diff --git a/pyrtl/positutils.py b/pyrtl/positutils.py new file mode 100644 index 00000000..6de448b6 --- /dev/null +++ b/pyrtl/positutils.py @@ -0,0 +1,407 @@ +"""Implements utility functions for posit operations.""" + +import math +import pyrtl +from pyrtl.corecircuits import shift_right_logical, shift_left_logical + + +def decode_posit( + x: pyrtl.WireVector, nbits: int, es: int +) -> tuple[ + pyrtl.WireVector, + pyrtl.WireVector, + pyrtl.WireVector, + pyrtl.WireVector, + pyrtl.WireVector, +]: + """Decode posit into its components and return them as a tuple. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + >>> nbits = 8 + >>> es = 2 + + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> sign_out = pyrtl.Output(bitwidth=nbits, name='sign_out') + >>> k_out = pyrtl.Output(bitwidth=nbits, name='k_out') + >>> exp_out = pyrtl.Output(bitwidth=es, name='exp_out') + >>> frac_bits_out = pyrtl.Output(bitwidth=nbits, name='frac_bits_out') + >>> frac_len_out = pyrtl.Output(bitwidth=nbits, name='frac_len_out') + + >>> sign, k, exp, frac_bits, frac_len = decode_posit(a, nbits, es) + + >>> sign_out <<= sign + >>> k_out <<= k + >>> exp_out <<= exp + >>> frac_bits_out <<= frac_bits + >>> frac_len_out <<= frac_len + + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100}) + + >>> sim.inspect('sign_out') + '0' + >>> sim.inspect('k_out') + '0' + >>> sim.inspect('exp_out') + '3' + >>> format(sim.inspect('frac_bits_out'), '08b') + '00000100' + >>> sim.inspect('frac_len_out') + '3' + + :param x: A WireVector that represents the posit. + :param nbits: An int that represents the bitwidth of the posit. + :param es: An int that represents the exponent size of the posit. + + :return: A tuple consisting of: + - WireVector for sign + - WireVector for k + - WireVector for exponent + - WireVector for fractional bits + - WireVector for length of fraction + """ + sign = x[nbits - 1] + rest = [x[nbits - 2 - i] for i in range(nbits - 1)] + regime_bit = rest[0] + run_len = pyrtl.Const(0, bitwidth=nbits) + active = pyrtl.Const(1, bitwidth=1) + + for i in range(nbits - 1): + bit = rest[i] + is_same = bit == regime_bit + run_len = run_len + pyrtl.select( + active & is_same, + pyrtl.Const(1), + pyrtl.Const(0), + ) + active = active & is_same + + k_pos = run_len - pyrtl.Const(1, bitwidth=nbits) + k_neg = (~run_len) + pyrtl.Const(1, bitwidth=nbits) + k = pyrtl.select(regime_bit, k_pos, k_neg) + + exp_bits = [] + for j in range(es): + bit_val = pyrtl.Const(0, bitwidth=1) + for i in range(nbits - 2): + cond = run_len == pyrtl.Const(i, bitwidth=nbits) + # exponent bit is at rest[i + 1 + j] + target_idx = i + 1 + j + if target_idx < (nbits - 1): + bit_val = pyrtl.select(cond, rest[target_idx], bit_val) + exp_bits.append(bit_val) + + exp = pyrtl.concat_list(exp_bits[::-1]) if es > 0 else pyrtl.Const(0) + + start_idx = run_len + pyrtl.Const(1 + es, bitwidth=nbits) + fraction_bits = [] + + for i in range(nbits - 1): + in_range = (i >= start_idx) & (i < (nbits - 1)) + bit_val = pyrtl.select(in_range, rest[i], pyrtl.Const(0)) + fraction_bits.append(bit_val) + + frac_result = pyrtl.concat_list(fraction_bits[::-1]) + fraction_length = ( + pyrtl.Const(nbits, bitwidth=nbits) + - run_len + - pyrtl.Const(es, bitwidth=nbits) + - pyrtl.Const(2, bitwidth=nbits) + ) + + return sign, k, exp, frac_result, fraction_length + + +def get_upto_regime( + k: pyrtl.WireVector, + n_val: int, + sign_final: pyrtl.WireVector, +) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + """Calculates the remaining bits and the regime bits. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> k_in = pyrtl.Input(bitwidth=nbits, name='k_in') + >>> sign_final = pyrtl.Input(bitwidth=1, name='sign_final') + + >>> rem_bits_out = pyrtl.Output(bitwidth=nbits, name='rem_bits_out') + >>> sign_w_regime_out = pyrtl.Output(bitwidth=nbits, name='sign_w_regime_out') + + >>> rem_bits, sign_w_regime = get_upto_regime(k_in, nbits, sign_final) + + >>> rem_bits_out <<= rem_bits + >>> sign_w_regime_out <<= sign_w_regime + + >>> sim = pyrtl.Simulation() + >>> sim.step({'k_in': 2, 'sign_final': 0}) + + >>> sim.inspect('rem_bits_out') + '3' + >>> format(sim.inspect('sign_w_regime_out'), '08b') + '01110000' + + :param k: A WireVector that represents the k value. + :param n_val: A int that represents the bitwidth of the posit. + :param sign_final: A WireVector that represents the final sign. + + :return: A tuple consisting of: + - WireVector representing the remaining bits. + - WireVector representing the regime bits with sign bit. + """ + precomputed_val = (1 << (n_val - 1)) - 1 + + n_c = pyrtl.Const(n_val) + n_minus_1 = pyrtl.Const(n_val - 1) + n_minus_2 = pyrtl.Const(n_val - 2) + n_minus_3 = pyrtl.Const(n_val - 3) + + rem_bits = pyrtl.WireVector(bitwidth=n_val) + sign_w_regime = pyrtl.WireVector(bitwidth=n_val) + + abs_k = pyrtl.WireVector(bitwidth=k.bitwidth) + abs_k <<= pyrtl.select( + k >= (1 << (n_val - 1)), + (~k + 1) & ((1 << n_val) - 1), + k, + ) + + large_neg_regime = abs_k >= n_minus_1 + large_pos_regime = abs_k >= n_minus_2 + + with pyrtl.conditional_assignment: + with k >= (1 << (n_val - 1)): + with large_neg_regime: + rem_bits |= 0 + sign_w_regime |= 0 + with ~large_neg_regime: + temp_rem = n_c + k - 2 + rem_bits |= temp_rem + sign_w_regime |= shift_right_logical( + pyrtl.Const(1 << (n_val - 2), bitwidth=n_val), abs_k + ) + + with k < (1 << (n_val - 1)): + with large_pos_regime: + rem_bits |= 0 + sign_w_regime |= pyrtl.Const(precomputed_val, bitwidth=n_val) + with ~large_pos_regime: + temp_rem = n_minus_3 - k + shift_amt = k + 2 + rem_bits |= temp_rem + ones = shift_left_logical( + pyrtl.Const(1, bitwidth=n_val), shift_amt + ) - pyrtl.Const(2, bitwidth=n_val) + shifted = shift_left_logical(ones, temp_rem) + sign_w_regime |= shifted + + sign_w_regime_trimmed = pyrtl.WireVector(bitwidth=n_val - 1) + sign_w_regime_trimmed <<= sign_w_regime[: n_val - 1] + sign_w_regime_final = pyrtl.concat(sign_final, sign_w_regime_trimmed) + + return rem_bits, sign_w_regime_final + + +def frac_with_hidden_one( + frac: pyrtl.WireVector, + frac_length: pyrtl.WireVector, + nbits: int, +) -> pyrtl.WireVector: + """Adds a hidden 1 to the fractional bits. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + >>> nbits = 8 + >>> frac_in = pyrtl.Input(bitwidth=nbits-1, name='frac_in') + >>> frac_len_in = pyrtl.Input(bitwidth=nbits, name='frac_len_in') + >>> frac_out = pyrtl.Output(bitwidth=32, name='frac_out') + + >>> frac_out <<= frac_with_hidden_one(frac_in, frac_len_in, nbits) + + >>> sim = pyrtl.Simulation() + >>> sim.step({'frac_in': 0b0010101, 'frac_len_in': 5}) + + >>> format(sim.inspect('frac_out'), '08b') + '000110101' + + :param frac: A WireVector that represents the fractional bits. + :param frac_length: A WireVector that represents the length of the fractional bits. + :param nbits: An int that represents the bitwidth of the posit. + + :return: A WireVector that represents the fraction with the hidden 1. + """ + one_table = [pyrtl.Const(1 << i, bitwidth=nbits + 1) for i in range(nbits + 1)] + one_shifted = pyrtl.Const(0, bitwidth=32) + + for i in range(nbits + 1): + one_shifted = pyrtl.select( + frac_length == pyrtl.Const(i, bitwidth=8), + one_table[i], + one_shifted, + ) + + frac_32 = pyrtl.concat(pyrtl.Const(0, bitwidth=32 - (nbits - 1)), frac) + full = one_shifted + frac_32 + return full + + +def remove_first_one(val: pyrtl.WireVector) -> pyrtl.WireVector: + """Removes the leading hidden bit of 1. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + >>> nbits = 8 + >>> frac_with_one = pyrtl.Input(bitwidth=nbits, name='frac_with_one') + >>> frac_removed = pyrtl.Output(bitwidth=nbits, name='frac_removed') + + >>> frac_removed <<= remove_first_one(frac_with_one) + + >>> sim = pyrtl.Simulation() + >>> sim.step({'frac_with_one': 0b10010110}) + + >>> format(sim.inspect('frac_removed'), '08b') + '00010110' + + :param val: A WireVector that represents the fractional bits. + + :return: A WireVector with the hidden bit of 1 removed. + """ + found = pyrtl.Const(0, bitwidth=1) + result_bits = [] + + for i in range(val.bitwidth): + bit = val[val.bitwidth - 1 - i] # MSB first + new_bit = pyrtl.select( + (found == 0) & (bit == 1), + pyrtl.Const(0, bitwidth=1), + bit, + ) + result_bits.append(new_bit) + + found = pyrtl.select( + (found == 0) & (bit == 1), + pyrtl.Const(1, bitwidth=1), + found, + ) + + return pyrtl.concat_list(result_bits[::-1]) + + +def decimal_to_posit(x: float, nbits: int, es: int) -> int: + """Convert a decimal float to Posit representation. + + .. doctest only:: + + >>> import math + + Example:: + + >>> nbits, es = 16, 2 + >>> format(decimal_to_posit(4992, nbits, es), '016b') + '0111100000111000' + + >>> nbits, es = 8, 1 + >>> format(decimal_to_posit(5000, nbits, es), '08b') + '01111111' + + :param x: The decimal float to be converted. + :param nbits: Total number of bits in the posit representation. + :param es: Maximum number of exponent bits. + :return: The integer representation of the posit encoding. + """ + if x == 0: + return 0 + + # Sign + sign = 0 + if x < 0: + sign = 1 + x = -x + + useed = 2 ** (2 ** es) + + k = int(math.floor(math.log(x, useed))) + regime_value = useed ** k + remaining = x / regime_value + + exponent = int(math.floor(math.log2(remaining))) if es > 0 else 0 + exponent = max(0, exponent) + remaining /= 2 ** exponent + + # Fraction bits + fraction = remaining - 1.0 + frac_bits = [] + for _ in range(nbits * 2): + fraction *= 2 + if fraction >= 1: + frac_bits.append("1") + fraction -= 1 + else: + frac_bits.append("0") + + # Regime bits + if k >= 0: + regime_bits = "1" * (k + 1) + "0" + else: + regime_bits = "0" * (-k) + "1" + + bits = "0" + regime_bits + + # Exponent bits + if es > 0: + exp_str = format(exponent & ((1 << es) - 1), f"0{es}b") + bits += exp_str + + bits += "".join(frac_bits) + + # Handle rounding if bits exceed nbits + if len(bits) > nbits: + main = bits[:nbits] + guard = bits[nbits] + roundb = bits[nbits + 1] if nbits + 1 < len(bits) else "0" + sticky = "1" if "1" in bits[nbits + 2:] else "0" + + increment = ( + (guard == "1") + and (roundb == "1" or sticky == "1" or main[-1] == "1") + ) + + if increment: + main_int = int(main, 2) + 1 + if main_int >= (1 << (nbits - 1)): + main_int = (1 << (nbits - 1)) - 1 + main = format(main_int, f"0{nbits}b") + + bits = main + else: + bits = bits.ljust(nbits, "0") + + ones_comp = "" + if sign: + for i in bits: + if i == "0": + ones_comp = ones_comp + "1" + else: + ones_comp = ones_comp + "0" + result = int(ones_comp, 2) + 1 + return result + + return int(bits, 2) diff --git a/pyrtl/rtllib/positadder.py b/pyrtl/rtllib/positadder.py new file mode 100644 index 00000000..d3322602 --- /dev/null +++ b/pyrtl/rtllib/positadder.py @@ -0,0 +1,225 @@ +import pyrtl +from pyrtl.corecircuits import shift_left_logical, shift_right_logical +from pyrtl.positutils import ( + decode_posit, + get_upto_regime, + frac_with_hidden_one, + remove_first_one, +) + + +def posit_add( + a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int +) -> pyrtl.WireVector: + """Adds two numbers in posit format and returns their sum. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> es = 1 + + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> b = pyrtl.Input(bitwidth=nbits, name='b') + >>> posit = pyrtl.Output(bitwidth=nbits, name='posit') + + >>> added_posit = posit_add(a, b, nbits, es) + + >>> posit <<= added_posit + + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100, 'b': 0b01100000}) # 3.5 + 4 = 7.5 + >>> format(sim.inspect('posit'), '08b') + '01100111' + + :param a: A :class:`.WireVector` to add. Bitwidths need to match. + :param b: A :class:`WireVector` to add. Bitwidths need to match. + :param nbits: A :class:`.int` representing the total bitwidth of the posit. + :param es: A :class:`.int` representing the exponent size of the posit. + + :return: A :class:`WireVector` that represents the sum of the two posits. + """ + # Decode input posits into regime (k), exponent, fraction, and fraction length + _, k_a, exp_a, frac_a, frac_len_a = decode_posit(a, nbits, es) + _, k_b, exp_b, frac_b, frac_len_b = decode_posit(b, nbits, es) + + # Match bitwidths of fractional part + frac_a_aligned = pyrtl.WireVector(bitwidth=nbits) + frac_b_aligned = pyrtl.WireVector(bitwidth=nbits) + + frac_len_a_aligned = pyrtl.WireVector(bitwidth=nbits) + frac_len_b_aligned = pyrtl.WireVector(bitwidth=nbits) + + with pyrtl.conditional_assignment: + with frac_len_a > frac_len_b: + shift_amt = frac_len_a - frac_len_b + frac_b_aligned |= shift_left_logical(frac_b, shift_amt) + frac_a_aligned |= frac_a + frac_len_b_aligned |= frac_len_b + shift_amt + frac_len_a_aligned |= frac_len_a + + with frac_len_a < frac_len_b: + shift_amt = frac_len_b - frac_len_a + frac_a_aligned |= shift_left_logical(frac_a, shift_amt) + frac_b_aligned |= frac_b + frac_len_a_aligned |= frac_len_a + shift_amt + frac_len_b_aligned |= frac_len_b + + with pyrtl.otherwise: + frac_a_aligned |= frac_a + frac_b_aligned |= frac_b + frac_len_a_aligned |= frac_len_a + frac_len_b_aligned |= frac_len_b + + # Add hidden leading one to fractions + frac_a_full = frac_with_hidden_one(frac_a_aligned, frac_len_a_aligned, nbits) + frac_b_full = frac_with_hidden_one(frac_b_aligned, frac_len_b_aligned, nbits) + + # Compute scales (regime*k + exponent) + scale_a = (k_a if es == 0 else shift_left_logical(k_a, es)) + exp_a + scale_b = (k_b if es == 0 else shift_left_logical(k_b, es)) + exp_b + + offset = scale_a - scale_b + + is_neg_offset = pyrtl.select( + offset > pyrtl.Const(127, bitwidth=nbits), + truecase=pyrtl.Const(1, bitwidth=nbits), + falsecase=pyrtl.Const(0, bitwidth=nbits), + ) + + shifted_a = pyrtl.WireVector(bitwidth=frac_a_full.bitwidth) + shifted_b = pyrtl.WireVector(bitwidth=frac_b_full.bitwidth) + result_scale = pyrtl.WireVector(bitwidth=offset.bitwidth) + if es == 0: + result_exp = pyrtl.WireVector(bitwidth=1) # placeholder + else: + result_exp = pyrtl.WireVector(bitwidth=es) + + neg_offset = (~offset) + pyrtl.Const(1, bitwidth=offset.bitwidth) + + # Align fractions based on offset + with pyrtl.conditional_assignment: + with is_neg_offset == pyrtl.Const(0, bitwidth=nbits): + shifted_b |= frac_b_full + shifted_a |= shift_left_logical(frac_a_full, offset) + result_scale |= scale_a + result_exp |= exp_a + + with is_neg_offset == pyrtl.Const(1, bitwidth=nbits): + shifted_b |= shift_left_logical(frac_b_full, neg_offset) + shifted_a |= frac_a_full + result_scale |= scale_b + result_exp |= exp_b + + # Add shifted fractions + result_frac = shifted_a + shifted_b + + # Checking for overflow, if overflow, increase scale + result_scale = pyrtl.select( + offset == pyrtl.Const(0, bitwidth=offset.bitwidth), + result_scale + 1, + result_scale, + ) + result_k = ( + result_scale + if es == 0 + else shift_right_logical(result_scale, pyrtl.Const(es, bitwidth=nbits)) + ) + + # Extract regime bits + rem_bits, regime_bits = get_upto_regime(result_k, nbits, 0) + + # Extract exponent from scale + result_exp = result_scale - ( + result_k if es == 0 else shift_left_logical(result_k, es) + ) + shift_amt = rem_bits - pyrtl.Const(es, bitwidth=nbits) + result_exp = pyrtl.select( + shift_amt == pyrtl.Const(0), + result_exp, + shift_left_logical(result_exp, shift_amt), + ) + + # Remaining fraction length + frac_len = rem_bits - es + + # Handling rounding of fractional bits + count = pyrtl.Const(0, bitwidth=nbits) + rounded_frac = pyrtl.Const(0, bitwidth=nbits) + found = pyrtl.Const(0, bitwidth=nbits) + + for i in range(nbits): + bit = result_frac[nbits - 1 - i] + cond = pyrtl.select(bit == pyrtl.Const(1), 1, found) + found = found | cond + count = pyrtl.select(found == pyrtl.Const(1), count + 1, count) + + # Exclude leading one + count = count - 1 + bits_to_shift = count - frac_len + + # Normalize fraction + truncated_frac = shift_right_logical(result_frac, bits_to_shift) + + # Guard, Round, Sticky bits + ground_bit = result_frac & shift_right_logical( + pyrtl.Const(1, bitwidth=nbits), (bits_to_shift - 1) + ) + round_bit = result_frac & shift_right_logical( + pyrtl.Const(1, bitwidth=nbits), (bits_to_shift - 2) + ) + sticky_bit = result_frac & shift_right_logical( + pyrtl.Const(1, bitwidth=nbits), (bits_to_shift - 3) + ) + + cond = ground_bit & (round_bit | sticky_bit) + rounded_frac = pyrtl.WireVector(bitwidth=nbits) + + # Apply rounding rules + with pyrtl.conditional_assignment: + with ground_bit == pyrtl.Const(0): + rounded_frac |= truncated_frac + with cond == pyrtl.Const(1): + rounded_frac |= truncated_frac + 1 + with cond == pyrtl.Const(0): + with truncated_frac[0] == 1: + rounded_frac |= truncated_frac + 1 + with truncated_frac[0] == 0: + rounded_frac |= truncated_frac + + # Remove hidden one from rounded fraction + rounded_frac = remove_first_one(rounded_frac) + + # Combine regime, exponent, and fraction + added_posit = ( + pyrtl.Const(0, bitwidth=nbits) + + regime_bits + + result_exp + + rounded_frac + ) + result_posit = pyrtl.WireVector(bitwidth=nbits) + + # Checking for special cases (NaR and 0) + isNar = ( + pyrtl.select(a == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) + | pyrtl.select(b == pyrtl.Const(1 << nbits - 1, bitwidth=nbits), 1, 0) + ) + + with pyrtl.conditional_assignment: + with isNar == pyrtl.Const(1, bitwidth=nbits): + result_posit |= pyrtl.Const(1 << nbits - 1, bitwidth=nbits) + + with a == pyrtl.Const(0, bitwidth=nbits): + result_posit |= b + + with b == pyrtl.Const(0, bitwidth=nbits): + result_posit |= a + + with pyrtl.otherwise: + result_posit |= added_posit + + return result_posit \ No newline at end of file diff --git a/pyrtl/rtllib/positmatmul.py b/pyrtl/rtllib/positmatmul.py new file mode 100644 index 00000000..833eb319 --- /dev/null +++ b/pyrtl/rtllib/positmatmul.py @@ -0,0 +1,83 @@ +import pyrtl +from pyrtl import PyrtlError +from pyrtl.rtllib.matrix import Matrix, matrix_wv_to_list +from pyrtl.rtllib.positadder import posit_add +from pyrtl.rtllib.positmul import posit_mul + + +def posit_matmul(x: Matrix, y: Matrix, nbits: int, es: int) -> Matrix: + """Performs matrix multiplication on posits. + + .. doctest only:: + + >>> import pyrtl + >>> from pyrtl.rtllib.matrix import Matrix, matrix_wv_to_list + >>> pyrtl.reset_working_block() + + Example:: + + >>> nbits = 8 + >>> es = 1 + + >>> matrix_x = [[0b01000000, 0b01010000], [0b01011000, 0b01100000]] + >>> test_x = Matrix(2, 2, bits=nbits, value=matrix_x) + + >>> matrix_y = [[0b01000000, 0b01010000], [0b01011000, 0b01100000]] + >>> test_y = Matrix(2, 2, bits=nbits, value=matrix_y) + + >>> result = posit_matmul(test_x, test_y, nbits, es) + + >>> output = pyrtl.Output(name='output') + >>> output <<= result.to_wirevector() + + >>> sim = pyrtl.Simulation() + >>> sim.step() + + >>> raw_matrix = matrix_wv_to_list( + ... sim.inspect("output"), result.rows, result.columns, result.bits + ... ) + + >>> pretty_matrix = [[format(val, '08b') for val in row] for row in raw_matrix] + >>> pretty_matrix + [['01100110', '01101010'], ['01101111', '01110001']] + + :param x: A :class:`Matrix` to be multiplied. + :param y: A :class:`Matrix` to be multiplied. + :param nbits: A :class:`int` representing the bitwidth of each cell of + the matrix. + :param es: A :class:`int` representing the exponent size of the posit. + + :return: A :class:`Matrix` that represents the product of two posit + matrices. + """ + if not isinstance(x, Matrix): + msg = f"error: expecting a Matrix, got {type(x)} instead" + raise PyrtlError(msg) + + if not isinstance(y, Matrix): + msg = f"error: expecting a Matrix, got {type(y)} instead" + raise PyrtlError(msg) + + if x.columns != y.rows: + msg = ( + f"error: rows and columns mismatch. " + f"Matrix a: {x.columns} columns, Matrix b: {y.rows} rows" + ) + raise PyrtlError(msg) + + result = Matrix( + x.rows, + y.columns, + nbits, + max_bits=x.max_bits, + ) + + for i in range(x.rows): + for j in range(y.columns): + acc = pyrtl.Const(0, bitwidth=nbits) + for k in range(x.columns): + prod = posit_mul(x[i, k], y[k, j], nbits, es) + acc = posit_add(acc, prod, nbits, es) + result[i, j] = acc + + return result \ No newline at end of file diff --git a/pyrtl/rtllib/positmul.py b/pyrtl/rtllib/positmul.py new file mode 100644 index 00000000..e3f462a8 --- /dev/null +++ b/pyrtl/rtllib/positmul.py @@ -0,0 +1,137 @@ +import pyrtl +from pyrtl.corecircuits import shift_right_logical, shift_left_logical +from pyrtl.positutils import decode_posit, get_upto_regime + + +def posit_mul( + a: pyrtl.WireVector, b: pyrtl.WireVector, nbits: int, es: int +) -> pyrtl.WireVector: + """Multiplies two numbers in posit format and returns their product. + + .. doctest only:: + + >>> import pyrtl + >>> pyrtl.reset_working_block() + >>> nbits, es = 8, 1 + >>> a = pyrtl.Input(bitwidth=nbits, name='a') + >>> b = pyrtl.Input(bitwidth=nbits, name='b') + >>> out = pyrtl.Output(bitwidth=nbits, name='out') + >>> out <<= posit_mul(a, b, nbits, es) + >>> sim = pyrtl.Simulation() + >>> sim.step({'a': 0b01011100, 'b': 0b01000000}) # 3.5 * 2 = 7.0 (approx) + >>> format(sim.inspect('out'), '08b') # doctest: +ELLIPSIS + '...' + + :param a: A WireVector posit multiplicand. + :param b: A WireVector posit multiplier. + :param nbits: Total bitwidth of the posit. + :param es: Exponent size of the posit. + :return: A WireVector representing the product of the two posits. + """ + + # Decode inputs + sign_a, k_a, exp_a, frac_a, fraclength_a = decode_posit(a, nbits, es) + sign_b, k_b, exp_b, frac_b, fraclength_b = decode_posit(b, nbits, es) + + # Handle multiplication of special cases + either_zero = (a == 0) | (b == 0) + either_inf = (a == (1 << (nbits - 1))) | (b == (1 << (nbits - 1))) + + final_value = pyrtl.WireVector(bitwidth=nbits) + result_zero = pyrtl.Const(0, bitwidth=nbits) + result_nar = pyrtl.Const(1 << (nbits - 1), bitwidth=nbits) + normal_case = ~(either_zero | either_inf) + + # Compute resultant sign + sign_final = sign_a ^ sign_b + + # Compute scale + scale_a = k_a * pyrtl.Const(2 ** es) + exp_a + scale_b = k_b * pyrtl.Const(2 ** es) + exp_b + scale_sum = scale_a + scale_b + + # Fraction multiplication with implicit 1 + one_table = [pyrtl.Const(1 << i, bitwidth=32) for i in range(nbits)] + one_shifted_a = pyrtl.Const(0, bitwidth=32) + one_shifted_b = pyrtl.Const(0, bitwidth=32) + for i in range(nbits): + one_shifted_a = pyrtl.select(fraclength_a == i, one_table[i], one_shifted_a) + one_shifted_b = pyrtl.select(fraclength_b == i, one_table[i], one_shifted_b) + + frac_a_32 = pyrtl.concat(pyrtl.Const(0, bitwidth=24), frac_a) + frac_b_32 = pyrtl.concat(pyrtl.Const(0, bitwidth=24), frac_b) + frac_a_full = one_shifted_a + frac_a_32 + frac_b_full = one_shifted_b + frac_b_32 + + frac_product = frac_a_full * frac_b_full + + # Normalize fraction + fraclen_total = fraclength_a + fraclength_b + threshold = shift_left_logical(pyrtl.Const(1, bitwidth=32), fraclen_total + 1) + frac = frac_product + scale = pyrtl.Const(0, bitwidth=8) + for _ in range(8): + shifted = shift_right_logical(frac, 1) + should_shift = frac >= threshold + frac = pyrtl.select(should_shift, shifted, frac) + scale = pyrtl.select(should_shift, scale + 1, scale) + normalized_frac = frac + normalized_scale = scale + + # Remove extra 1 + mask_table = [pyrtl.Const((1 << i) - 1, bitwidth=32) for i in range(33)] + mask_val = pyrtl.Const(0) + for i in range(1, 33): + mask_val = pyrtl.select(fraclen_total == i, mask_table[i], mask_val) + frac_result = normalized_frac & mask_val + + # Final scale + final_scale = scale_sum + normalized_scale + + # Extract k and exponent + resultk = shift_right_logical(final_scale, pyrtl.Const(es, bitwidth=8)) + mod_mask = pyrtl.Const((1 << es) - 1, bitwidth=final_scale.bitwidth) + resultExponent = final_scale & mod_mask + + # Get remaining bits and regime + rem_bits, sign_w_regime = get_upto_regime(resultk, nbits, sign_final) + + # Fraction bits with rounding + frac_bits = rem_bits - es + is_small = rem_bits <= es + + shift_amt_small = es - rem_bits + exp_shifted_small = shift_right_logical(resultExponent, shift_amt_small) + value_small = sign_w_regime + exp_shifted_small + + sum_fraclens = fraclength_a + fraclength_b + roundup_bit = pyrtl.Const(0, bitwidth=1) + cond_round = sum_fraclens > frac_bits + shift_amt1 = sum_fraclens - frac_bits - 1 + shift_amt2 = sum_fraclens - frac_bits + + roundup_candidate = shift_right_logical(frac_result, shift_amt1) & 1 + frac_shifted = shift_right_logical(frac_result, shift_amt2) + frac_shifted_else = shift_left_logical(frac_result, frac_bits - sum_fraclens) + + frac_final = pyrtl.WireVector(bitwidth=nbits) + frac_final <<= pyrtl.select(cond_round, frac_shifted, frac_shifted_else) + + roundup_bit = pyrtl.WireVector(bitwidth=1) + roundup_bit <<= pyrtl.select(cond_round, roundup_candidate, pyrtl.Const(0, bitwidth=1)) + + exp_shifted_large = shift_left_logical(resultExponent, frac_bits) + value_large = sign_w_regime + exp_shifted_large + frac_final + all_ones = pyrtl.Const((1 << nbits) - 1, bitwidth=nbits) + value_rounded = pyrtl.select( + (roundup_bit & (value_large != all_ones)), value_large + 1, value_large + ) + + computed_value = pyrtl.select(is_small, value_small, value_rounded) + + # Select between normal values and special computed value + final_value <<= pyrtl.select( + either_zero, result_zero, pyrtl.select(either_inf, result_nar, computed_value) + ) + + return final_value diff --git a/tests/rtllib/test_positadder.py b/tests/rtllib/test_positadder.py new file mode 100644 index 00000000..225ccfd3 --- /dev/null +++ b/tests/rtllib/test_positadder.py @@ -0,0 +1,54 @@ +import doctest +import random +import unittest + +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.positadder import posit_add +from pyrtl.positutils import decimal_to_posit + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.rtllib.positadder) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + +class TestPositAdder(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_adder(self): + nbits_list = [8, 16, 32] + es_list = [0, 1, 2, 3, 4] + nbits, es = random.choice(nbits_list), random.choice(es_list) + a = pyrtl.Input(bitwidth=nbits, name="a") + b = pyrtl.Input(bitwidth=nbits, name="b") + out = pyrtl.Output(name="out") + + out <<= posit_add(a, b, nbits, es) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + wires = [a, b] + vals_raw = [[random.randint(0, maxpos) for _ in range(7)] for _ in wires] + vals = [[decimal_to_posit(j, nbits, es) for j in i] for i in vals_raw] + + out_vals = utils.sim_and_ret_out(out, wires, vals) + true_result_raw = [x + y for x, y in zip(vals_raw[0], vals_raw[1])] + true_result = [decimal_to_posit(i, nbits, es) for i in true_result_raw] + + for sim, expected in zip(out_vals, true_result): + self.assertLessEqual(abs(sim - expected), 1) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/rtllib/test_positmatmul.py b/tests/rtllib/test_positmatmul.py new file mode 100644 index 00000000..8e80392a --- /dev/null +++ b/tests/rtllib/test_positmatmul.py @@ -0,0 +1,123 @@ +import doctest +import random +import unittest + +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.matrix import Matrix, matrix_wv_to_list +from pyrtl.rtllib.positmatmul import posit_matmul +from pyrtl.positutils import decimal_to_posit + + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.rtllib.positadder) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + + +class PositMatrixTestBase(unittest.TestCase): + def check_against_expected(self, result, expected_output, rows, cols, nbits): + expected = Matrix(rows, cols, bits=nbits, value=expected_output) + + result_wv = pyrtl.Output(name='result') + expected_wv = pyrtl.Output(name='expected') + + result_wv <<= result.to_wirevector() + expected_wv <<= expected.to_wirevector() + + sim = pyrtl.Simulation() + sim.step({}) + + result_vals = matrix_wv_to_list( + sim.inspect('result'), rows, cols, nbits + ) + expected_vals = matrix_wv_to_list( + sim.inspect('expected'), rows, cols, nbits + ) + + for i in range(len(result_vals)): + for j in range(len(result_vals[0])): + self.assertLessEqual( + abs(result_vals[i][j] - expected_vals[i][j]), 2 + ) + + def generate_and_check(self, m, n, p, identity=False): + nbits_list = [8, 16] + es_list = [0, 1, 2] + nbits, es = random.choice(nbits_list), random.choice(es_list) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + matrix_x_raw = [ + [random.randint(0, maxpos) for _ in range(n)] for _ in range(m) + ] + matrix_x = [ + [decimal_to_posit(val, nbits, es) for val in row] + for row in matrix_x_raw + ] + test_x = Matrix(m, n, bits=nbits, value=matrix_x) + + if identity: + matrix_y_raw = [ + [1 if i == j else 0 for j in range(n)] for i in range(n) + ] + p = n + else: + matrix_y_raw = [ + [random.randint(0, maxpos) for _ in range(p)] + for _ in range(n) + ] + + matrix_y = [ + [decimal_to_posit(val, nbits, es) for val in row] + for row in matrix_y_raw + ] + test_y = Matrix(n, p, bits=nbits, value=matrix_y) + + result = posit_matmul(test_x, test_y, nbits, es) + + expected_output_raw = [[0 for _ in range(p)] for _ in range(m)] + for i in range(m): + for j in range(p): + for k in range(n): + expected_output_raw[i][j] += ( + matrix_x_raw[i][k] * matrix_y_raw[k][j] + ) + + expected_output = [ + [decimal_to_posit(val, nbits, es) for val in row] + for row in expected_output_raw + ] + + self.check_against_expected(result, expected_output, m, p, nbits) + + +class TestPositMatmul(PositMatrixTestBase): + @classmethod + def setUpClass(cls): + random.seed(42) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_matmul_identity(self): + m = random.randint(1, 5) + n = random.randint(1, 5) + self.generate_and_check(m, n, p=None, identity=True) + + def test_posit_matmul(self): + m = random.randint(1, 5) + n = random.randint(1, 5) + p = random.randint(1, 5) + self.generate_and_check(m, n, p, identity=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rtllib/test_positmul.py b/tests/rtllib/test_positmul.py new file mode 100644 index 00000000..e41a0f53 --- /dev/null +++ b/tests/rtllib/test_positmul.py @@ -0,0 +1,54 @@ +import doctest +import random +import unittest + +import pyrtl +import pyrtl.rtllib.testingutils as utils +from pyrtl.rtllib.positmul import posit_mul +from pyrtl.positutils import decimal_to_posit + +class TestDocTests(unittest.TestCase): + """Test documentation examples.""" + + def test_doctests(self): + failures, tests = doctest.testmod(m=pyrtl.rtllib.positmul) + self.assertGreater(tests, 0) + self.assertEqual(failures, 0) + +class TestPositMultiplier(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + + def setUp(self): + pyrtl.reset_working_block() + + def tearDown(self): + pyrtl.reset_working_block() + + def test_posit_multiplier(self): + nbits_list = [8, 16, 32] + es_list = [0, 1, 2, 3, 4] + nbits, es = random.choice(nbits_list), random.choice(es_list) + a = pyrtl.Input(bitwidth=nbits, name="a") + b = pyrtl.Input(bitwidth=nbits, name="b") + out = pyrtl.Output(name="out") + + out <<= posit_mul(a, b, nbits, es) + + useed = 2 ** (2 ** es) + maxpos = useed ** (nbits - 2) + + wires = [a, b] + vals_raw = [[random.randint(0, maxpos) for _ in range(7)] for _ in wires] + vals = [[decimal_to_posit(j, nbits, es) for j in i] for i in vals_raw] + + out_vals = utils.sim_and_ret_out(out, wires, vals) + true_result_raw = [x * y for x, y in zip(vals_raw[0], vals_raw[1])] + true_result = [decimal_to_posit(i, nbits, es) for i in true_result_raw] + + for sim, expected in zip(out_vals, true_result): + self.assertLessEqual(abs(sim - expected), 1) + +if __name__ == "__main__": + unittest.main()