// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// The discrete cosine transform (DCT) is used in many applications like the  //
// JPEG image encoding. In essence, the situation is similar to the discrete  //
// Fourier transform: The definition is also a matrix multiplication with the //
// NxN matrix C having entries c[i][j] as defined below. Simple matrix        //
// multiplication requires N^2 multiplications and N(N-1) additions, hence,   //
// O(N^2) sequential time, which can be easily reduced by N^2 processors to   //
// O(log(N)) time. In analogy to the fast Fourier transform FFT, there are    //
// also fast DCT algorithms [Rezn10] that reduce the complexity to O(Nlog(N)) //
// sequential time and with N processors to O(log(N)) parallel time.          //
//   The version below is a simple reduction of the DCT to the DFT, so that   //
// any FFT algorithm can be used to compute it in O(log(N)) parallel time.    //
// This reduction is well-known, and can be found, e.g. in [Khay03].          //
// ************************************************************************** //

// problem size: the vectors have length N = 2^K
macro K = 4;
macro N = exp(2,K);

// helpers for the coefficients of the DCT matrix
macro pi = 3.1415926535897932384626433832795;
// square root expressed by the power function exp(x,y)
macro sqrt(x) = exp(x,0.5);
// normalization factor of row i: sqrt(1/N) for i==0, and sqrt(2/N) otherwise
macro alpha(i) = sqrt((i==0 ? 1.0 : 2.0)/N);
// complex numbers are pairs (real part,imaginary part)
macro Re(x) = x.0;
macro Im(x) = x.1;
macro cadd(x,y) = (Re(x)+Re(y),Im(x)+Im(y));
macro csub(x,y) = (Re(x)-Re(y),Im(x)-Im(y));
macro cmul(x,y) = (Re(x)*Re(y)-Im(x)*Im(y),Im(x)*Re(y)+Re(x)*Im(y));
// tolerance used when comparing real numbers
macro eps = 1.0e-4;
// holds iff the absolute difference of x and y is below eps
macro almost_equal(x,y) = (y>x ? y-x : x-y) < eps ;
// snap values that are within eps of 0.0 or 1.0 to exactly 0.0 or 1.0
macro rnd(x) = (almost_equal(x,0.0) ? 0.0 : (almost_equal(x,1.0) ? 1.0 : x));


module DiscreteCosineTransform([N]real ?x,!y,[N](real * real) z,u) {
    // Computes the DCT y of the input vector x by a reduction to the DFT:
    //   y[j] = sum(k=0..N-1) (alpha(j) * cos((pi*j*(2*k+1))/(2*N)) * x[k])
    // z receives the permuted input as complex numbers, u receives its FFT.

    // Step 1: permute the inputs. The even-indexed values of x fill the lower
    // half of z in ascending order, the odd-indexed values fill the upper
    // half of z in descending order; all imaginary parts are 0.0. Then:
    //   y[j] = sum(k=0..N-1) (alpha(j) * cos((pi*j*(4*k+1))/(2*N)) * Re(z[k]))
    for(m=0..N/2-1) {
        z[N-1-m] = (x[2*m+1],0.0);
        z[m] = (x[2*m],0.0);
    }

    // Step 2: perform FFT on permuted vector z with depth O(log(N)) using
    // O(N) procs. The call below stands for the following plain DFT:
    // for(j=0..N-1)
    //     let(re = sum(k=0..N-1) Re(cmul((rnd(cos((2*pi*j*k)/N)),-rnd(sin((2*pi*j*k)/N))),z[k])))
    //     let(im = sum(k=0..N-1) Im(cmul((rnd(cos((2*pi*j*k)/N)),-rnd(sin((2*pi*j*k)/N))),z[k])))
    //     u[j] = (re,im);
    FastFourierTransform(z,u);

    // Step 3: scale the FFT u to obtain the DCT y.
    // Since u[j] = sum(k=0..N-1) (exp(e,-(2*pi*i*j*k)/N) * z[k]), we have
    //   alpha(j) * exp(e,-(pi*i*j)/(2*N)) * u[j]
    //     = sum(k=0..N-1) (alpha(j) * exp(e,-(pi*i*j*(4*k+1))/(2*N)) * z[k])
    // and y[j] is the real part of that product. With the complex factor
    // (cos(theta),-sin(theta)), the real part of the product with u[j] is
    // cos(theta)*Re(u[j]) + sin(theta)*Im(u[j]).
    for(m=0..N-1) {
        let(theta = (pi*m)/(2.0*N))
        y[m] = alpha(m) * (cos(theta)*Re(u[m]) + sin(theta)*Im(u[m]));
    }

}
drivenby {
    // stimulus: the ramp x[m] = m
    for(m=0..N-1)
        x[m] = m;
    // reference result: multiply x with the DCT matrix directly, whose
    // coefficients are by definition
    //   c[i][j] = alpha(i) * cos((pi*i*(2*j+1))/(2*N)),
    // and compare each entry with the output y of the module
    for(r=0..N-1) {
        let(expected = sum(c=0..N-1) (alpha(r) * cos((pi*r*(2*c+1))/(2.0*N)) * x[c]))
        assert(almost_equal(expected,y[r]));
    }
}