1010
1111DecimalLike = Union [int , float , Decimal , str ]
1212
13+
1314def _parenthesise (s : str , capture : bool = False ) -> str :
1415 open = "(" if capture else "(?:"
1516 return f"{ open } { s } )"
@@ -22,6 +23,7 @@ def _to_decimal(value: DecimalLike) -> Decimal:
2223 return Decimal (str (value ))
2324 return Decimal (value )
2425
26+
2527class Node (ABC ):
2628 @abstractmethod
2729 def render (self , capturing : bool = False ) -> str :
@@ -45,7 +47,6 @@ def max_repeats(self) -> Optional[int]:
4547 return 1
4648
4749
48-
4950@dataclass (frozen = True )
5051class Empty (Node ):
5152 def render (self , capturing : bool = False ) -> str :
@@ -124,9 +125,13 @@ def render(self, capturing: bool = False) -> str:
124125
125126 def normalize (self ) -> "Node" :
126127 # Normalize repeated child
127- if (self .min_repeats == 0 and self .max_repeats == 0 ) or (isinstance (self .node , Sequence ) and not self .node .parts ) or (isinstance (self .node , Empty ) and self .min_repeats == 0 ):
128+ if (
129+ (self .min_repeats == 0 and self .max_repeats == 0 )
130+ or (isinstance (self .node , Sequence ) and not self .node .parts )
131+ or (isinstance (self .node , Empty ) and self .min_repeats == 0 )
132+ ):
128133 return _seq ()
129- if ( isinstance (self .node , Empty ) and self .min_repeats > 0 ) :
134+ if isinstance (self .node , Empty ) and self .min_repeats > 0 :
130135 return Empty ()
131136 normalized_child = self .node .normalize ()
132137 # (a+)? === a* and (a*)? === a*
@@ -144,14 +149,20 @@ def normalize(self) -> "Node":
144149 # In those cases a{n,}{m} === a{n*m,} and a{n}{m} === a{n*m}
145150 if (
146151 isinstance (normalized_child , FixedRepetition )
147- and self .min_repeats == self .max_repeats # implies self.max_repeats is not None
148- and (normalized_child .max_repeats is None or normalized_child .min_repeats == normalized_child .max_repeats )
152+ and self .min_repeats
153+ == self .max_repeats # implies self.max_repeats is not None
154+ and (
155+ normalized_child .max_repeats is None
156+ or normalized_child .min_repeats == normalized_child .max_repeats
157+ )
149158 ):
150159 multiplier = self .min_repeats
151160 return FixedRepetition (
152161 normalized_child .node ,
153162 normalized_child .min_repeats * multiplier ,
154- normalized_child .max_repeats * multiplier if normalized_child .max_repeats is not None else None ,
163+ normalized_child .max_repeats * multiplier
164+ if normalized_child .max_repeats is not None
165+ else None ,
155166 )
156167 return FixedRepetition (normalized_child , self .min_repeats , self .max_repeats )
157168
@@ -174,18 +185,15 @@ def normalize(self) -> "Node":
174185 for part in flattened_parts [1 :]:
175186 # Merge adjacent literals into one node to reduce AST noise.
176187 prev = merged_parts [- 1 ]
177- if (
178- isinstance (prev , Literal )
179- and isinstance (part , Literal )
180- ):
188+ if isinstance (prev , Literal ) and isinstance (part , Literal ):
181189 merged_parts [- 1 ] = Literal (prev .text + part .text )
182190 continue
183191 # Merge two adjacent fixed repetitions of the same node.
184192 # a{n,m}a{k,l} === a{n+k,m+l}
185- if (
186- ( prev . node if isinstance ( prev , FixedRepetition ) else prev ) == ( part .node if isinstance (part , FixedRepetition ) else part )
193+ if (prev . node if isinstance ( prev , FixedRepetition ) else prev ) == (
194+ part .node if isinstance (part , FixedRepetition ) else part
187195 ):
188- node = ( prev .node if isinstance (prev , FixedRepetition ) else prev )
196+ node = prev .node if isinstance (prev , FixedRepetition ) else prev
189197 max_repeats = (
190198 None
191199 if prev .max_repeats is None or part .max_repeats is None
@@ -211,7 +219,10 @@ class Either(Node):
211219 options : tuple [Node , ...]
212220
213221 def render (self , capturing : bool = False ) -> str :
214- return _parenthesise ('|' .join (option .render (capturing = capturing ) for option in self .options ), capture = capturing )
222+ return _parenthesise (
223+ "|" .join (option .render (capturing = capturing ) for option in self .options ),
224+ capture = capturing ,
225+ )
215226
216227 def normalize (self ) -> Node :
217228 # Normalize options first so each branch is internally simplified.
@@ -237,7 +248,9 @@ def normalize(self) -> Node:
237248 if base_node not in repetition_groups :
238249 repetition_groups [base_node ] = []
239250 repetition_group_order .append (base_node )
240- repetition_groups [base_node ].append ((option .min_repeats , option .max_repeats ))
251+ repetition_groups [base_node ].append (
252+ (option .min_repeats , option .max_repeats )
253+ )
241254
242255 merged_repetition_options : list [Node ] = []
243256 for base_node in repetition_group_order :
@@ -253,10 +266,7 @@ def normalize(self) -> Node:
253266 continue
254267
255268 prev_min , prev_max = merged [- 1 ]
256- can_merge = (
257- prev_max is None
258- or interval_min <= prev_max + 1
259- )
269+ can_merge = prev_max is None or interval_min <= prev_max + 1
260270 if can_merge :
261271 if prev_max is None or interval_max is None :
262272 merged [- 1 ] = (prev_min , None )
@@ -305,9 +315,11 @@ def _seq(*nodes: Node) -> Node:
305315 flat .extend (node .as_parts ())
306316 return Sequence (tuple (flat ))
307317
318+
308319def optional (node : Node ) -> Node :
309320 return FixedRepetition (node , 0 , 1 )
310321
322+
311323def __compute_numerical_range_ast (
312324 str_a : str , str_b : str , start_parts : Optional [list [Node ]] = None
313325) -> Node :
@@ -455,7 +467,9 @@ def _fractional_precision(value: Decimal) -> int:
455467def _fixed_width_int_range_ast (lower : int , upper : int , width : int ) -> Node :
456468 if upper < lower :
457469 return Empty ()
458- return __compute_numerical_range_ast (str (lower ).zfill (width ), str (upper ).zfill (width ))
470+ return __compute_numerical_range_ast (
471+ str (lower ).zfill (width ), str (upper ).zfill (width )
472+ )
459473
460474
461475def _fractional_interval_ast (
@@ -480,7 +494,7 @@ def _fractional_interval_ast(
480494 if p == 0 :
481495 p = 1
482496
483- scale = 10 ** p
497+ scale = 10 ** p
484498 scaled_lower = int (lower * scale )
485499 scaled_upper = (scale - 1 ) if upper_open_one else int (upper * scale )
486500 patterns : list [Node ] = []
@@ -503,7 +517,9 @@ def _fractional_interval_ast(
503517 patterns .append (_seq (full_prefixes , _any_digits (None )))
504518 else :
505519 if scaled_lower <= scaled_upper - 1 :
506- full_prefixes = _fixed_width_int_range_ast (scaled_lower , scaled_upper - 1 , p )
520+ full_prefixes = _fixed_width_int_range_ast (
521+ scaled_lower , scaled_upper - 1 , p
522+ )
507523 patterns .append (_seq (full_prefixes , _any_digits (None )))
508524 upper_text = str (scaled_upper ).zfill (p )
509525 patterns .append (
@@ -572,6 +588,7 @@ def _float_range_ast_within_one(a: Decimal, b: Decimal) -> Node:
572588 sign = [Literal ("-" )] if pre_decs < 0 else []
573589 return _seq (* sign , pre_decs_node , Literal ("." ), decimal_ast )
574590
591+
575592def _float_range_ast (a : DecimalLike , b : DecimalLike , strict = False ) -> Node :
576593 """
577594 Generate a regex AST that matches decimal numbers in the inclusive range [a, b].
@@ -625,6 +642,7 @@ def _zero() -> Literal:
625642def _integer_unbounded_ast () -> Node :
626643 return _one_of (_negative_unbounded_ast (), _positive_unbounded_ast (), _zero ())
627644
645+
628646def _positive_with_min_digits_ast (extra_digits : int ) -> Node :
629647 if extra_digits < 0 :
630648 raise ValueError ("extra_digits must be >= 0" )
@@ -634,6 +652,7 @@ def _positive_with_min_digits_ast(extra_digits: int) -> Node:
634652 FixedRepetition (DigitRange (0 , 9 ), 0 , None ),
635653 )
636654
655+
637656def _positive_unbounded_ast () -> Node :
638657 return _positive_with_min_digits_ast (0 )
639658
@@ -670,8 +689,11 @@ def _strict_decimal_unbounded_ast() -> Node:
670689 ),
671690 )
672691
692+
673693def _negative_strict_decimal_with_min_int_digits_ast (extra_digits : int ) -> Node :
674- return _seq (Literal ("-" ), _positive_strict_decimal_with_min_int_digits_ast (extra_digits ))
694+ return _seq (
695+ Literal ("-" ), _positive_strict_decimal_with_min_int_digits_ast (extra_digits )
696+ )
675697
676698
677699def _positive_strict_decimal_with_min_int_digits_ast (extra_digits : int ) -> Node :
@@ -763,14 +785,16 @@ def _float_range_from_bounds_ast(
763785 maximum_decimal = _to_decimal (maximum )
764786 if maximum_decimal >= 0 :
765787 bounded = _float_range_ast (Decimal ("0.0" ), maximum_decimal , strict = strict )
766- decimal_ast = _one_of (_negative_strict_decimal_with_min_int_digits_ast (0 ), bounded )
767- else :
768- int_digits = len (
769- str (int (floor (abs (maximum_decimal ))))
788+ decimal_ast = _one_of (
789+ _negative_strict_decimal_with_min_int_digits_ast (0 ), bounded
770790 )
791+ else :
792+ int_digits = len (str (int (floor (abs (maximum_decimal )))))
771793 lower = - (Decimal (10 ) ** int_digits )
772794 bounded = _float_range_ast (lower , maximum_decimal , strict )
773- decimal_ast = _one_of (bounded , _negative_strict_decimal_with_min_int_digits_ast (int_digits ))
795+ decimal_ast = _one_of (
796+ bounded , _negative_strict_decimal_with_min_int_digits_ast (int_digits )
797+ )
774798
775799 integer_upper = int (floor (maximum_decimal ))
776800 integer_ast = _range_from_bounds_ast (None , integer_upper )
@@ -783,12 +807,18 @@ def _float_range_from_bounds_ast(
783807 minimum_decimal = _to_decimal (minimum )
784808 if minimum_decimal <= 0 :
785809 bounded = _float_range_ast (minimum_decimal , Decimal ("0.0" ), strict = strict )
786- decimal_ast = _one_of (bounded , _positive_strict_decimal_with_min_int_digits_ast (0 ))
810+ decimal_ast = _one_of (
811+ bounded , _positive_strict_decimal_with_min_int_digits_ast (0 )
812+ )
787813 else :
788- int_digits = len (str (int (minimum_decimal .to_integral_value (rounding = ROUND_FLOOR ))))
814+ int_digits = len (
815+ str (int (minimum_decimal .to_integral_value (rounding = ROUND_FLOOR )))
816+ )
789817 upper = Decimal (10 ) ** int_digits
790818 bounded = _float_range_ast (minimum_decimal , upper , strict = strict )
791- decimal_ast = _one_of (bounded , _positive_strict_decimal_with_min_int_digits_ast (int_digits ))
819+ decimal_ast = _one_of (
820+ bounded , _positive_strict_decimal_with_min_int_digits_ast (int_digits )
821+ )
792822
793823 integer_lower = int (minimum_decimal .to_integral_value (rounding = ROUND_CEILING ))
794824 integer_ast = _range_from_bounds_ast (integer_lower , None )
@@ -823,7 +853,9 @@ def range_regex(
823853
824854 For floating-point ranges, use ``float_range_regex``.
825855 """
826- return _range_from_bounds_ast (minimum , maximum ).normalize ().render (capturing = capturing )
856+ return (
857+ _range_from_bounds_ast (minimum , maximum ).normalize ().render (capturing = capturing )
858+ )
827859
828860
829861def float_range_regex (
@@ -843,6 +875,8 @@ def float_range_regex(
843875 matched, as long as their numeric value is in range.
844876 - If ``capturing`` is ``True``, grouping uses ``(...)`` instead of ``(?:...)``.
845877 """
846- return _float_range_from_bounds_ast (minimum , maximum , strict ).normalize ().render (
847- capturing = capturing
878+ return (
879+ _float_range_from_bounds_ast (minimum , maximum , strict )
880+ .normalize ()
881+ .render (capturing = capturing )
848882 )
0 commit comments