package FloatingPoint;

import FloatingPoint.FixedSquareRoot;
/*import FloatingPoint.FloatMult;
import FloatingPoint.FloatAbs;
import FloatingPoint.FloatSubtract;*/




// ************************************************************************** //
//                                                                            //
//    eses                   eses                                             //
//   eses                     eses                                            //
//  eses    eseses  esesese    eses   Embedded Systems Group                  //
//  ese    ese  ese ese         ese                                           //
//  ese    eseseses eseseses    ese   Department of Computer Science          //
//  eses   eses          ese   eses                                           //
//   eses   eseses  eseseses  eses    University of Kaiserslautern            //
//    eses                   eses                                             //
//                                                                            //
// ************************************************************************** //
// This module implements square root computation of IEEE-like floating point //
// numbers. The algorithm below deals with the full IEEE standard including   //
// denormal and normal numbers, infinity and NaN values.                      //
//                                                                            //
// Let FP be the floating point number for which the square root has to be    //
// computed. In particular, let s,exp,mant be the parts of FP:                //
// FP = (-1)^s * 2^(exp-bias) * Mant(mant)                                    //
// with s=0, since sqrt: R->R is only defined for non-negative numbers, and   //
// Mant(mant) is 1.mant for normalized and 0.mant*2 for denormalized numbers  //
// (*2 is the correction of exp for denormalized numbers)                     //
//                                                                            //
// It follows                                                                 //
// sqrt(FP) = sqrt( 2^(exp-bias) * Mant(mant) )                               //
//          = 2^( (exp-bias)/2 ) * sqrt( Mant(mant) )                         //
//                                                                            //
// To simplify the computation to integer arithmetic:                         //
// sqrt(FP) = 2^( (exp-bias)/2 ) * sqrt( Mant(mant) ) if exp-bias is even     //
//            2^( (exp-bias-1)/2 ) * sqrt( Mant(mant)*2 ) if exp-bias is odd  //
//                                                                            //
// Hence, we require the computation of the square root of fixed-point        //
// numbers. We therefore use FixedSquareRoot.                                 //
//                                                                            //
// NOTE: FixedSquareRoot is used to compute the square root of the mantissa.  //
//       Therefore, the bit width must be adapted to MantWidth                //
// ************************************************************************** //

// parameters of the floating-point format
// NOTE: FixedSquareRoot must be configured to match MantWidth (see the NOTE
// in the header comment). The test case numbers refer to the drivenby blocks
// of FloatSquareRoot that were written for the respective MantWidth.
macro MantWidth = 4;   // test cases 6-21
//macro MantWidth = 5;   // test cases 6b-7b
//macro MantWidth = 256;   // test cases 6c
//macro MantWidth = 257;   // test cases 6c
macro ExpWidth  = 3;
macro ExpMax    = exp(2,ExpWidth)-1;    // largest exponent field value (all ones)
macro ExpBias   = ExpMax / 2;           // = 2^(ExpWidth-1)-1 (integer division)
macro ExpMin    = +1 - ExpBias;         // unbiased exponent of exponent field 1 (smallest normal exponent)
macro Width     = 1+ExpWidth+MantWidth; // total width: sign bit, exponent bits, mantissa bits

// NOTE(review): the two macros below are not referenced in the code shown here.
macro MantMinNorm = exp(2,MantWidth-1);  // smallest normalized mantissa (2^(MantWidth-1): only the MSB of the mantissa field set)
macro MantMax     = exp(2,MantWidth)-1;  // largest mantissa value

// macros for extracting parts of the floating point number
// Bit layout, most significant bit first (Width = 1+ExpWidth+MantWidth):
//     sign (1 bit) | biased exponent (ExpWidth bits) | mantissa (MantWidth bits)
// Negative bit indices count from the most significant bit (x{-1} is the MSB).
macro Sign(x) = x{-1};                 // sign bit
macro Exp(x)  = x{-2:MantWidth};       // biased exponent field
macro Mant(x) = x{MantWidth-1:0};      // stored mantissa bits (without hidden bit)
macro RmHBit(x) = Mant(x);            // remove hidden bit

// macros for interpretation of the floating point numbers
// (map the bit fields to their real/integer values)
macro ValSign(s)    = (s?-1.0:1.0);                                    // -1.0 if the sign bit is set, else 1.0
macro ValNMant(m)   = bv2nat(true@m)  * exp(2.0,-1*MantWidth);          // 1.m (normal numbers, hidden bit 1)
macro ValDMant(m)   = bv2nat(false@m) * exp(2.0,-1*MantWidth);          // 0.m (subnormal numbers, hidden bit 0)
macro ValMant(e,m)  = (e=={false::ExpWidth} ? ValDMant(m) : ValNMant(m)); // mantissa value selected by the exponent field
// unbiased exponent of a normal number; the leading '+' presumably converts
// the nat to an int so the result may be negative -- TODO confirm
macro ValNExp(e)    = (+bv2nat(e))-ExpBias;
macro ValDExp(e)    = +1-ExpBias;                                       // exponent of subnormal numbers (= ExpMin); e is unused
macro ValExp(e)     = (e=={false::ExpWidth} ? ValDExp(e) : ValNExp(e));
// (-1)^s * mantissa * 2^exponent.
// NOTE: infinities and NaNs (exponent field all ones) are not special-cased;
// they are mapped to a finite real value.
macro Float2Real(x) = ValSign(Sign(x)) * ValMant(Exp(x),Mant(x)) * exp(2.0,ValExp(Exp(x)));

// special values, grouped by sign bit (false = positive, true = negative):
//   zero     : exponent field and mantissa all zeros
//   infinity : exponent field all ones, mantissa all zeros
//   NaN      : exponent field all ones, mantissa all ones (the NaN patterns
//              produced/tested here; any non-zero mantissa is a NaN)
macro posZero = false@{false::ExpWidth}@{false::MantWidth};
macro posInf  = false@{true::ExpWidth}@{false::MantWidth};
macro posNaN  = false@{true::ExpWidth}@{true::MantWidth};

macro negZero = true@{false::ExpWidth}@{false::MantWidth};
macro negInf  = true@{true::ExpWidth}@{false::MantWidth};
macro negNaN  = true@{true::ExpWidth}@{true::MantWidth};


// Classification predicates on the bit pattern of a floating point number.
// Every body is wrapped in one outer pair of parentheses so that a predicate
// can be used as an operand (e.g. of '!' or '&') without precedence surprises.
macro isZero(x)      = ( ( Exp(x) == {false::ExpWidth} ) & ( Mant(x) == {false::MantWidth} ) );
macro isNormal(x)    = ( ( 1 <= bv2nat(Exp(x)) ) & ( bv2nat(Exp(x)) <= (exp(2,ExpWidth)-2) ) );
macro isSubnormal(x) = ( ( Exp(x) == {false::ExpWidth} ) & ( Mant(x) != {false::MantWidth} ) );
// zero, subnormal and normal numbers are exactly the bit patterns whose
// exponent field is not all ones
macro isFinite(x)    = ( Exp(x) != {true::ExpWidth} );
macro isInfinite(x)  = ( ( Exp(x) == {true::ExpWidth} ) & ( Mant(x) == {false::MantWidth} ) );
macro isNaN(x)       = ( ( Exp(x) == {true::ExpWidth} ) & ( Mant(x) != {false::MantWidth} ) );
macro isSignMinus(x) = ( Sign(x) );

// Macros that return a float; x{-2:0} is the exponent field followed by
// the mantissa field, i.e. everything except the sign bit.
macro fAbs(x)         = (false@x{-2:0});        // clear the sign bit
macro fCopySign(x,y)  = (Sign(y)@x{-2:0});      // magnitude of x, sign of y
macro fNegate(x)      = ((!Sign(x))@x{-2:0});   // flip the sign bit


// FloatSquareRoot: compute the square root of a floating point number
//
// x     = input; it is only read in the instant in which the module starts
// z     = result (valid when ready is set)
// ready = signals that the result has been computed. For NaN, zero,
//         negative and infinite inputs, z is set and ready is emitted in the
//         starting instant; for all other inputs at least one instant later.
//
// Special cases: sqrt(NaN) = NaN (x is passed through), sqrt(+-0) = +-0,
// sqrt(+inf) = +inf, sqrt(negative non-zero, including -inf) = posNaN.
// Finite positive inputs are rounded to nearest, ties to even.
//
// NOTE: FloatSquareRoot uses FixedSquareRoot, which must have the correct
// settings of BITWIDTH and FIXWIDTH.
// macro BITWIDTH         = 2;   // use 2 bit for usage of FixedSquareRoot in FloatSquareRoot
// macro FIXWIDTH         = 5;   // use width of mantissa + 1 (+ 1 to compute the r bit to determine rounding direction)

module FloatSquareRoot(bv{Width} ?x,!z, event !ready) {
    bv{MantWidth+3} m_x;       // radicand: 2 integer bits, MantWidth+1 fraction bits
    bv{MantWidth+3} m_sqrt_x;  // root: 01.<MantWidth mantissa bits><round bit>
    bool sqrt_exact;           // from FixedSquareRoot; assumed true iff the root is exact -- TODO confirm
    bool r_bit;                // round bit: the bit just below the mantissa LSB m_sqrt_x{1}
    bool s_bit;                // sticky bit: set if any bit below the round bit is set
    bool fixSqrtReady;
    bool x_subnormal;          // isSubnormal(x), latched when the module starts
    int  newExp;               // unbiased exponent of the result

    if( isNaN(x) ) {
        z = x;                 // NaN is passed through unchanged
    } else if ( isZero(x) ) {
        z = x;                 // sqrt(+0)=+0 and sqrt(-0)=-0 => special case according to IEEE-754 !!
    } else if ( Sign(x) ) {
        // negative non-zero number (including -inf)
        // TODO: set qNaN
        z = posNaN;
    } else if ( isInfinite(x) ) {
        z = x;                 // sqrt(+inf)=+inf
    } else {
        // Everything that depends on x is derived in this instant, so the
        // computation is not affected if x changes during the pauses below.
        x_subnormal = isSubnormal(x);

        // Build the radicand m_x (format xx.xxxxx) and newExp such that
        //     x = m_x * 2^(2*newExp)   with   m_x{-1:-2} != 00,
        // i.e. the exponent applied to m_x is always even and
        //     sqrt(x) = sqrt(m_x) * 2^newExp.
        if( x_subnormal ) {
            // x = 0.m * 2^(1-ExpBias) = m{-1}m{-2}.m{-3:0} * 2^(-ExpBias-1).
            // -ExpBias-1 is even, since ExpBias = 2^(ExpWidth-1)-1 is odd
            // (for ExpWidth >= 2).
            m_x = (Mant(x))@false@false@false;
            newExp = ( ValNExp(Exp(x)) - 1 ) / 2;

            // The position of the leftmost '1' is unknown, but there is one,
            // since a subnormal number has Mant(x) != 0. Shift the mantissa
            // left two bits at a time (keeping the exponent even) until one
            // of the two integer bits is set.
            while( m_x{-1:-2} == {false::2} ) {
                next(m_x) = m_x{-3:0}@{false::2};
                next(newExp) = newExp - 1;   // two bit shift = one step of the halved exponent
                pause;
            }
        } else {
            // normal number: x = 1.m * 2^e with e = ValNExp(Exp(x)).
            // Parity is tested with != 0 rather than == 1: e may be negative,
            // and then e % 2 may be -1 depending on the sign convention of %.
            if( ( ValNExp(Exp(x)) % 2 ) != 0 ) {
                // e odd: x = 1m{-1}.m{-2:0} * 2^(e-1)
                m_x = true@(Mant(x))@false@false;
                newExp = ( ValNExp(Exp(x)) - 1 ) / 2;
            } else {
                // e even: x = 01.m * 2^e
                m_x = false@true@(Mant(x))@false;
                newExp = ValNExp(Exp(x)) / 2;
            }
        }

        // Square root of the radicand, shared by subnormal and normal inputs.
        // Since 1 <= m_x < 4, the root satisfies 1 <= m_sqrt_x < 2.
        FixedSquareRoot(m_x, m_sqrt_x, sqrt_exact, fixSqrtReady);
        await(fixSqrtReady);

        assert( m_sqrt_x{-1:-2} == false@true );
        r_bit = m_sqrt_x{0};
        s_bit = !sqrt_exact;

        // A result that is not a normal number can only come from a subnormal
        // input, and only if MantWidth > abs(ExpMin): the smallest subnormal
        // number 2^(ExpMin-MantWidth) has the root 2^((ExpMin-MantWidth)/2),
        // which is smaller than 2^ExpMin iff
        //     (ExpMin-MantWidth)/2 < ExpMin  <=>  MantWidth > -ExpMin.
        // (A normal input x >= 2^ExpMin has sqrt(x) >= 2^(ExpMin/2) >= 2^ExpMin.)
        if( x_subnormal /* & MantWidth > -ExpMin */ ) {
            // While the exponent is smaller than -ExpBias = ExpMin-1 (the
            // exponent belonging to exponent field 0), shift the root right
            // and increment the exponent; shifted-out bits are collected in
            // the round and sticky bits.
            while( newExp + ExpBias < 0 ) {
                next(m_sqrt_x) = false@m_sqrt_x{-1:1};
                next(newExp)   = newExp + 1;
                next(s_bit)    = s_bit | r_bit;
                next(r_bit)    = m_sqrt_x{1};
                pause;
            }

            // Is the result a denormal number?
            if( newExp + ExpBias == 0 ) {
                // 1.f * 2^(ExpMin-1) = 0.1f * 2^ExpMin: shift m_sqrt_x one bit
                // right, so that the hidden bit becomes the leading mantissa
                // bit and the same assignment to z below (with exponent
                // field 0) can be used. A denormal number loses one bit of
                // precision, hence the round and sticky bits are updated.
                next(m_sqrt_x) = false@m_sqrt_x{-1:1};
                next(s_bit)    = s_bit | r_bit;
                next(r_bit)    = m_sqrt_x{1};
                pause;
            }
        }

        // Round to nearest, ties to even: round up iff the round bit is set
        // and either the sticky bit or the mantissa LSB m_sqrt_x{1} is set.
        // The complement, !r_bit | (!s_bit & !m_sqrt_x{1}), rounds down.
        if( r_bit & ( s_bit | m_sqrt_x{1} ) ) {
            if( m_sqrt_x{-3:1} == {true::MantWidth} ) {
                // Rounding up overflows the mantissa: increment the exponent
                // (for a denormal result this yields the smallest normal
                // number). TODO: tests have shown that this case might never
                // happen (see test case 6c using MantWidth = {256,257}).
                z = false@(nat2bv(abs(newExp + ExpBias + 1), ExpWidth))@{false::MantWidth};
            } else {
                z = false@(nat2bv(abs(newExp + ExpBias), ExpWidth))@(nat2bv( bv2nat(m_sqrt_x{-3:1}) + 1, MantWidth ));
            }
        } else {
            z = false@(nat2bv(abs(newExp + ExpBias), ExpWidth))@(m_sqrt_x{-3:1});
        }

        pause;
    }

    emit( ready );
}
// check special cases (NaN, zero, infinity): for these inputs the module sets
// z in its starting instant, so the tests do not await ready
drivenby { // 0
    // NaN is passed through unchanged
    x = posNaN;
    assert(z==posNaN);
}
drivenby { // 1
    // NaN is passed through unchanged (sign included)
    x = negNaN;
    assert(z==negNaN);
}
drivenby { // 2
    // sqrt(+0)=+0
    x = posZero;
    assert(z==posZero);
}
drivenby { // 3
    // sqrt(-0)=-0 => special case according to IEEE-754 !!
    x = negZero;
    assert(z==negZero);
}
drivenby { // 4
    // sqrt(+inf)=+inf
    x = posInf;
    assert(z==posInf);
}
drivenby { // 5
    // Actually x=negInf yields a qNaN!
    x = negInf;
    assert(z==posNaN);
}
// some negative numbers: the square root of a negative non-zero number
// yields posNaN
drivenby { // 6
    // -2.0 = -1.0000 * 2^(100-011)
    x = (true@0b100@0b0000);
    assert(z==posNaN);
}
drivenby { // 7
    // -0.625 = -1.0100 * 2^(010-011)
    x = (true@0b010@0b0100);
    immediate await(ready);
    assert(z==posNaN);
}
drivenby { // 8
    // -14.5 = -1.1101 * 2^(110-011)
    x = (true@0b110@0b1101);
    immediate await(ready);
    assert(z==posNaN);
}
drivenby { // 9
    // -5.25 = -1.0101 * 2^(101-011)
    x = (true@0b101@0b0101);
    immediate await(ready);
    assert(z==posNaN);
}
// normal numbers (always yield a normal number as result)
// The values in the comments of tests 10-12 are written in binary.
drivenby { // 10
    // sqrt(1) = 1
    x = (false@0b011@0b0000);
    immediate await(ready);
    assert(z==false@0b011@0b0000);
}
drivenby { // 11
    // sqrt(100) = 10 (binary), i.e. sqrt(4) = 2
    x = (false@0b101@0b0000);
    immediate await(ready);
    assert(z==false@0b100@0b0000);
}
drivenby { // 12
    // sqrt(0.01) = 0.1 (binary), i.e. sqrt(0.25) = 0.5
    x = (false@0b001@0b0000);
    immediate await(ready);
    assert(z==false@0b010@0b0000);
}
// normal, even exponent (except test 14, whose exponent 100-011 = 1 is odd)
drivenby { // 13
    // sqrt(7) = 2,6457513110645905905
    // 7 = 1.11 * 2^(101-011) = 0 101 1100
    //   2,6457513110645905905
    // = 10.101001... (r=0,s=1 => round down)
    // =  1.0101 * 2^(100-011)
    // = 0 100 0101
    x = (false@0b101@0b1100);
    immediate await(ready);
    assert(z==false@0b100@0b0101);
}
drivenby { // 14
    // 3.875 = 11.111 = 1.1111 * 2^(100-011) = 0 100 1111   (odd exponent)
    //   1,96850196850295275492
    // = 2^0 * 1.1111 0 1... (r=0,s=1 => round down)
    // = 2^(011-011) * 1.1111
    // = 1.1111 * 2^(011-011)
    // = 0 011 1111
    x = (false@0b100@0b1111);
    immediate await(ready);
    assert(z==false@0b011@0b1111);
}
drivenby { // 15
    // 0,328125 = 1.0101 * 2^(001-011) = 0 001 0101
    //   0,57282196186948000082
    // = 2^(-1) * 1.0010 0 10.. (r=0,s=1 => round down)
    // = 2^(010-011) * 1.0010
    // = 1.0010 * 2^(010-011)
    // = 0 010 0010
    x = (false@0b001@0b0101);
    immediate await(ready);
    assert(z==false@0b010@0b0010);
}
drivenby { // 16
    // 1,25 = 1.0100 * 2^(011-011) = 0 011 0100
    //   1,1180339887498948482
    // = 2^0 * 1.0001 1 11.. (r=1,s=1 => round up)
    // = 2^(011-011) * 1.0010
    // = 1.0010 * 2^(011-011)
    // = 0 011 0010
    x = (false@0b011@0b0100);
    immediate await(ready);
    assert(z==false@0b011@0b0010);
}
// normal, odd exponent
drivenby { // 17
    // 0,875 = 1.1100 * 2^(010-011) = 0 010 1100
    //   0,9354143466934853464
    // = 2^(-1) * 1.1101 1 11.. (r=1,s=1 => round up)
    // = 2^(010-011) * 1.1110
    // = 1.1110 * 2^(010-011)
    // = 0 010 1110
    x = (false@0b010@0b1100);
    immediate await(ready);
    assert(z==false@0b010@0b1110);
}
drivenby { // 18
    // 2.75 = 1.0110 * 2^(100-011) = 0 100 0110
    //   1,65831239517769992456
    // = 2^0 * 1.1010 1 0001.. (r=1,s=1 => round up)
    // = 2^(011-011) * 1.1011
    // = 1.1011 * 2^(011-011)
    // = 0 011 1011
    x = (false@0b100@0b0110);
    immediate await(ready);
    assert(z==false@0b011@0b1011);
}
// denormal inputs (tests 19 and 20); test 21 is a normal input with odd
// exponent and all-ones mantissa
drivenby { // 19
    // 0,125 = 0.1000 * 2^(001-011) = 0 000 1000
    //   0,3535533905932737622
    // = 2^(-2) * 1.0110 1 01.. (r=1,s=1 => round up)
    // = 2^(001-011) * 1.0111
    // = 1.0111 * 2^(001-011)
    // = 0 001 0111
    x = (false@0b000@0b1000);
    immediate await(ready);
    assert(z==false@0b001@0b0111);
}
drivenby { // 20
    // 0,015625 = 0.0001 * 2^(001-011) = 0 000 0001
    //   sqrt: 0,125 = 0.1000 * 2^(001-011) = 0 000 1000 (denormal result)
    x = (false@0b000@0b0001);
    immediate await(ready);
    assert(z==false@0b000@0b1000);
}
drivenby { // 21
    // 0,96875 = 1.1111 * 2^(010-011) = 0 010 1111
    // 0,98425098425147637746
    // = 2^(-1) * 1.1111 0 11... (r=0,s=1 => round down)
    // = 2^(010-011) * 1.1111
    // = 1.1111 * 2^(010-011)
    // = 0 010 1111
    x = (false@0b010@0b1111);
    immediate await(ready);
    assert(z==false@0b010@0b1111);
}
/*
drivenby { // 6b (MantWidth=5)
    // 0.0078125 = 0.00001 * 2^(001-011) = 0 000 00001
    //   sqrt(0.0078125) = 0,08838834764831844055
    // = 2^(001-011) * 0.01011 0 10... (r=0,s=1 => round down)
    // = 0 000 01011
    x = (false@0b000@0b00001);
    immediate await(ready);
    assert(z==false@0b000@0b01011);
}
drivenby { // 7b (MantWidth=5)
    // 0,984375 = 1.11111 * 2^(010-011) = 0 010 11111
    // sqrt(0,984375) = 0,99215674164922147144
    // = 2^(-1) * 1.11111 0 11... (r=0,s=1 => round down)
    // = 2^(010-011) * 1.11111
    // = 1.11111 * 2^(010-011)
    // = 0 010 11111
    x = (false@0b010@0b11111);
    immediate await(ready);
    assert(z==false@0b010@0b11111);
}
*/
/*
drivenby { // 6c: generic mantissa width
    x = (false@0b010@{true::MantWidth});
    immediate await(ready);
    assert(z==false@0b010@{true::MantWidth});
}
*/