// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// The following module implements a solution to the parallel sum problem:    //
// Given N operands z[0][0..N-1] and an associative operation like "+" that   //
// has to be applied to all of these operands, construct a balanced binary    //
// tree of depth O(log(N)) to compute the result. In the example below, we    //
// simply sum up the entries of a vector x[0..N-1]. (Starting instead from    //
// z[0][j] = x[j] * y[j] would yield the inner product of x[0..N-1] and       //
// y[0..N-1].) For the sum, we first set z[0][j] = x[j] and then build a      //
// binary tree where z[i][j] are the values determined in the i-th level      //
// of the tree. In each level, we apply the operation to pairs starting       //
// with z[i][0] and z[i][1]. Since the width w_i of level i might be odd,     //
// its last operand z[i][w_i-1] may have no partner; we then forward this     //
// operand unchanged to the next level. For example, the binary tree for      //
// N=6 looks as follows:                                                      //
//                                                                            //
//                z[0][5] z[0][4] z[0][3] z[0][2] z[0][1] z[0][0]             //
//                   |       |       |       |       |       |                //
//                    -------         -------         -------                 //
//                       |               |               |                    //
//                    z[1][2]         z[1][1]         z[1][0]                 //
//                       |               |               |                    //
//                       |                ---------------                     //
//                       |                       |                            //
//                    z[2][1]                 z[2][0]                         //
//                       |                       |                            //
//                        -----------------------                             //
//                                    |                                       //
//                                 z[3][0]                                    //
//                                                                            //
// For computing the width of level i, note that ceil(n/2) = floor((n+1)/2),  //
// which is exactly the integer division (n+1)/2 used in WidthOfLevel.        //
// Note that the module below is quite wasteful in that it declares in each   //
// level N values z[i][0..N-1] even though only a few are required. See the   //
// module PrefixSumOpt for an optimized version.                              //
// ************************************************************************** //


// Problem size: the number of operands x[0..N-1].
macro N = 121;
// Number of reduction levels below level 0.
// NOTE(review): this relies on log(N) rounding up, i.e. ceil(log2(N)), which
// is 7 for N=121, so that level K has width 1 -- confirm against the Quartz
// definition of log.
macro K = log(N);
// Width of level i of the tree when level 0 holds n operands: each level
// halves the width of the one before it, rounding up ((w+1)/2 = ceil(w/2)).
// Overall this yields ceil(n/2^i).
macro WidthOfLevel(i,n) = (i==0 ? n : (WidthOfLevel(i-1,n)+1)/2);

// PrefixSum reduces the N operands x[0..N-1] along a balanced binary tree.
// Row z[i] holds level i of the tree. Only its first WidthOfLevel(i,N) entries
// are assigned; the remaining entries of the row are never read. The drivenby
// section checks that the root z[K][0] equals the sum of all x[j].
module PrefixSum([N]nat ?x,[K+1][N]nat z) {
    // level 0 simply copies the N operands
    for(j=0..N-1)
        z[0][j] = x[j];
    // level i+1 is computed from level i, whose width is w = WidthOfLevel(i,N)
    for(i=0..K-1)
        let(w = WidthOfLevel(i,N))
        let(h = w/2) {
            // combine the h complete pairs (z[i][2*j], z[i][2*j+1])
            for(j=0..h-1)
                z[i+1][j] = z[i][2*j] + z[i][2*j+1];
            // if w is odd, the last operand z[i][w-1] has no partner and is
            // forwarded unchanged; handling it separately means z[i][2*j+1]
            // is never formed with 2*j+1 == w (at level 0 that would be the
            // out-of-range element z[0][N])
            if(w == 2*h+1)
                z[i+1][h] = z[i][w-1];
        }
}
drivenby {
    // operands 1, 2, ..., N
    for(i=0..N-1) {
        x[i] = i+1;
    }
    // the root of the tree must hold the sum of all operands
    assert(z[K][0] == sum(j=0..N-1) x[j]);
}