// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// The module below implements a matrix multiplication that requires          //
// O(D1*D2*D3) processors to perform the multiplication in O(log(D2)) depth   //
// which is optimal.                                                          //
// ************************************************************************** //

// Matrix dimensions: a is D1 x D2, b is D2 x D3, and the product c is D1 x D3.
// D1 = number of rows of a and c
macro D1 = 3;
// D2 = shared inner dimension, i.e. the number of products summed per entry
macro D2 = 8;
// D3 = number of columns of b and c
macro D3 = 4;
// K = number of levels above the leaves in each summation tree.
// NOTE(review): presumably log is the base-2 logarithm rounded up, so that
// K halving steps reduce D2 leaves to a single root -- confirm against the
// Quartz reference for values of D2 that are not powers of two.
macro K = log(D2);
// WidthOfLevel(i,n) = number of nodes on level i of a tree with n leaves:
// level 0 has n nodes, and every further level applies n -> (n+1)/2 once
// (halving that rounds up, assuming `/` is integer division here).
macro WidthOfLevel(i,n) = (i==0 ? n : WidthOfLevel(i-1,(n+1)/2));


// Computes the matrix product c = a * b using one balanced binary summation
// tree per entry of c.
//   a : input matrix with D1 rows and D2 columns
//   b : input matrix with D2 rows and D3 columns
//   c : result matrix with D1 rows and D3 columns
// The body uses only immediate assignments and contains no pause statement.
module MatrixMultCombLogN([D1][D2]int ?a, [D2][D3]int ?b, [D1][D3]int c){
    // d[row][col][lvl][pos] is node `pos` on level `lvl` of the summation tree
    // for c[row][col]: level 0 holds the D2 leaves, level K holds the root.
    // On level lvl only the first WidthOfLevel(lvl,D2) positions are written.
    [D1][D3][K+1][D2]int d;

    // Leaves: one product a[row][idx] * b[idx][col] per tree position,
    // D1*D2*D3 products in total.
    for(row=0..D1-1) {
        for(col=0..D3-1) {
            for(idx=0..D2-1) {
                d[row][col][0][idx] = a[row][idx] * b[idx][col];
            }
        }
    }

    // Inner nodes: every level adds up neighbouring pairs of the level below,
    // so the number of nodes is halved (rounding up) from level to level.
    for(row=0..D1-1) {
        for(col=0..D3-1) {
            for(lvl=0..K-1) {
                let(width = WidthOfLevel(lvl,D2))
                let(parents = (width+1)/2)
                for(pos=0..parents-1) {
                    // If the level below has an odd number of nodes, its last
                    // node has no right neighbour and is passed up unchanged.
                    d[row][col][lvl+1][pos] = (2*pos+1 != width
                                               ? d[row][col][lvl][2*pos] + d[row][col][lvl][2*pos+1]
                                               : d[row][col][lvl][2*pos]);
                }
            }
            // The root of the tree is the complete sum for this entry.
            c[row][col] = d[row][col][K][0];
        }
    }
}
drivenby {
    // Stimulus: fill a with the values 0 .. D1*D2-1 in row-major order.
    for(row=0..D1-1) {
        for(col=0..D2-1) {
            a[row][col] = row*D2+col;
        }
    }
    // Stimulus: fill b with the values 0 .. D2*D3-1 in row-major order.
    for(row=0..D2-1) {
        for(col=0..D3-1) {
            b[row][col] = row*D3+col;
        }
    }
    // Check every entry of c against the directly computed sum of products.
    for(row=0..D1-1) {
        for(col=0..D3-1) {
            assert(c[row][col] == sum(m=0..D2-1) (a[row][m] * b[m][col]));
        }
    }
}