/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       API oneapi::mkl::dft to perform 2-D Single Precision Complex to Complex
*       Fast-Fourier Transform using external workspace on a SYCL GPU device.
*       The external workspace is not supported on CPU.
*
*       The supported floating point data types for data are:
*           float
*           std::complex<float>
*
*******************************************************************************/

#include <vector>
#include <iostream>
#include <CL/sycl.hpp>
#include "oneapi/mkl/dfti.hpp"

#include <stdexcept>
#include <cfloat>
#include <cstddef>
#include <limits>
#include <type_traits>
#include "mkl.h"

// local includes
#define NO_MATRIX_HELPERS
#include "common_for_examples.hpp"

typedef oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::COMPLEX> descriptor_t;

constexpr int SUCCESS = 0;
constexpr int FAILURE = 1;
constexpr float TWOPI = 6.2831853071795864769f;

// Compute (K*L)%M accurately
static float moda(int K, int L, int M)
{
    return (float)(((long long)K * L) % M);
}

static void init(float *data, int N1, int N2, int H1, int H2)
{
    // Generalized strides for row-major addressing of data
    int S1 = 1, S2 = N1;

    for (int n2 = 0; n2 < N2; ++n2) {
        for (int n1 = 0; n1 < N1; ++n1) {
            float phase = TWOPI * (moda(n1, H1, N1) / N1
                                       + moda(n2, H2, N2) / N2);
            int index = 2*(n2*S2 + n1*S1);
            data[index+0] = cosf(phase) / (N2*N1);
            data[index+1] = sinf(phase) / (N2*N1);
        }
    }
}

static int verify_bwd(const float* data, int N1, int N2, int H1, int H2) {
    // Note: this simple error bound doesn't take into account error of
    //       input data
    float errthr = 5.0f * logf((float) N2*N1) / logf(2.0f) * FLT_EPSILON;
    std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    // Generalized strides for row-major addressing of data
    int S1 = 1, S2 = N1;

    float maxerr = 0.0f;
    for (int n2 = 0; n2 < N2; n2++) {
        for (int n1 = 0; n1 < N1; n1++) {
            float phase = TWOPI * (moda(n1, H1, N1) / N1
                                       + moda(n2, H2, N2) / N2);
            float re_exp = cosf(phase) / (N2*N1);
            float im_exp = sinf(phase) / (N2*N1);

            int index = 2*(n2*S2 + n1*S1);
            float re_got = data[index+0];  // real component
            float im_got = data[index+1];  // imaginary component
            float err  = fabsf(re_got - re_exp) + fabsf(im_got - im_exp);
            if (err > maxerr) maxerr = err;
            if (!(err < errthr)) {
                std::cout << "\t\tdata[" << n2 << ", " << n1 << "]: "
                          << "expected (" << re_exp << "," << im_exp << "), "
                          << "got (" << re_got << "," << im_got << "), "
                          << "err " << err << std::endl;
                std::cout << "\t\tVerification FAILED" << std::endl;
                return FAILURE;
            }
        }
    }
    std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;
    return SUCCESS;
}

int run_dft_example(sycl::device &dev) {
    //
    // Initialize data for DFT
    //
    int N1 = 256, N2 = 512;
    int H1 = -1, H2 = -2;
    int result = FAILURE;

    float* in = (float*) mkl_malloc(N2*N1*2*sizeof(float), 64);
    init(in, N1, N2, H1, H2);

    //
    // Execute DFT
    //
    try {
        // Catch asynchronous exceptions
        auto exception_handler = [] (sycl::exception_list exceptions) {
            for (std::exception_ptr const& e : exceptions) {
                try {
                    std::rethrow_exception(e);
                } catch(sycl::exception const& e) {
                    std::cout << "Caught asynchronous SYCL exception:" << std::endl
                              << e.what() << std::endl;
                }
            }
        };

        // create execution queue with asynchronous error handling
        sycl::queue queue(dev, exception_handler);

        // Setting up USM and initialization
        float *in_usm = (float*) malloc_shared(N2*N1*2*sizeof(float), queue.get_device(), queue.get_context());
        init(in_usm, N1, N2, H1, H2);

        descriptor_t desc({N2, N1});
        desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0/(N1*N2)));

        // Letting the 'desc' know that the workspace will be provided
        // by the user.
        desc.set_value(oneapi::mkl::dft::config_param::WORKSPACE,
                       oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);

        // Obtain the workspace estimate.
        size_t workspace_estimate_size_bytes = 0, workspace_size_bytes = 0;
        desc.get_value(oneapi::mkl::dft::config_param::WORKSPACE_ESTIMATE_BYTES,
                       &workspace_estimate_size_bytes);
        std::cout << "Estimated workspace size : " << workspace_estimate_size_bytes << std::endl;

        desc.commit(queue);

        // Get the exact amount of workspace that is required after 'commit'
        desc.get_value(oneapi::mkl::dft::config_param::WORKSPACE_BYTES,
                       &workspace_size_bytes);

        std::cout << "Exact workspace size : " << workspace_size_bytes << std::endl;

        // Allocate the USM workspace that will be used by the 'desc'
        float *workspace_usm = (float*) malloc_device(workspace_size_bytes, queue.get_device(), queue.get_context());

        // Set the workspace
        desc.set_workspace(workspace_usm);

        // Using USM
        std::cout<<"\tUsing USM"<<std::endl;
        auto fwd = oneapi::mkl::dft::compute_forward(desc, in_usm);

        auto bwd = oneapi::mkl::dft::compute_backward(desc, in_usm, {fwd});
        bwd.wait();
        result = verify_bwd(in_usm, N1, N2, H1, H2);

        free(in_usm, queue.get_context());
        free(workspace_usm, queue.get_context());
    }
    catch(sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
        std::cout << "\t\tError code: " << get_error_code(e) << std::endl;
    }
    catch(std::runtime_error const& e) {
        std::cout << "\t\truntime exception during FFT" << std::endl;
        std::cout << "\t\t" << e.what() << std::endl;
    }
    mkl_free(in);

    return result;
}

//
// Description of example setup, apis used and supported floating point type precisions
//
void print_example_banner() {
    std::cout << "" << std::endl;
    std::cout << "########################################################################" << std::endl;
    std::cout << "# 2D FFT Complex-Complex Single-Precision Workspace Example: " << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Using apis:" << std::endl;
    std::cout << "#   dft" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Supported floating point type precisions:" << std::endl;
    std::cout << "#   float" << std::endl;
    std::cout << "#   std::complex<float>" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "########################################################################" << std::endl;
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- Workspace is not supported on CPU device. This example is ignored.
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- Runs on GPU, ignores CPU.
//
//  For each device selected and each supported data type, Basic_Sp_C2C_2D_FFTExample
//  runs is with all supported data types
//
int main() {
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int returnCode = SUCCESS;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            if (!my_dev.is_gpu()) {
                std::cout << "DFT with external workspace is only supported on GPU. Skipping tests on " << sycl_device_names[*it] << ".\n";
                continue;
            }
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            std::cout << "\tRunning with single precision complex-to-complex 2-D FFT:" << std::endl;
            int status = run_dft_example(my_dev);
            if (status != SUCCESS) {
                std::cout << "\tTest Failed" << std::endl << std::endl;
                returnCode = status;
            } else {
                std::cout << "\tTest Passed" << std::endl << std::endl;
            }
        } else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it] << " devices found; Fail on missing devices is enabled." << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping " << sycl_device_names[*it] << " tests." << std::endl << std::endl;
#endif
        }
    }

    mkl_free_buffers();
    return returnCode;
}
