Skip to content

Commit 09c98e4

Browse files
belm0Stefan Krah
and
Stefan Krah
authored
[3.12] gh-114563: C decimal falls back to pydecimal for unsupported format strings (GH-114879) (GH-115353)
Immediate merits: * eliminate complex workarounds for 'z' format support (NOTE: mpdecimal recently added 'z' support, so this becomes efficient in the long term.) * fix 'z' format memory leak * fix 'z' format applied to 'F' * fix missing '#' format support Suggested and prototyped by Stefan Krah. Fixes gh-114563, gh-91060 (cherry picked from commit 72340d1) Co-authored-by: John Belmonte <[email protected]> Co-authored-by: Stefan Krah <[email protected]>
1 parent 2ed47d8 commit 09c98e4

File tree

4 files changed

+88
-122
lines changed

4 files changed

+88
-122
lines changed

Lib/test/test_decimal.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,13 @@ def test_formatting(self):
11211121
('z>z6.1f', '-0.', 'zzz0.0'),
11221122
('x>z6.1f', '-0.', 'xxx0.0'),
11231123
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
1124+
('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char
1125+
1126+
# issue 114563 ('z' format on F type in cdecimal)
1127+
('z3,.10F', '-6.24E-323', '0.0000000000'),
1128+
1129+
# issue 91060 ('#' format in cdecimal)
1130+
('#', '0', '0.'),
11241131

11251132
# issue 6850
11261133
('a=-7.0', '0.12345', 'aaaa0.1'),
@@ -5712,6 +5719,21 @@ def test_c_signaldict_segfault(self):
57125719
with self.assertRaisesRegex(ValueError, err_msg):
57135720
sd.copy()
57145721

5722+
def test_format_fallback_capitals(self):
5723+
# Fallback to _pydecimal formatting (triggered by `#` format which
5724+
# is unsupported by mpdecimal) should honor the current context.
5725+
x = C.Decimal('6.09e+23')
5726+
self.assertEqual(format(x, '#'), '6.09E+23')
5727+
with C.localcontext(capitals=0):
5728+
self.assertEqual(format(x, '#'), '6.09e+23')
5729+
5730+
def test_format_fallback_rounding(self):
5731+
y = C.Decimal('6.09')
5732+
self.assertEqual(format(y, '#.1f'), '6.1')
5733+
with C.localcontext(rounding=C.ROUND_DOWN):
5734+
self.assertEqual(format(y, '#.1f'), '6.0')
5735+
5736+
57155737
@requires_docstrings
57165738
@requires_cdecimal
57175739
class SignatureTest(unittest.TestCase):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`:
2+
* memory leak in some rare cases when using the ``z`` format option (coerce negative 0)
3+
* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``)
4+
* incorrect output when applying the ``#`` format option (alternate form)

Modules/_decimal/_decimal.c

Lines changed: 61 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ static PyObject *default_context_template = NULL;
143143
static PyObject *basic_context_template = NULL;
144144
static PyObject *extended_context_template = NULL;
145145

146+
/* Invariant: NULL or pointer to _pydecimal.Decimal */
147+
static PyObject *PyDecimal = NULL;
146148

147149
/* Error codes for functions that return signals or conditions */
148150
#define DEC_INVALID_SIGNALS (MPD_Max_status+1U)
@@ -3219,56 +3221,6 @@ dotsep_as_utf8(const char *s)
32193221
return utf8;
32203222
}
32213223

3222-
/* copy of libmpdec _mpd_round() */
3223-
static void
3224-
_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
3225-
const mpd_context_t *ctx, uint32_t *status)
3226-
{
3227-
mpd_ssize_t exp = a->exp + a->digits - prec;
3228-
3229-
if (prec <= 0) {
3230-
mpd_seterror(result, MPD_Invalid_operation, status);
3231-
return;
3232-
}
3233-
if (mpd_isspecial(a) || mpd_iszero(a)) {
3234-
mpd_qcopy(result, a, status);
3235-
return;
3236-
}
3237-
3238-
mpd_qrescale_fmt(result, a, exp, ctx, status);
3239-
if (result->digits > prec) {
3240-
mpd_qrescale_fmt(result, result, exp+1, ctx, status);
3241-
}
3242-
}
3243-
3244-
/* Locate negative zero "z" option within a UTF-8 format spec string.
3245-
* Returns pointer to "z", else NULL.
3246-
* The portion of the spec we're working with is [[fill]align][sign][z] */
3247-
static const char *
3248-
format_spec_z_search(char const *fmt, Py_ssize_t size) {
3249-
char const *pos = fmt;
3250-
char const *fmt_end = fmt + size;
3251-
/* skip over [[fill]align] (fill may be multi-byte character) */
3252-
pos += 1;
3253-
while (pos < fmt_end && *pos & 0x80) {
3254-
pos += 1;
3255-
}
3256-
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
3257-
pos += 1;
3258-
} else {
3259-
/* fill not present-- skip over [align] */
3260-
pos = fmt;
3261-
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
3262-
pos += 1;
3263-
}
3264-
}
3265-
/* skip over [sign] */
3266-
if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
3267-
pos += 1;
3268-
}
3269-
return pos < fmt_end && *pos == 'z' ? pos : NULL;
3270-
}
3271-
32723224
static int
32733225
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
32743226
{
@@ -3294,6 +3246,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
32943246
return 0;
32953247
}
32963248

3249+
/*
3250+
* Fallback _pydecimal formatting for new format specifiers that mpdecimal does
3251+
* not yet support. As documented, libmpdec follows the PEP-3101 format language:
3252+
* https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
3253+
*/
3254+
static PyObject *
3255+
pydec_format(PyObject *dec, PyObject *context, PyObject *fmt)
3256+
{
3257+
PyObject *result;
3258+
PyObject *pydec;
3259+
PyObject *u;
3260+
3261+
if (PyDecimal == NULL) {
3262+
PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
3263+
if (PyDecimal == NULL) {
3264+
return NULL;
3265+
}
3266+
}
3267+
3268+
u = dec_str(dec);
3269+
if (u == NULL) {
3270+
return NULL;
3271+
}
3272+
3273+
pydec = PyObject_CallOneArg(PyDecimal, u);
3274+
Py_DECREF(u);
3275+
if (pydec == NULL) {
3276+
return NULL;
3277+
}
3278+
3279+
result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
3280+
Py_DECREF(pydec);
3281+
3282+
if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
3283+
/* Do not confuse users with the _pydecimal exception */
3284+
PyErr_Clear();
3285+
PyErr_SetString(PyExc_ValueError, "invalid format string");
3286+
}
3287+
3288+
return result;
3289+
}
3290+
32973291
/* Formatted representation of a PyDecObject. */
32983292
static PyObject *
32993293
dec_format(PyObject *dec, PyObject *args)
@@ -3306,16 +3300,11 @@ dec_format(PyObject *dec, PyObject *args)
33063300
PyObject *fmtarg;
33073301
PyObject *context;
33083302
mpd_spec_t spec;
3309-
char const *fmt;
3310-
char *fmt_copy = NULL;
3303+
char *fmt;
33113304
char *decstring = NULL;
33123305
uint32_t status = 0;
33133306
int replace_fillchar = 0;
3314-
int no_neg_0 = 0;
33153307
Py_ssize_t size;
3316-
mpd_t *mpd = MPD(dec);
3317-
mpd_uint_t dt[MPD_MINALLOC_MAX];
3318-
mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};
33193308

33203309

33213310
CURRENT_CONTEXT(context);
@@ -3324,39 +3313,20 @@ dec_format(PyObject *dec, PyObject *args)
33243313
}
33253314

33263315
if (PyUnicode_Check(fmtarg)) {
3327-
fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
3316+
fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
33283317
if (fmt == NULL) {
33293318
return NULL;
33303319
}
3331-
/* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
3332-
* format string manipulation below can be eliminated by enhancing
3333-
* the forked mpd_parse_fmt_str(). */
3320+
33343321
if (size > 0 && fmt[0] == '\0') {
33353322
/* NUL fill character: must be replaced with a valid UTF-8 char
33363323
before calling mpd_parse_fmt_str(). */
33373324
replace_fillchar = 1;
3338-
fmt = fmt_copy = dec_strdup(fmt, size);
3339-
if (fmt_copy == NULL) {
3325+
fmt = dec_strdup(fmt, size);
3326+
if (fmt == NULL) {
33403327
return NULL;
33413328
}
3342-
fmt_copy[0] = '_';
3343-
}
3344-
/* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
3345-
* NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
3346-
char const *z_position = format_spec_z_search(fmt, size);
3347-
if (z_position != NULL) {
3348-
no_neg_0 = 1;
3349-
size_t z_index = z_position - fmt;
3350-
if (fmt_copy == NULL) {
3351-
fmt = fmt_copy = dec_strdup(fmt, size);
3352-
if (fmt_copy == NULL) {
3353-
return NULL;
3354-
}
3355-
}
3356-
/* Shift characters (including null terminator) left,
3357-
overwriting the 'z' option. */
3358-
memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
3359-
size -= 1;
3329+
fmt[0] = '_';
33603330
}
33613331
}
33623332
else {
@@ -3366,10 +3336,13 @@ dec_format(PyObject *dec, PyObject *args)
33663336
}
33673337

33683338
if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
3369-
PyErr_SetString(PyExc_ValueError,
3370-
"invalid format string");
3371-
goto finish;
3339+
if (replace_fillchar) {
3340+
PyMem_Free(fmt);
3341+
}
3342+
3343+
return pydec_format(dec, context, fmtarg);
33723344
}
3345+
33733346
if (replace_fillchar) {
33743347
/* In order to avoid clobbering parts of UTF-8 thousands separators or
33753348
decimal points when the substitution is reversed later, the actual
@@ -3422,45 +3395,8 @@ dec_format(PyObject *dec, PyObject *args)
34223395
}
34233396
}
34243397

3425-
if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
3426-
/* Round into a temporary (carefully mirroring the rounding
3427-
of mpd_qformat_spec()), and check if the result is negative zero.
3428-
If so, clear the sign and format the resulting positive zero. */
3429-
mpd_ssize_t prec;
3430-
mpd_qcopy(&tmp, mpd, &status);
3431-
if (spec.prec >= 0) {
3432-
switch (spec.type) {
3433-
case 'f':
3434-
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
3435-
break;
3436-
case '%':
3437-
tmp.exp += 2;
3438-
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
3439-
break;
3440-
case 'g':
3441-
prec = (spec.prec == 0) ? 1 : spec.prec;
3442-
if (tmp.digits > prec) {
3443-
_mpd_round(&tmp, &tmp, prec, CTX(context), &status);
3444-
}
3445-
break;
3446-
case 'e':
3447-
if (!mpd_iszero(&tmp)) {
3448-
_mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
3449-
}
3450-
break;
3451-
}
3452-
}
3453-
if (status & MPD_Errors) {
3454-
PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
3455-
goto finish;
3456-
}
3457-
if (mpd_iszero(&tmp)) {
3458-
mpd_set_positive(&tmp);
3459-
mpd = &tmp;
3460-
}
3461-
}
34623398

3463-
decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
3399+
decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
34643400
if (decstring == NULL) {
34653401
if (status & MPD_Malloc_error) {
34663402
PyErr_NoMemory();
@@ -3483,7 +3419,7 @@ dec_format(PyObject *dec, PyObject *args)
34833419
Py_XDECREF(grouping);
34843420
Py_XDECREF(sep);
34853421
Py_XDECREF(dot);
3486-
if (fmt_copy) PyMem_Free(fmt_copy);
3422+
if (replace_fillchar) PyMem_Free(fmt);
34873423
if (decstring) mpd_free(decstring);
34883424
return result;
34893425
}
@@ -5893,6 +5829,9 @@ PyInit__decimal(void)
58935829
/* Create the module */
58945830
ASSIGN_PTR(m, PyModule_Create(&_decimal_module));
58955831

5832+
/* For format specifiers not yet supported by libmpdec */
5833+
PyDecimal = NULL;
5834+
58965835
/* Add types to the module */
58975836
CHECK_INT(PyModule_AddObjectRef(m, "Decimal", (PyObject *)&PyDec_Type));
58985837
CHECK_INT(PyModule_AddObjectRef(m, "Context", (PyObject *)&PyDecContext_Type));

Tools/c-analyzer/cpython/globals-to-fix.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ Modules/_decimal/_decimal.c - basic_context_template -
436436
Modules/_decimal/_decimal.c - current_context_var -
437437
Modules/_decimal/_decimal.c - default_context_template -
438438
Modules/_decimal/_decimal.c - extended_context_template -
439+
Modules/_decimal/_decimal.c - PyDecimal -
439440
Modules/_decimal/_decimal.c - round_map -
440441
Modules/_decimal/_decimal.c - Rational -
441442
Modules/_decimal/_decimal.c - SignalTuple -

0 commit comments

Comments
 (0)