Skip to content

Commit d43a3d3

Browse files
authored
[ty] Avoid unnecessary argument type expansion (#19999)
## Summary Part of: astral-sh/ty#868 This PR adds a heuristic to avoid argument type expansion if it's going to eventually lead to no matching overload. This is done by checking whether the non-expandable argument types are assignable to the corresponding annotated parameter type. If one of them is not assignable to all of the remaining overloads, then argument type expansion isn't going to help. ## Test Plan Add mdtest that would otherwise take a long time because of the number of arguments that it would need to expand (30).
1 parent 9911196 commit d43a3d3

File tree

3 files changed

+277
-1
lines changed

3 files changed

+277
-1
lines changed

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,217 @@ def _(ab: A | B, ac: A | C, cd: C | D):
620620
reveal_type(f(*(cd,))) # revealed: Unknown
621621
```
622622

623+
### Optimization: Avoid argument type expansion
624+
625+
Argument type expansion could lead to exponential growth of the number of argument lists that needs
626+
to be evaluated, so ty deploys some heuristics to prevent this from happening.
627+
628+
Heuristic: If an argument type that cannot be expanded and cannot be assighned to any of the
629+
remaining overloads before argument type expansion, then even with argument type expansion, it won't
630+
lead to a successful evaluation of the call.
631+
632+
`overloaded.pyi`:
633+
634+
```pyi
635+
from typing import overload
636+
637+
class A: ...
638+
class B: ...
639+
class C: ...
640+
641+
@overload
642+
def f() -> None: ...
643+
@overload
644+
def f(**kwargs: int) -> C: ...
645+
@overload
646+
def f(x: A, /, **kwargs: int) -> A: ...
647+
@overload
648+
def f(x: B, /, **kwargs: int) -> B: ...
649+
650+
class Foo:
651+
@overload
652+
def f(self) -> None: ...
653+
@overload
654+
def f(self, **kwargs: int) -> C: ...
655+
@overload
656+
def f(self, x: A, /, **kwargs: int) -> A: ...
657+
@overload
658+
def f(self, x: B, /, **kwargs: int) -> B: ...
659+
```
660+
661+
```py
662+
from overloaded import A, B, C, Foo, f
663+
from typing_extensions import reveal_type
664+
665+
def _(ab: A | B, a=1):
666+
reveal_type(f(a1=a, a2=a, a3=a)) # revealed: C
667+
reveal_type(f(A(), a1=a, a2=a, a3=a)) # revealed: A
668+
reveal_type(f(B(), a1=a, a2=a, a3=a)) # revealed: B
669+
670+
# Here, the arity check filters out the first and second overload, type checking fails on the
671+
# remaining overloads, so ty moves on to argument type expansion. But, the first argument (`C`)
672+
# isn't assignable to any of the remaining overloads (3 and 4), so there's no point in expanding
673+
# the other 30 arguments of type `Unknown | Literal[1]` which would result in allocating a
674+
# vector containing 2**30 argument lists after expanding all of the arguments.
675+
reveal_type(
676+
# error: [no-matching-overload]
677+
# revealed: Unknown
678+
f(
679+
C(),
680+
a1=a,
681+
a2=a,
682+
a3=a,
683+
a4=a,
684+
a5=a,
685+
a6=a,
686+
a7=a,
687+
a8=a,
688+
a9=a,
689+
a10=a,
690+
a11=a,
691+
a12=a,
692+
a13=a,
693+
a14=a,
694+
a15=a,
695+
a16=a,
696+
a17=a,
697+
a18=a,
698+
a19=a,
699+
a20=a,
700+
a21=a,
701+
a22=a,
702+
a23=a,
703+
a24=a,
704+
a25=a,
705+
a26=a,
706+
a27=a,
707+
a28=a,
708+
a29=a,
709+
a30=a,
710+
)
711+
)
712+
713+
# Here, the heuristics won't come into play because all arguments can be expanded but expanding
714+
# the first argument resutls in a successful evaluation of the call, so there's no exponential
715+
# growth of the number of argument lists.
716+
reveal_type(
717+
# revealed: A | B
718+
f(
719+
ab,
720+
a1=a,
721+
a2=a,
722+
a3=a,
723+
a4=a,
724+
a5=a,
725+
a6=a,
726+
a7=a,
727+
a8=a,
728+
a9=a,
729+
a10=a,
730+
a11=a,
731+
a12=a,
732+
a13=a,
733+
a14=a,
734+
a15=a,
735+
a16=a,
736+
a17=a,
737+
a18=a,
738+
a19=a,
739+
a20=a,
740+
a21=a,
741+
a22=a,
742+
a23=a,
743+
a24=a,
744+
a25=a,
745+
a26=a,
746+
a27=a,
747+
a28=a,
748+
a29=a,
749+
a30=a,
750+
)
751+
)
752+
753+
def _(foo: Foo, ab: A | B, a=1):
754+
reveal_type(foo.f(a1=a, a2=a, a3=a)) # revealed: C
755+
reveal_type(foo.f(A(), a1=a, a2=a, a3=a)) # revealed: A
756+
reveal_type(foo.f(B(), a1=a, a2=a, a3=a)) # revealed: B
757+
758+
reveal_type(
759+
# error: [no-matching-overload]
760+
# revealed: Unknown
761+
foo.f(
762+
C(),
763+
a1=a,
764+
a2=a,
765+
a3=a,
766+
a4=a,
767+
a5=a,
768+
a6=a,
769+
a7=a,
770+
a8=a,
771+
a9=a,
772+
a10=a,
773+
a11=a,
774+
a12=a,
775+
a13=a,
776+
a14=a,
777+
a15=a,
778+
a16=a,
779+
a17=a,
780+
a18=a,
781+
a19=a,
782+
a20=a,
783+
a21=a,
784+
a22=a,
785+
a23=a,
786+
a24=a,
787+
a25=a,
788+
a26=a,
789+
a27=a,
790+
a28=a,
791+
a29=a,
792+
a30=a,
793+
)
794+
)
795+
796+
reveal_type(
797+
# revealed: A | B
798+
foo.f(
799+
ab,
800+
a1=a,
801+
a2=a,
802+
a3=a,
803+
a4=a,
804+
a5=a,
805+
a6=a,
806+
a7=a,
807+
a8=a,
808+
a9=a,
809+
a10=a,
810+
a11=a,
811+
a12=a,
812+
a13=a,
813+
a14=a,
814+
a15=a,
815+
a16=a,
816+
a17=a,
817+
a18=a,
818+
a19=a,
819+
a20=a,
820+
a21=a,
821+
a22=a,
822+
a23=a,
823+
a24=a,
824+
a25=a,
825+
a26=a,
826+
a27=a,
827+
a28=a,
828+
a29=a,
829+
a30=a,
830+
)
831+
)
832+
```
833+
623834
## Filtering based on `Any` / `Unknown`
624835

625836
This is the step 5 of the overload call evaluation algorithm which specifies that:

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ruff_python_ast as ast;
55

66
use crate::Db;
77
use crate::types::KnownClass;
8-
use crate::types::enums::enum_member_literals;
8+
use crate::types::enums::{enum_member_literals, enum_metadata};
99
use crate::types::tuple::{Tuple, TupleLength, TupleType};
1010

1111
use super::Type;
@@ -208,10 +208,32 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<
208208
}
209209
}
210210

211+
/// Returns `true` if the type can be expanded into its subtypes.
212+
///
213+
/// In other words, it returns `true` if [`expand_type`] returns [`Some`] for the given type.
214+
pub(crate) fn is_expandable_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {
215+
match ty {
216+
Type::NominalInstance(instance) => {
217+
let class = instance.class(db);
218+
class.is_known(db, KnownClass::Bool)
219+
|| instance.tuple_spec(db).is_some_and(|spec| match &*spec {
220+
Tuple::Fixed(fixed_length_tuple) => fixed_length_tuple
221+
.all_elements()
222+
.any(|element| is_expandable_type(db, *element)),
223+
Tuple::Variable(_) => false,
224+
})
225+
|| enum_metadata(db, class.class_literal(db).0).is_some()
226+
}
227+
Type::Union(_) => true,
228+
_ => false,
229+
}
230+
}
231+
211232
/// Expands a type into its possible subtypes, if applicable.
212233
///
213234
/// Returns [`None`] if the type cannot be expanded.
214235
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
236+
// NOTE: Update `is_expandable_type` if this logic changes accordingly.
215237
match ty {
216238
Type::NominalInstance(instance) => {
217239
let class = instance.class(db);

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::Program;
1616
use crate::db::Db;
1717
use crate::dunder_all::dunder_all_names;
1818
use crate::place::{Boundness, Place};
19+
use crate::types::call::arguments::is_expandable_type;
1920
use crate::types::diagnostic::{
2021
CALL_NON_CALLABLE, CONFLICTING_ARGUMENT_FORMS, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT,
2122
NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS,
@@ -1337,6 +1338,48 @@ impl<'db> CallableBinding<'db> {
13371338
// for evaluating the expanded argument lists.
13381339
snapshotter.restore(self, pre_evaluation_snapshot);
13391340

1341+
// At this point, there's at least one argument that can be expanded.
1342+
//
1343+
// This heuristic tries to detect if there's any need to perform argument type expansion or
1344+
// not by checking whether there are any non-expandable argument type that cannot be
1345+
// assigned to any of the remaining overloads.
1346+
//
1347+
// This heuristic needs to be applied after restoring the bindings state to the one before
1348+
// type checking as argument type expansion would evaluate it from that point on.
1349+
for (argument_index, (argument, argument_type)) in argument_types.iter().enumerate() {
1350+
// TODO: Remove `Keywords` once `**kwargs` support is added
1351+
if matches!(argument, Argument::Synthetic | Argument::Keywords) {
1352+
continue;
1353+
}
1354+
let Some(argument_type) = argument_type else {
1355+
continue;
1356+
};
1357+
if is_expandable_type(db, argument_type) {
1358+
continue;
1359+
}
1360+
let mut is_argument_assignable_to_any_overload = false;
1361+
'overload: for (_, overload) in self.matching_overloads() {
1362+
for parameter_index in &overload.argument_matches[argument_index].parameters {
1363+
let parameter_type = overload.signature.parameters()[*parameter_index]
1364+
.annotated_type()
1365+
.unwrap_or(Type::unknown());
1366+
if argument_type.is_assignable_to(db, parameter_type) {
1367+
is_argument_assignable_to_any_overload = true;
1368+
break 'overload;
1369+
}
1370+
}
1371+
}
1372+
if !is_argument_assignable_to_any_overload {
1373+
tracing::debug!(
1374+
"Argument at {argument_index} (`{}`) is not assignable to any of the \
1375+
remaining overloads, skipping argument type expansion",
1376+
argument_type.display(db)
1377+
);
1378+
snapshotter.restore(self, post_evaluation_snapshot);
1379+
return;
1380+
}
1381+
}
1382+
13401383
for expanded_argument_lists in expansions {
13411384
// This is the merged state of the bindings after evaluating all of the expanded
13421385
// argument lists. This will be the final state to restore the bindings to if all of

0 commit comments

Comments
 (0)