Skip to content

[5.9][ConstraintSystem] Use fallback type constraint to default pack expansion #66709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions include/swift/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,12 @@ enum class ConstraintKind : char {
/// constraint.
OneWayBindParam,
/// If there is no contextual info e.g. `_ = { 42 }` default first type
/// to a second type (inferred closure type). This is effectively a
/// `Defaultable` constraint which a couple of differences:
/// to a second type. This is effectively a `Defaultable` constraint
/// which one significant difference:
///
/// - References inferred closure type and all of the outer parameters
/// referenced by closure body.
/// - Handled specially by binding inference, specifically contributes
/// to the bindings only if there are no contextual types available.
DefaultClosureType,
FallbackType,
/// The first type represents a result of an unresolved member chain,
/// and the second type is its base type. This constraint acts almost
/// like `Equal` but also enforces following semantics:
Expand Down Expand Up @@ -701,7 +699,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
case ConstraintKind::OptionalObject:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PackElementOf:
case ConstraintKind::SameShape:
Expand Down
11 changes: 6 additions & 5 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -4847,11 +4847,12 @@ class ConstraintSystem {
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Attempt to simplify the given defaultable closure type constraint.
SolutionKind simplifyDefaultClosureTypeConstraint(
Type closureType, Type inferredType,
ArrayRef<TypeVariableType *> referencedOuterParameters,
TypeMatchOptions flags, ConstraintLocatorBuilder locator);
/// Attempt to simplify the given fallback type constraint.
SolutionKind
simplifyFallbackTypeConstraint(Type defaultableType, Type fallbackType,
ArrayRef<TypeVariableType *> referencedVars,
TypeMatchOptions flags,
ConstraintLocatorBuilder locator);

/// Attempt to simplify a property wrapper constraint.
SolutionKind simplifyPropertyWrapperConstraint(Type wrapperType, Type wrappedValueType,
Expand Down
15 changes: 7 additions & 8 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ void BindingSet::inferTransitiveBindings(

// Infer transitive defaults.
for (const auto &def : bindings.Defaults) {
if (def.getSecond()->getKind() == ConstraintKind::DefaultClosureType)
if (def.getSecond()->getKind() == ConstraintKind::FallbackType)
continue;

addDefault(def.second);
Expand Down Expand Up @@ -1510,7 +1510,7 @@ void PotentialBindings::infer(Constraint *constraint) {
}

case ConstraintKind::Defaultable:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
// Do these in a separate pass.
if (CS.getFixedTypeRecursive(constraint->getFirstType(), true)
->getAs<TypeVariableType>() == TypeVar) {
Expand Down Expand Up @@ -1634,7 +1634,7 @@ void PotentialBindings::retract(Constraint *constraint) {
break;

case ConstraintKind::Defaultable:
case ConstraintKind::DefaultClosureType: {
case ConstraintKind::FallbackType: {
Defaults.erase(constraint);
break;
}
Expand Down Expand Up @@ -2075,11 +2075,10 @@ bool TypeVarBindingProducer::computeNext() {
if (NumTries == 0) {
// Add defaultable constraints (if any).
for (auto *constraint : DelayedDefaults) {
if (constraint->getKind() == ConstraintKind::DefaultClosureType) {
// If there are no other possible bindings for this closure
// let's default it to the type inferred from its parameters/body,
// otherwise we should only attempt contextual types as a
// top-level closure type.
if (constraint->getKind() == ConstraintKind::FallbackType) {
// If there are no other possible bindings for this variable
// let's default it to the fallback type, otherwise we should
// only attempt contextual types.
if (!ExploredTypes.empty())
continue;
}
Expand Down
6 changes: 3 additions & 3 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2921,9 +2921,9 @@ namespace {
SmallVector<TypeVariableType *, 4> referencedVars{
collectVarRefs.varRefs.begin(), collectVarRefs.varRefs.end()};

CS.addUnsolvedConstraint(Constraint::create(
CS, ConstraintKind::DefaultClosureType, closureType, inferredType,
locator, referencedVars));
CS.addUnsolvedConstraint(
Constraint::create(CS, ConstraintKind::FallbackType, closureType,
inferredType, locator, referencedVars));

CS.setClosureType(closure, inferredType);
return closureType;
Expand Down
40 changes: 20 additions & 20 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2284,7 +2284,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2,
case ConstraintKind::BridgingConversion:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::SyntacticElement:
Expand Down Expand Up @@ -2643,7 +2643,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1,
case ConstraintKind::ValueWitness:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::SyntacticElement:
Expand Down Expand Up @@ -3161,7 +3161,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
case ConstraintKind::BridgingConversion:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::SyntacticElement:
Expand Down Expand Up @@ -6814,7 +6814,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind,
case ConstraintKind::ValueWitness:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::SyntacticElement:
Expand Down Expand Up @@ -10950,18 +10950,18 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyDefaultableConstraint(
return SolutionKind::Solved;
}

ConstraintSystem::SolutionKind
ConstraintSystem::simplifyDefaultClosureTypeConstraint(
Type closureType, Type inferredType,
ArrayRef<TypeVariableType *> referencedOuterParameters,
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
closureType = getFixedTypeRecursive(closureType, flags, /*wantRValue=*/true);
ConstraintSystem::SolutionKind ConstraintSystem::simplifyFallbackTypeConstraint(
Type defaultableType, Type fallbackType,
ArrayRef<TypeVariableType *> referencedVars, TypeMatchOptions flags,
ConstraintLocatorBuilder locator) {
defaultableType =
getFixedTypeRecursive(defaultableType, flags, /*wantRValue=*/true);

if (closureType->isTypeVariableOrMember()) {
if (defaultableType->isTypeVariableOrMember()) {
if (flags.contains(TMF_GenerateConstraints)) {
addUnsolvedConstraint(Constraint::create(
*this, ConstraintKind::DefaultClosureType, closureType, inferredType,
getConstraintLocator(locator), referencedOuterParameters));
*this, ConstraintKind::FallbackType, defaultableType, fallbackType,
getConstraintLocator(locator), referencedVars));
return SolutionKind::Solved;
}

Expand Down Expand Up @@ -15014,7 +15014,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
case ConstraintKind::Conjunction:
case ConstraintKind::KeyPath:
case ConstraintKind::KeyPathApplication:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::SyntacticElement:
llvm_unreachable("Use the correct addConstraint()");
}
Expand Down Expand Up @@ -15546,12 +15546,12 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
/*flags*/ None,
constraint.getLocator());

case ConstraintKind::DefaultClosureType:
return simplifyDefaultClosureTypeConstraint(constraint.getFirstType(),
constraint.getSecondType(),
constraint.getTypeVariables(),
/*flags*/ None,
constraint.getLocator());
case ConstraintKind::FallbackType:
return simplifyFallbackTypeConstraint(constraint.getFirstType(),
constraint.getSecondType(),
constraint.getTypeVariables(),
/*flags*/ None,
constraint.getLocator());

case ConstraintKind::PropertyWrapper:
return simplifyPropertyWrapperConstraint(constraint.getFirstType(),
Expand Down
12 changes: 6 additions & 6 deletions lib/Sema/Constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
llvm_unreachable("Wrong constructor for member constraint");

case ConstraintKind::Defaultable:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
assert(!First.isNull());
assert(!Second.isNull());
break;
Expand Down Expand Up @@ -164,7 +164,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third,
case ConstraintKind::Conjunction:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::SyntacticElement:
Expand Down Expand Up @@ -314,7 +314,7 @@ Constraint *Constraint::clone(ConstraintSystem &cs) const {
case ConstraintKind::Defaultable:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::BindTupleOfFunctionParams:
Expand Down Expand Up @@ -470,8 +470,8 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm,
case ConstraintKind::OpenedExistentialOf: Out << " opened archetype of "; break;
case ConstraintKind::OneWayEqual: Out << " one-way bind to "; break;
case ConstraintKind::OneWayBindParam: Out << " one-way bind param to "; break;
case ConstraintKind::DefaultClosureType:
Out << " closure can default to ";
case ConstraintKind::FallbackType:
Out << " can fallback to ";
break;
case ConstraintKind::UnresolvedMemberChainBase:
Out << " unresolved member chain base ";
Expand Down Expand Up @@ -742,7 +742,7 @@ gatherReferencedTypeVars(Constraint *constraint,
case ConstraintKind::SelfObjectOfProtocol:
case ConstraintKind::OneWayEqual:
case ConstraintKind::OneWayBindParam:
case ConstraintKind::DefaultClosureType:
case ConstraintKind::FallbackType:
case ConstraintKind::UnresolvedMemberChainBase:
case ConstraintKind::PropertyWrapper:
case ConstraintKind::BindTupleOfFunctionParams:
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ Type ConstraintSystem::openPackExpansionType(PackExpansionType *expansion,
// This constraint is important to make sure that pack expansion always
// has a binding and connect pack expansion var to any type variables
// that appear in pattern and shape types.
addUnsolvedConstraint(Constraint::create(*this, ConstraintKind::Defaultable,
addUnsolvedConstraint(Constraint::create(*this, ConstraintKind::FallbackType,
expansionVar, openedPackExpansion,
expansionLoc));

Expand Down Expand Up @@ -7391,7 +7391,7 @@ bool TypeVarBindingProducer::requiresOptionalAdjustment(
PotentialBinding
TypeVarBindingProducer::getDefaultBinding(Constraint *constraint) const {
assert(constraint->getKind() == ConstraintKind::Defaultable ||
constraint->getKind() == ConstraintKind::DefaultClosureType);
constraint->getKind() == ConstraintKind::FallbackType);

auto type = constraint->getSecondType();
Binding binding{type, BindingKind::Exact, constraint};
Expand Down
42 changes: 42 additions & 0 deletions test/Constraints/pack-expansion-expressions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,45 @@ func configure<T, each Element>(
repeat item[keyPath: (each configuration).0] = (each configuration).1
return item
}

// rdar://110819621 - generic parameter is bound before pack expansion type which result in inference failures
func test_that_expansions_are_bound_early() {
struct Data {
let prop: Int?
}

struct Value<each T> {
init(_ body: (repeat each T) -> Bool) {}
}

func compute<Root, Value>(
root: Root,
keyPath: KeyPath<Root, Value>,
other: Value) -> Bool { true }

func test_keypath(v: Int) {
let _: Value<Data> = Value({
compute(
root: $0,
keyPath: \.prop,
other: v
)
}) // Ok

let _: Value = Value<Data>({
compute(
root: $0,
keyPath: \.prop,
other: v
)
}) // Ok
}

func equal<Value>(_: Value, _: Value) -> Bool {}

func test_equality(i: Int) {
let _: Value<Data> = Value({
equal($0.prop, i) // Ok
})
}
}
2 changes: 1 addition & 1 deletion unittests/Sema/ConstraintSimplificationTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ TEST_F(SemaTest, TestClosureInferenceFromOptionalContext) {
auto *closureTy = cs.createTypeVariable(closureLoc, /*options=*/0);

cs.addUnsolvedConstraint(Constraint::create(
cs, ConstraintKind::DefaultClosureType, closureTy, defaultTy,
cs, ConstraintKind::FallbackType, closureTy, defaultTy,
cs.getConstraintLocator(closure), /*referencedVars=*/{}));

auto contextualTy =
Expand Down