3
\$\begingroup\$

I wrote code for the multiple polynomial quadratic sieve (MPQS) here:

"""Multiple polynomial quadratic sieve (MPQS) integer factorisation."""

import logging
import random
import time
from math import ceil, exp, floor, isqrt, log, log2, sqrt

import numpy as np

try:
    from sympy import nextprime
except ImportError:  # sympy is optional; a pure-Python fallback is defined below
    nextprime = None

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

_known_primes = [2, 3]

# First 16 primes.  Used as Miller-Rabin witnesses for huge n when
# init_known_primes() has not populated _known_primes yet; otherwise only
# bases 2 and 3 would be tried.
_DEFAULT_WITNESSES = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53)


def init_known_primes(limit=1000):
    """Extend the module-level prime table with the odd primes below ``limit``.

    Calling this more than once does not insert duplicates.
    """
    already_known = set(_known_primes)
    # Build the list first so is_prime() never sees a half-extended table.
    new_primes = [
        x for x in range(5, limit, 2)
        if x not in already_known and is_prime(x)
    ]
    _known_primes.extend(new_primes)
    logger.info("Initialized _known_primes up to %d. Total known primes: %d", limit, len(_known_primes))


def _try_composite(a, d, n, s):
    """Return True when base ``a`` proves ``n`` composite (n - 1 = 2**s * d)."""
    if pow(a, d, n) == 1:
        return False
    for i in range(s):
        if pow(a, 2**i * d, n) == n - 1:
            return False
    return True


def is_prime(n, _precision_for_huge_n=16):
    """Miller-Rabin primality test.

    The witness sets below the hard-coded bounds are the ones this file
    treats as deterministic for those ranges; above the last bound the first
    ``_precision_for_huge_n`` known primes are used as witnesses.
    """
    if n < 2:
        return False
    if n in _known_primes:
        return True
    if any((n % p) == 0 for p in _known_primes):
        return False
    d, s = n - 1, 0
    while d % 2 == 0:
        d >>= 1
        s += 1
    if n < 1373653:
        return not any(_try_composite(a, d, n, s) for a in (2, 3))
    if n < 25326001:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5))
    if n < 118670087467:
        if n == 3215031751:
            return False
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7))
    if n < 2152302898747:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11))
    if n < 3474749660383:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11, 13))
    if n < 341550071728321:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11, 13, 17))
    witnesses = _known_primes[:_precision_for_huge_n]
    if len(witnesses) < min(_precision_for_huge_n, len(_DEFAULT_WITNESSES)):
        # The prime table was never initialised: fall back to a fixed set
        # instead of silently testing with bases 2 and 3 only.
        witnesses = _DEFAULT_WITNESSES[:_precision_for_huge_n]
    return not any(_try_composite(a, d, n, s) for a in witnesses)


if nextprime is None:
    def nextprime(n):
        """Return the smallest prime strictly greater than ``n`` (sympy fallback)."""
        candidate = int(n) + 1
        if candidate <= 2:
            return 2
        if candidate % 2 == 0:
            candidate += 1
        while not is_prime(candidate):
            candidate += 2
        return candidate


class QuadraticSieve:
    """Multiple polynomial quadratic sieve with single-large-prime relations.

    M is the sieve half-width (x runs over [-M, M]), B the factor base
    bound, and T the number of relations collected beyond the factor base
    size.
    """

    def __init__(
            self,
            M: int = 200000,
            B: int = 10000,
            T: int = 1
    ):
        self.logger = logging.getLogger(__name__)
        self.prime_log_map = {}  # prime -> rounded log2(prime)
        self.root_map = {}       # prime -> (r, p - r), square roots of N mod p

        # Store hyperparameters
        self.M = M
        self.B = B
        self.T = T

    @staticmethod
    def gcd(a, b):
        """Compute GCD of two integers using Euclid's Algorithm."""
        a, b = abs(a), abs(b)
        while a:
            a, b = b % a, a
        return b

    @staticmethod
    def legendre(n, p):
        """Compute the Legendre symbol (n/p)."""
        val = pow(n, (p - 1) // 2, p)
        return val - p if val > 1 else val

    @staticmethod
    def jacobi(a, m):
        """Compute the Jacobi symbol (a/m)."""
        a = a % m
        t = 1
        while a != 0:
            while a % 2 == 0:
                a //= 2
                if m % 8 in [3, 5]:
                    t = -t
            a, m = m, a
            if a % 4 == 3 and m % 4 == 3:
                t = -t
            a %= m
        return t if m == 1 else 0

    @staticmethod
    def modinv(n, p):
        """Return the inverse of n modulo p (ValueError if it does not exist)."""
        return pow(n, -1, p)

    @staticmethod
    def hensel(r, p, n):
        """Lift a root r of x^2 = n (mod p) to a root modulo p^2.

        Writing the lifted root as r + y*p gives 2*r*y = (n - r^2)/p (mod p).
        Integer division keeps the arithmetic exact for big integers.
        """
        x = (n - r * r) // p  # exact: p divides n - r^2
        z = QuadraticSieve.modinv(2 * r, p) % p  # 1 / f'(r), with f'(b) = 2b
        y = (x * z) % p
        return r + y * p

    def factorise_fast(self, value, factor_base):
        """Factor a number over the given factor base.

        Returns ``(sorted_factors, cofactor)``; a negative input contributes
        the factor -1.  Zero is returned unfactored (cofactor 0) because it
        is divisible by every prime and trial division would never stop.
        """
        factors = []
        if value < 0:
            factors.append(-1)
            value = -value
        if value == 0:
            return factors, 0
        # factor_base[0] is -1, handled above.
        for factor in factor_base[1:]:
            while value % factor == 0:
                factors.append(factor)
                value //= factor
        return sorted(factors), value

    @staticmethod
    def tonelli_shanks(a, p):
        """Solve x^2 ≡ a (mod p) for x using the Tonelli-Shanks algorithm."""
        a %= p
        if p % 8 in [3, 7]:
            x = pow(a, (p + 1) // 4, p)
            return x, p - x

        if p % 8 == 5:
            x = pow(a, (p + 3) // 8, p)
            if pow(x, 2, p) != a % p:
                x = (x * pow(2, (p - 1) // 4, p)) % p
            return x, p - x

        # p ≡ 1 (mod 8): find a non-residue d, then run the general algorithm.
        d = 2
        symb = 0
        while symb != -1:
            symb = QuadraticSieve.jacobi(d, p)
            d += 1
        d -= 1

        n = p - 1
        s = 0
        while n % 2 == 0:
            n //= 2
            s += 1
        t = n

        A = pow(a, t, p)
        D = pow(d, t, p)
        m = 0
        for i in range(s):
            i1 = pow(2, s - 1 - i)
            i2 = (A * pow(D, m, p)) % p
            i3 = pow(i2, i1, p)
            if i3 == p - 1:
                m += pow(2, i)
        x = (pow(a, (t + 1) // 2, p) * pow(D, m // 2, p)) % p
        return x, p - x

    @staticmethod
    def gauss_elim(x):
        """Perform Gaussian elimination on a binary matrix over GF(2)

        Returns the reduced matrix (int8) and the sorted pivot column indices.
        """
        x = x.astype(bool, copy=False)
        n = x.shape[0]
        marks = []

        for i in range(n):
            row = x[i]
            ones = np.flatnonzero(row)
            if ones.size == 0:
                continue
            pivot = ones[0]
            marks.append(pivot)

            # Clear the pivot column in every other row.
            mask = x[:, pivot].copy()
            mask[i] = False
            x[mask] ^= row

        return x.astype(np.int8, copy=False), sorted(marks)

    @staticmethod
    def find_null_space_GF2(reduced_matrix, pivot_rows):
        """Find null space vectors of the reduced binary matrix over GF(2)"""
        n = reduced_matrix.shape[0]
        nulls = []
        pivot_set = {int(r) for r in pivot_rows}
        free_rows = [row for row in range(n) if row not in pivot_set]
        # Loop-invariant: which rows are pivot rows.
        mask = np.isin(np.arange(n), pivot_rows)
        k = 0
        for row in free_rows:
            ones = np.where(reduced_matrix[row] == 1)[0]
            null = np.zeros(n)
            null[row] = 1

            relevant_cols = reduced_matrix[:, ones]
            matching_rows = np.any(relevant_cols == 1, axis=1)
            null[mask & matching_rows] = 1

            nulls.append(null)
            k += 1
            if k == 4:  # no need to find entire null space, just a few values will do
                break

        return np.asarray(nulls, dtype=np.int8)

    @staticmethod
    def prime_sieve(n):
        """Return list of primes up to n using Sieve of Eratosthenes."""
        if n < 2:
            return []
        sieve_array = np.ones((n + 1,), dtype=bool)
        sieve_array[0], sieve_array[1] = False, False
        for i in range(2, int(n ** 0.5) + 1):
            if sieve_array[i]:
                # Smaller multiples were already struck by smaller primes.
                sieve_array[i * i::i] = False
        return np.where(sieve_array)[0].tolist()

    def find_b(self, N):
        """Determine the factor base bound B"""
        x = ceil(exp(sqrt(0.5 * log(N) * log(log(N))))) + 1
        return x

    def get_smooth_b(self, N, B):
        """Build the factor base of primes p ≤ B where (N/p) = 1.

        Also fills ``prime_log_map`` and ``root_map`` for every odd prime kept.
        """
        primes = self.prime_sieve(B)
        factor_base = [-1, 2]
        self.prime_log_map[2] = 1
        for p in primes[1:]:
            if self.legendre(N, p) == 1:
                factor_base.append(p)
                self.prime_log_map[p] = round(log2(p))
                self.root_map[p] = self.tonelli_shanks(N, p)
        return factor_base

    def decide_bound(self, N, B=None):
        """Decide on bound B using heuristic if none provided."""
        if B is None:
            B = self.find_b(N)
        self.logger.info("Using B = %d", B)
        return B

    def build_factor_base(self, N, B):
        """Build factor base for Quadratic Sieve."""
        fb = self.get_smooth_b(N, B)
        self.logger.info("Factor base size: %d", len(fb))
        return fb

    def sieve_interval(self, N, a, b, c, factor_base, M):
        """Accumulate log2(p) at every index whose polynomial value p divides.

        Index i corresponds to x = i - M for Q(x) = a*x^2 + 2*b*x + c.
        Primes below 20 are skipped, as are primes dividing ``a`` (``a`` has
        no inverse modulo those).  Returns a list of length 2*M + 1.
        """
        size = 2 * M + 1
        sieve_values = [0] * size
        root_map = self.root_map
        prime_log_map = self.prime_log_map
        for p in factor_base:
            if p < 20:
                continue
            if a % p == 0:
                continue
            ainv = self.modinv(a, p)
            log_p = prime_log_map[p]  # hoisted out of the inner loop
            for root in root_map[p]:
                # a*x + b ≡ root (mod p)  =>  x ≡ ainv*(root - b); shift by M.
                start = (ainv * (root - b) + M) % p
                for i in range(start, size, p):
                    sieve_values[i] += log_p
        return sieve_values

    def sieve(self, N, B, factor_base, M, partial=True):
        """Collect relations (a*x + b)^2 ≡ a*Q(x) (mod N) over many polynomials.

        Returns ``(matrix, relations, roots)``: exponent-parity rows over the
        factor base, the left-hand values, and the right-hand values whose
        selected products are perfect squares.  With ``partial`` set, values
        with one cofactor below ``128 * B`` are paired up on that cofactor.
        """
        fb_len = len(factor_base)
        fb_index = {prime: idx for idx, prime in enumerate(factor_base)}
        zero_row = [0] * fb_len
        error = 30

        # large prime stuff
        large_prime_bound = B * 128
        partials = {}
        lp_found = 0

        target_relations = fb_len + self.T
        # Start at 2 or above so the first polynomial prime is odd;
        # tonelli_shanks never terminates for p = 2.
        d = max(2, int(sqrt(sqrt(2 * N) / M)))
        threshold = int(log2(M * sqrt(N)) - error)
        num_poly = 0
        matrix = []
        relations = []
        roots = []

        while len(relations) < target_relations:
            print(str(round(len(relations) / target_relations * 100, 2)) + "% done")
            num_poly += 1

            # Next polynomial: a = d^2 with d prime and (N/d) = 1.
            d = nextprime(d)
            while self.legendre(N, d) != 1:
                d = nextprime(d)
            a = d * d
            b = self.tonelli_shanks(N, d)[0]
            # Hensel-lift b so that b^2 ≡ N (mod d^2).
            b = (b + (N - b * b) * QuadraticSieve.modinv(b + b, d)) % a
            c = (b * b - N) // a

            sieve_values = self.sieve_interval(N, a, b, c, factor_base, M)
            for i, val in enumerate(sieve_values):
                if val <= threshold:
                    continue
                x = i - M
                relation = a * x + b
                poly_val = a * x * x + 2 * b * x + c
                local_factors, value = self.factorise_fast(poly_val, factor_base)

                if value != 1:
                    if not partial or value >= large_prime_bound:
                        continue
                    if value not in partials:
                        partials[value] = (relation, local_factors, poly_val * a)
                        continue
                    # Two partials share the cofactor: their product is
                    # smooth up to the square of that cofactor.
                    lp_found += 1
                    other_relation, other_factors, other_root = partials[value]
                    relation = relation * other_relation
                    local_factors = local_factors + other_factors
                    poly_val = poly_val * other_root

                # Exponent parity per factor-base entry.
                row = zero_row.copy()
                for fac in local_factors:
                    row[fb_index[fac]] ^= 1
                matrix.append(row)
                relations.append(relation)
                roots.append(poly_val * a)

        self.logger.info(
            "Sieved %d polynomials; %d large-prime combinations", num_poly, lp_found
        )
        return matrix, relations, roots

    def solve_dependencies(self, matrix):
        """Solve for dependencies in GF(2)."""
        self.logger.info("Solving linear system in GF(2).")
        matrix = np.array(matrix).T
        reduced_matrix, marks = self.gauss_elim(matrix)
        null_basis = self.find_null_space_GF2(reduced_matrix.T, marks)
        return null_basis

    def extract_factors(self, N, relations, roots, dep_vectors):
        """Extract factors using dependency vectors.

        Returns ``(factor, cofactor)`` or ``(0, 0)`` when every dependency
        yields only a trivial gcd.
        """
        for r in dep_vectors:
            prod_left = 1
            prod_right = 1
            for idx, bit in enumerate(r):
                if bit == 1:
                    prod_left *= relations[idx]
                    prod_right *= roots[idx]

            sqrt_right = isqrt(prod_right)
            prod_left = prod_left % N
            sqrt_right = sqrt_right % N
            factor_candidate = self.gcd(N, prod_left - sqrt_right)

            if factor_candidate not in (1, N):
                other_factor = N // factor_candidate
                self.logger.info("Found factors: %d, %d", factor_candidate, other_factor)
                return factor_candidate, other_factor

        return 0, 0

    def factor(self, N, B=None):
        """Main factorization method using the Quadratic Sieve algorithm.

        ``B`` overrides the instance's factor base bound when given.
        Returns ``(f1, f2)`` with ``f1 * f2 == N``, or ``(0, 0)`` on failure.
        """
        overall_start = time.time()
        self.logger.info("========== Quadratic Sieve V4 Start ==========")
        self.logger.info("Factoring N = %d", N)

        # Step 1: Decide Bound
        step_start = time.time()
        B = self.decide_bound(N, B if B is not None else self.B)
        step_end = time.time()
        self.logger.info("Step 1 (Decide Bound) took %.3f seconds", step_end - step_start)

        # Step 2: Build Factor Base
        step_start = time.time()
        factor_base = self.build_factor_base(N, B)
        step_end = time.time()
        self.logger.info("Step 2 (Build Factor Base) took %.3f seconds", step_end - step_start)

        # Step 3: Sieve Phase
        step_start = time.time()
        matrix, relations, roots = self.sieve(N, B, factor_base, self.M)
        step_end = time.time()
        self.logger.info("Step 3 (Sieve Interval) took %.3f seconds", step_end - step_start)

        if len(matrix) < len(factor_base) + 1:
            self.logger.warning("Not enough smooth relations found. Try increasing the sieve interval.")
            return 0, 0

        # Step 4: Solve for Dependencies
        step_start = time.time()
        dep_vectors = self.solve_dependencies(matrix)
        step_end = time.time()
        self.logger.info("Step 4 (Solve Dependencies) took %.3f seconds", step_end - step_start)

        # Step 5: Extract Factors
        step_start = time.time()
        f1, f2 = self.extract_factors(N, relations, roots, dep_vectors)
        step_end = time.time()
        self.logger.info("Step 5 (Extract Factors) took %.3f seconds", step_end - step_start)

        if f1 and f2:
            self.logger.info("Quadratic Sieve successful: %d * %d = %d", f1, f2, N)
        else:
            self.logger.warning("No non-trivial factors found with the current settings.")

        overall_end = time.time()
        self.logger.info("Total time for Quadratic Sieve: %.3f seconds", overall_end - overall_start)
        self.logger.info("========== Quadratic Sieve End ==========")

        return f1, f2


if __name__ == '__main__':
    N = 97245170828229363259 * 49966345331749027373
    qs = QuadraticSieve(B=11812, M=59060, T=1)

    factor1, factor2 = qs.factor(N)
    if factor1 and factor2:
        print(f"Factors of N: {factor1} * {factor2} = {N}")
    else:
        print("Failed to factorize N with the current settings.")

The main thing I've been working on making faster is the sieving portion of the code (the sieve_interval and sieve methods):

 def sieve_interval(self, N, a, b, c, factor_base, M):
     """Accumulate log2(p) at every index whose polynomial value p divides.

     Index i corresponds to x = i - M for Q(x) = a*x^2 + 2*b*x + c.
     Primes below 20 are skipped, as are primes dividing ``a`` (``a`` has
     no inverse modulo those).  Returns a list of length 2*M + 1.
     """
     size = 2 * M + 1
     sieve_values = [0] * size
     root_map = self.root_map
     prime_log_map = self.prime_log_map
     for p in factor_base:
         if p < 20:
             continue
         if a % p == 0:
             continue
         ainv = self.modinv(a, p)
         log_p = prime_log_map[p]  # hoisted out of the inner loop
         for root in root_map[p]:
             # a*x + b ≡ root (mod p)  =>  x ≡ ainv*(root - b); shift by M.
             start = (ainv * (root - b) + M) % p
             for i in range(start, size, p):
                 sieve_values[i] += log_p
     return sieve_values

 def sieve(self, N, B, factor_base, M, partial=True):
     """Collect relations (a*x + b)^2 ≡ a*Q(x) (mod N) over many polynomials.

     Returns ``(matrix, relations, roots)``: exponent-parity rows over the
     factor base, the left-hand values, and the right-hand values whose
     selected products are perfect squares.  With ``partial`` set, values
     with one cofactor below ``128 * B`` are paired up on that cofactor.
     """
     fb_len = len(factor_base)
     fb_index = {prime: idx for idx, prime in enumerate(factor_base)}
     zero_row = [0] * fb_len
     error = 30

     # large prime stuff
     large_prime_bound = B * 128
     partials = {}
     lp_found = 0

     target_relations = fb_len + self.T
     # Start at 2 or above so the first polynomial prime is odd;
     # tonelli_shanks never terminates for p = 2.
     d = max(2, int(sqrt(sqrt(2 * N) / M)))
     threshold = int(log2(M * sqrt(N)) - error)
     num_poly = 0
     matrix = []
     relations = []
     roots = []

     while len(relations) < target_relations:
         print(str(round(len(relations) / target_relations * 100, 2)) + "% done")
         num_poly += 1

         # Next polynomial: a = d^2 with d prime and (N/d) = 1.
         d = nextprime(d)
         while self.legendre(N, d) != 1:
             d = nextprime(d)
         a = d * d
         b = self.tonelli_shanks(N, d)[0]
         # Hensel-lift b so that b^2 ≡ N (mod d^2).
         b = (b + (N - b * b) * QuadraticSieve.modinv(b + b, d)) % a
         c = (b * b - N) // a

         sieve_values = self.sieve_interval(N, a, b, c, factor_base, M)
         for i, val in enumerate(sieve_values):
             if val <= threshold:
                 continue
             x = i - M
             relation = a * x + b
             poly_val = a * x * x + 2 * b * x + c
             local_factors, value = self.factorise_fast(poly_val, factor_base)

             if value != 1:
                 if not partial or value >= large_prime_bound:
                     continue
                 if value not in partials:
                     partials[value] = (relation, local_factors, poly_val * a)
                     continue
                 # Two partials share the cofactor: their product is
                 # smooth up to the square of that cofactor.
                 lp_found += 1
                 other_relation, other_factors, other_root = partials[value]
                 relation = relation * other_relation
                 local_factors = local_factors + other_factors
                 poly_val = poly_val * other_root

             # Exponent parity per factor-base entry.
             row = zero_row.copy()
             for fac in local_factors:
                 row[fb_index[fac]] ^= 1
             matrix.append(row)
             relations.append(relation)
             roots.append(poly_val * a)

     self.logger.info(
         "Sieved %d polynomials; %d large-prime combinations", num_poly, lp_found
     )
     return matrix, relations, roots

It works relatively well in base Python; I was mainly comparing my timings to the code located here. Under plain Python my code was faster during the sieving phase. However, when I ran both under PyPy, although my code got a speedup, it was still 10-20x slower than the linked code running under PyPy (for a random 60-digit semiprime).

I tried various things, like using a pure-Python nextprime instead of sympy's, and rewriting this section as plain Python loops (which I thought the JIT would compile better):

# Build the exponent-parity row for one relation: one entry per
# factor-base element, holding that element's multiplicity modulo 2.
row = zero_row.copy()
counts = {}
# Tally how many times each factor occurs in the relation.
for fac in local_factors:
    counts[fac] = counts.get(fac, 0) + 1
# Only the parity of each exponent is stored.
for idx, prime in enumerate(factor_base):
    row[idx] = counts.get(prime, 0) % 2

and a few other minor tweaks which I thought might be the problem however nothing worked.

What exactly do I need to change in the sieving portion of my code so that, when run under PyPy, it is as fast as the code linked above? This is the main question, but beyond this, are there any other parts of the code that could be improved for performance?

\$\endgroup\$
3
  • 2
    \$\begingroup\$I have rolled back Rev 3 → 2. Please see What should I do when someone answers my question?.\$\endgroup\$CommentedDec 30, 2024 at 20:37
  • \$\begingroup\$The answer didn't really answer the question I was asking. The main performance bottleneck is in the sieving code. I optimized that a bit myself not using anything from the answer.\$\endgroup\$
    – J. Doe
    CommentedDec 30, 2024 at 21:49
  • 2
    \$\begingroup\$Even though the answer didn't really answer the question asked: "Including revised versions of the code makes the question confusing, especially if someone later reviews the newer code."\$\endgroup\$CommentedJan 1 at 16:21

1 Answer 1

2
\$\begingroup\$

The OP question is about performance. Feel free to edit the question to append profiling measurements.

appropriate datastructure

_known_primes = [2, 3] 

Mostly we ask if a number is in the known primes. So this should definitely be a set(), letting us answer in \$O(1)\$ constant rather than linear time. And then we can ditch that global keyword.

type annotation

Annotation is optional, but this code really deserves to be annotated. Then you could tack on Numba @jit decorators and measure the amount of speedup.

list element pointers

Here M is "large".

 sieve_values = [0] * (2 * M + 1) 

That isn't terrible. But you might prefer a numpy ndarray or a python array, to avoid storing and chasing all those object pointers. However, at least one example semi-prime you're interested in has factors of 67- and 66-bits, which won't fit within np.int64, so you might need python's BigInts after all.

magic numbers

Bounds like \$1,373,653\$ or \$341,550,071,728,321\$, B=11812, M=59060, deserve a # comment on how you came up with them, or a paper citation so the interested reader can better understand the approach.

Also, we use a tuple for a fixed number of items where position dictates meaning, and a list for a variable number of similar items. So instead of for a in (2, 3, 5), prefer for a in [2, 3, 5].

Insert _ characters for readability: 341_550_071_728_321

\$\endgroup\$
1
  • \$\begingroup\$Because I am using PyPy, c-based extensions are highly detrimental to overall performance because of the JIT so sieving with numpy(which I already tried) just slows down the code. Everything else you mentioned just doesn't affect the actual running time of the code beyond maybe a couple hundred microseconds because the main time spent in the code is sieving which is what I want to make better written for PyPy compilation.\$\endgroup\$
    – J. Doe
    CommentedDec 30, 2024 at 21:59

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.