Skip to content

Commit ee79280

Browse files
authored
fix: Make Interval subclasses properly inherit bounds (#266)
1 parent a01ff42 commit ee79280

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

src/phantom/interval.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,24 @@ def _get_scalar_float_bounds(
110110
return low, high
111111

112112

113+
def _resolve_bound(
114+
cls: type,
115+
name: str,
116+
argument: Comparable | None,
117+
default: Comparable,
118+
) -> None:
119+
inherited = getattr(cls, name, None)
120+
121+
if argument is not None:
122+
resolved = argument
123+
elif inherited is not None:
124+
resolved = inherited
125+
else:
126+
resolved = default
127+
128+
setattr(cls, name, resolved)
129+
130+
113131
class Interval(Phantom[Comparable], bound=Comparable, abstract=True):
114132
"""
115133
Base class for all interval types, providing the following class arguments:
@@ -128,12 +146,12 @@ class Interval(Phantom[Comparable], bound=Comparable, abstract=True):
128146
def __init_subclass__(
129147
cls,
130148
check: IntervalCheck | None = None,
131-
low: Comparable = neg_inf,
132-
high: Comparable = inf,
149+
low: Comparable | None = None,
150+
high: Comparable | None = None,
133151
**kwargs: Any,
134152
) -> None:
135-
resolve_class_attr(cls, "__low__", low)
136-
resolve_class_attr(cls, "__high__", high)
153+
_resolve_bound(cls, "__low__", low, neg_inf)
154+
_resolve_bound(cls, "__high__", high, inf)
137155
resolve_class_attr(cls, "__check__", check)
138156
if getattr(cls, "__check__", None) is None:
139157
raise TypeError(f"{cls.__qualname__} must define an interval check")

tests/test_interval.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323

2424
class TestInterval:
2525
def test_subclassing_without_check_raises(self):
26-
with pytest.raises(TypeError, match="I must define an interval check$"):
26+
with pytest.raises(TypeError, match="A must define an interval check$"):
2727

28-
class I(Interval, abstract=False): # noqa: E742
28+
class A(Interval, abstract=False):
2929
...
3030

3131
def test_parse_coerces_str(self):
@@ -35,7 +35,7 @@ class Great(int, Inclusive, low=10):
3535
assert Great.parse("10") == 10
3636

3737
def test_allows_decimal_bound(self):
38-
class I( # noqa: E742
38+
class A(
3939
Decimal,
4040
Interval,
4141
check=interval.exclusive,
@@ -44,9 +44,40 @@ class I( # noqa: E742
4444
):
4545
...
4646

47-
assert not isinstance(2, I)
48-
assert not isinstance(1.98, I)
49-
assert isinstance(Decimal("1.98"), I)
47+
assert not isinstance(2, A)
48+
assert not isinstance(1.98, A)
49+
assert isinstance(Decimal("1.98"), A)
50+
51+
def test_subclass_inherits_bounds(self):
52+
class A(int, Inclusive, low=-10, high=10):
53+
...
54+
55+
class B(A):
56+
...
57+
58+
assert B.__check__ is A.__check__
59+
assert isinstance(-10, B)
60+
assert isinstance(10, B)
61+
assert not isinstance(-11, B)
62+
assert not isinstance(11, B)
63+
64+
class C(A, low=0):
65+
...
66+
67+
assert C.__check__ is A.__check__
68+
assert isinstance(0, C)
69+
assert isinstance(10, C)
70+
assert not isinstance(-1, C)
71+
assert not isinstance(11, C)
72+
73+
class D(A, high=0):
74+
...
75+
76+
assert D.__check__ is A.__check__
77+
assert isinstance(-10, D)
78+
assert isinstance(0, D)
79+
assert not isinstance(-11, D)
80+
assert not isinstance(1, D)
5081

5182

5283
parametrize_negative_ints = pytest.mark.parametrize("i", (-10, -1, -0, +0))

0 commit comments

Comments
 (0)