Skip to content

Commit 37c4880

Browse files
committed
Formatting
1 parent f3e03e4 commit 37c4880

File tree

3 files changed

+83
-39
lines changed

3 files changed

+83
-39
lines changed

range_ex/range_regex.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
DecimalLike = Union[int, float, Decimal, str]
1212

13+
1314
def _parenthesise(s: str, capture: bool = False) -> str:
1415
open = "(" if capture else "(?:"
1516
return f"{open}{s})"
@@ -22,6 +23,7 @@ def _to_decimal(value: DecimalLike) -> Decimal:
2223
return Decimal(str(value))
2324
return Decimal(value)
2425

26+
2527
class Node(ABC):
2628
@abstractmethod
2729
def render(self, capturing: bool = False) -> str:
@@ -45,7 +47,6 @@ def max_repeats(self) -> Optional[int]:
4547
return 1
4648

4749

48-
4950
@dataclass(frozen=True)
5051
class Empty(Node):
5152
def render(self, capturing: bool = False) -> str:
@@ -124,9 +125,13 @@ def render(self, capturing: bool = False) -> str:
124125

125126
def normalize(self) -> "Node":
126127
# Normalize repeated child
127-
if (self.min_repeats == 0 and self.max_repeats == 0) or (isinstance(self.node, Sequence) and not self.node.parts) or (isinstance(self.node, Empty) and self.min_repeats == 0):
128+
if (
129+
(self.min_repeats == 0 and self.max_repeats == 0)
130+
or (isinstance(self.node, Sequence) and not self.node.parts)
131+
or (isinstance(self.node, Empty) and self.min_repeats == 0)
132+
):
128133
return _seq()
129-
if (isinstance(self.node, Empty) and self.min_repeats > 0):
134+
if isinstance(self.node, Empty) and self.min_repeats > 0:
130135
return Empty()
131136
normalized_child = self.node.normalize()
132137
# (a+)? === a* and (a*)? === a*
@@ -144,14 +149,20 @@ def normalize(self) -> "Node":
144149
# In those cases a{n,}{m} === a{n*m,} and a{n}{m} === a{n*m}
145150
if (
146151
isinstance(normalized_child, FixedRepetition)
147-
and self.min_repeats == self.max_repeats # implies self.max_repeats is not None
148-
and (normalized_child.max_repeats is None or normalized_child.min_repeats == normalized_child.max_repeats)
152+
and self.min_repeats
153+
== self.max_repeats # implies self.max_repeats is not None
154+
and (
155+
normalized_child.max_repeats is None
156+
or normalized_child.min_repeats == normalized_child.max_repeats
157+
)
149158
):
150159
multiplier = self.min_repeats
151160
return FixedRepetition(
152161
normalized_child.node,
153162
normalized_child.min_repeats * multiplier,
154-
normalized_child.max_repeats * multiplier if normalized_child.max_repeats is not None else None,
163+
normalized_child.max_repeats * multiplier
164+
if normalized_child.max_repeats is not None
165+
else None,
155166
)
156167
return FixedRepetition(normalized_child, self.min_repeats, self.max_repeats)
157168

@@ -174,18 +185,15 @@ def normalize(self) -> "Node":
174185
for part in flattened_parts[1:]:
175186
# Merge adjacent literals into one node to reduce AST noise.
176187
prev = merged_parts[-1]
177-
if (
178-
isinstance(prev, Literal)
179-
and isinstance(part, Literal)
180-
):
188+
if isinstance(prev, Literal) and isinstance(part, Literal):
181189
merged_parts[-1] = Literal(prev.text + part.text)
182190
continue
183191
# Merge two adjacent fixed repetitions of the same node.
184192
# a{n,m}a{k,l} === a{n+k,m+l}
185-
if (
186-
(prev.node if isinstance(prev, FixedRepetition) else prev) == (part.node if isinstance(part, FixedRepetition) else part)
193+
if (prev.node if isinstance(prev, FixedRepetition) else prev) == (
194+
part.node if isinstance(part, FixedRepetition) else part
187195
):
188-
node = (prev.node if isinstance(prev, FixedRepetition) else prev)
196+
node = prev.node if isinstance(prev, FixedRepetition) else prev
189197
max_repeats = (
190198
None
191199
if prev.max_repeats is None or part.max_repeats is None
@@ -211,7 +219,10 @@ class Either(Node):
211219
options: tuple[Node, ...]
212220

213221
def render(self, capturing: bool = False) -> str:
    """Render all options as one ``|``-separated alternation group.

    ``capturing`` is forwarded to every option's own ``render`` call and
    also selects the kind of group that wraps the alternation.
    """
    rendered_options = [
        option.render(capturing=capturing) for option in self.options
    ]
    alternation = "|".join(rendered_options)
    return _parenthesise(alternation, capture=capturing)
215226

216227
def normalize(self) -> Node:
217228
# Normalize options first so each branch is internally simplified.
@@ -237,7 +248,9 @@ def normalize(self) -> Node:
237248
if base_node not in repetition_groups:
238249
repetition_groups[base_node] = []
239250
repetition_group_order.append(base_node)
240-
repetition_groups[base_node].append((option.min_repeats, option.max_repeats))
251+
repetition_groups[base_node].append(
252+
(option.min_repeats, option.max_repeats)
253+
)
241254

242255
merged_repetition_options: list[Node] = []
243256
for base_node in repetition_group_order:
@@ -253,10 +266,7 @@ def normalize(self) -> Node:
253266
continue
254267

255268
prev_min, prev_max = merged[-1]
256-
can_merge = (
257-
prev_max is None
258-
or interval_min <= prev_max + 1
259-
)
269+
can_merge = prev_max is None or interval_min <= prev_max + 1
260270
if can_merge:
261271
if prev_max is None or interval_max is None:
262272
merged[-1] = (prev_min, None)
@@ -305,9 +315,11 @@ def _seq(*nodes: Node) -> Node:
305315
flat.extend(node.as_parts())
306316
return Sequence(tuple(flat))
307317

318+
308319
def optional(node: Node) -> Node:
    """Return ``node`` wrapped in a repetition with bounds 0 and 1."""
    fewest, most = 0, 1
    return FixedRepetition(node, fewest, most)
310321

322+
311323
def __compute_numerical_range_ast(
312324
str_a: str, str_b: str, start_parts: Optional[list[Node]] = None
313325
) -> Node:
@@ -455,7 +467,9 @@ def _fractional_precision(value: Decimal) -> int:
455467
def _fixed_width_int_range_ast(lower: int, upper: int, width: int) -> Node:
    """Build the AST for the integers ``lower``..``upper`` at a fixed width.

    Both bounds are zero-padded on the left to ``width`` characters before
    being handed to the numerical range builder. An inverted range
    (``upper`` below ``lower``) yields ``Empty()``.
    """
    if lower > upper:
        return Empty()
    padded_lower = str(lower).zfill(width)
    padded_upper = str(upper).zfill(width)
    return __compute_numerical_range_ast(padded_lower, padded_upper)
459473

460474

461475
def _fractional_interval_ast(
@@ -480,7 +494,7 @@ def _fractional_interval_ast(
480494
if p == 0:
481495
p = 1
482496

483-
scale = 10 ** p
497+
scale = 10**p
484498
scaled_lower = int(lower * scale)
485499
scaled_upper = (scale - 1) if upper_open_one else int(upper * scale)
486500
patterns: list[Node] = []
@@ -503,7 +517,9 @@ def _fractional_interval_ast(
503517
patterns.append(_seq(full_prefixes, _any_digits(None)))
504518
else:
505519
if scaled_lower <= scaled_upper - 1:
506-
full_prefixes = _fixed_width_int_range_ast(scaled_lower, scaled_upper - 1, p)
520+
full_prefixes = _fixed_width_int_range_ast(
521+
scaled_lower, scaled_upper - 1, p
522+
)
507523
patterns.append(_seq(full_prefixes, _any_digits(None)))
508524
upper_text = str(scaled_upper).zfill(p)
509525
patterns.append(
@@ -572,6 +588,7 @@ def _float_range_ast_within_one(a: Decimal, b: Decimal) -> Node:
572588
sign = [Literal("-")] if pre_decs < 0 else []
573589
return _seq(*sign, pre_decs_node, Literal("."), decimal_ast)
574590

591+
575592
def _float_range_ast(a: DecimalLike, b: DecimalLike, strict=False) -> Node:
576593
"""
577594
Generate a regex AST that matches decimal numbers in the inclusive range [a, b].
@@ -625,6 +642,7 @@ def _zero() -> Literal:
625642
def _integer_unbounded_ast() -> Node:
    """Combine the negative, positive and zero ASTs into one alternation."""
    negative_branch = _negative_unbounded_ast()
    positive_branch = _positive_unbounded_ast()
    zero_branch = _zero()
    return _one_of(negative_branch, positive_branch, zero_branch)
627644

645+
628646
def _positive_with_min_digits_ast(extra_digits: int) -> Node:
629647
if extra_digits < 0:
630648
raise ValueError("extra_digits must be >= 0")
@@ -634,6 +652,7 @@ def _positive_with_min_digits_ast(extra_digits: int) -> Node:
634652
FixedRepetition(DigitRange(0, 9), 0, None),
635653
)
636654

655+
637656
def _positive_unbounded_ast() -> Node:
    """Return the positive-number AST built with no extra required digits."""
    return _positive_with_min_digits_ast(extra_digits=0)
639658

@@ -670,8 +689,11 @@ def _strict_decimal_unbounded_ast() -> Node:
670689
),
671690
)
672691

692+
673693
def _negative_strict_decimal_with_min_int_digits_ast(extra_digits: int) -> Node:
    """Prefix the positive strict-decimal AST with a literal minus sign.

    ``extra_digits`` is passed through unchanged to the positive builder.
    """
    minus_sign = Literal("-")
    magnitude = _positive_strict_decimal_with_min_int_digits_ast(extra_digits)
    return _seq(minus_sign, magnitude)
675697

676698

677699
def _positive_strict_decimal_with_min_int_digits_ast(extra_digits: int) -> Node:
@@ -763,14 +785,16 @@ def _float_range_from_bounds_ast(
763785
maximum_decimal = _to_decimal(maximum)
764786
if maximum_decimal >= 0:
765787
bounded = _float_range_ast(Decimal("0.0"), maximum_decimal, strict=strict)
766-
decimal_ast = _one_of(_negative_strict_decimal_with_min_int_digits_ast(0), bounded)
767-
else:
768-
int_digits = len(
769-
str(int(floor(abs(maximum_decimal))))
788+
decimal_ast = _one_of(
789+
_negative_strict_decimal_with_min_int_digits_ast(0), bounded
770790
)
791+
else:
792+
int_digits = len(str(int(floor(abs(maximum_decimal)))))
771793
lower = -(Decimal(10) ** int_digits)
772794
bounded = _float_range_ast(lower, maximum_decimal, strict)
773-
decimal_ast = _one_of(bounded, _negative_strict_decimal_with_min_int_digits_ast(int_digits))
795+
decimal_ast = _one_of(
796+
bounded, _negative_strict_decimal_with_min_int_digits_ast(int_digits)
797+
)
774798

775799
integer_upper = int(floor(maximum_decimal))
776800
integer_ast = _range_from_bounds_ast(None, integer_upper)
@@ -783,12 +807,18 @@ def _float_range_from_bounds_ast(
783807
minimum_decimal = _to_decimal(minimum)
784808
if minimum_decimal <= 0:
785809
bounded = _float_range_ast(minimum_decimal, Decimal("0.0"), strict=strict)
786-
decimal_ast = _one_of(bounded, _positive_strict_decimal_with_min_int_digits_ast(0))
810+
decimal_ast = _one_of(
811+
bounded, _positive_strict_decimal_with_min_int_digits_ast(0)
812+
)
787813
else:
788-
int_digits = len(str(int(minimum_decimal.to_integral_value(rounding=ROUND_FLOOR))))
814+
int_digits = len(
815+
str(int(minimum_decimal.to_integral_value(rounding=ROUND_FLOOR)))
816+
)
789817
upper = Decimal(10) ** int_digits
790818
bounded = _float_range_ast(minimum_decimal, upper, strict=strict)
791-
decimal_ast = _one_of(bounded, _positive_strict_decimal_with_min_int_digits_ast(int_digits))
819+
decimal_ast = _one_of(
820+
bounded, _positive_strict_decimal_with_min_int_digits_ast(int_digits)
821+
)
792822

793823
integer_lower = int(minimum_decimal.to_integral_value(rounding=ROUND_CEILING))
794824
integer_ast = _range_from_bounds_ast(integer_lower, None)
@@ -823,7 +853,9 @@ def range_regex(
823853
824854
For floating-point ranges, use ``float_range_regex``.
825855
"""
826-
return _range_from_bounds_ast(minimum, maximum).normalize().render(capturing=capturing)
856+
return (
857+
_range_from_bounds_ast(minimum, maximum).normalize().render(capturing=capturing)
858+
)
827859

828860

829861
def float_range_regex(
@@ -843,6 +875,8 @@ def float_range_regex(
843875
matched, as long as their numeric value is in range.
844876
- If ``capturing`` is ``True``, grouping uses ``(...)`` instead of ``(?:...)``.
845877
"""
846-
return _float_range_from_bounds_ast(minimum, maximum, strict).normalize().render(
847-
capturing=capturing
878+
return (
879+
_float_range_from_bounds_ast(minimum, maximum, strict)
880+
.normalize()
881+
.render(capturing=capturing)
848882
)

scripts/print_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def build_examples() -> list[tuple[str, str]]:
2828
),
2929
(
3030
"float_range_regex(-.1, '1.5', strict=True)",
31-
float_range_regex(-.1, "1.5", strict=True),
31+
float_range_regex(-0.1, "1.5", strict=True),
3232
),
3333
(
3434
"float_range_regex(minimum='1.5', strict=False)",

tests/test_range_regex.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,24 @@ def float_ranges_and_values(draw):
4949

5050
@st.composite
def optional_float_bounds_and_integer(draw):
    """Draw two optional bounds and an integer probe value.

    Each bound is either ``None`` or an integer in [-1000, 1000]; the probe
    is an integer in [-2000, 2000]. The ``*_tenths`` names suggest callers
    scale the bounds by ten -- NOTE(review): confirm against the tests
    that consume this strategy.
    """
    # One strategy object serves both bounds; each draw() call still
    # produces its own independent value.
    bound_strategy = st.one_of(
        st.none(), st.integers(min_value=-1000, max_value=1000)
    )
    minimum_tenths = draw(bound_strategy)
    maximum_tenths = draw(bound_strategy)
    integer_value = draw(st.integers(min_value=-2000, max_value=2000))
    return minimum_tenths, maximum_tenths, integer_value
5660

5761

5862
@st.composite
5963
def optional_float_bounds_and_dot_leading_fraction(draw):
60-
minimum_tenths = draw(st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)))
61-
maximum_tenths = draw(st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000)))
64+
minimum_tenths = draw(
65+
st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000))
66+
)
67+
maximum_tenths = draw(
68+
st.one_of(st.none(), st.integers(min_value=-1000, max_value=1000))
69+
)
6270
digit = draw(st.integers(min_value=1, max_value=9))
6371
negative = draw(st.booleans())
6472
text = f"-.{digit}" if negative else f".{digit}"
@@ -290,10 +298,12 @@ def test_float_range_rejects_values_above_upper_bound_with_extra_digits_from_int
290298
compiled = re.compile(float_range_regex(1, 1.5))
291299
assert compiled.fullmatch("1.51") is None
292300

301+
293302
def test_pos_float_range_accepts_dot():
    """A dot-leading fraction inside [0.1, 1.5] must fully match."""
    pattern = float_range_regex(0.1, 1.5)
    assert re.fullmatch(pattern, ".51") is not None
296305

306+
297307
def test_float_range_rejects_non_parseable_string_bounds():
298308
with pytest.raises(InvalidOperation):
299309
float_range_regex("foo", "1.5")

0 commit comments

Comments
 (0)