package FloatingPoint;







// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// This module implements square root computation of fixed-point numbers.     //
// Consider this module as a helper function to implement square root         //
// computation of floating point numbers that conform to the IEEE standard.   //
// ************************************************************************** //

// BITWIDTH and FIXWIDTH are adapted to the floating-point format of FloatSquareRoot.
// BITWIDTH is fixed to two bits: the mantissa has one digit left of the binary point
// and may be shifted left to correct odd exponents (see FloatSquareRoot for an
// explanation). FIXWIDTH is computed as the width of the mantissa plus additional
// precision bits that allow correct rounding and left shifts to obtain normalized numbers.

// Configuration for the test cases (enable these two lines and disable the set below):
//macro BITWIDTH         = 8;  // use 8 bit int width for test cases
//macro FIXWIDTH         = 6;  // use 6 bit precision for test cases

// Configuration for calls from FloatSquareRoot:
macro BITWIDTH = 2;   // two integer bits, as needed by FloatSquareRoot
macro FIXWIDTH = 5;   // mantissa width + 1; the extra bit is the r bit that determines the rounding direction

// Width of a and r: BITWIDTH integer bits followed by FIXWIDTH fraction bits.
macro OVERALLWIDTH     = FIXWIDTH + BITWIDTH;
// Width of the internal registers: a is extended by FIXWIDTH zero bits so that
// its integer square root still carries FIXWIDTH fraction bits.
macro OVERALLWIDTH_TMP = 2*FIXWIDTH + BITWIDTH;

// Scale factor 2^FIXWIDTH relating r*r to a on the integer encodings.
macro MULT_CORR = exp(2, FIXWIDTH);

// Fixed-point square root (helper for FloatSquareRoot).
//
// a and r are unsigned fixed-point numbers with BITWIDTH integer bits and
// FIXWIDTH fraction bits, i.e. a bitvector x stands for bv2nat(x) / 2^FIXWIDTH.
// The root is computed with the digit-by-digit method: one result bit per
// step, from the most significant bit downwards.
//
// a     = input number; it is read only in the first step
// r     = result, sqrt(a) truncated to FIXWIDTH fraction bits, i.e.
//         bv2nat(r) = floor(sqrt(bv2nat(a) * 2^FIXWIDTH))
// exact = true when r*r=a, false if r*r < a < (r+u)*(r+u) with u = 2^-FIXWIDTH
// ready = result was computed, r and exact contain the result
module FixedSquareRoot(bv{OVERALLWIDTH} ?a, bv{OVERALLWIDTH} r, bool !exact, event !ready) {
    bv{OVERALLWIDTH_TMP} ns;         // =2*n*2^i, n = root found so far (the "c" term of the digit-by-digit method)
    bv{OVERALLWIDTH_TMP} n2i;        // =2^{2i}: power of four tested in the current step
    bv{OVERALLWIDTH_TMP} ni;         // =2^{i+FIXWIDTH}: result bit 2^i, stored shifted left by FIXWIDTH
    bv{OVERALLWIDTH_TMP} remainingA; // =a*2^FIXWIDTH - n^2

    r = {false::OVERALLWIDTH};
    ns = {false::OVERALLWIDTH_TMP};
    // Append FIXWIDTH zero bits (scale a by 2^FIXWIDTH) so that the integer
    // square root of remainingA carries FIXWIDTH fraction bits.
    remainingA = a@{false::(OVERALLWIDTH_TMP - OVERALLWIDTH)};
    // Start with the largest power of four that fits into OVERALLWIDTH_TMP bits,
    // i = (OVERALLWIDTH_TMP-1)/2. Deriving both start values from the same i keeps
    // n2i = 2^{2i} and ni = 2^{i+FIXWIDTH} consistent for odd BITWIDTH as well.
    // For a == 0 the loop below is not entered, so no special start value is needed.
    n2i = nat2bv(exp(2, 2*((OVERALLWIDTH_TMP-1)/2)), OVERALLWIDTH_TMP);
    ni = nat2bv(exp(2, (OVERALLWIDTH_TMP-1)/2 + FIXWIDTH), OVERALLWIDTH_TMP);

    // One result bit per step. Stop early once the remainder is zero: every
    // further test n2i + ns <= 0 would fail, so the remaining bits of r stay 0.
    while ((n2i != {false::OVERALLWIDTH_TMP}) & (remainingA != {false::OVERALLWIDTH_TMP})) {
        if(bv2nat(n2i) + bv2nat(ns) <= bv2nat(remainingA)) {
            // (n + 2^i)^2 <= a*2^FIXWIDTH: result bit i is set
            next(remainingA) = nat2bv(bv2nat(remainingA) - bv2nat(n2i) - bv2nat(ns), OVERALLWIDTH_TMP);
            next(r) = r | ni{-1:FIXWIDTH};
            next(ns) = nat2bv(bv2nat(false@ns{-1:1}) + bv2nat(n2i), OVERALLWIDTH_TMP); // ns/2 + n2i
        } else {
            next(ns) = false@ns{-1:1};  // ns/2
        }
        next(n2i) = {false::2}@(n2i{-1:2}); // n2i/4
        next(ni) = false@(ni{-1:1}); // ni/2
        pause;
    }

    // If remainingA!=0 => the result is not exact, r=floor(sqrt(a))<sqrt(a)
    exact = (remainingA=={false::OVERALLWIDTH_TMP});

    emit(ready);
}
satisfies {
    // Verification with the model checker has not been tested yet!
    // Intended property: whenever ready holds, r is sqrt(a) truncated to
    // FIXWIDTH fraction bits, i.e. on the integer encodings
    // r*r <= a*2^FIXWIDTH < (r+1)*(r+1).
    // NOTE(review): the formula uses the path quantifier A without G; if A is
    // read as in CTL*, `ready -> ...` is only evaluated in the initial state,
    // where ready is emitted only for a == 0. Presumably A G (...) was intended
    // — confirm against the Averest specification syntax.
    // NOTE(review): the module reads `a` only in its first step, but this
    // property reads `a` in the step where ready holds. If the model checker
    // lets the input change in between, a G-version of this property would
    // need the sampled value of a — confirm before enabling verification.
    x : assert(
        A (
            ready ->
                ( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR &
                  bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1)
                )
          )
        );
}
// input: int numbers (all fraction bits zero)
// Each driver checks that r = sqrt(a) truncated to FIXWIDTH fraction bits,
// i.e. r*r <= a*2^FIXWIDTH < (r+1)*(r+1) on the integer encodings, and that
// exact is set exactly when r*r == a*2^FIXWIDTH.
// NOTE: these inputs are meant for the test configuration (BITWIDTH = 8).
// With BITWIDTH = 2 only 0..3 are representable; larger values are
// presumably truncated by nat2bv — check before relying on drivers 5-14.
drivenby { // 1: a = 0
    a = nat2bv(0, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 2: a = 1
    a = nat2bv(1, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 3: a = 2, the smallest non-square integer input
    a = nat2bv(2, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 4: a = 3
    a = nat2bv(3, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 5: a = 4
    a = nat2bv(4, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 6: a = 7
    a = nat2bv(7, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 7: a = 13
    a = nat2bv(13, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 8: a = 23
    a = nat2bv(23, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 9: a = 67
    a = nat2bv(67, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 10: a = 99
    a = nat2bv(99, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 11: a = 100
    a = nat2bv(100, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 12: a = 101
    a = nat2bv(101, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 13: a = 144
    a = nat2bv(144, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 14: a = 255
    a = nat2bv(255, BITWIDTH)@{false::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR & bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
// input: real numbers (non-zero fraction bits)
// Every driver checks the two bounds r*r <= a*2^FIXWIDTH and
// a*2^FIXWIDTH < (r+1)*(r+1) separately, then that exact is consistent.
drivenby { // 15: a = 0.11...1, the largest value below 1
    a = nat2bv(0, BITWIDTH)@{true::FIXWIDTH};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 16: a = 0.5
    a = nat2bv(0, BITWIDTH)@true@{false::(FIXWIDTH-1)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 17: a = 0.25
    a = nat2bv(0, BITWIDTH)@false@true@{false::(FIXWIDTH-2)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 18: a = 1.75
    a = nat2bv(1, BITWIDTH)@true@true@{false::(FIXWIDTH-2)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 19: a = 1.125
    a = nat2bv(1, BITWIDTH)@false@false@true@{false::(FIXWIDTH-3)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 20: a = 1.25
    a = nat2bv(1, BITWIDTH)@false@true@{false::(FIXWIDTH-2)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}

drivenby { // 21: a = 1.9375 (needs FIXWIDTH >= 4)
    a = nat2bv(1, BITWIDTH)@{true::4}@{false::(FIXWIDTH-4)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}
drivenby { // 22: a = 3.875 (needs FIXWIDTH >= 3)
    a = nat2bv(3, BITWIDTH)@{true::3}@{false::(FIXWIDTH-3)};
    immediate await(ready);
    assert( bv2nat(r)*bv2nat(r) <= bv2nat(a)*MULT_CORR );
    assert( bv2nat(a)*MULT_CORR < (bv2nat(r)+1)*(bv2nat(r)+1) );
    assert( exact -> ( bv2nat(r)*bv2nat(r) == bv2nat(a)*MULT_CORR ) );
    assert( !exact -> ( bv2nat(r)*bv2nat(r) < bv2nat(a)*MULT_CORR ) );
}