Skip to content

Commit 3ce91c4

Browse files
committed
Merge pull request #5942 from Microsoft/fixUnionToUnionTypeInference
Fix union/union or intersection/intersection type inference
2 parents 4735c00 + 6901a98 commit 3ce91c4

File tree

5 files changed

+98
-31
lines changed

5 files changed

+98
-31
lines changed

src/compiler/checker.ts

+30-31
Original file line numberDiff line numberDiff line change
@@ -6148,14 +6148,25 @@ namespace ts {
61486148
function inferFromTypes(source: Type, target: Type) {
61496149
if (source.flags & TypeFlags.Union && target.flags & TypeFlags.Union ||
61506150
source.flags & TypeFlags.Intersection && target.flags & TypeFlags.Intersection) {
6151-
// Source and target are both unions or both intersections. To improve the quality of
6152-
// inferences we first reduce the types by removing constituents that are identically
6153-
// matched by a constituent in the other type. For example, when inferring from
6154-
// 'string | string[]' to 'string | T', we reduce the types to 'string[]' and 'T'.
6155-
const reducedSource = reduceUnionOrIntersectionType(<UnionOrIntersectionType>source, <UnionOrIntersectionType>target);
6156-
const reducedTarget = reduceUnionOrIntersectionType(<UnionOrIntersectionType>target, <UnionOrIntersectionType>source);
6157-
source = reducedSource;
6158-
target = reducedTarget;
6151+
// Source and target are both unions or both intersections. First, find each
6152+
// target constituent type that has an identically matching source constituent
6153+
// type, and for each such target constituent type infer from the type to itself.
6154+
// When inferring from a type to itself we effectively find all type parameter
6155+
// occurrences within that type and infer themselves as their type arguments.
6156+
let matchingTypes: Type[];
6157+
for (const t of (<UnionOrIntersectionType>target).types) {
6158+
if (typeIdenticalToSomeType(t, (<UnionOrIntersectionType>source).types)) {
6159+
(matchingTypes || (matchingTypes = [])).push(t);
6160+
inferFromTypes(t, t);
6161+
}
6162+
}
6163+
// Next, to improve the quality of inferences, reduce the source and target types by
6164+
// removing the identically matched constituents. For example, when inferring from
6165+
// 'string | string[]' to 'string | T' we reduce the types to 'string[]' and 'T'.
6166+
if (matchingTypes) {
6167+
source = removeTypesFromUnionOrIntersection(<UnionOrIntersectionType>source, matchingTypes);
6168+
target = removeTypesFromUnionOrIntersection(<UnionOrIntersectionType>target, matchingTypes);
6169+
}
61596170
}
61606171
if (target.flags & TypeFlags.TypeParameter) {
61616172
// If target is a type parameter, make an inference, unless the source type contains
@@ -6317,39 +6328,27 @@ namespace ts {
63176328
}
63186329
}
63196330

6320-
function typeIdenticalToSomeType(source: Type, target: UnionOrIntersectionType): boolean {
6321-
for (const t of target.types) {
6322-
if (isTypeIdenticalTo(source, t)) {
6331+
function typeIdenticalToSomeType(type: Type, types: Type[]): boolean {
6332+
for (const t of types) {
6333+
if (isTypeIdenticalTo(t, type)) {
63236334
return true;
63246335
}
63256336
}
63266337
return false;
63276338
}
63286339

63296340
/**
6330-
* Return the reduced form of the source type. This type is computed by by removing all source
6331-
* constituents that have an identical match in the target type.
6341+
* Return a new union or intersection type computed by removing a given set of types
6342+
* from a given union or intersection type.
63326343
*/
6333-
function reduceUnionOrIntersectionType(source: UnionOrIntersectionType, target: UnionOrIntersectionType) {
6334-
let sourceTypes = source.types;
6335-
let sourceIndex = 0;
6336-
let modified = false;
6337-
while (sourceIndex < sourceTypes.length) {
6338-
if (typeIdenticalToSomeType(sourceTypes[sourceIndex], target)) {
6339-
if (!modified) {
6340-
sourceTypes = sourceTypes.slice(0);
6341-
modified = true;
6342-
}
6343-
sourceTypes.splice(sourceIndex, 1);
6344-
}
6345-
else {
6346-
sourceIndex++;
6344+
function removeTypesFromUnionOrIntersection(type: UnionOrIntersectionType, typesToRemove: Type[]) {
6345+
const reducedTypes: Type[] = [];
6346+
for (const t of type.types) {
6347+
if (!typeIdenticalToSomeType(t, typesToRemove)) {
6348+
reducedTypes.push(t);
63476349
}
63486350
}
6349-
if (modified) {
6350-
return source.flags & TypeFlags.Union ? getUnionType(sourceTypes, /*noSubtypeReduction*/ true) : getIntersectionType(sourceTypes);
6351-
}
6352-
return source;
6351+
return type.flags & TypeFlags.Union ? getUnionType(reducedTypes, /*noSubtypeReduction*/ true) : getIntersectionType(reducedTypes);
63536352
}
63546353

63556354
function getInferenceCandidates(context: InferenceContext, index: number): Type[] {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//// [recursiveUnionTypeInference.ts]
2+
interface Foo<T> {
3+
x: T;
4+
}
5+
6+
function bar<T>(x: Foo<T> | string): T {
7+
return bar(x);
8+
}
9+
10+
11+
//// [recursiveUnionTypeInference.js]
12+
function bar(x) {
13+
return bar(x);
14+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
=== tests/cases/compiler/recursiveUnionTypeInference.ts ===
2+
interface Foo<T> {
3+
>Foo : Symbol(Foo, Decl(recursiveUnionTypeInference.ts, 0, 0))
4+
>T : Symbol(T, Decl(recursiveUnionTypeInference.ts, 0, 14))
5+
6+
x: T;
7+
>x : Symbol(x, Decl(recursiveUnionTypeInference.ts, 0, 18))
8+
>T : Symbol(T, Decl(recursiveUnionTypeInference.ts, 0, 14))
9+
}
10+
11+
function bar<T>(x: Foo<T> | string): T {
12+
>bar : Symbol(bar, Decl(recursiveUnionTypeInference.ts, 2, 1))
13+
>T : Symbol(T, Decl(recursiveUnionTypeInference.ts, 4, 13))
14+
>x : Symbol(x, Decl(recursiveUnionTypeInference.ts, 4, 16))
15+
>Foo : Symbol(Foo, Decl(recursiveUnionTypeInference.ts, 0, 0))
16+
>T : Symbol(T, Decl(recursiveUnionTypeInference.ts, 4, 13))
17+
>T : Symbol(T, Decl(recursiveUnionTypeInference.ts, 4, 13))
18+
19+
return bar(x);
20+
>bar : Symbol(bar, Decl(recursiveUnionTypeInference.ts, 2, 1))
21+
>x : Symbol(x, Decl(recursiveUnionTypeInference.ts, 4, 16))
22+
}
23+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
=== tests/cases/compiler/recursiveUnionTypeInference.ts ===
2+
interface Foo<T> {
3+
>Foo : Foo<T>
4+
>T : T
5+
6+
x: T;
7+
>x : T
8+
>T : T
9+
}
10+
11+
function bar<T>(x: Foo<T> | string): T {
12+
>bar : <T>(x: Foo<T> | string) => T
13+
>T : T
14+
>x : Foo<T> | string
15+
>Foo : Foo<T>
16+
>T : T
17+
>T : T
18+
19+
return bar(x);
20+
>bar(x) : T
21+
>bar : <T>(x: Foo<T> | string) => T
22+
>x : Foo<T> | string
23+
}
24+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
interface Foo<T> {
2+
x: T;
3+
}
4+
5+
function bar<T>(x: Foo<T> | string): T {
6+
return bar(x);
7+
}

0 commit comments

Comments
 (0)