@@ -11,6 +11,9 @@ from flint.utils.typecheck cimport typecheck
1111cimport libc.stdlib
1212
1313from typing import Optional
14+ from flint.utils.flint_exceptions import IncompatibleContextError
15+
16+ from flint.types.fmpz cimport fmpz, any_as_fmpz
1417
1518
1619FLINT_BITS = _FLINT_BITS
@@ -38,7 +41,7 @@ cdef class flint_scalar(flint_elem):
3841 def is_zero (self ):
3942 return False
4043
41- def _any_as_self (self ):
44+ def _any_as_self (self , other ):
4245 return NotImplemented
4346
4447 def _neg_ (self ):
@@ -329,14 +332,13 @@ cdef class flint_mpoly_context(flint_elem):
329332 return nametup
330333
331334 @classmethod
332- def get_context (cls , slong nvars = 1 , ordering = Ordering.lex, names: Optional[str] = "x", nametup: Optional[tuple] = None ):
335+ def create_context_key (cls , slong nvars = 1 , ordering = Ordering.lex, names: Optional[str] = "x", nametup: Optional[tuple] = None ):
333336 """
334- Retrieve a context via the number of variables, `nvars`, the ordering, `ordering`, and either a variable
335- name string, `names`, or a tuple of variable names, `nametup` .
337+ Create a key for the context cache via the number of variables, the ordering, and
338+ either a variable name string, or a tuple of variable names.
336339 """
337-
338340 # A type hint of `ordering: Ordering` results in the error "TypeError: an integer is required" if a Ordering
339- # object is not provided. This is pretty obtuse so we check it's type ourselves
341+ # object is not provided. This is pretty obtuse so we check its type ourselves
340342 if not isinstance (ordering, Ordering):
341343 raise TypeError (f" `ordering` ('{ordering}') is not an instance of flint.Ordering" )
342344
@@ -346,6 +348,15 @@ cdef class flint_mpoly_context(flint_elem):
346348 key = nvars, ordering, cls .create_variable_names(nvars, names)
347349 else :
348350 raise ValueError (" must provide either `names` or `nametup`" )
351+ return key
352+
353+ @classmethod
354+ def get_context (cls , *args , **kwargs ):
355+ """
356+ Retrieve a context via the number of variables, `nvars`, the ordering, `ordering`, and either a variable
357+ name string, `names`, or a tuple of variable names, `nametup`.
358+ """
359+ key = cls .create_context_key(* args, ** kwargs)
349360
350361 ctx = cls ._ctx_cache.get(key)
351362 if ctx is None :
@@ -361,6 +372,18 @@ cdef class flint_mpoly_context(flint_elem):
361372 nametup = ctx.names()
362373 )
363374
375+ def any_as_scalar (self , other ):
376+ raise NotImplementedError (" abstract method" )
377+
378+ def scalar_as_mpoly (self , other ):
379+ raise NotImplementedError (" abstract method" )
380+
381+ def compatible_context_check (self , other ):
382+ if not typecheck(other, type (self )):
383+ raise TypeError (f" type {type(other)} is not {type(self)}" )
384+ elif other is not self :
385+ raise IncompatibleContextError(f" {other} is not {self}" )
386+
364387
365388cdef class flint_mpoly(flint_elem):
366389 """
@@ -373,6 +396,281 @@ cdef class flint_mpoly(flint_elem):
373396 def to_dict (self ):
374397 return {self .monomial(i): self .coefficient(i) for i in range (len (self ))}
375398
399+ def _division_check (self , other ):
400+ if not other:
401+ raise ZeroDivisionError (" nmod_mpoly division by zero" )
402+
403+ def _add_scalar_ (self , other ):
404+ return NotImplemented
405+
406+ def _add_mpoly_ (self , other ):
407+ return NotImplemented
408+
409+ def _iadd_scalar_ (self , other ):
410+ return NotImplemented
411+
412+ def _iadd_mpoly_ (self , other ):
413+ return NotImplemented
414+
415+ def _sub_scalar_ (self , other ):
416+ return NotImplemented
417+
418+ def _sub_mpoly_ (self , other ):
419+ return NotImplemented
420+
421+ def _isub_scalar_ (self , other ):
422+ return NotImplemented
423+
424+ def _isub_mpoly_ (self , other ):
425+ return NotImplemented
426+
427+ def _mul_scalar_ (self , other ):
428+ return NotImplemented
429+
430+ def _imul_mpoly_ (self , other ):
431+ return NotImplemented
432+
433+ def _imul_scalar_ (self , other ):
434+ return NotImplemented
435+
436+ def _mul_mpoly_ (self , other ):
437+ return NotImplemented
438+
439+ def _pow_ (self , other ):
440+ return NotImplemented
441+
442+ def _divmod_mpoly_ (self , other ):
443+ return NotImplemented
444+
445+ def _floordiv_mpoly_ (self , other ):
446+ return NotImplemented
447+
448+ def _truediv_mpoly_ (self , other ):
449+ return NotImplemented
450+
451+ def __add__ (self , other ):
452+ if typecheck(other, type (self )):
453+ self .context().compatible_context_check(other.context())
454+ return self ._add_mpoly_(other)
455+
456+ other = self .context().any_as_scalar(other)
457+ if other is NotImplemented :
458+ return NotImplemented
459+
460+ return self ._add_scalar_(other)
461+
462+ def __radd__ (self , other ):
463+ return self .__add__ (other)
464+
465+ def iadd (self , other ):
466+ """
467+ In-place addition, mutates self.
468+
469+ >>> from flint import Ordering, fmpz_mpoly_ctx
470+ >>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
471+ >>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
472+ >>> f
473+ 4*x0*x1 + 2*x0 + 3*x1
474+ >>> f.iadd(5)
475+ >>> f
476+ 4*x0*x1 + 2*x0 + 3*x1 + 5
477+
478+ """
479+ if typecheck(other, type (self )):
480+ self .context().compatible_context_check(other.context())
481+ self ._iadd_mpoly_(other)
482+ return
483+
484+ other_scalar = self .context().any_as_scalar(other)
485+ if other_scalar is NotImplemented :
486+ raise NotImplementedError (f" cannot add {type(self)} and {type(other)}" )
487+
488+ self ._iadd_scalar_(other_scalar)
489+
490+ def __sub__ (self , other ):
491+ if typecheck(other, type (self )):
492+ self .context().compatible_context_check(other.context())
493+ return self ._sub_mpoly_(other)
494+
495+ other = self .context().any_as_scalar(other)
496+ if other is NotImplemented :
497+ return NotImplemented
498+
499+ return self ._sub_scalar_(other)
500+
501+ def __rsub__ (self , other ):
502+ return - self .__sub__ (other)
503+
504+ def isub (self , other ):
505+ """
506+ In-place subtraction, mutates self.
507+
508+ >>> from flint import Ordering, fmpz_mpoly_ctx
509+ >>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
510+ >>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
511+ >>> f
512+ 4*x0*x1 + 2*x0 + 3*x1
513+ >>> f.isub(5)
514+ >>> f
515+ 4*x0*x1 + 2*x0 + 3*x1 - 5
516+
517+ """
518+ if typecheck(other, type (self )):
519+ self .context().compatible_context_check(other.context())
520+ self ._isub_mpoly_(other)
521+ return
522+
523+ other_scalar = self .context().any_as_scalar(other)
524+ if other_scalar is NotImplemented :
525+ raise NotImplementedError (f" cannot subtract {type(self)} and {type(other)}" )
526+
527+ self ._isub_scalar_(other_scalar)
528+
529+ def __mul__ (self , other ):
530+ if typecheck(other, type (self )):
531+ self .context().compatible_context_check(other.context())
532+ return self ._mul_mpoly_(other)
533+
534+ other = self .context().any_as_scalar(other)
535+ if other is NotImplemented :
536+ return NotImplemented
537+
538+ return self ._mul_scalar_(other)
539+
540+ def __rmul__ (self , other ):
541+ return self .__mul__ (other)
542+
543+ def imul (self , other ):
544+ """
545+ In-place multiplication, mutates self.
546+
547+ >>> from flint import Ordering, fmpz_mpoly_ctx
548+ >>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
549+ >>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
550+ >>> f
551+ 4*x0*x1 + 2*x0 + 3*x1
552+ >>> f.imul(2)
553+ >>> f
554+ 8*x0*x1 + 4*x0 + 6*x1
555+
556+ """
557+ if typecheck(other, type (self )):
558+ self .context().compatible_context_check(other.context())
559+ self ._imul_mpoly_(other)
560+ return
561+
562+ other_scalar = self .context().any_as_scalar(other)
563+ if other_scalar is NotImplemented :
564+ raise NotImplementedError (f" cannot multiply {type(self)} and {type(other)}" )
565+
566+ self ._imul_scalar_(other_scalar)
567+
568+ def __pow__ (self , other , modulus ):
569+ if modulus is not None :
570+ raise NotImplementedError (" cannot specify modulus outside of the context" )
571+ elif typecheck(other, fmpz):
572+ return self ._pow_(other)
573+
574+ other = any_as_fmpz(other)
575+ if other is NotImplemented :
576+ return NotImplemented
577+ elif other < 0 :
578+ raise ValueError (" cannot raise to a negative power" )
579+
580+ return self ._pow_(other)
581+
582+ def __divmod__ (self , other ):
583+ if typecheck(other, type (self )):
584+ self .context().compatible_context_check(other.context())
585+ self ._division_check(other)
586+ return self ._divmod_mpoly_(other)
587+
588+ other = self .context().any_as_scalar(other)
589+ if other is NotImplemented :
590+ return NotImplemented
591+
592+ other = self .context().scalar_as_mpoly(other)
593+ self ._division_check(other)
594+ return self ._divmod_mpoly_(other)
595+
596+ def __rdivmod__ (self , other ):
597+ other = self .context().any_as_scalar(other)
598+ if other is NotImplemented :
599+ return NotImplemented
600+
601+ other = self .context().scalar_as_mpoly(other)
602+ other._division_check(self )
603+ return other._divmod_mpoly_(self )
604+
605+ def __truediv__ (self , other ):
606+ if typecheck(other, type (self )):
607+ self .context().compatible_context_check(other.context())
608+ self ._division_check(other)
609+ return self ._truediv_mpoly_(other)
610+
611+ other = self .context().any_as_scalar(other)
612+ if other is NotImplemented :
613+ return NotImplemented
614+
615+ other = self .context().scalar_as_mpoly(other)
616+ self ._division_check(other)
617+ return self ._truediv_mpoly_(other)
618+
619+ def __rtruediv__ (self , other ):
620+ other = self .context().any_as_scalar(other)
621+ if other is NotImplemented :
622+ return NotImplemented
623+
624+ other = self .context().scalar_as_mpoly(other)
625+ other._division_check(self )
626+ return other._truediv_mpoly_(self )
627+
628+ def __floordiv__ (self , other ):
629+ if typecheck(other, type (self )):
630+ self .context().compatible_context_check(other.context())
631+ self ._division_check(other)
632+ return self ._floordiv_mpoly_(other)
633+
634+ other = self .context().any_as_scalar(other)
635+ if other is NotImplemented :
636+ return NotImplemented
637+
638+ other = self .context().scalar_as_mpoly(other)
639+ self ._division_check(other)
640+ return self ._floordiv_mpoly_(other)
641+
642+ def __rfloordiv__ (self , other ):
643+ other = self .context().any_as_scalar(other)
644+ if other is NotImplemented :
645+ return NotImplemented
646+
647+ other = self .context().scalar_as_mpoly(other)
648+ other._division_check(self )
649+ return other._floordiv_mpoly_(self )
650+
651+ def __mod__ (self , other ):
652+ if typecheck(other, type (self )):
653+ self .context().compatible_context_check(other.context())
654+ self ._division_check(other)
655+ return self ._mod_mpoly_(other)
656+
657+ other = self .context().any_as_scalar(other)
658+ if other is NotImplemented :
659+ return NotImplemented
660+
661+ other = self .context().scalar_as_mpoly(other)
662+ self ._division_check(other)
663+ return self ._mod_mpoly_(other)
664+
665+ def __rmod__ (self , other ):
666+ other = self .context().any_as_scalar(other)
667+ if other is NotImplemented :
668+ return NotImplemented
669+
670+ other = self .context().scalar_as_mpoly(other)
671+ other._division_check(self )
672+ return other._mod_mpoly_(self )
673+
376674 def __contains__ (self , x ):
377675 """
378676 Returns True if `self` contains a term with exponent vector `x` and a non-zero coefficient.
0 commit comments