// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// The following module implements an algorithm for subtraction of binary     //
// signed digit numbers. It is simply based on adding the negated number. Note//
// that changing the sign of a SD number is simply done by changing the sign  //
// of each digit in parallel. For the chosen encoding, this is simply done by //
// swapping the components. The depth of the circuit is also O(1).            //
// ************************************************************************** //

// number of signed digits per operand
macro N = 4;

// A signed digit is encoded as a pair (neg,pos): (true,false) is -1,
// (false,false) is 0 and (false,true) is +1.
macro enc(v) = (v==0 ? (false,false) : (v==-1 ? (true,false) : (false,true)));
// integer value of one encoded digit d
macro dec(d) = (d.0 ? -1 : 0) + (d.1 ? +1 : 0);
// integer value of the n least significant digits of v (radix 2)
macro sdval(v,n) = sum(j=0..n-1) (exp(2,j) * dec(v[j]));
// holds iff none of the n least significant digits of v has both components set
macro isSDnum(v,n) = forall(j=0..n-1) (!v[j].0 | !v[j].1);

// signals coming from the next lower digit position; position 0 receives
// neutral values (lout and tout are the local arrays of the using module)
macro lin(p) = (p==0 ? false : lout[p-1]);
macro tin(p) = (p==0 ? (false,false) : tout[p-1]);


module SgnSub(event [N](bool*bool) ?x,?y, [N+1](bool*bool) s) {
    // Carry-free subtraction s = x - y of N-digit binary signed digit numbers.
    // Each digit is a pair (neg,pos); subtracting y is adding y with swapped
    // components, which is why y[i].0 acts as "+1" and y[i].1 as "-1" below.
    //
    // digit difference x[i]-y[i] is +1 / -1 because exactly one operand is zero
    event [N]bool upByY,downByY,upByX,downByX;
    // unitDiff: difference is +1 or -1; interimNeg/interimPos: interim digit is -1/+1
    event [N]bool unitDiff,interimNeg,interimPos;
    // lout[i]: position i may produce a negative transfer (read via macro lin)
    event [N]bool lout;
    // tout[i]: transfer digit sent to position i+1 (read via macro tin)
    event [N](bool*bool) tout;
    loop {
       pause;
       for(i=0..N-1) {
            // information passed upwards: x[i] is -1 or -y[i] is -1
            lout[i]   = y[i].1 | x[i].0;

            // classify the digit difference when it is +1 or -1
            upByY[i]   = !(x[i].0 | x[i].1) & y[i].0;
            downByY[i] = !(x[i].0 | x[i].1) & y[i].1;
            upByX[i]   = !(y[i].0 | y[i].1) & x[i].1;
            downByX[i] = !(y[i].0 | y[i].1) & x[i].0;
            unitDiff[i] = upByY[i] | downByY[i] | upByX[i] | downByX[i];

            // choose the sign of the interim digit so that it cannot collide
            // with the transfer arriving from position i-1
            interimNeg[i] = unitDiff[i] & !lin(i);
            interimPos[i] = unitDiff[i] &  lin(i);

            // transfer digit for position i+1: difference -2 or (-1 with lin)
            // yields -1; difference +2 or (+1 without lin) yields +1
            tout[i].0 = (x[i].0 & y[i].1) | ( lin(i) & (downByY[i] | downByX[i]));
            tout[i].1 = (x[i].1 & y[i].0) | (!lin(i) & (upByY[i]   | upByX[i]));

            // result digit = interim digit + incoming transfer
            s[i].0 = (tin(i).0 & !interimPos[i]) | (interimNeg[i] & !tin(i).1);
            s[i].1 = (tin(i).1 & !interimNeg[i]) | (interimPos[i] & !tin(i).0);
       }
       // the most significant result digit is the last transfer digit
       s[N].0 = tout[N-1].0;
       s[N].1 = tout[N-1].1;
       // for well-formed operands the result is well-formed and equals x-y
       assert(isSDnum(x,N) & isSDnum(y,N)
              -> sdval(s,N+1) == sdval(x,N) - sdval(y,N)
               & isSDnum(s,N+1) );
   }
}
drivenby {
    // Exhaustive testbench: applies every pair of 2*N-bit patterns to the
    // inputs x and y so that the assertion of the module is checked for all
    // operand combinations (patterns that are no valid signed digit numbers
    // are filtered by the premise of the assertion).
    //
    // i and j count the bit patterns for x and y; they need one value more
    // than exp(2,2*N) since they reach exp(2,2*N) when their loop terminates.
    nat{exp(2,2*N+1)} i,j;
    bv{2*N} vi,vj;
    // skip the first step: the module also starts with a pause
    pause;
    // enumerate all possible bitvectors for x
    do {
        // enumerate all bitvectors for y
        do {
            // i,j < exp(2,2*N) holds here, so both conversions fit in 2*N bits
            vi = nat2bv(i,2*N);
            vj = nat2bv(j,2*N);
            // x and y are event inputs and therefore have to be driven in
            // every single step; driving x only once per outer iteration
            // would apply the default value of x for all j > 0
            for(k=0..N-1) {
                x[k].0 = vi{2*k};
                x[k].1 = vi{2*k+1};
                y[k].0 = vj{2*k};
                y[k].1 = vj{2*k+1};
            }
            next(j) = j+1;
            pause;
        } while(j<exp(2,2*N));
        // j == exp(2,2*N) in this step, hence no conversion of j is done here;
        // advance to the next pattern for x and restart the enumeration of y
        next(i) = i+1;
        next(j) = 0;
        pause;
   } while(i<exp(2,2*N));
}