Skip to content

Commit

Permalink
remove float from compress/decompress
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 19, 2024
1 parent 5eb35bf commit 49feac4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
6 changes: 3 additions & 3 deletions benchmark_kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ def profile_kyber(Kyber):
gvars = {}
lvars = {"Kyber": Kyber, "c": c, "pk": pk, "sk": sk}

cProfile.runctx("Kyber.keygen()", globals=gvars, locals=lvars, sort=1)
cProfile.runctx("Kyber.enc(pk)", globals=gvars, locals=lvars, sort=1)
cProfile.runctx("Kyber.dec(c, sk)", globals=gvars, locals=lvars, sort=1)
cProfile.runctx("[Kyber.keygen() for _ in range(100)]", globals=gvars, locals=lvars, sort=1)
cProfile.runctx("[Kyber.enc(pk) for _ in range(100)]", globals=gvars, locals=lvars, sort=1)
cProfile.runctx("[Kyber.dec(c, sk) for _ in range(100)]", globals=gvars, locals=lvars, sort=1)

def benchmark_kyber(Kyber, name, count):
keygen_times = []
Expand Down
33 changes: 15 additions & 18 deletions polynomials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import random
from utils import *
from utils import bytes_to_bits, bitstring_to_bytes, compress, decompress

class PolynomialRing:
"""
Expand Down Expand Up @@ -139,28 +139,25 @@ def encode(self, l=None):

def compress(self, d):
"""
Compress the polynomial by compressing each coefficent
Compress the polynomial by compressing each coefficient
NOTE: This is lossy compression
"""
compress_mod = 2**d
compress_float = compress_mod / self.parent.q
self.coeffs = [round_up(compress_float * c) % compress_mod for c in self.coeffs]
self.coeffs = [compress(c, d, self.parent.q) for c in self.coeffs]
return self

def decompress(self, d):
"""
Decompress the polynomial by decompressing each coefficent
Decompress the polynomial by decompressing each coefficient
NOTE: This as compression is lossy, we have
x' = decompress(compress(x)), which x' != x, but is
close in magnitude.
"""
decompress_float = self.parent.q / 2**d
self.coeffs = [round_up(decompress_float * c) for c in self.coeffs ]
self.coeffs = [decompress(c, d, self.parent.q) for c in self.coeffs ]
return self

def add_mod_q(self, x, y):
"""
add two coefficents modulo q
add two coefficients modulo q
"""
tmp = x + y
if tmp >= self.parent.q:
Expand All @@ -169,7 +166,7 @@ def add_mod_q(self, x, y):

def sub_mod_q(self, x, y):
"""
sub two coefficents modulo q
sub two coefficients modulo q
"""
tmp = x - y
if tmp < 0:
Expand Down Expand Up @@ -231,13 +228,13 @@ def __neg__(self):
def __add__(self, other):
if isinstance(other, PolynomialRing.Polynomial):
if self.is_ntt ^ other.is_ntt:
raise ValueError(f"Both or neither polynomials must be in NTT form before multiplication")
raise ValueError("Both or neither polynomials must be in NTT form before multiplication")
new_coeffs = [self.add_mod_q(x,y) for x,y in zip(self.coeffs, other.coeffs)]
elif isinstance(other, int):
new_coeffs = self.coeffs.copy()
new_coeffs[0] = self.add_mod_q(new_coeffs[0], other)
else:
raise NotImplementedError(f"Polynomials can only be added to each other")
raise NotImplementedError("Polynomials can only be added to each other")
return self.parent(new_coeffs, is_ntt=self.is_ntt)

def __radd__(self, other):
Expand All @@ -250,13 +247,13 @@ def __iadd__(self, other):
def __sub__(self, other):
if isinstance(other, PolynomialRing.Polynomial):
if self.is_ntt ^ other.is_ntt:
raise ValueError(f"Both or neither polynomials must be in NTT form before multiplication")
raise ValueError("Both or neither polynomials must be in NTT form before multiplication")
new_coeffs = [self.sub_mod_q(x,y) for x,y in zip(self.coeffs, other.coeffs)]
elif isinstance(other, int):
new_coeffs = self.coeffs.copy()
new_coeffs[0] = self.sub_mod_q(new_coeffs[0], other)
else:
raise NotImplementedError(f"Polynomials can only be subracted from each other")
raise NotImplementedError("Polynomials can only be subracted from each other")
return self.parent(new_coeffs, is_ntt=self.is_ntt)

def __rsub__(self, other):
Expand All @@ -271,13 +268,13 @@ def __mul__(self, other):
if self.is_ntt and other.is_ntt:
return self.ntt_multiplication(other)
elif self.is_ntt ^ other.is_ntt:
raise ValueError(f"Both or neither polynomials must be in NTT form before multiplication")
raise ValueError("Both or neither polynomials must be in NTT form before multiplication")
else:
new_coeffs = self.schoolbook_multiplication(other)
elif isinstance(other, int):
new_coeffs = [(c * other) % self.parent.q for c in self.coeffs]
else:
raise NotImplementedError(f"Polynomials can only be multiplied by each other, or scaled by integers")
raise NotImplementedError("Polynomials can only be multiplied by each other, or scaled by integers")
return self.parent(new_coeffs, is_ntt=self.is_ntt)

def __rmul__(self, other):
Expand All @@ -289,11 +286,11 @@ def __imul__(self, other):

def __pow__(self, n):
if not isinstance(n, int):
raise TypeError(f"Exponentiation of a polynomial must be done using an integer.")
raise TypeError("Exponentiation of a polynomial must be done using an integer.")

# Deal with negative scalar multiplication
if n < 0:
raise ValueError(f"Negative powers are not supported for elements of a Polynomial Ring")
raise ValueError("Negative powers are not supported for elements of a Polynomial Ring")
f = self
g = self.parent(1, is_ntt=self.is_ntt)
while n > 0:
Expand Down
21 changes: 16 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,27 @@ def bitstring_to_bytes(s):
Convert a string of bits to bytes with bytes stored little endian
"""
return bytes([int(s[i:i+8][::-1], 2) for i in range(0, len(s), 8)])

def round_up(x):

def compress(x, d, q):
"""
Compute round((2^d / q) * x) % 2^d
"""
t = 1 << d
q_over_2 = q // 2
y = (t * x + q_over_2) // q
return y % t

def decompress(x, d, q):
"""
Round x.5 up always
Compute round((q / 2^d) * x)
"""
return round(x + 0.000001)
t = 1 << (d - 1)
y = (q * x + t) >> d
return y

def xor_bytes(a, b):
"""
XOR two byte arrays, assume that they are
of the same length
"""
return bytes(a^b for a,b in zip(a,b))
return bytes(a^b for a,b in zip(a,b))

0 comments on commit 49feac4

Please sign in to comment.