Skip to content

Commit a65482d

Browse files
committed
fix: respect both mulitple_of and minimum/maximum constraints
Previously, `generate_constrained_number()` would potentially generate invalid numbers when `mulitple_of` is not None and exactly one of either `minimum` or `maximum` is not None, since it would just return `mulitple_of` without respecting the upper or lower bound. This significantly changes the implementation of `generate_constrained_number()` in an attempt to handle this case. We now first check for the presence of `mulitple_of`, and if it is None, then we return early by delegating to the `method` parameter. Otherwise, we first generate a random number with `method`, and then we attempt to find the nearest number that is a proper multiple of `mulitple_of`. Most of the newly added complexity of the function is meant to handle floating-point or Decimal precision issues, and in worst-case scenarios, it may still not be capable of finding a legal value.
1 parent c4e3d91 commit a65482d

File tree

3 files changed

+128
-24
lines changed

3 files changed

+128
-24
lines changed

polyfactory/value_generators/constrained_numbers.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import decimal
34
from decimal import Decimal
5+
from math import ceil, floor, isinf
46
from sys import float_info
57
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
68

@@ -227,16 +229,62 @@ def generate_constrained_number(
227229
228230
:returns: A value of type T.
229231
"""
230-
if minimum is None or maximum is None:
231-
return multiple_of if multiple_of is not None else method(random=random)
232232
if multiple_of is None:
233233
return method(random=random, minimum=minimum, maximum=maximum)
234-
if multiple_of >= minimum:
234+
235+
def passes_all_constraints(value: T) -> bool:
236+
return (
237+
(minimum is None or value >= minimum)
238+
and (maximum is None or value <= maximum)
239+
and (multiple_of is None or passes_pydantic_multiple_validator(value, multiple_of))
240+
)
241+
242+
# If the arguments are Decimals, they might have precision that is greater than the current decimal context. If
243+
# so, recreate them under the current context to ensure they have the appropriate precision. This is important
244+
# because otherwise, x * 1 == x may not always hold, which can cause the algorithm below to fail in unintuitive
245+
# ways.
246+
if isinstance(minimum, Decimal):
247+
minimum = decimal.getcontext().create_decimal(minimum)
248+
if isinstance(maximum, Decimal):
249+
maximum = decimal.getcontext().create_decimal(maximum)
250+
if isinstance(multiple_of, Decimal):
251+
multiple_of = decimal.getcontext().create_decimal(multiple_of)
252+
253+
max_attempts = 10
254+
for _ in range(max_attempts):
255+
# We attempt to generate a random number and find the nearest valid multiple, but a naive approach of rounding
256+
# to the nearest multiple may push the number out of range. To handle edge cases, we find both the nearest
257+
# multiple in both the negative and positive directions (floor and ceil), and we pick one that fits within
258+
# range. We should be guaranteed to find a number other than in the case where the range (minimum, maximum) is
259+
# narrow and does not contain any multiple of multiple_of.
260+
random_value = method(random=random, minimum=minimum, maximum=maximum)
261+
quotient = random_value / multiple_of
262+
if isinf(quotient):
263+
continue
264+
lower = floor(quotient) * multiple_of
265+
upper = ceil(quotient) * multiple_of
266+
267+
# If both the lower and upper candidates are out of bounds, then there are no valid multiples that fit within
268+
# the specified range.
269+
if minimum is not None and maximum is not None and lower < minimum and upper > maximum:
270+
msg = f"no multiple of {multiple_of} exists between {minimum} and {maximum}"
271+
raise ParameterException(msg)
272+
273+
for candidate in [lower, upper]:
274+
if not passes_all_constraints(candidate):
275+
continue
276+
return candidate
277+
278+
# Try last-ditch attempt at using the multiple_of, 0, or -multiple_of as the value
279+
if passes_all_constraints(multiple_of):
235280
return multiple_of
236-
result = minimum
237-
while not passes_pydantic_multiple_validator(result, multiple_of):
238-
result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of
239-
return result
281+
if passes_all_constraints(-multiple_of):
282+
return -multiple_of
283+
if passes_all_constraints(multiple_of * 0):
284+
return multiple_of * 0
285+
286+
msg = f"could not find solution in {max_attempts} attempts"
287+
raise ValueError(msg)
240288

241289

242290
def handle_constrained_int(

tests/constraints/test_decimal_constraints.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional, cast
44

55
import pytest
6-
from hypothesis import given
6+
from hypothesis import assume, given
77
from hypothesis.strategies import decimals, integers
88

99
from pydantic import BaseModel, condecimal
@@ -239,19 +239,23 @@ def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, v
239239
decimals(
240240
allow_nan=False,
241241
allow_infinity=False,
242-
min_value=-1000000000,
243-
max_value=1000000000,
242+
min_value=-100000000,
243+
max_value=100000000,
244244
),
245245
decimals(
246246
allow_nan=False,
247247
allow_infinity=False,
248-
min_value=-1000000000,
249-
max_value=1000000000,
248+
min_value=-100000000,
249+
max_value=100000000,
250250
),
251251
)
252252
def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, val2: Decimal) -> None:
253253
min_value, multiple_of = sorted([val1, val2])
254254
if multiple_of != Decimal("0"):
255+
# When multiple_of is too many orders of magnitude smaller than min_value, then floating-point precision issues
256+
# prevent us from constructing a number that can pass passes_pydantic_multiple_validator(). This scenario is
257+
# very unlikely to occur in practice, so we tell Hypothesis to not generate these cases.
258+
assume(abs(min_value / multiple_of) < Decimal("1e8"))
255259
result = handle_constrained_decimal(
256260
random=Random(),
257261
multiple_of=multiple_of,
@@ -267,23 +271,37 @@ def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, v
267271
)
268272

269273

274+
# Note: The magnitudes of the min and max values have been specifically chosen to avoid issues with floating-point
275+
# rounding errors. Despite these tests using Decimal numbers, the function under test will convert them to floats when
276+
# calling `passes_pydantic_multiple_validator()`. Because `passes_pydantic_multiple_validator()` uses the modulus
277+
# operator (%) with a fixed modulo of 1.0, we actually have to care about the absolute rounding error, not the relative
278+
# error. IEEE 754 double-precision floating-point numbers are guaranteed to have at least 15 decimal digits of
279+
# significand and up to 17 decimal digits of significant. `passes_pydantic_multiple_validator()` requires that the
280+
# remainder modulo 1.0 be within 1e-8 of 0.0 or 1.0. Therefore, we can support a maximum value of approximately 10**(15
281+
# - 8) = 10**7. We have some probabilistic buffer, so can set a maximum value of 10**8 and expect the tests to pass with
282+
# reasonable confidence.
270283
@given(
271284
decimals(
272285
allow_nan=False,
273286
allow_infinity=False,
274-
min_value=-1000000000,
275-
max_value=1000000000,
287+
min_value=-100000000,
288+
max_value=100000000,
276289
),
277290
decimals(
278291
allow_nan=False,
279292
allow_infinity=False,
280-
min_value=-1000000000,
281-
max_value=1000000000,
293+
min_value=-100000000,
294+
max_value=100000000,
282295
),
283296
)
284297
def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, val2: Decimal) -> None:
285298
min_value, multiple_of = sorted([val1, val2])
286299
if multiple_of != Decimal("0"):
300+
# Despite the note above about choosing a max_value to avoid _absolute_ rounding errors, we also have to worry
301+
# about _relative_ rounding errors between min_value and multiple_of. Once again,
302+
# `passes_pydantic_multiple_validator()` requires that the remainder be no greater than 1e-8, so we tell
303+
# Hypothesis not to generate cases where the min_value and multiple_of have a ratio greater than that.
304+
assume(abs(min_value / multiple_of) < Decimal("1e8"))
287305
result = handle_constrained_decimal(
288306
random=Random(),
289307
multiple_of=multiple_of,

tests/test_number_generation.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from decimal import Decimal, localcontext
12
from random import Random
23

34
import pytest
@@ -6,24 +7,61 @@
67
generate_constrained_number,
78
passes_pydantic_multiple_validator,
89
)
9-
from polyfactory.value_generators.primitives import create_random_float
10+
from polyfactory.value_generators.primitives import create_random_decimal, create_random_float
1011

1112

1213
@pytest.mark.parametrize(
1314
("maximum", "minimum", "multiple_of"),
14-
((100, 2, 8), (-100, -187, -10), (7.55, 0.13, 0.0123)),
15+
(
16+
(100, 2, 8),
17+
(-100, -187, -10),
18+
(7.55, 0.13, 0.0123),
19+
(None, 10, 3),
20+
(None, -10, 3),
21+
(13, 2, None),
22+
(50, None, 7),
23+
(-50, None, 7),
24+
(None, None, 4),
25+
(900, None, 1000),
26+
),
1527
)
16-
def test_generate_constrained_number(maximum: float, minimum: float, multiple_of: float) -> None:
17-
assert passes_pydantic_multiple_validator(
28+
def test_generate_constrained_number(maximum: float | None, minimum: float | None, multiple_of: float | None) -> None:
29+
value = generate_constrained_number(
30+
random=Random(),
31+
minimum=minimum,
32+
maximum=maximum,
1833
multiple_of=multiple_of,
19-
value=generate_constrained_number(
34+
method=create_random_float,
35+
)
36+
if maximum is not None:
37+
assert value <= maximum
38+
if minimum is not None:
39+
assert value >= minimum
40+
if multiple_of is not None:
41+
assert passes_pydantic_multiple_validator(multiple_of=multiple_of, value=value)
42+
43+
44+
def test_generate_constrained_number_with_overprecise_decimals() -> None:
45+
minimum = Decimal("1.0005")
46+
maximum = Decimal("2")
47+
multiple_of = Decimal("1.0005")
48+
49+
with localcontext() as ctx:
50+
ctx.prec = 3
51+
52+
value = generate_constrained_number(
2053
random=Random(),
2154
minimum=minimum,
2255
maximum=maximum,
2356
multiple_of=multiple_of,
24-
method=create_random_float,
25-
),
26-
)
57+
method=create_random_decimal,
58+
)
59+
if maximum is not None:
60+
assert value <= ctx.create_decimal(maximum)
61+
if minimum is not None:
62+
assert value >= ctx.create_decimal(minimum)
63+
if multiple_of is not None:
64+
assert passes_pydantic_multiple_validator(multiple_of=ctx.create_decimal(multiple_of), value=value)
2765

2866

2967
def test_passes_pydantic_multiple_validator_handles_zero_multiplier() -> None:

0 commit comments

Comments
 (0)