### Lenstra's algorithm for divisors in residue classes

Posted:

**Thu Jan 11, 2018 6:17 pm UTC**I'm attempting to implement Lenstra's algorithm for finding divisors in residue classes as described as Algorithm 9.1.29 in Cohen's A Course in Computational Algebraic Number Theory (also available here in PDF form).

My initial translation of the description into Python-with-goto is as follows:

Translating this into proper Python yields

I tested this by evaluating list(moddiv(19, 101, 40320)), which should return [120], but it either produces a ZeroDivisionError (if run as shown) or enters an infinite loop (if run with the suggested omission). Clearly I did something wrong, but I can't figure out what. Can anybody help?

My initial translation of the description into Python-with-goto is as follows:

Code: Select all

`def isqrt(n): # Shamelessly stolen from https://codegolf.stackexchange.com/a/9088.`

if n < 0: return int(n)

c = n*4//3

d = c.bit_length()

a = d>>1

if d&1:

x = 1 << a

y = (x + (n >> a)) >> 1

else:

x = (3 << a) >> 2

y = (x + (c >> a)) >> 1

if x != y:

x, y = y, (y + n//y) >> 1

while y < x: x, y = y, (y + n//y) >> 1

return x

def xgcd_goto(a, b): # Algorithm 1.3.6

assert a >= 0 <= b

# 1: Initialize

u = 1

d = a

if b == 0: v = 0; return (u, v, d)

v1 = 0

v3 = b

# 2: Finished?

if v3 == 0: v = (d - a*u) // b; return (u, v, d)

# 3: Euclidean step

q, t3 = divmod(d, v3)

t1 = u - q * v1

u = v1

d = v3

v1 = t1

v3 = t3

goto 2

def moddiv_goto(r, s, n): # Algorithm 9.1.29

assert 0 <= r < s < n and gcd(r,s) == 1 and s**3 > n

# 1: Initialization

u, v, _ = xgcd(r, s)

rprime = (u * n) % s

assert 0 <= rprime < s

a0 = s

b0 = 0

c0 = 0

a1 = (u * rprime) % s

b1 = 1

c1s = u * (n - r * rprime)

assert c1s % s == 0

c1 = (c1s // s) % s

j = 1

if a1 == 0: a1 = s

assert 0 < a1 <= s

alist, blist, clist = [a0, a1], [b0, b1], [c0, c1]

# 2: Compute c

aj, bj, cj = alist[j], blist[j], clist[j]

if j % 2 == 0: c = cj

else: c = cj + s * ( (n + s**2 * (aj*bj-cj) ) // s**3 )

if c < 2 * aj * bj: goto 6 # TODO: This may belong under the "else" in the previous line.

# 3: Solve quadratic equation

A = c*s + aj*r + bj*rprime

B = aj * bj * n

D = A*A - 4*B

d = isqrt(D)

if d**2 != D: goto 5

t12, t22 = A + d, A - d

assert t12 % 2 == t22 % 2 == 0

t1, t2 = t12//2, t22//2

assert t1**2 - A * t1 + B == t2**2 - A * t2 + B == 0

# 4: Divisor found?

if t1 % aj == 0 and t2 % bj == 0 and (t1 // aj - r) % s == 0 and (t2 // bj - rprime) % s == 0: yield t1 // aj

# 5: Other value of c

if j % 2 == 0 and c > 0: c = c - s; goto 3

# 6: Next j

if aj == 0: return

j = j + 1

if j % 2 == 0: qj = alist[j-2] // alist[j-1]

else: qj = (alist[j-2] - 1) // alist[j-1]

aj = alist[j-2] - qj * alist[j-1]

bj = blist[j-2] - qj * blist[j-1]

cj = clist[j-2] - qj * clist[j-1]

alist.append(aj)

blist.append(bj)

clist.append(cj)

goto 2

Translating this into proper Python yields

Code: Select all

`def isqrt(n): # Shamelessly stolen from https://codegolf.stackexchange.com/a/9088.`

if n < 0: return int(n)

c = n*4//3

d = c.bit_length()

a = d>>1

if d&1:

x = 1 << a

y = (x + (n >> a)) >> 1

else:

x = (3 << a) >> 2

y = (x + (c >> a)) >> 1

if x != y:

x, y = y, (y + n//y) >> 1

while y < x: x, y = y, (y + n//y) >> 1

return x

def xgcd(a, b): # Algorithm 1.3.6

assert a >= 0 <= b

if b == 0: return (1, 0, d)

u, d, v, w = 1, a, 0, b

while w != 0:

q, r = divmod(d, w)

u, d, v, w = v, w, u - q * v, r

return (u, (d - a*u) // b, d)

def moddiv(r, s, n): # Algorithm 9.1.29

assert 0 <= r < s < n < s**3 and xgcd(r,s)[2] == 1

# 1: Initialization

u = xgcd(r,s)[0]

rprime = (u * n) % s

j, a1, c1 = 1, (u * rprime) % s, u * (n - r * rprime)

assert c1 % s == 0

if a1 == 0: a1 = s

alist, blist, clist = [s, a1], [0, 1], [0, (c1 // s) % s]

while True:

# 2: Compute c

aj, bj, cj, flag = alist[j], blist[j], clist[j], True

c = cj if j % 2 == 0 else cj + s * ( (n + s**2 * (aj*bj-cj) ) // s**3 )

if c < 2 * aj * bj and (j == 1): flag = False # TODO: The second condition should perhaps be omitted.

while flag:

# 3: Solve quadratic equation

A, B = c*s + aj*r + bj*rprime, aj * bj * n

D = A*A - 4*B

d = isqrt(D)

if d**2 == D: #goto 5

t1, t2 = (A + d) // 2, (A - d) // 2

assert (A - t1) * t1 == (A - t2) * t2 == B

# 4: Divisor found?

if t1 % aj == t2 % bj == (t1//aj - r) % s == (t2//bj - rprime) % s == 0: yield t1//aj

# 5: Other value of c

if j % 2 == 0 < c: c -= s

else: flag = False

# 6: Next j

if aj == 0: return

j += 1

qj = (alist[j-2] - (j % 2)) // alist[j-1]

aj = alist[j-2] - qj * alist[j-1]

bj = blist[j-2] - qj * blist[j-1]

cj = clist[j-2] - qj * clist[j-1]

alist.append(aj)

blist.append(bj)

clist.append(cj)

#goto 2

I tested this by evaluating list(moddiv(19, 101, 40320)), which should return [120], but it either produces a ZeroDivisionError (if run as shown) or enters an infinite loop (if run with the suggested omission). Clearly I did something wrong, but I can't figure out what. Can anybody help?