Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions pyrefly/lib/alt/class/class_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2217,7 +2217,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
// Otherwise, analyze the value to determine the type
let (inherited_ty, inherited_annotation) =
self.get_inherited_type_and_annotation(class, name);
let is_inherited = if inherited_ty.is_none() {
let mut is_inherited = if inherited_ty.is_none() {
IsInherited::No
} else {
IsInherited::Maybe
Expand Down Expand Up @@ -2246,15 +2246,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&errors2,
);
if errors2.is_empty() {
// The new type is compatible with the inherited one; use the inherited type to
// avoid spurious errors about changing the type of a read-write attribute.
// However, we need to clear the is_abstract_method flag since assigning
// a concrete implementation makes this field non-abstract.
let mut ty = inherited_ty;
ty.transform_toplevel_func_metadata(|meta| {
meta.flags.is_abstract_method = false;
});
ty
// The new type is compatible with the inherited one; use the child's
// inferred type to preserve type precision. Skip the override check
// since we've already validated compatibility above.
is_inherited = IsInherited::No;
self.attribute_expr_infer(e, None, name, errors)
} else {
// The hint was no good; infer the type without it.
self.attribute_expr_infer(e, None, name, errors)
Expand Down
34 changes: 29 additions & 5 deletions pyrefly/lib/test/class_overrides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ class Foo:
class Bar(Foo):
x = 1

assert_type(Bar.x, Any | None)
assert_type(Bar.x, int)
assert_type(Foo.x, Any | None)
def test(x: type[Foo]):
assert_type(x.x, Any | None)
Expand All @@ -832,7 +832,7 @@ class Child2(Parent):
testcase!(
test_unannotated_empty_tuple_attribute_override,
r#"
from typing import Any, assert_type
from typing import Any, Literal, assert_type

class Foo:
x = ()
Expand All @@ -841,7 +841,7 @@ class Bar(Foo):
x = ("feature_x",)

assert_type(Foo.x, tuple[Any, ...])
assert_type(Bar.x, tuple[Any, ...])
assert_type(Bar.x, tuple[Literal['feature_x']])
"#,
);

Expand Down Expand Up @@ -966,7 +966,7 @@ class A:
class B(A):
x = 1
def f(b: B):
assert_type(b.x, int | None)
assert_type(b.x, int)
"#,
);

Expand All @@ -993,7 +993,31 @@ class D(C):
x = [B()]

def f(d: D):
assert_type(d.x, list[A])
assert_type(d.x, list[B])
"#,
);

testcase!(
test_class_variable_override_with_subclass,
r#"
class Request:
@classmethod
def make(cls) -> "Request":
return cls()

class FormRequest(Request):
@classmethod
def from_response(cls) -> "FormRequest":
return cls()

class BaseTest:
request_class = Request

class MyTest(BaseTest):
request_class = FormRequest

def test(self) -> None:
self.request_class.from_response()
"#,
);

Expand Down