@@ -5,49 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
5
5
from flint.types.fmpz_poly cimport any_as_fmpz_poly
6
6
from flint.types.fmpz_poly cimport fmpz_poly
7
7
from flint.types.nmod cimport any_as_nmod_ctx
8
- from flint.types.nmod cimport nmod
8
+ from flint.types.nmod cimport nmod, nmod_ctx
9
9
10
10
from flint.flintlib.nmod_vec cimport *
11
11
from flint.flintlib.nmod_poly cimport *
12
12
from flint.flintlib.nmod_poly_factor cimport *
13
13
from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
14
- from flint.flintlib.ulong_extras cimport n_gcdinv
14
+ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime
15
15
16
16
from flint.utils.flint_exceptions import DomainError
17
17
18
18
19
- cdef any_as_nmod_poly(obj, nmod_t mod):
20
- cdef nmod_poly r
21
- cdef mp_limb_t v
22
- # XXX: should check that modulus is the same here, and not all over the place
23
- if typecheck(obj, nmod_poly):
19
+ _nmod_poly_ctx_cache = {}
20
+
21
+
22
+ cdef nmod_ctx any_as_nmod_poly_ctx(obj):
23
+ """ Convert an int to an nmod_ctx."""
24
+ if typecheck(obj, nmod_poly_ctx):
24
25
return obj
25
- if any_as_nmod(& v, obj, mod):
26
- r = nmod_poly.__new__ (nmod_poly)
27
- nmod_poly_init(r.val, mod.n)
28
- nmod_poly_set_coeff_ui(r.val, 0 , v)
29
- return r
30
- x = any_as_fmpz_poly(obj)
31
- if x is not NotImplemented :
32
- r = nmod_poly.__new__ (nmod_poly)
33
- nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
34
- fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
35
- return r
26
+ if typecheck(obj, int ):
27
+ ctx = _nmod_poly_ctx_cache.get(obj)
28
+ if ctx is None :
29
+ ctx = nmod_poly_ctx(obj)
30
+ _nmod_poly_ctx_cache[obj] = ctx
31
+ return ctx
36
32
return NotImplemented
37
33
38
- cdef nmod_poly_set_list(nmod_poly_t poly, list val):
39
- cdef long i, n
40
- cdef nmod_t mod
41
- cdef mp_limb_t v
42
- nmod_init(& mod, nmod_poly_modulus(poly)) # XXX
43
- n = PyList_GET_SIZE(val)
44
- nmod_poly_fit_length(poly, n)
45
- for i from 0 <= i < n:
46
- c = val[i]
47
- if any_as_nmod(& v, val[i], mod):
48
- nmod_poly_set_coeff_ui(poly, i, v)
49
- else :
50
- raise TypeError (" unsupported coefficient in list" )
34
+
35
+ cdef class nmod_poly_ctx:
36
+ """
37
+ Context object for creating :class:`~.nmod_poly` initalised
38
+ with modulus :math:`N`.
39
+
40
+ >>> nmod_ctx(17)
41
+ nmod_ctx(17)
42
+
43
+ """
44
+ def __init__ (self , mod ):
45
+ cdef mp_limb_t m
46
+ m = mod
47
+ nmod_init(& self .mod, m)
48
+ self .ctx = nmod_ctx(mod)
49
+ self ._is_prime = n_is_prime(m)
50
+
51
+ cdef int any_as_nmod(self , mp_limb_t * val, obj) except - 1 :
52
+ return self .ctx.any_as_nmod(val, obj)
53
+
54
+ cdef any_as_nmod_poly(self , obj):
55
+ cdef nmod_poly r
56
+ cdef mp_limb_t v
57
+ # XXX: should check that modulus is the same here, and not all over the place
58
+ if typecheck(obj, nmod_poly):
59
+ return obj
60
+ if self .ctx.any_as_nmod(& v, obj):
61
+ r = nmod_poly.__new__ (nmod_poly)
62
+ nmod_poly_init(r.val, self .mod.n)
63
+ nmod_poly_set_coeff_ui(r.val, 0 , v)
64
+ return r
65
+ x = any_as_fmpz_poly(obj)
66
+ if x is not NotImplemented :
67
+ r = nmod_poly.__new__ (nmod_poly)
68
+ nmod_poly_init(r.val, self .mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
69
+ fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
70
+ return r
71
+ return NotImplemented
72
+
73
+ cdef nmod_poly_set_list(self , nmod_poly_t poly, list val):
74
+ cdef long i, n
75
+ cdef mp_limb_t v
76
+ n = PyList_GET_SIZE(val)
77
+ nmod_poly_fit_length(poly, n)
78
+ for i from 0 <= i < n:
79
+ c = val[i]
80
+ if self .any_as_nmod(& v, val[i]):
81
+ nmod_poly_set_coeff_ui(poly, i, v)
82
+ else :
83
+ raise TypeError (" unsupported coefficient in list" )
84
+
51
85
52
86
cdef class nmod_poly(flint_poly):
53
87
"""
@@ -79,24 +113,32 @@ cdef class nmod_poly(flint_poly):
79
113
def __dealloc__ (self ):
80
114
nmod_poly_clear(self .val)
81
115
82
- def __init__ (self , val = None , ulong mod = 0 ):
116
+ def __init__ (self , val = None , mod = 0 ):
83
117
cdef ulong m2
84
118
cdef mp_limb_t v
119
+ cdef nmod_poly_ctx ctx
120
+
85
121
if typecheck(val, nmod_poly):
86
122
m2 = nmod_poly_modulus((< nmod_poly> val).val)
87
123
if m2 != mod:
88
124
raise ValueError (" different moduli!" )
89
125
nmod_poly_init(self .val, m2)
90
126
nmod_poly_set(self .val, (< nmod_poly> val).val)
127
+ self .ctx = (< nmod_poly> val).ctx
91
128
else :
92
129
if mod == 0 :
93
130
raise ValueError (" a nonzero modulus is required" )
94
- nmod_poly_init(self .val, mod)
131
+ ctx = any_as_nmod_poly_ctx(mod)
132
+ if ctx is NotImplemented :
133
+ raise TypeError (" cannot create nmod_poly_ctx from input of type %s " , type (mod))
134
+
135
+ self .ctx = ctx
136
+ nmod_poly_init(self .val, ctx.mod.n)
95
137
if typecheck(val, fmpz_poly):
96
138
fmpz_poly_get_nmod_poly(self .val, (< fmpz_poly> val).val)
97
139
elif typecheck(val, list ):
98
- nmod_poly_set_list(self .val, val)
99
- elif any_as_nmod(& v, val, self .val.mod ):
140
+ ctx. nmod_poly_set_list(self .val, val)
141
+ elif ctx. any_as_nmod(& v, val):
100
142
nmod_poly_fit_length(self .val, 1 )
101
143
nmod_poly_set_coeff_ui(self .val, 0 , v)
102
144
else :
@@ -178,7 +220,7 @@ cdef class nmod_poly(flint_poly):
178
220
cdef mp_limb_t v
179
221
if i < 0 :
180
222
raise ValueError (" cannot assign to index < 0 of polynomial" )
181
- if any_as_nmod(& v, x, self .val.mod ):
223
+ if self .ctx. any_as_nmod(& v, x):
182
224
nmod_poly_set_coeff_ui(self .val, i, v)
183
225
else :
184
226
raise TypeError (" cannot set element of type %s " % type (x))
@@ -291,7 +333,7 @@ cdef class nmod_poly(flint_poly):
291
333
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
292
334
"""
293
335
cdef nmod_poly res
294
- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
336
+ other = self .ctx.any_as_nmod_poly(other )
295
337
if other is NotImplemented :
296
338
raise TypeError (" cannot convert input to nmod_poly" )
297
339
res = nmod_poly.__new__ (nmod_poly)
@@ -316,11 +358,11 @@ cdef class nmod_poly(flint_poly):
316
358
147*x^3 + 159*x^2 + 4*x + 7
317
359
"""
318
360
cdef nmod_poly res
319
- g = any_as_nmod_poly(other, self .val.mod )
361
+ g = self .ctx.any_as_nmod_poly(other )
320
362
if g is NotImplemented :
321
363
raise TypeError (f" cannot convert {other = } to nmod_poly" )
322
364
323
- h = any_as_nmod_poly(modulus, self .val.mod )
365
+ h = self . any_as_nmod_poly(modulus)
324
366
if h is NotImplemented :
325
367
raise TypeError (f" cannot convert {modulus = } to nmod_poly" )
326
368
@@ -334,11 +376,11 @@ cdef class nmod_poly(flint_poly):
334
376
335
377
def __call__ (self , other ):
336
378
cdef mp_limb_t c
337
- if any_as_nmod(& c, other, self .val.mod ):
379
+ if self .ctx. any_as_nmod(& c, other):
338
380
v = nmod(0 , self .modulus())
339
381
(< nmod> v).val = nmod_poly_evaluate_nmod(self .val, c)
340
382
return v
341
- t = any_as_nmod_poly(other, self .val.mod )
383
+ t = self .ctx.any_as_nmod_poly(other )
342
384
if t is not NotImplemented :
343
385
r = nmod_poly.__new__ (nmod_poly)
344
386
nmod_poly_init_preinv((< nmod_poly> r).val, self .val.mod.n, self .val.mod.ninv)
@@ -369,7 +411,7 @@ cdef class nmod_poly(flint_poly):
369
411
370
412
def _add_ (s , t ):
371
413
cdef nmod_poly r
372
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
414
+ t = s.ctx.any_as_nmod_poly(t )
373
415
if t is NotImplemented :
374
416
return t
375
417
if (< nmod_poly> s).val.mod.n != (< nmod_poly> t).val.mod.n:
@@ -395,20 +437,20 @@ cdef class nmod_poly(flint_poly):
395
437
return r
396
438
397
439
def __sub__ (s , t ):
398
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
440
+ t = s.ctx.any_as_nmod_poly(t )
399
441
if t is NotImplemented :
400
442
return t
401
443
return s._sub_(t)
402
444
403
445
def __rsub__ (s , t ):
404
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
446
+ t = s. any_as_nmod_poly(t)
405
447
if t is NotImplemented :
406
448
return t
407
449
return t._sub_(s)
408
450
409
451
def _mul_ (s , t ):
410
452
cdef nmod_poly r
411
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
453
+ t = s. any_as_nmod_poly(t)
412
454
if t is NotImplemented :
413
455
return t
414
456
if (< nmod_poly> s).val.mod.n != (< nmod_poly> t).val.mod.n:
@@ -425,7 +467,7 @@ cdef class nmod_poly(flint_poly):
425
467
return s._mul_(t)
426
468
427
469
def __truediv__ (s , t ):
428
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
470
+ t = s. any_as_nmod_poly(t)
429
471
if t is NotImplemented :
430
472
return t
431
473
res, r = s._divmod_(t)
@@ -434,7 +476,7 @@ cdef class nmod_poly(flint_poly):
434
476
return res
435
477
436
478
def __rtruediv__ (s , t ):
437
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
479
+ t = s. any_as_nmod_poly(t)
438
480
if t is NotImplemented :
439
481
return t
440
482
res, r = t._divmod_(s)
@@ -454,13 +496,13 @@ cdef class nmod_poly(flint_poly):
454
496
return r
455
497
456
498
def __floordiv__ (s , t ):
457
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
499
+ t = s. any_as_nmod_poly(t)
458
500
if t is NotImplemented :
459
501
return t
460
502
return s._floordiv_(t)
461
503
462
504
def __rfloordiv__ (s , t ):
463
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
505
+ t = s. any_as_nmod_poly(t)
464
506
if t is NotImplemented :
465
507
return t
466
508
return t._floordiv_(s)
@@ -479,13 +521,13 @@ cdef class nmod_poly(flint_poly):
479
521
return P, Q
480
522
481
523
def __divmod__ (s , t ):
482
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
524
+ t = s. any_as_nmod_poly(t)
483
525
if t is NotImplemented :
484
526
return t
485
527
return s._divmod_(t)
486
528
487
529
def __rdivmod__ (s , t ):
488
- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
530
+ t = s. any_as_nmod_poly(t)
489
531
if t is NotImplemented :
490
532
return t
491
533
return t._divmod_(s)
@@ -534,7 +576,7 @@ cdef class nmod_poly(flint_poly):
534
576
if e < 0 :
535
577
raise ValueError (" Exponent must be non-negative" )
536
578
537
- modulus = any_as_nmod_poly(modulus, ( < nmod_poly > self ).val.mod )
579
+ modulus = self .ctx.any_as_nmod_poly(modulus )
538
580
if modulus is NotImplemented :
539
581
raise TypeError (" cannot convert input to nmod_poly" )
540
582
@@ -556,7 +598,7 @@ cdef class nmod_poly(flint_poly):
556
598
557
599
# To optimise powering, we precompute the inverse of the reverse of the modulus
558
600
if mod_rev_inv is not None :
559
- mod_rev_inv = any_as_nmod_poly(mod_rev_inv, ( < nmod_poly > self ).val.mod )
601
+ mod_rev_inv = self . any_as_nmod_poly(mod_rev_inv)
560
602
if mod_rev_inv is NotImplemented :
561
603
raise TypeError (f" Cannot interpret {mod_rev_inv} as a polynomial" )
562
604
else :
@@ -585,7 +627,7 @@ cdef class nmod_poly(flint_poly):
585
627
586
628
"""
587
629
cdef nmod_poly res
588
- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
630
+ other = self . any_as_nmod_poly(other)
589
631
if other is NotImplemented :
590
632
raise TypeError (" cannot convert input to nmod_poly" )
591
633
if self .val.mod.n != (< nmod_poly> other).val.mod.n:
@@ -597,7 +639,7 @@ cdef class nmod_poly(flint_poly):
597
639
598
640
def xgcd (self , other ):
599
641
cdef nmod_poly res1, res2, res3
600
- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
642
+ other = self . any_as_nmod_poly(other)
601
643
if other is NotImplemented :
602
644
raise TypeError (" cannot convert input to fmpq_poly" )
603
645
res1 = nmod_poly.__new__ (nmod_poly)
0 commit comments