// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// The following module implements a solution to the parallel sum problem:    //
// Given N operands z[0][0..N-1] and an associative operation like "+" that   //
// has to be applied to all of these operands, construct a balanced binary    //
// tree of depth O(log(N)) to compute the result. In the example below, the   //
// operands are the elements of an input vector x[0..N-1], i.e., we set       //
// z[0][j] = x[j] and compute the sum of all x[j]. (The same scheme computes  //
// e.g. the inner product of two vectors x[0..N-1] and y[0..N-1] by setting   //
// z[0][j] = x[j] * y[j] instead.) We then build a binary tree where z[i][j]  //
// are the values determined in the i-th level of the tree. In each level, we //
// apply the operation to pairs starting with z[i][0] and z[i][1]. Since the  //
// width N_i of level i might be odd, the last operand z[i][N_i-1] can be     //
// left without a partner; in that case, we forward it unchanged to the next  //
// level. For example, the binary tree for N=6 looks as follows:              //
//                                                                            //
//                z[0][5] z[0][4] z[0][3] z[0][2] z[0][1] z[0][0]             //
//                   |       |       |       |       |       |                //
//                    -------         -------         -------                 //
//                       |               |               |                    //
//                    z[1][2]         z[1][1]         z[1][0]                 //
//                       |               |               |                    //
//                       |                ---------------                     //
//                       |                       |                            //
//                    z[2][1]                 z[2][0]                         //
//                       |                       |                            //
//                        -----------------------                             //
//                                    |                                       //
//                                 z[3][0]                                    //
//                                                                            //
// For computing the width of level i, note that ceil(n/2) = floor((n+1)/2).  //
// In contrast to module PrefixSum, the above tree is encoded in a single     //
// one-dimensional array by means of an index mapping idm(i,j) that maps the  //
// two-dimensional virtual array element z[i][j] to z[idm(i,j)].              //
// ************************************************************************** //


// Number of operands, and number of reduction steps from level 0 to the root.
// NOTE(review): the tree is only complete if log(N) >= ceil(log2(N)); extra
// levels are harmless (a level of width 1 is forwarded unchanged) -- confirm
// the semantics of the built-in log when changing N.
macro N = 121;
macro K = log(N);
// Width of the level reached from a level of width w after lev halvings,
// each rounding up (integer division of w+1 by 2).
macro WidthOfLevel(lev,w) = (lev==0 ? w : WidthOfLevel(lev-1,(w+1)/2));
// Total number of tree nodes stored for the levels 0..lev.
macro SumOfWidths(lev) = sum(lv=0..lev) WidthOfLevel(lv,N);
// Flat index of the virtual element z[lev][pos]: skip all nodes of the
// levels 0..lev-1, then add the position inside level lev.
macro idm(lev,pos) = pos + (lev==0 ? 0 : SumOfWidths(lev-1));


module PrefixSumOpt([N]nat ?x,[SumOfWidths(K)]nat z) {
    // Level 0 of the tree holds the N operands: z[idm(0,k)] = x[k].
    for(k=0..N-1)
        z[idm(0,k)] = x[k];
    // Levels 1..K: node k of level lev+1 combines nodes 2k and 2k+1 of level
    // lev. If level lev has odd width wCur, its last node 2k (2k+1 == wCur)
    // has no partner and is copied unchanged; in the else branch 2k+1 < wCur,
    // so z[idm(lev,2*k+1)] always names a node of level lev.
    for(lev=0..K-1)
        let(wCur = WidthOfLevel(lev,N))
        let(wNext = (wCur+1)/2)
        for(k=0..wNext-1)
            if(2*k+1==wCur)
                z[idm(lev+1,k)] = z[idm(lev,2*k)];
            else
                z[idm(lev+1,k)] = z[idm(lev,2*k)] + z[idm(lev,2*k+1)];
}
drivenby {
    // Stimulus: the operands 1,2,...,N.
    for(k=0..N-1)
        x[k] = k+1;
    // The root z[K][0] must equal the sum of all operands.
    assert(z[idm(K,0)] == sum(k=0..N-1) x[k]);
}