/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
! Content:
!       Example of using the fftw_plan_dft_r2c_1d and fftw_plan_dft_c2r_1d
!       functions on a (GPU) device through the OpenMP target (offload) interface
!
!*****************************************************************************/

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <float.h>
#include "fftw/fftw3.h"
#include "fftw/offload/fftw3_omp_offload.h"

/* Helpers defined after main(). The init_* functions fill the buffer x with the
 * test signal for a length-N transform with harmonic H. The verify_* functions
 * check the transform output in x and return 0 on success or 1 on failure. */
static void init_r(double *x, int N, int H);
static int verify_c(double *x, int N, int H);
static void init_c(double *x, int N, int H);
static int verify_r(double *x, int N, int H);

int main(void)
{
    /* Size of 1D transform */
    int N = 64;
    /* Number of complex bins produced by an r2c transform of length N.
     * The in-place buffer holds 2*halfNplus1 doubles: room for those bins,
     * which also covers the N real inputs. */
    const MKL_LONG halfNplus1 = N/2 + 1;

    /* Arbitrary harmonic used to verify FFT */
    int H = -N/2;

    /* FFTW plan handles; they stay 0 unless plan creation succeeds */
    fftw_plan forward_plan = 0, backward_plan = 0;
    /* Pointer to input/output data */
    double *x = NULL;

    /* Execution status: statusf/statusb record the forward/backward stage,
     * status is the overall exit code. */
    int statusf = 0, statusb = 0, status = 0;

    /* Device used by the data region, the host/device updates and every dispatch */
    const int devNum = 0;

    printf("Example dp_plan_dft_real_1d\n");
    printf("Forward and backward 1D real inplace transform\n");
    printf("Configuration parameters:\n");
    printf(" N = %d\n", N);
    printf(" H = %d\n", H);

    printf("Allocate array for input data\n");
    x  = (double *) fftw_malloc(2*halfNplus1*sizeof(double));
    if (0 == x) goto failed;

    printf("Initialize input for forward transform\n");
    init_r(x, N, H);

    printf("Create FFTW plan for 1D double-precision Real to Complex forward transform\n");
    /* x stays mapped on devNum for both transforms. It is copied to the device
     * on entry and back to the host on exit. Control must not jump out of this
     * structured block with goto, so failures inside it are recorded in
     * statusf/statusb and checked after the region ends. */
#pragma omp target data map(tofrom:x[0:halfNplus1*2]) device(devNum)
    {
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch use_device_ptr(x) device(devNum)
#endif
        forward_plan = fftw_plan_dft_r2c_1d(N, x, (fftw_complex *)x, FFTW_ESTIMATE);

        if (forward_plan == 0) {
            /* Never execute a null plan; just record the failure. */
            printf("Call to fftw_plan_dft_r2c_1d has failed\n");
            statusf = 1;
        } else {
            printf("Compute forward FFT\n");
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch device(devNum)
#endif
            fftw_execute(forward_plan);

            /* Update the host with the results from the forward FFT. The
             * device clause targets the same device as the data region rather
             * than the (possibly different) default device. */
#pragma omp target update from(x[0:halfNplus1*2]) device(devNum)

            printf("Verify the results of the forward FFT\n");
            statusf = verify_c(x, N, H);
        }

        printf("Initialize input for Complex to Real backward transform\n");
        init_c(x, N, H);
#pragma omp target update to(x[0:halfNplus1*2]) device(devNum)

        printf("Create FFTW plan for 1D double-precision Complex to Real backward transform\n");
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch use_device_ptr(x) device(devNum)
#endif
        backward_plan = fftw_plan_dft_c2r_1d(N, (fftw_complex *)x, x, FFTW_ESTIMATE);

        if (backward_plan == 0) {
            printf("Call to fftw_plan_dft_c2r_1d has failed\n");
            statusb = 1;
        } else {
            printf("Compute backward FFT\n");
#if defined(ONEMKL_USE_OPENMP_VERSION) && (ONEMKL_USE_OPENMP_VERSION >= 202011)
#pragma omp dispatch device(devNum)
#else
#pragma omp target variant dispatch device(devNum)
#endif
            fftw_execute(backward_plan);
        }

    } // target data map

    /* The map(tofrom:) clause copied x back to the host when the region ended.
     * The result of verify_r must feed statusb, otherwise a wrong backward
     * transform would still report PASSED. */
    if (backward_plan != 0) {
        printf("Verify the result of backward FFT\n");
        statusb = verify_r(x, N, H);
    }

    if (statusf != 0 || statusb != 0) goto failed;

 cleanup:

    printf("Destroy FFTW plans\n");
    /* Plans are 0 if creation failed or was never reached (allocation failure). */
    if (forward_plan != 0) {
        fftw_destroy_plan(forward_plan);
    }
    if (backward_plan != 0) {
        fftw_destroy_plan(backward_plan);
    }

    printf("Free data array\n");
    fftw_free(x);

    printf("TEST %s\n",0==status ? "PASSED" : "FAILED");
    return status;

 failed:
    printf(" ERROR\n");
    status = 1;
    goto cleanup;
}

/* Return (K*L) % M as a double. The product is formed in long long so that it
 * cannot overflow int. As with C's % operator, a negative K*L yields a
 * non-positive result. */
static double moda(int K, int L, int M)
{
    const long long product = (long long)K * (long long)L;
    const long long remainder = product % M;
    return (double)remainder;
}

/* 2*pi. Given internal linkage so that this generic name is not exported from
 * the translation unit. */
static const double TWOPI = 6.2831853071795864769;

/* Fill x[0..N-1] with a real cosine of harmonic H, scaled by 1/N.
 * When 2*H is a multiple of N the cosine's two spectral lines (at +H and -H)
 * fall on the same bin, so no extra factor is applied. Otherwise each line
 * carries half the amplitude, and the factor 2.0 restores a unit peak. */
static void init_r(double *x, int N, int H)
{
    const int lines_coincide = ((2 * (N - H)) % N) == 0;
    const double scale = lines_coincide ? 1.0 : 2.0;

    for (int idx = 0; idx < N; ++idx) {
        const double frac = moda(idx, H, N) / N;
        x[idx] = scale * cos(TWOPI * frac) / N;
    }
}

/* Verify that x has unit peak at H */
static int verify_c(double *x, int N, int H)
{
    double errthr = 2.5 * log( (double)N ) / log(2.0) * DBL_EPSILON;
    printf(" Verify the result, errthr = %.3lg\n", errthr);

    double maxerr = 0.0;
    for (MKL_LONG n = 0; n < N/2+1; n++) {
        double re_exp = 0.0, im_exp = 0.0, re_got, im_got;

        if ((n-H)%N == 0 || (-n-H)%N == 0) re_exp = 1.0;

        re_got = x[2*n + 0];
        im_got = x[2*n + 1];
        double err  = fabs(re_got - re_exp) + fabs(im_got - im_exp);
        if (err > maxerr) maxerr = err;
        if (!(err < errthr))
        {
            printf(" x[%lld]: ", n);
            printf(" expected (%.17lg,%.17lg), ",re_exp,im_exp);
            printf(" got (%.17lg,%.17lg), ",re_got,im_got);
            printf(" err %.3lg\n", err);
            printf(" Verification FAILED\n");
            return 1;
        }
    }
    printf(" Verified, maximum error was %.3lg\n", maxerr);
    return 0;
}

/* Fill the N/2+1 complex bins of x (interleaved re, im) with
 * exp(-2*pi*i*k*H/N) / N for k = 0..N/2. This is the half spectrum of a unit
 * impulse at index H. The 1/N factor compensates for FFTW's unnormalized
 * backward transform, so that verify_r can expect exactly 1.0 at H. */
static void init_c(double *x, int N, int H)
{
    /* int index matches N's type and moda()'s int parameter. This avoids a
     * signed/unsigned comparison, under which a non-positive N/2+1 would have
     * become a huge size_t bound. */
    for (int k = 0; k < N/2+1; k++) {
        double phase  = moda(k, H, N) / N;
        x[2*k + 0]  =  cos(TWOPI * phase) / N;
        x[2*k + 1]  = -sin(TWOPI * phase) / N;
    }
}

// Verify that x has unit peak at H
static int verify_r(double *x, int N, int H)
{
    const double errthr = 2.5 * log((double) N) / log(2.0) * DBL_EPSILON;
    printf(" Check if err is below errthr %.3lg\n", errthr);

    double maxerr = 0.0;
    for (MKL_LONG n = 0; n < N; n++) {
        double re_exp = 0.0, re_got;

        if ((n-H)%N == 0) re_exp = 1.0;

        re_got = x[n];
        double err  = fabs(re_got - re_exp);
        if (err > maxerr) maxerr = err;
        if (!(err < errthr)) {
            printf(" x[%lld]: ", n);
            printf(" expected %.7g, ", re_exp);
            printf(" got %.7g, ", re_got);
            printf(" err %.3lg\n", err);
            printf(" Verification FAILED\n");
            return 1;
        }
    }
    printf(" Verified,  maximum error was %.3lg\n", maxerr);
    return 0;
}
