diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 5b053c73e20bc9..10acc28ae3d994 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -99,6 +99,17 @@ typedef struct { PyCFunction _py_float_as_integer_ratio; } decimal_state; +/* Like PyType_GetModule, but skips verification + * that type is a heap type (TODO: pycore_typeobject.h) */ +static inline PyObject * +_PyType_GetModule(PyTypeObject *type) +{ + assert(PyType_Check(type)); + assert(type->tp_flags & Py_TPFLAGS_HEAPTYPE); + PyHeapTypeObject *et = (PyHeapTypeObject *)type; + return et->ht_module; +} + static inline decimal_state * get_module_state(PyObject *mod) { @@ -112,21 +123,33 @@ static struct PyModuleDef _decimal_module; static inline decimal_state * get_module_state_by_def(PyTypeObject *tp) { - PyObject *mod = PyType_GetModuleByDef(tp, &_decimal_module); - assert(mod != NULL); + PyObject *mod = _PyType_GetModule(tp); + if (mod != NULL) { + assert(_PyModule_GetDef(mod) == &_decimal_module); + } + else { + mod = PyType_GetModuleByDef(tp, &_decimal_module); + assert(mod != NULL); + } return get_module_state(mod); } static inline decimal_state * find_state_left_or_right(PyObject *left, PyObject *right) { - PyObject *mod = PyType_GetModuleByDef(Py_TYPE(left), &_decimal_module); - if (mod == NULL) { + PyTypeObject *tp = Py_TYPE(right); + if (tp->tp_flags & Py_TPFLAGS_HEAPTYPE) { + PyObject *mod = _PyType_GetModule(tp); + if (mod && _PyModule_GetDef(mod) == &_decimal_module) { + return get_module_state(mod); + } + mod = PyType_GetModuleByDef(tp, &_decimal_module); + if (mod != NULL) { + return get_module_state(mod); + } PyErr_Clear(); - mod = PyType_GetModuleByDef(Py_TYPE(right), &_decimal_module); } - assert(mod != NULL); - return get_module_state(mod); + return get_module_state_by_def(Py_TYPE(left)); } @@ -178,6 +201,7 @@ typedef struct PyDecContextObject { PyObject *flags; int capitals; PyThreadState *tstate; + decimal_state *mstate; } PyDecContextObject; typedef struct { @@ -198,6 +222,14 @@ typedef struct { #define CTX(v) (&((PyDecContextObject *)v)->ctx) #define CtxCaps(v) (((PyDecContextObject *)v)->capitals) +static inline decimal_state * +ctx_get_module_state(PyObject *v) { + decimal_state *state = ((PyDecContextObject *)v)->mstate; + assert(state != NULL); + assert(PyDecContext_Check(state, v)); + return state; +} + Py_LOCAL_INLINE(PyObject *) incr_true(void) @@ -552,7 +584,7 @@ static int dec_addstatus(PyObject *context, uint32_t status) { mpd_context_t *ctx = CTX(context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); ctx->status |= status; if (status & (ctx->traps|MPD_Malloc_error)) { @@ -634,7 +666,7 @@ signaldict_iter(PyObject *self) if (SdFlagAddr(self) == NULL) { return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG); } - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_by_def(Py_TYPE(self)->tp_base); return PyTuple_Type.tp_iter(state->SignalTuple); } @@ -645,7 +677,7 @@ signaldict_getitem(PyObject *self, PyObject *key) if (SdFlagAddr(self) == NULL) { return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG); } - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_by_def(Py_TYPE(self)->tp_base); flag = exception_as_flag(state, key); if (flag & DEC_ERRORS) { @@ -669,7 +701,7 @@ signaldict_setitem(PyObject *self, PyObject *key, PyObject *value) return value_error_int("signal keys cannot be deleted"); } - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_by_def(Py_TYPE(self)->tp_base); flag = exception_as_flag(state, key); if (flag & DEC_ERRORS) { return -1; @@ -720,7 +752,7 @@ signaldict_repr(PyObject *self) assert(SIGNAL_MAP_LEN == 9); - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_by_def(Py_TYPE(self)->tp_base); for (cm=state->signal_map, i=0; cm->name != NULL; cm++, i++) { n[i] = cm->fqname; b[i] = SdFlags(self)&cm->flag ? "True" : "False"; @@ -739,7 +771,7 @@ signaldict_richcompare(PyObject *v, PyObject *w, int op) { PyObject *res = Py_NotImplemented; - decimal_state *state = find_state_left_or_right(v, w); + decimal_state *state = get_module_state_by_def(Py_TYPE(v)->tp_base); assert(PyDecSignalDict_Check(state, v)); if ((SdFlagAddr(v) == NULL) || (SdFlagAddr(w) == NULL)) { @@ -776,7 +808,7 @@ signaldict_copy(PyObject *self, PyObject *args UNUSED) if (SdFlagAddr(self) == NULL) { return value_error_ptr(INVALID_SIGNALDICT_ERROR_MSG); } - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = get_module_state_by_def(Py_TYPE(self)->tp_base); return flags_as_dict(state, SdFlags(self)); } @@ -846,7 +878,7 @@ static PyObject * context_getround(PyObject *self, void *closure UNUSED) { int i = mpd_getround(CTX(self)); - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); return Py_NewRef(state->round_map[i]); } @@ -1005,7 +1037,7 @@ context_setround(PyObject *self, PyObject *value, void *closure UNUSED) mpd_context_t *ctx; int x; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); x = getround(state, value); if (x == -1) { return -1; @@ -1064,7 +1096,7 @@ context_settraps_list(PyObject *self, PyObject *value) { mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); flags = list_as_flags(state, value); if (flags & DEC_ERRORS) { return -1; @@ -1084,7 +1116,7 @@ context_settraps_dict(PyObject *self, PyObject *value) mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); if (PyDecSignalDict_Check(state, value)) { flags = SdFlags(value); } @@ -1129,7 +1161,7 @@ context_setstatus_list(PyObject *self, PyObject *value) { mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); flags = list_as_flags(state, value); if (flags & DEC_ERRORS) { @@ -1150,7 +1182,7 @@ context_setstatus_dict(PyObject *self, PyObject *value) mpd_context_t *ctx; uint32_t flags; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); if (PyDecSignalDict_Check(state, value)) { flags = SdFlags(value); } @@ -1379,7 +1411,7 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED) CtxCaps(self) = 1; self->tstate = NULL; - + ((PyDecContextObject *)self)->mstate = state; return (PyObject *)self; } @@ -1406,7 +1438,7 @@ context_dealloc(PyDecContextObject *self) PyTypeObject *tp = Py_TYPE(self); PyObject_GC_UnTrack(self); #ifndef WITH_DECIMAL_CONTEXTVAR - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state((PyObject *)self); if (self == state->cached_context) { state->cached_context = NULL; } @@ -1458,7 +1490,7 @@ context_repr(PyDecContextObject *self) int n, mem; #ifdef Py_DEBUG - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state((PyObject *)self); assert(PyDecContext_Check(state, self)); #endif ctx = CTX(self); @@ -1509,7 +1541,7 @@ init_extended_context(PyObject *v) #ifdef EXTRA_FUNCTIONALITY /* Factory function for creating IEEE interchange format contexts */ static PyObject * -ieee_context(PyObject *dummy UNUSED, PyObject *v) +ieee_context(PyObject *module, PyObject *v) { PyObject *context; mpd_ssize_t bits; @@ -1526,7 +1558,7 @@ ieee_context(PyObject *dummy UNUSED, PyObject *v) goto error; } - decimal_state *state = get_module_state_by_def(Py_TYPE(v)); + decimal_state *state = get_module_state(module); context = PyObject_CallObject((PyObject *)state->PyDecContext_Type, NULL); if (context == NULL) { return NULL; @@ -1549,7 +1581,7 @@ context_copy(PyObject *self, PyObject *args UNUSED) { PyObject *copy; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); copy = PyObject_CallObject((PyObject *)state->PyDecContext_Type, NULL); if (copy == NULL) { return NULL; @@ -1569,7 +1601,7 @@ context_reduce(PyObject *self, PyObject *args UNUSED) PyObject *traps; PyObject *ret; mpd_context_t *ctx; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + decimal_state *state = ctx_get_module_state(self); ctx = CTX(self); @@ -2004,11 +2036,10 @@ static PyType_Spec ctxmanager_spec = { /******************************************************************************/ static PyObject * -PyDecType_New(PyTypeObject *type) +PyDecType_New(decimal_state *state, PyTypeObject *type) { PyDecObject *dec; - decimal_state *state = get_module_state_by_def(type); if (type == state->PyDec_Type) { dec = PyObject_GC_New(PyDecObject, state->PyDec_Type); } @@ -2030,7 +2061,7 @@ PyDecType_New(PyTypeObject *type) return (PyObject *)dec; } -#define dec_alloc(st) PyDecType_New((st)->PyDec_Type) +#define dec_alloc(st) PyDecType_New(st, (st)->PyDec_Type) static int dec_traverse(PyObject *dec, visitproc visit, void *arg) @@ -2133,7 +2164,8 @@ PyDecType_FromCString(PyTypeObject *type, const char *s, PyObject *dec; uint32_t status = 0; - dec = PyDecType_New(type); + decimal_state *state = ctx_get_module_state(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2157,7 +2189,8 @@ PyDecType_FromCStringExact(PyTypeObject *type, const char *s, uint32_t status = 0; mpd_context_t maxctx; - dec = PyDecType_New(type); + decimal_state *state = ctx_get_module_state(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2244,7 +2277,8 @@ PyDecType_FromSsize(PyTypeObject *type, mpd_ssize_t v, PyObject *context) PyObject *dec; uint32_t status = 0; - dec = PyDecType_New(type); + decimal_state *state = ctx_get_module_state(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2265,7 +2299,8 @@ PyDecType_FromSsizeExact(PyTypeObject *type, mpd_ssize_t v, PyObject *context) uint32_t status = 0; mpd_context_t maxctx; - dec = PyDecType_New(type); + decimal_state *state = ctx_get_module_state(context); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2283,13 +2318,13 @@ PyDecType_FromSsizeExact(PyTypeObject *type, mpd_ssize_t v, PyObject *context) /* Convert from a PyLongObject. The context is not modified; flags set during conversion are accumulated in the status parameter. */ static PyObject * -dec_from_long(PyTypeObject *type, PyObject *v, +dec_from_long(decimal_state *state, PyTypeObject *type, PyObject *v, const mpd_context_t *ctx, uint32_t *status) { PyObject *dec; PyLongObject *l = (PyLongObject *)v; - dec = PyDecType_New(type); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2334,7 +2369,8 @@ PyDecType_FromLong(PyTypeObject *type, PyObject *v, PyObject *context) return NULL; } - dec = dec_from_long(type, v, CTX(context), &status); + decimal_state *state = ctx_get_module_state(context); + dec = dec_from_long(state, type, v, CTX(context), &status); if (dec == NULL) { return NULL; } @@ -2363,7 +2399,8 @@ PyDecType_FromLongExact(PyTypeObject *type, PyObject *v, } mpd_maxcontext(&maxctx); - dec = dec_from_long(type, v, &maxctx, &status); + decimal_state *state = ctx_get_module_state(context); + dec = dec_from_long(state, type, v, &maxctx, &status); if (dec == NULL) { return NULL; } @@ -2395,7 +2432,7 @@ PyDecType_FromFloatExact(PyTypeObject *type, PyObject *v, mpd_t *d1, *d2; uint32_t status = 0; mpd_context_t maxctx; - decimal_state *state = get_module_state_by_def(type); + decimal_state *state = ctx_get_module_state(context); #ifdef Py_DEBUG assert(PyType_IsSubtype(type, state->PyDec_Type)); @@ -2416,7 +2453,7 @@ PyDecType_FromFloatExact(PyTypeObject *type, PyObject *v, sign = (copysign(1.0, x) == 1.0) ? 0 : 1; if (Py_IS_NAN(x) || Py_IS_INFINITY(x)) { - dec = PyDecType_New(type); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2533,12 +2570,12 @@ PyDecType_FromDecimalExact(PyTypeObject *type, PyObject *v, PyObject *context) PyObject *dec; uint32_t status = 0; - decimal_state *state = get_module_state_by_def(type); + decimal_state *state = ctx_get_module_state(context); if (type == state->PyDec_Type && PyDec_CheckExact(state, v)) { return Py_NewRef(v); } - dec = PyDecType_New(type); + dec = PyDecType_New(state, type); if (dec == NULL) { return NULL; } @@ -2822,7 +2859,7 @@ dec_from_float(PyObject *type, PyObject *pyfloat) static PyObject * ctx_from_float(PyObject *context, PyObject *v) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); return PyDec_FromFloat(state, v, context); } @@ -2833,7 +2870,7 @@ dec_apply(PyObject *v, PyObject *context) PyObject *result; uint32_t status = 0; - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); result = dec_alloc(state); if (result == NULL) { return NULL; @@ -2860,7 +2897,7 @@ dec_apply(PyObject *v, PyObject *context) static PyObject * PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context) { - decimal_state *state = get_module_state_by_def(type); + decimal_state *state = ctx_get_module_state(context); if (v == NULL) { return PyDecType_FromSsizeExact(type, 0, context); } @@ -2895,7 +2932,7 @@ PyDecType_FromObjectExact(PyTypeObject *type, PyObject *v, PyObject *context) static PyObject * PyDec_FromObject(PyObject *v, PyObject *context) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); if (v == NULL) { return PyDec_FromSsize(state, 0, context); } @@ -2982,7 +3019,7 @@ ctx_create_decimal(PyObject *context, PyObject *args) Py_LOCAL_INLINE(int) convert_op(int type_err, PyObject **conv, PyObject *v, PyObject *context) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); if (PyDec_Check(state, v)) { *conv = Py_NewRef(v); return 1; @@ -3085,7 +3122,7 @@ multiply_by_denominator(PyObject *v, PyObject *r, PyObject *context) if (tmp == NULL) { return NULL; } - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); denom = PyDec_FromLongExact(state, tmp, context); Py_DECREF(tmp); if (denom == NULL) { @@ -3140,7 +3177,7 @@ numerator_as_decimal(PyObject *r, PyObject *context) return NULL; } - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); num = PyDec_FromLongExact(state, tmp, context); Py_DECREF(tmp); return num; @@ -3159,7 +3196,7 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w, *vcmp = v; - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); if (PyDec_Check(state, w)) { *wcmp = Py_NewRef(w); } @@ -4399,12 +4436,11 @@ dec_conjugate(PyObject *self, PyObject *dummy UNUSED) return Py_NewRef(self); } -static PyObject * -dec_mpd_radix(PyObject *self, PyObject *dummy UNUSED) +static inline PyObject * +_dec_mpd_radix(decimal_state *state) { PyObject *result; - decimal_state *state = get_module_state_by_def(Py_TYPE(self)); result = dec_alloc(state); if (result == NULL) { return NULL; @@ -4414,6 +4450,13 @@ dec_mpd_radix(PyObject *self, PyObject *dummy UNUSED) return result; } +static PyObject * +dec_mpd_radix(PyObject *self, PyObject *dummy UNUSED) +{ + decimal_state *state = get_module_state_by_def(Py_TYPE(self)); + return _dec_mpd_radix(state); +} + static PyObject * dec_mpd_qcopy_abs(PyObject *self, PyObject *dummy UNUSED) { @@ -4638,11 +4681,10 @@ dec_richcompare(PyObject *v, PyObject *w, int op) uint32_t status = 0; int a_issnan, b_issnan; int r; - decimal_state *state = find_state_left_or_right(v, w); -#ifdef Py_DEBUG + decimal_state *state = get_module_state_by_def(Py_TYPE(v)); assert(PyDec_Check(state, v)); -#endif + CURRENT_CONTEXT(state, context); CONVERT_BINOP_CMP(&a, &b, v, w, op, context); @@ -5114,8 +5156,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *v) \ uint32_t status = 0; \ \ CONVERT_OP_RAISE(&a, v, context); \ - decimal_state *state = \ - get_module_state_by_def(Py_TYPE(context)); \ + decimal_state *state = ctx_get_module_state(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ return NULL; \ @@ -5146,8 +5187,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \ } \ \ CONVERT_BINOP_RAISE(&a, &b, v, w, context); \ - decimal_state *state = \ - get_module_state_by_def(Py_TYPE(context)); \ + decimal_state *state = ctx_get_module_state(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ Py_DECREF(b); \ @@ -5182,8 +5222,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \ } \ \ CONVERT_BINOP_RAISE(&a, &b, v, w, context); \ - decimal_state *state = \ - get_module_state_by_def(Py_TYPE(context)); \ + decimal_state *state = ctx_get_module_state(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ Py_DECREF(b); \ @@ -5212,7 +5251,7 @@ ctx_##MPDFUNC(PyObject *context, PyObject *args) \ } \ \ CONVERT_TERNOP_RAISE(&a, &b, &c, v, w, x, context); \ - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); \ + decimal_state *state = ctx_get_module_state(context); \ if ((result = dec_alloc(state)) == NULL) { \ Py_DECREF(a); \ Py_DECREF(b); \ @@ -5278,7 +5317,7 @@ ctx_mpd_qdivmod(PyObject *context, PyObject *args) } CONVERT_BINOP_RAISE(&a, &b, v, w, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); q = dec_alloc(state); if (q == NULL) { Py_DECREF(a); @@ -5333,7 +5372,7 @@ ctx_mpd_qpow(PyObject *context, PyObject *args, PyObject *kwds) } } - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5368,7 +5407,8 @@ DecCtx_TernaryFunc(mpd_qfma) static PyObject * ctx_mpd_radix(PyObject *context, PyObject *dummy) { - return dec_mpd_radix(context, dummy); + decimal_state *state = ctx_get_module_state(context); + return _dec_mpd_radix(state); } /* Boolean functions: single decimal argument */ @@ -5385,7 +5425,7 @@ DecCtx_BoolFunc_NO_CTX(mpd_iszero) static PyObject * ctx_iscanonical(PyObject *context, PyObject *v) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); if (!PyDec_Check(state, v)) { PyErr_SetString(PyExc_TypeError, "argument must be a Decimal"); @@ -5411,7 +5451,7 @@ PyDecContext_Apply(PyObject *context, PyObject *v) static PyObject * ctx_canonical(PyObject *context, PyObject *v) { - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); if (!PyDec_Check(state, v)) { PyErr_SetString(PyExc_TypeError, "argument must be a Decimal"); @@ -5428,7 +5468,7 @@ ctx_mpd_qcopy_abs(PyObject *context, PyObject *v) uint32_t status = 0; CONVERT_OP_RAISE(&a, v, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5461,7 +5501,7 @@ ctx_mpd_qcopy_negate(PyObject *context, PyObject *v) uint32_t status = 0; CONVERT_OP_RAISE(&a, v, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a); @@ -5558,7 +5598,7 @@ ctx_mpd_qcopy_sign(PyObject *context, PyObject *args) } CONVERT_BINOP_RAISE(&a, &b, v, w, context); - decimal_state *state = get_module_state_by_def(Py_TYPE(context)); + decimal_state *state = ctx_get_module_state(context); result = dec_alloc(state); if (result == NULL) { Py_DECREF(a);