/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file TestCustomKernel.cpp
 HPCG routine
 */


#include <sycl/sycl.hpp>
#include <fstream>
#include <iostream>
#include <vector>
#include "hpcg.hpp"
#include "WriteProblem.hpp"
#include "mytimer.hpp"

#include "TestCustomKernels.hpp"
#include "kernels/axpby_kernel.hpp"

#ifndef HPCG_NO_MPI
#include "ExchangeHalo.hpp"
#include <mpi.h>
#include <cstdlib>
#endif

#include <cmath>  // for convert_to_gflops()
#include <cstdio> // for convert_to_gflops()

#include "CustomKernels.hpp"
#include "ComputeDotProduct.hpp"
#include "ComputeSPMV.hpp"
#include "ComputeSYMGS.hpp"
#include "test_kernels/ComputeSYMGS_MKL.hpp"
#include "PrefixSum.hpp"

#include "VeryBasicProfiler.hpp"

// Profiling hooks. When BASIC_PROFILING is defined they forward to the profiler
// reached through a variable named `optData`, which must be in scope wherever the
// macro is expanded; otherwise they expand to nothing.
// NOTE(review): every expansion already ends in ';', and END_PROFILE_WAIT expands
// to two statements, so using it as the sole body of an unbraced if/else would
// only guard the first statement (event.wait()).
#ifdef BASIC_PROFILING
#define BEGIN_PROFILE(n) optData->profiler->begin(n);
#define END_PROFILE(n) optData->profiler->end(n);
#define END_PROFILE_WAIT(n, event) event.wait(); optData->profiler->end(n);
#else
#define BEGIN_PROFILE(n)
#define END_PROFILE(n)
#define END_PROFILE_WAIT(n, event)
#endif


// =====================================================================
// ==========  Master Switches for Functional Test Runs  ===============
// =====================================================================

// select which functionality to run functional testing on:
// Each macro gates the `#ifdef` section of the same name in TestCustomKernels().
// NOTE(review): TEST_SPGEMV_DOT has no matching `#ifdef` in the visible code of
// this file; the SpGEMV dot-product check runs inside the TEST_SPGEMV section.

#define TEST_AXPBY
#define TEST_DOT
#define TEST_PREFIX_SUM
#define TEST_SPGEMV
#define TEST_SPGEMV_DOT
#define TEST_SPTRMVL
#define TEST_SPTRMVU
#define TEST_SPTRSVL
#define TEST_SPTRSVU
#define TEST_SPTRSVL_FUSED
#define TEST_SPTRSVU_FUSED
#define TEST_SYMGS
#define TEST_SYMGS_MV

// =====================================================================
// =====================================================================



namespace {

    // Absolute tolerance used by every element-wise and scalar comparison in this file.
    #define TOL  1.0e-9

    namespace sparse = oneapi::mkl::sparse;
    namespace mkl = oneapi::mkl;

    /*!
      Compares two arrays element by element on the device.

      After the call (which blocks until the result is on the host) tmp_host[0] is 1
      if any |ref[i] - test[i]| exceeds TOL or is NaN, and 0 otherwise.

      @param[in]  queue         queue the fill, kernel and copy are submitted to
      @param[in]  length        number of elements compared
      @param[in]  ref           reference values (read inside the kernel)
      @param[in]  test          values under test (read inside the kernel)
      @param[out] tmp_dev       one-element flag written inside the kernel
      @param[out] tmp_host      one-element host-readable copy of the flag
      @param[in]  dependencies  events that must complete before the flag is reset
    */
    void check_arrays(sycl::queue &queue, local_int_t length, double *ref, double *test, local_int_t *tmp_dev, local_int_t *tmp_host, const std::vector<sycl::event> &dependencies)
    {
        // Typed pattern: queue.fill() sizes the write from the pattern type, so this
        // clears the whole flag whatever the width of local_int_t is.
        auto ev_reset = queue.fill(tmp_dev, static_cast<local_int_t>(0), 1, dependencies);
        auto ev_check = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_reset);
            auto kernel = [=](sycl::item<1> item) {
                const auto row = item.get_id(0);
                const double diff = ref[row] - test[row];
                // Negated "within tolerance" test so that a NaN difference is a mismatch
                // (every ordered comparison with NaN is false).
                if (!(diff <= TOL && diff >= -TOL)) {
                    tmp_dev[0] = 1; // all mismatching work-items store the same value
                }
            };
            cgh.parallel_for<class test_custom_kernel_check_arrays>(sycl::range<1>(length), kernel);
        });
        queue.memcpy(tmp_host, tmp_dev, 1*sizeof(local_int_t), ev_check).wait();
    }



    /*!
      Debug variant of check_arrays(): same contract, but additionally prints every
      mismatching entry whose row index is below 1000 from inside the kernel.

      Only referenced from commented-out call sites, hence [[maybe_unused]].
    */
    [[maybe_unused]] void check_and_print_arrays(sycl::queue &queue, local_int_t length, double *ref, double *test, local_int_t *tmp_dev, local_int_t *tmp_host, const std::vector<sycl::event> &dependencies)
    {
        auto ev_reset = queue.fill(tmp_dev, static_cast<local_int_t>(0), 1, dependencies);
        auto ev_check = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_reset);
            auto kernel = [=](sycl::item<1> item) {
                int row = item.get_id(0);
                const double diff = ref[row] - test[row];
                // NaN differences are treated as mismatches, see check_arrays()
                if (!(diff <= TOL && diff >= -TOL)) {
                    tmp_dev[0] = 1;
                    if (row < 1000)
                        sycl::ext::oneapi::experimental::printf("row %d: diff = %3.7f,  ref = %g,  test = %g\n", row, diff, ref[row], test[row]);
                }

//                if (row < 10 )
//                    sycl::ext::oneapi::experimental::printf("row %d: diff = %3.7f,  ref = %g,  test = %g\n", row, diff, ref[row], test[row]);
            };
            cgh.parallel_for<class test_custom_kernel_check_and_print_arrays>(sycl::range<1>(length), kernel);
        });
        queue.memcpy(tmp_host, tmp_dev, 1*sizeof(local_int_t), ev_check).wait();
    }


    /*!
      Prints the column header of the status table on rank 0.
      In MPI builds all ranks synchronise before and after the header is printed.
    */
    void print_header(const int rank)
    {
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
        if (rank == 0) {
            printf("\t%4s %-4s%-13s%-6s %-5s%s\n", "rank", "", "Functionality", "", "", "Status");
            printf("\t=========================================================\n");
        }
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
    }



    /*!
      Prints one status line per rank, visiting the ranks in increasing order.

      In MPI builds a barrier follows every rank's turn so that the ranks take
      turns printing.

      @param[in] rank    rank of the calling process
      @param[in] size    number of ranks taking part
      @param[in] func    name of the functionality that was verified
      @param[in] status  verification outcome text
    */
    void print_status(const int rank, const int size,  const std::string &func, const std::string &status)
    {
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
        for (int r = 0; r < size; ++r) {
            if (rank == r) {
                printf("\t %2d: %-22s verification %s\n", rank, func.c_str(), status.c_str()); // with headers
            }
#ifndef HPCG_NO_MPI
            MPI_Barrier(MPI_COMM_WORLD);
#endif
        }
    }


} // anonymous namespace




/*!
  Functional validation suite for the custom kernels.

  Every section enabled by a TEST_* switch runs one custom kernel and a reference
  computation (a plain SYCL kernel or a oneMKL sparse routine), compares the
  results against TOL and prints one status line per rank.

  @param[in]     A            local sparse matrix; A.optimizationData is read as a struct optData
  @param[in]     b            not used by this routine
  @param[in]     x            not used by this routine
  @param[in]     rank         rank of the calling process
  @param[in]     size         number of ranks (used for ordered printing)
  @param[in,out] testck_data  count_fail / count_pass are reset and then updated once per check
  @param[in]     queue        SYCL queue used for all allocations and submissions

  @return 1 if a scratch allocation failed, 0 otherwise (verification failures are
          reported through testck_data.count_fail, not through the return value)
*/
int TestCustomKernels(SparseMatrix &A, Vector &b, Vector &x, const int rank, const int size, TestCustomKernelsData &testck_data, sycl::queue &queue)
{
    sycl::event ev_test, ev_ref, ev_update;

    // both counters describe this invocation only
    testck_data.count_fail = 0;
    testck_data.count_pass = 0;

    const local_int_t nRows = A.localNumberOfRows;
    const local_int_t nCols = A.localNumberOfColumns;
    struct optData *optData = (struct optData *)A.optimizationData;

    custom::sparseMatrix *sparseM = (custom::sparseMatrix *)optData->esbM;

    Vector r, w, y, y1, z, Ay, Az;

    // vectors:  nCols x 1
    InitializeVectorDevice(r, nCols, queue);
    InitializeVectorDevice(w, nCols, queue);
    InitializeVectorDevice(y, nCols, queue);
    InitializeVectorDevice(y1, nCols, queue);
    InitializeVectorDevice(z, nCols, queue);

    // vectors: nRows x 1     = [nRows x nCols] [nCols x 1]
    // use nCols as nCols >= nRows to make sure we have room
    InitializeVectorDevice(Ay, nCols, queue);
    InitializeVectorDevice(Az, nCols, queue);

    double *rv = r.values;
    double *wv = w.values;
    double *yv = y.values;
    double *y1v = y1.values;
    double *zv = z.values;
    double *Ayv = Ay.values;
    double *Azv = Az.values;


#ifdef TEST_PREFIX_SUM
    // for prefix_sum
    local_int_t *int_nrows_dev = (local_int_t *)sparse_malloc_device(nRows * sizeof(local_int_t), queue);
    local_int_t *int_nrowsp1_dev = (local_int_t *)sparse_malloc_device((nRows + 1) * sizeof(local_int_t), queue);
    if (int_nrows_dev == nullptr || int_nrowsp1_dev == nullptr) {
        std::cerr << "rank " << rank << ": error in TestCustomKernels local_int_t array allocation" << std::endl;
        return 1;
    }
#endif

    double *fp_dev = (double *)sparse_malloc_device(1 * sizeof(double), queue);
    double *fp_host = (double *)sparse_malloc_host(1 * sizeof(double), queue);
    local_int_t *tmp_dev = (local_int_t *)sparse_malloc_device(1 * sizeof(local_int_t), queue);
    local_int_t *tmp_host = (local_int_t *)sparse_malloc_host(1 * sizeof(local_int_t), queue);
    double *tmp2_dev = (double *)sparse_malloc_device(nRows * sizeof(double), queue);
    // fp_host is dereferenced below, so it is checked together with the others
    if (fp_dev == nullptr || fp_host == nullptr || tmp_dev == nullptr || tmp_host == nullptr || tmp2_dev == nullptr) {
        std::cerr << "rank " << rank << ": error in TestCustomKernels allocation" << std::endl;
        return 1;
    }


    local_int_t nRows_b = optData->nrow_b;
    local_int_t *bmap = optData->bmap;

    using val_t = double;
    constexpr std::uint64_t seed = 0; // 777;
    oneapi::mkl::rng::philox4x32x10 engine(queue, seed);
    oneapi::mkl::rng::uniform<val_t> distribution(static_cast<val_t>(-1.0), static_cast<val_t>(1.0));

    oneapi::mkl::rng::generate(distribution, engine, nCols, rv, {}).wait();
    oneapi::mkl::rng::generate(distribution, engine, nCols, wv, {}).wait();
//    queue.fill(wv, 0.01, nRows).wait();

    queue.fill(yv, 0.0, nCols).wait();
    queue.fill(y1v, 0.0, nCols).wait();
    queue.fill(zv, 0.0, nCols).wait();
    queue.fill(Ayv, 0.0, nCols).wait();
    queue.fill(Azv, 0.0, nCols).wait();


#ifdef TEST_PREFIX_SUM
    // Typed patterns: queue.fill() sizes each element from the pattern type.
    // The second array holds nRows+1 entries and all of them feed the prefix sum.
    queue.fill(int_nrows_dev, static_cast<local_int_t>(1), nRows).wait();
    queue.fill(int_nrowsp1_dev, static_cast<local_int_t>(1), nRows + 1).wait();
#endif

    queue.fill(tmp_dev, static_cast<local_int_t>(0), 1).wait();

    // copy over nnz for performance tests
    queue.memcpy(tmp_host, optData->ia + nRows, 1*sizeof(local_int_t)).wait();
    const std::int64_t nnz_a = static_cast<std::int64_t>(tmp_host[0]);

    //
    // setup onemkl objects for comparisons
    //

    sparse::matrix_handle_t hMatrixA = nullptr;
    sparse::matrix_handle_t hMatrixB = nullptr;

    sparse::init_matrix_handle(&hMatrixA);
    sparse::init_matrix_handle(&hMatrixB);

    sparse::set_csr_data(queue, hMatrixA, nRows, nRows, nnz_a, mkl::index_base::zero,
                         optData->ia, optData->ja, optData->a, {}).wait();
    if (A.geom->size > 1) {
        // copy over nnz to set the data for matrix B
        queue.memcpy(tmp_host, optData->ib + nRows_b, 1*sizeof(local_int_t)).wait();
        const std::int64_t nnz_b = static_cast<std::int64_t>(tmp_host[0]);
        sparse::set_csr_data(queue, hMatrixB, optData->nrow_b, nCols, nnz_b, mkl::index_base::zero,
                             optData->ib, optData->jb, optData->b).wait();
    }

    sparse::set_matrix_property(hMatrixA, sparse::property::symmetric);
    sparse::set_matrix_property(hMatrixA, sparse::property::sorted);

    sparse::optimize_gemv(queue, mkl::transpose::nontrans, hMatrixA, {}).wait();
    if (A.geom->size > 1) {
        sparse::optimize_gemv(queue, mkl::transpose::nontrans, hMatrixB, {}).wait();
    }

    // call the optimize steps for everything related to onemkl algorithms being tested
    sparse::optimize_trsv(queue, oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit, hMatrixA, {}).wait();
    sparse::optimize_trsv(queue, oneapi::mkl::uplo::upper, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit, hMatrixA, {}).wait();
    sparse::optimize_trmv(queue, oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit, hMatrixA, {}).wait();
    sparse::optimize_trmv(queue, oneapi::mkl::uplo::upper, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit, hMatrixA, {}).wait();

#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif

    // generic scalars for both functional and performance tests
    double alpha = 1.23, beta = 4.56;

    // Declared ahead of all test sections so that every TEST_* switch can be
    // disabled independently of the others.
    [[maybe_unused]] bool failed = false;

    // prints the per-rank status line and updates the pass/fail counters
    [[maybe_unused]] auto record_result = [&](const std::string &name, const bool has_failed) {
        print_status(rank, size, name, has_failed ? "failed" : "passed");
        if (has_failed) testck_data.count_fail++;
        else            testck_data.count_pass++;
    };

    // drains the queue and, in MPI builds, lines the ranks up before the next test
    [[maybe_unused]] auto finish_test = [&]() {
        queue.wait();
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
    };


    if (rank == 0) {
        std::cout << "Starting Test Custom Kernel Functional Validation Suite:" << std::endl;
    }


    print_header(rank);


#ifdef TEST_AXPBY

    oneapi::mkl::rng::generate(distribution, engine, nCols, yv, {}).wait();
    queue.memcpy(zv, yv, sizeof(double) * nCols).wait();

    ev_test = queue.submit([&](sycl::handler &cgh) {

        // AXPBY esimd kernel parameters
        constexpr local_int_t block_size = HPCG_BLOCK_SIZE;
        const local_int_t nWG = 8;
        constexpr local_int_t uroll = 4;
        local_int_t nBlocks = ceil_div(nCols, uroll * block_size);

        auto kernel = [=] (sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
            axpby_body<block_size, uroll>(item, wv, yv, alpha, beta, nCols, nBlocks);
        };
        cgh.parallel_for<class axbpy_esimd>(sycl::nd_range<1>(ceil_div(nBlocks, nWG) * nWG, nWG), kernel);
    });

    // reference code for axpby
    ev_ref = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_test);
        auto kernel = [=](sycl::item<1> item) {
            const local_int_t i = item.get_id(0);
            zv[i] = alpha * wv[i] +  beta * zv[i];
        };
        cgh.parallel_for<class test_axpby>( sycl::range<1>(nCols), kernel);
    });

    check_arrays(queue, nCols, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("AXPBY", failed);

    finish_test();

#endif //TEST_AXPBY


#ifdef TEST_DOT
    // test ComputeDotProductLocal  rdotw_test = dot(r,w)
    ev_test = ComputeDotProductLocal(nRows, r, w, fp_dev, queue, {});
    queue.memcpy(fp_host, fp_dev, 1*sizeof(double), ev_test).wait();
    double rdotw_test = fp_host[0];

    // reference code for  rdotw_ref = dot(r,w)
    ev_ref = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_test);
        auto sumReducer = sycl::reduction(fp_dev, sycl::plus<>(), sycl::property::reduction::initialize_to_identity());
        auto kernel = [=](sycl::item<1> item, auto &sumDot){
            const local_int_t i = item.get_id(0);
            sumDot += wv[i] * rv[i];
        };
        cgh.parallel_for<class test_dotproduct>( sycl::range<1>(nRows), sumReducer, kernel);
    });
    queue.memcpy(fp_host, fp_dev, 1*sizeof(double), ev_ref).wait();
    double rdotw_ref = fp_host[0];

    failed = std::fabs(rdotw_test - rdotw_ref) > TOL;
    record_result("Dot Product", failed);

    finish_test();

#endif //TEST_DOT


#ifdef TEST_PREFIX_SUM
    // test PrefixSum length nRows  (less common, but shows up in this TestCase for oneMKL B matrix creation)
    ev_test = prefix_sum(queue, nRows, int_nrows_dev, {});
    queue.memcpy(tmp_host, int_nrows_dev+nRows-1, 1*sizeof(local_int_t), ev_test).wait();
    failed = (tmp_host[0] != nRows);

    // test PrefixSum length nRows+1  (most common usage)
    ev_test = prefix_sum(queue, nRows+1, int_nrowsp1_dev, {});
    queue.memcpy(tmp_host, int_nrowsp1_dev+nRows, 1*sizeof(local_int_t), ev_test).wait();
    // the test fails if either length produced a wrong total
    failed |= (tmp_host[0] != nRows+1);

    record_result("Prefix Sum", failed);

    finish_test();

#endif // TEST_PREFIX_SUM


#ifdef TEST_SPGEMV
    //
    // Test GEMV with A via ComputeSPMV_DOT
    //

    // yv = (A + B) * wv; fp_dev = dot(yv, wv) custom
//    ev_test = custom::SpGEMV(queue, sparseM, wv, yv, {});
    // yv <- A*wv, fp_dev <- wv*A*wv
    ev_test = custom::SpGEMV_DOT(queue, sparseM, wv, yv, fp_dev, {ev_test});
    queue.memcpy(fp_host, fp_dev, 1*sizeof(double), ev_test).wait();
    double ydotw = fp_host[0];

    // zv = A * wv; fp_dev = dot(zv, wv) onemkl
    ev_ref = sparse::gemv(queue, oneapi::mkl::transpose::nontrans, 1.0, hMatrixA, wv, 0.0, zv, {ev_test});
    if (A.geom->size > 1) {
        // tmp2_dev = B * wv; onemkl
        ev_ref = sparse::gemv(queue, oneapi::mkl::transpose::nontrans, 1.0, hMatrixB, wv, 0.0, tmp2_dev, {ev_ref});
        // zv += tmp2_dev(bmap)
        ev_ref = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_ref);
            auto kernel = [=](sycl::item<1> item){
                const local_int_t row = item.get_id(0);
                zv[bmap[row]] += tmp2_dev[row];
            };
            cgh.parallel_for<class test_gemvB_update>( sycl::range<1>(nRows_b), kernel);
        });
    }
    ev_ref = ComputeDotProductLocal(nRows, z, w, fp_dev, queue, {ev_ref});
    queue.memcpy(fp_host, fp_dev, 1*sizeof(double), ev_ref).wait();
    double zdotw = fp_host[0];

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SpGEMV A+B", failed);

    failed = std::fabs(ydotw - zdotw) > TOL * nRows;
    record_result("SpGEMV Dot", failed);

    finish_test();

#endif // TEST_SPGEMV


#ifdef TEST_SPTRMVL
    //
    // Test SpTRMV Lower + Diagonal + Nonlocal + Update
    //

    ZeroVector(y, queue, {}).wait();
    ZeroVector(y1, queue, {}).wait();
    ZeroVector(z, queue, {}).wait();

    // yv = yv + (L+B) * wv;  custom, y1v not used
    ev_test = custom::SpTRMV(queue, sparseM, custom::uplo::lower_update, wv, rv, yv, y1v, {});
    // zv = zv + (L+I) * wv;  onemkl
    ev_ref = sparse::trmv(queue, oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit,
                          1.0, hMatrixA, wv, 1.0, zv, {ev_test});

    // Subtract off unit diagonals
    ev_ref = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_ref);
        auto kernel = [=](sycl::item<1> item){
            const local_int_t row = item.get_id(0);
            zv[row] -= wv[row];
        };
        cgh.parallel_for<class test_trmvL_gemvB_update>( sycl::range<1>(nRows), kernel);
    });

    if (A.geom->size > 1) {
        // tmp2_dev = B * wv
        ev_ref = sparse::gemv(queue, oneapi::mkl::transpose::nontrans, 1.0, hMatrixB, wv, 0.0, tmp2_dev, {ev_ref});
        // zv += tmp2_dev(bmap)
        ev_ref = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_ref);
            auto kernel = [=](sycl::item<1> item){
                const local_int_t row = item.get_id(0);
                zv[bmap[row]] += tmp2_dev[row];
            };
            cgh.parallel_for<class test_trmvL_sub_diag>( sycl::range<1>(nRows_b), kernel);
        });
    }

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SpTRMV L+B+update", failed);

    finish_test();

#endif // TEST_SPTRMVL



#ifdef TEST_SPTRMVU
    //
    // Test SpTRMV Upper + Nonlocal
    //

    // yv = (U + B) * wv;  custom
    ev_test = custom::SpTRMV(queue, sparseM, custom::uplo::upper_nonlocal, wv, rv, yv, y1v, {});
    // zv = (D+U) * wv;  onemkl
    ev_ref = sparse::trmv(queue, oneapi::mkl::uplo::upper, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit,
                          1.0, hMatrixA, wv, 0.0, zv, {ev_test});
    if (A.geom->size > 1) {
        // tmp2_dev = B * wv
        ev_ref = sparse::gemv(queue, oneapi::mkl::transpose::nontrans, 1.0, hMatrixB, wv, 0.0, tmp2_dev, {ev_ref});
        // zv += tmp2_dev(bmap)
        ev_ref = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_ref);
            auto kernel = [=](sycl::item<1> item){
                const local_int_t row = item.get_id(0);
                zv[bmap[row]] += tmp2_dev[row];
            };
            cgh.parallel_for<class test_trmvU_gemvB_update>( sycl::range<1>(nRows_b), kernel);
        });
    }
    // zv = rv - (zv - D * wv)
    ev_update = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(ev_ref);
            double *diags = sparseM->diags;
            auto kernel = [=](sycl::item<1> item){
                const local_int_t row = item.get_id(0);
                zv[row] -= diags[row] * wv[row];
                zv[row] = rv[row] - zv[row];
            };
            cgh.parallel_for<class test_trmvU_update_diag>(
                    sycl::range<1>(nRows), kernel);
        });

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_update});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_update});

    failed = (tmp_host[0] == 1);
    record_result("SpTRMV r-(U+B)", failed);

    queue.wait();

    //
    // Test SpTRMV Upper
    //
    if (A.geom->size > 1) { // Convert zv back to Ux
        // remove r component
        auto evt = queue.submit([&](sycl::handler &cgh) {
            auto kernel = [=](sycl::item<1> item){
                const local_int_t row = item.get_id(0);
                zv[row] = rv[row] - zv[row];
            };
            cgh.parallel_for<class test_remove_r>(
                    sycl::range<1>(nRows), kernel);
        });

        // remove Bx component
        ev_ref = queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(evt);
            auto kernel = [=](sycl::item<1> item){
                const local_int_t row = item.get_id(0);
                zv[bmap[row]] -= tmp2_dev[row];
            };
            cgh.parallel_for<class test_trmvU_update>( sycl::range<1>(nRows_b), kernel);
        });

        check_arrays(queue, nRows, zv, y1v, tmp_dev, tmp_host, {ev_ref});

        failed = (tmp_host[0] == 1);
        record_result("SpTRMV U", failed);
    }

    finish_test();

#endif // TEST_SPTRMVU


#ifdef TEST_SPTRSVL

    //
    // Test TRSV Lower
    //

    // solve (L+D) * yv = wv;  custom
    ev_test = custom::SpTRSV(queue, sparseM, custom::uplo::lower_diagonal, wv, yv, {});

    // solve (L+D) * zv = wv;  onemkl
    ev_ref = sparse::trsv(queue, oneapi::mkl::uplo::lower,  oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit,
#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION > 20240000)
                          1.0,
#endif
                          hMatrixA, wv, zv, {ev_test});

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SpTRSV FWD", failed);

    finish_test();

#endif // TEST_SPTRSVL



#ifdef TEST_SPTRSVU

    //
    // Test TRSV Upper
    //
    // solve (D+U) * yv = wv;  custom
    ev_test = custom::SpTRSV(queue, sparseM, custom::uplo::upper_diagonal, wv, yv, {});

    // solve (D+U) * zv = wv;  onemkl
    ev_ref = sparse::trsv(queue, oneapi::mkl::uplo::upper, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit,
#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION > 20240000)
                          1.0,
#endif
                          hMatrixA, wv, zv, {ev_test});

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SpTRSV BWD", failed);

    finish_test();

#endif // TEST_SPTRSVU




#ifdef TEST_SPTRSVL_FUSED
    //
    // Test TRSV Lower Fused   (L+D)*y_out = w_in;   fused with w_out = y_in+D*y_out
    //
    queue.memcpy(Ayv, wv, sizeof(double) * nRows).wait();
    queue.memcpy(Azv, wv, sizeof(double) * nRows).wait();
    queue.fill(yv, 1.0, nCols).wait();
    queue.fill(zv, 1.0, nCols).wait();

    // Ayv = wv;  yv_in = 1.0
    // solve (L+D) * yv_out = Ayv_in; and  Ayv_out = yv_in + D*yv_out  custom
    ev_test = custom::SpTRSV_FUSED(queue, sparseM, custom::uplo::lower_diagonal, Ayv, yv, {});

    // Azv = wv;  zv_in = 1.0
    // solve (L+D) * zv_out = Azv_in;    onemkl
    ev_ref = sparse::trsv(queue, oneapi::mkl::uplo::lower,  oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit,
#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION > 20240000)
                          1.0,
#endif
                          hMatrixA, Azv, zv, {ev_test});

    // Azv_out = zv_in + D*zv_out = 1.0 + D * zv
    ev_update = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_ref);
        double *diags = sparseM->diags;
        auto kernel = [=](sycl::item<1> item){
            const local_int_t row = item.get_id(0);
            Azv[row] = 1.0 + diags[row] * zv[row];
        };
        cgh.parallel_for<class test_trsv_fused_lower_diag>(
                sycl::range<1>(nRows), kernel);
    });

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_update});
    failed = (tmp_host[0] == 1);
    check_arrays(queue, nRows, Azv, Ayv, tmp_dev, tmp_host, {ev_update});
    // a mismatch in either output vector fails the test
    failed |= (tmp_host[0] == 1);

    record_result("SpTRSV_FUSED FWD", failed);

    finish_test();

#endif // TEST_SPTRSVL_FUSED





#ifdef TEST_SPTRSVU_FUSED
    //
    // Test TRSV Upper Fused   (D+U)* y_out = w_in;   fused with w_out = D*w_in
    //
    queue.memcpy(Ayv, wv, sizeof(double) * nRows).wait();
    queue.memcpy(Azv, wv, sizeof(double) * nRows).wait();

    // Ayv = wv
    // solve (D+U) * yv_out = Ayv_in; and  Ayv_out = D*Ayv_in
    ev_test = custom::SpTRSV_FUSED(queue, sparseM, custom::uplo::upper_diagonal, Ayv, yv, {});

    // Azv = wv;
    // solve (D+U) * zv_out = Azv_in;    onemkl
    ev_ref = sparse::trsv(queue, oneapi::mkl::uplo::upper,  oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::nonunit,
#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION > 20240000)
                          1.0,
#endif
                          hMatrixA, Azv, zv, {ev_test});

    // Azv_out = D*Azv_in
    ev_update = queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(ev_ref);
        double *diags = sparseM->diags;
        auto kernel = [=](sycl::item<1> item){
            const local_int_t row = item.get_id(0);
            Azv[row] = diags[row] * Azv[row];
        };
        cgh.parallel_for<class test_trsv_fused_upper_diag>(
                sycl::range<1>(nRows), kernel);
    });

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_update});
    failed = (tmp_host[0] == 1);
    check_arrays(queue, nRows, Azv, Ayv, tmp_dev, tmp_host, {ev_update});
    // a mismatch in either output vector fails the test
    failed |= (tmp_host[0] == 1);

    record_result("SpTRSV_FUSED BWD", failed);

    finish_test();

#endif // TEST_SPTRSVU_FUSED


#if defined(TEST_SYMGS) || defined(TEST_SYMGS_MV)
    // SYMGS input scaling, shared by the SYMGS and SYMGS_MV sections
    val_t max_scale = 0.01;
    oneapi::mkl::rng::uniform<val_t> distribution2(-max_scale, max_scale);
#endif


#ifdef TEST_SYMGS

    //
    // Test SYMGS
    //

    oneapi::mkl::rng::generate(distribution2, engine, nRows, yv, {}).wait();
    queue.memcpy(zv, yv, nRows*sizeof(double)).wait();

    ev_test = run_SYMGS_custom(queue, A, optData, sparseM, w, y, {});
    ev_ref = run_SYMGS_onemkl(queue, A, optData, hMatrixA, hMatrixB, w, z, {ev_test});

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SYMGS", failed);

    finish_test();

#endif // TEST_SYMGS


#ifdef TEST_SYMGS_MV

    // reset y,z to zero
    //ZeroVector(y, queue, {}).wait();
    //ZeroVector(z, queue, {}).wait();
    oneapi::mkl::rng::generate(distribution2, engine, nRows, yv, {}).wait();
    queue.memcpy(zv, yv, nRows*sizeof(double)).wait();

    //
    // Test SYMGS_MV full permutation
    //
    ev_test = run_SYMGS_MV_custom(queue, A, optData, sparseM, w, y, Ay, {});
    ev_ref = run_SYMGS_MV_onemkl(queue, A, optData, hMatrixA, hMatrixB, w, z, Az, {ev_test});

//    check_and_print_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, zv, yv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SYMGS_MV y==z", failed);

    queue.wait();

//    check_and_print_arrays(queue, nRows, Azv, Ayv, tmp_dev, tmp_host, {ev_ref});
    check_arrays(queue, nRows, Azv, Ayv, tmp_dev, tmp_host, {ev_ref});

    failed = (tmp_host[0] == 1);
    record_result("SYMGS_MV Ay==Az", failed);

    finish_test();

#endif // TEST_SYMGS_MV

    // cleanup
    queue.wait();
    DeleteVector(r, queue);
    DeleteVector(w, queue);
    DeleteVector(y, queue);
    DeleteVector(y1, queue);
    DeleteVector(z, queue);
    DeleteVector(Ay, queue);
    DeleteVector(Az, queue);

#ifdef TEST_PREFIX_SUM
    // released the same way as the other sparse_malloc_device() buffers below
    sycl::free(int_nrows_dev, queue);
    sycl::free(int_nrowsp1_dev, queue);
#endif

    sycl::free(fp_dev, queue);
    sycl::free(fp_host, queue);
    sycl::free(tmp_dev, queue);
    sycl::free(tmp_host, queue);
    sycl::free(tmp2_dev, queue);

    sparse::release_matrix_handle(queue, &hMatrixA, {}).wait();
    sparse::release_matrix_handle(queue, &hMatrixB, {}).wait();

    return 0;
}
