Skip to content

Commit 34a3b49

Browse files
committed
Add nmod_poly_ctx
1 parent c2b1f67 commit 34a3b49

File tree

2 files changed

+115
-57
lines changed

2 files changed

+115
-57
lines changed

src/flint/types/nmod_poly.pxd

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1-
from flint.flint_base.flint_base cimport flint_poly
2-
1+
from flint.flintlib.nmod cimport nmod_t
32
from flint.flintlib.nmod_poly cimport nmod_poly_t
43
from flint.flintlib.flint cimport mp_limb_t
54

5+
from flint.flint_base.flint_base cimport flint_poly
6+
7+
from flint.types.nmod cimport nmod_ctx
8+
9+
10+
cdef class nmod_poly_ctx:
11+
cdef nmod_ctx ctx
12+
cdef nmod_t mod
13+
cdef bint _is_prime
14+
15+
cdef nmod_poly_set_list(self, nmod_poly_t poly, list val)
16+
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
17+
cdef any_as_nmod_poly(self, obj)
18+
19+
620
cdef class nmod_poly(flint_poly):
721
cdef nmod_poly_t val
22+
cdef nmod_poly_ctx ctx
23+
824
cpdef long length(self)
925
cpdef long degree(self)
1026
cpdef mp_limb_t modulus(self)

src/flint/types/nmod_poly.pyx

Lines changed: 97 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
55
from flint.types.fmpz_poly cimport any_as_fmpz_poly
66
from flint.types.fmpz_poly cimport fmpz_poly
77
from flint.types.nmod cimport any_as_nmod_ctx
8-
from flint.types.nmod cimport nmod
8+
from flint.types.nmod cimport nmod, nmod_ctx
99

1010
from flint.flintlib.nmod_vec cimport *
1111
from flint.flintlib.nmod_poly cimport *
1212
from flint.flintlib.nmod_poly_factor cimport *
1313
from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
14-
from flint.flintlib.ulong_extras cimport n_gcdinv
14+
from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime
1515

1616
from flint.utils.flint_exceptions import DomainError
1717

1818

19-
cdef any_as_nmod_poly(obj, nmod_t mod):
20-
cdef nmod_poly r
21-
cdef mp_limb_t v
22-
# XXX: should check that modulus is the same here, and not all over the place
23-
if typecheck(obj, nmod_poly):
19+
_nmod_poly_ctx_cache = {}
20+
21+
22+
cdef nmod_ctx any_as_nmod_poly_ctx(obj):
23+
"""Convert an int to an nmod_ctx."""
24+
if typecheck(obj, nmod_poly_ctx):
2425
return obj
25-
if any_as_nmod(&v, obj, mod):
26-
r = nmod_poly.__new__(nmod_poly)
27-
nmod_poly_init(r.val, mod.n)
28-
nmod_poly_set_coeff_ui(r.val, 0, v)
29-
return r
30-
x = any_as_fmpz_poly(obj)
31-
if x is not NotImplemented:
32-
r = nmod_poly.__new__(nmod_poly)
33-
nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
34-
fmpz_poly_get_nmod_poly(r.val, (<fmpz_poly>x).val)
35-
return r
26+
if typecheck(obj, int):
27+
ctx = _nmod_poly_ctx_cache.get(obj)
28+
if ctx is None:
29+
ctx = nmod_poly_ctx(obj)
30+
_nmod_poly_ctx_cache[obj] = ctx
31+
return ctx
3632
return NotImplemented
3733

38-
cdef nmod_poly_set_list(nmod_poly_t poly, list val):
39-
cdef long i, n
40-
cdef nmod_t mod
41-
cdef mp_limb_t v
42-
nmod_init(&mod, nmod_poly_modulus(poly)) # XXX
43-
n = PyList_GET_SIZE(val)
44-
nmod_poly_fit_length(poly, n)
45-
for i from 0 <= i < n:
46-
c = val[i]
47-
if any_as_nmod(&v, val[i], mod):
48-
nmod_poly_set_coeff_ui(poly, i, v)
49-
else:
50-
raise TypeError("unsupported coefficient in list")
34+
35+
cdef class nmod_poly_ctx:
36+
"""
37+
Context object for creating :class:`~.nmod_poly` initalised
38+
with modulus :math:`N`.
39+
40+
>>> nmod_ctx(17)
41+
nmod_ctx(17)
42+
43+
"""
44+
def __init__(self, mod):
45+
cdef mp_limb_t m
46+
m = mod
47+
nmod_init(&self.mod, m)
48+
self.ctx = nmod_ctx(mod)
49+
self._is_prime = n_is_prime(m)
50+
51+
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
52+
return self.ctx.any_as_nmod(val, obj)
53+
54+
cdef any_as_nmod_poly(self, obj):
55+
cdef nmod_poly r
56+
cdef mp_limb_t v
57+
# XXX: should check that modulus is the same here, and not all over the place
58+
if typecheck(obj, nmod_poly):
59+
return obj
60+
if self.ctx.any_as_nmod(&v, obj):
61+
r = nmod_poly.__new__(nmod_poly)
62+
nmod_poly_init(r.val, self.mod.n)
63+
nmod_poly_set_coeff_ui(r.val, 0, v)
64+
return r
65+
x = any_as_fmpz_poly(obj)
66+
if x is not NotImplemented:
67+
r = nmod_poly.__new__(nmod_poly)
68+
nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
69+
fmpz_poly_get_nmod_poly(r.val, (<fmpz_poly>x).val)
70+
return r
71+
return NotImplemented
72+
73+
cdef nmod_poly_set_list(self, nmod_poly_t poly, list val):
74+
cdef long i, n
75+
cdef mp_limb_t v
76+
n = PyList_GET_SIZE(val)
77+
nmod_poly_fit_length(poly, n)
78+
for i from 0 <= i < n:
79+
c = val[i]
80+
if self.any_as_nmod(&v, val[i]):
81+
nmod_poly_set_coeff_ui(poly, i, v)
82+
else:
83+
raise TypeError("unsupported coefficient in list")
84+
5185

5286
cdef class nmod_poly(flint_poly):
5387
"""
@@ -79,24 +113,32 @@ cdef class nmod_poly(flint_poly):
79113
def __dealloc__(self):
80114
nmod_poly_clear(self.val)
81115

82-
def __init__(self, val=None, ulong mod=0):
116+
def __init__(self, val=None, mod=0):
83117
cdef ulong m2
84118
cdef mp_limb_t v
119+
cdef nmod_poly_ctx ctx
120+
85121
if typecheck(val, nmod_poly):
86122
m2 = nmod_poly_modulus((<nmod_poly>val).val)
87123
if m2 != mod:
88124
raise ValueError("different moduli!")
89125
nmod_poly_init(self.val, m2)
90126
nmod_poly_set(self.val, (<nmod_poly>val).val)
127+
self.ctx = (<nmod_poly>val).ctx
91128
else:
92129
if mod == 0:
93130
raise ValueError("a nonzero modulus is required")
94-
nmod_poly_init(self.val, mod)
131+
ctx = any_as_nmod_poly_ctx(mod)
132+
if ctx is NotImplemented:
133+
raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod))
134+
135+
self.ctx = ctx
136+
nmod_poly_init(self.val, ctx.mod.n)
95137
if typecheck(val, fmpz_poly):
96138
fmpz_poly_get_nmod_poly(self.val, (<fmpz_poly>val).val)
97139
elif typecheck(val, list):
98-
nmod_poly_set_list(self.val, val)
99-
elif any_as_nmod(&v, val, self.val.mod):
140+
ctx.nmod_poly_set_list(self.val, val)
141+
elif ctx.any_as_nmod(&v, val):
100142
nmod_poly_fit_length(self.val, 1)
101143
nmod_poly_set_coeff_ui(self.val, 0, v)
102144
else:
@@ -178,7 +220,7 @@ cdef class nmod_poly(flint_poly):
178220
cdef mp_limb_t v
179221
if i < 0:
180222
raise ValueError("cannot assign to index < 0 of polynomial")
181-
if any_as_nmod(&v, x, self.val.mod):
223+
if self.ctx.any_as_nmod(&v, x):
182224
nmod_poly_set_coeff_ui(self.val, i, v)
183225
else:
184226
raise TypeError("cannot set element of type %s" % type(x))
@@ -291,7 +333,7 @@ cdef class nmod_poly(flint_poly):
291333
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
292334
"""
293335
cdef nmod_poly res
294-
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
336+
other = self.ctx.any_as_nmod_poly(other)
295337
if other is NotImplemented:
296338
raise TypeError("cannot convert input to nmod_poly")
297339
res = nmod_poly.__new__(nmod_poly)
@@ -316,11 +358,11 @@ cdef class nmod_poly(flint_poly):
316358
147*x^3 + 159*x^2 + 4*x + 7
317359
"""
318360
cdef nmod_poly res
319-
g = any_as_nmod_poly(other, self.val.mod)
361+
g = self.ctx.any_as_nmod_poly(other)
320362
if g is NotImplemented:
321363
raise TypeError(f"cannot convert {other = } to nmod_poly")
322364

323-
h = any_as_nmod_poly(modulus, self.val.mod)
365+
h = self.any_as_nmod_poly(modulus)
324366
if h is NotImplemented:
325367
raise TypeError(f"cannot convert {modulus = } to nmod_poly")
326368

@@ -334,11 +376,11 @@ cdef class nmod_poly(flint_poly):
334376

335377
def __call__(self, other):
336378
cdef mp_limb_t c
337-
if any_as_nmod(&c, other, self.val.mod):
379+
if self.ctx.any_as_nmod(&c, other):
338380
v = nmod(0, self.modulus())
339381
(<nmod>v).val = nmod_poly_evaluate_nmod(self.val, c)
340382
return v
341-
t = any_as_nmod_poly(other, self.val.mod)
383+
t = self.ctx.any_as_nmod_poly(other)
342384
if t is not NotImplemented:
343385
r = nmod_poly.__new__(nmod_poly)
344386
nmod_poly_init_preinv((<nmod_poly>r).val, self.val.mod.n, self.val.mod.ninv)
@@ -369,7 +411,7 @@ cdef class nmod_poly(flint_poly):
369411

370412
def _add_(s, t):
371413
cdef nmod_poly r
372-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
414+
t = s.ctx.any_as_nmod_poly(t)
373415
if t is NotImplemented:
374416
return t
375417
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
@@ -395,20 +437,20 @@ cdef class nmod_poly(flint_poly):
395437
return r
396438

397439
def __sub__(s, t):
398-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
440+
t = s.ctx.any_as_nmod_poly(t)
399441
if t is NotImplemented:
400442
return t
401443
return s._sub_(t)
402444

403445
def __rsub__(s, t):
404-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
446+
t = s.any_as_nmod_poly(t)
405447
if t is NotImplemented:
406448
return t
407449
return t._sub_(s)
408450

409451
def _mul_(s, t):
410452
cdef nmod_poly r
411-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
453+
t = s.any_as_nmod_poly(t)
412454
if t is NotImplemented:
413455
return t
414456
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
@@ -425,7 +467,7 @@ cdef class nmod_poly(flint_poly):
425467
return s._mul_(t)
426468

427469
def __truediv__(s, t):
428-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
470+
t = s.any_as_nmod_poly(t)
429471
if t is NotImplemented:
430472
return t
431473
res, r = s._divmod_(t)
@@ -434,7 +476,7 @@ cdef class nmod_poly(flint_poly):
434476
return res
435477

436478
def __rtruediv__(s, t):
437-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
479+
t = s.any_as_nmod_poly(t)
438480
if t is NotImplemented:
439481
return t
440482
res, r = t._divmod_(s)
@@ -454,13 +496,13 @@ cdef class nmod_poly(flint_poly):
454496
return r
455497

456498
def __floordiv__(s, t):
457-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
499+
t = s.any_as_nmod_poly(t)
458500
if t is NotImplemented:
459501
return t
460502
return s._floordiv_(t)
461503

462504
def __rfloordiv__(s, t):
463-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
505+
t = s.any_as_nmod_poly(t)
464506
if t is NotImplemented:
465507
return t
466508
return t._floordiv_(s)
@@ -479,13 +521,13 @@ cdef class nmod_poly(flint_poly):
479521
return P, Q
480522

481523
def __divmod__(s, t):
482-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
524+
t = s.any_as_nmod_poly(t)
483525
if t is NotImplemented:
484526
return t
485527
return s._divmod_(t)
486528

487529
def __rdivmod__(s, t):
488-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
530+
t = s.any_as_nmod_poly(t)
489531
if t is NotImplemented:
490532
return t
491533
return t._divmod_(s)
@@ -534,7 +576,7 @@ cdef class nmod_poly(flint_poly):
534576
if e < 0:
535577
raise ValueError("Exponent must be non-negative")
536578

537-
modulus = any_as_nmod_poly(modulus, (<nmod_poly>self).val.mod)
579+
modulus = self.ctx.any_as_nmod_poly(modulus)
538580
if modulus is NotImplemented:
539581
raise TypeError("cannot convert input to nmod_poly")
540582

@@ -556,7 +598,7 @@ cdef class nmod_poly(flint_poly):
556598

557599
# To optimise powering, we precompute the inverse of the reverse of the modulus
558600
if mod_rev_inv is not None:
559-
mod_rev_inv = any_as_nmod_poly(mod_rev_inv, (<nmod_poly>self).val.mod)
601+
mod_rev_inv = self.any_as_nmod_poly(mod_rev_inv)
560602
if mod_rev_inv is NotImplemented:
561603
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
562604
else:
@@ -585,7 +627,7 @@ cdef class nmod_poly(flint_poly):
585627
586628
"""
587629
cdef nmod_poly res
588-
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
630+
other = self.any_as_nmod_poly(other)
589631
if other is NotImplemented:
590632
raise TypeError("cannot convert input to nmod_poly")
591633
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
@@ -597,7 +639,7 @@ cdef class nmod_poly(flint_poly):
597639

598640
def xgcd(self, other):
599641
cdef nmod_poly res1, res2, res3
600-
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
642+
other = self.any_as_nmod_poly(other)
601643
if other is NotImplemented:
602644
raise TypeError("cannot convert input to fmpq_poly")
603645
res1 = nmod_poly.__new__(nmod_poly)

0 commit comments

Comments
 (0)