from math import sqrt, floor, ceil, gcd
import fractions
import random


def xgcd(a, b):
    """
    Extended Euclidean algorithm.

    Given two integers (a, b), return (g, u, v) where g is the GCD of a and b,
    and (u, v) are the coefficients of the Bezout relation a*u + b*v == g.
    """
    (x0, x1, y0, y1) = (1, 0, 0, 1)
    while b != 0:
        (q, a, b) = (a // b, b, a % b)
        (x0, x1) = (x1, x0 - q * x1)
        (y0, y1) = (y1, y0 - q * y1)
    # When the loop exits b == 0 and a holds the GCD (returning b would
    # always report a GCD of 0).
    return (a, x0, y0)


def keygen(l, k):
    """
    Generate an instance of the problem.

    The private part is a pair (a, n) with n < 2**l, 0 < a < n and
    gcd(a, n) == 1, together with a matrix [[x, y], [z, t]] of integers below
    2**k (z may be negative) satisfying x*t - z*y == 1. The public part is
    (a', n') = (a*x + n*y, a*z + n*t).

    Returns ((aprime, nprime), (a, n, x, y, z, t)).
    """
    while True:
        n = random.getrandbits(l)
        # n == 0 would crash randrange; n == 1 would force the degenerate
        # secret a == 0, which break_protocol can never recover.
        if n < 2:
            continue
        a = random.randrange(n)
        if gcd(a, n) == 1:
            break

    while True:
        x = random.getrandbits(k)
        y = random.getrandbits(k)
        if gcd(x, y) == 1:
            break

    # xgcd gives x*t + y*v == 1; with z = -v this is x*t - z*y == 1.
    _, t, z = xgcd(x, y)
    z = -z
    assert x * t - z * y == 1

    aprime = a * x + n * y
    nprime = a * z + n * t
    return (aprime, nprime), (a, n, x, y, z, t)


################################################################

def dot(u, v):
    """Return the dot product of the two vectors u and v."""
    return sum(x * y for (x, y) in zip(u, v))


def sqnorm(u):
    """Return the square of the norm of the vector u."""
    return dot(u, u)


def aupbv(a, u, b, v):
    """Given two vectors (u, v) and two scalars (a, b), return the vector a*u + b*v."""
    return [a * x + b * y for (x, y) in zip(u, v)]


def norm(u):
    """
    Return the Euclidean norm of the vector u, as a float.

    math.sqrt converts its argument to float, so vectors with huge integer
    entries must be scaled down first (enumerate does this via scale()).
    """
    return sqrt(dot(u, u))


def scale(s, v):
    """
    Return 1/s * v, where s is a (potentially large) scalar and v is a
    vector. This yields a vector of rationals.
    """
    f = fractions.Fraction(1, s)
    return [f * x for x in v]


def lagrange_reduction(u, v):
    """
    Lagrange-Gauss reduction of a basis (u, v) of a two-dimensional lattice.

    Return a reduced basis (u, v) of the same lattice with
    sqnorm(u) <= sqnorm(v), u being a shortest non-zero lattice vector.
    All arithmetic is exact (integers and Fractions). u and v must be
    linearly independent.
    """
    # Start with u as the longer vector so the first step reduces it by v.
    if sqnorm(u) < sqnorm(v):
        u, v = v, u
    while True:
        q = round(fractions.Fraction(dot(u, v), sqnorm(v)))
        r = aupbv(1, u, -q, v)
        u, v = v, r
        if sqnorm(u) <= sqnorm(v):
            return (u, v)


def enumerate(r, s, B, N):
    """
    Return the list of all vectors of the lattice spanned by (r, s) with norm
    at most B (the zero vector included).

    (r, s) is assumed to be a reduced basis (e.g. from lagrange_reduction).
    N is a normalisation factor: the floating-point part of the computation
    works on quantities divided by N, so N should be about the norm of r and
    s (with B close to N). Membership is decided with an exact integer test.

    NOTE: this name shadows the builtin ``enumerate`` for the whole module.
    """
    short_vectors = []
    # Squared area of the parallelogram (r, s) (Gram determinant), scaled.
    vol2 = sqnorm(r) * sqnorm(s) - dot(r, s) ** 2
    svol = sqrt(fractions.Fraction(vol2, N ** 4))
    # Gram-Schmidt: s = s* + mu*r with |s*| = Vol / |r|.
    mu = fractions.Fraction(dot(r, s), sqnorm(r))
    snorm_rstar = norm(scale(N, r))
    snorm_sstar = svol / snorm_rstar

    # w = x1*r + x2*s = (x1 + mu*x2)*r + x2*s*, hence |w| <= B implies
    # |x2| <= B/|s*| and |x1 + mu*x2| <= B/|r|. These bounds are floats, so
    # widen each by one to make sure rounding never drops a vector; the
    # exact check below discards the extra candidates.
    sB = fractions.Fraction(B, N)
    radius1 = sB / snorm_rstar
    bound2 = B * B
    x2_max = floor(sB / snorm_sstar) + 1
    for x2 in range(-x2_max, x2_max + 1):
        shift = mu * x2
        x1_max = floor(radius1 - shift) + 1
        x1_min = ceil(-radius1 - shift) - 1
        for x1 in range(x1_min, x1_max + 1):
            w = aupbv(x1, r, x2, s)
            if sqnorm(w) <= bound2:
                short_vectors.append(w)
    return short_vectors


def linear_program(aprime, nprime, l, k):
    """
    Return the list of all [gamma, delta] such that

        |gamma * aprime + delta * nprime| <= 2**l
        |gamma| <= 2**k
        |delta| <= 2**k

    Such pairs correspond to the vectors gamma*u + delta*v of the lattice
    spanned by u = (aprime, 2**(l-k), 0) and v = (nprime, 0, 2**(l-k)). The
    weighting makes every coordinate of a solution at most 2**l in absolute
    value, so each has norm at most sqrt(3) * 2**l < 2**(l+1).

    Raises ValueError if k > l (the weight 2**(l-k) must be an integer).
    """
    if k > l:
        raise ValueError("linear_program requires k <= l")
    weight = 2 ** (l - k)
    u = [aprime, weight, 0]
    v = [nprime, 0, weight]
    r, s = lagrange_reduction(u, v)
    B = 2 ** (l + 1)
    bound_l = 2 ** l
    bound_k = 2 ** k
    solutions = []
    for (c, g, d) in enumerate(r, s, B, bound_l):
        # The last two coordinates are exact multiples of the weight.
        gamma = g // weight
        delta = d // weight
        if abs(c) <= bound_l and abs(gamma) <= bound_k and abs(delta) <= bound_k:
            solutions.append([gamma, delta])
    return solutions


def break_protocol(aprime, nprime, l, k):
    """
    Given the server input (a', n'), produce candidates for the secrets of
    the client (a, n).

    With the key matrix M = [[x, y], [z, t]] we have (a', n') = M (a, n) and
    det M == 1, so (a, n) = M^-1 (a', n') with M^-1 = [[t, -y], [-z, x]].
    Both rows of M^-1 are solutions of linear_program; every pair of
    solutions forming a matrix of determinant +-1 is tried and its image is
    kept when 0 < a < n.

    Returns a list of [a, n] candidates.
    """
    omega = linear_program(aprime, nprime, l, k)
    candidates = []
    for (g1, d1) in omega:
        for (g2, d2) in omega:
            if abs(g1 * d2 - g2 * d1) != 1:
                continue
            a = g1 * aprime + d1 * nprime
            n = g2 * aprime + d2 * nprime
            if 0 < a < n:
                candidates.append([a, n])
    return candidates


######################################################

if __name__ == "__main__":
    l = 2048
    k = 128
    pk, sk = keygen(l, k)
    aprime, nprime = pk
    candidates = break_protocol(aprime, nprime, l, k)
    print("{} candidates found".format(len(candidates)))
    a, n, *_ = sk
    # Explicit check: a bare assert would be stripped under `python -O`.
    if [a, n] not in candidates:
        raise AssertionError("correct solution not found")
    print("correct solution found")