Skip to content

[Prototype] Better type inference for lambdas (e.g., as used in folds) #9076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,12 @@ object desugar {
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
else mods
}
val appParamss =
// FIXME: This now infers `List[List[DefTree]]`, the issue
// is that `withMods` is defined in `DefTree` so that becomes the
// upper bound of the type variable (see logic in `constrainSelectionQualifier`),
// but the result type of `withMods` is a type member which is
// refined in `ValDef`.
val appParamss: List[List[ValDef]] =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
Expand Down
24 changes: 19 additions & 5 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,22 @@ trait ConstraintHandling[AbstractContext] {

protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using AbstractContext): Boolean =
if !constraint.contains(param) then true
else if !isUpper && param.occursIn(bound)
// We don't allow recursive lower bounds when defining a type,
// so we shouldn't allow them as constraints either.
else if
bound.existsPart {
case `param` =>
// We don't allow recursive lower bounds when defining a type,
// so we shouldn't allow them as constraints either.
!isUpper
case AppliedType(tycon: TypeRef, args) if tycon.info.isInstanceOf[MatchAlias] =>
// FIXME: this is incomplete, see tests/pos/type-match-fbounds.scala
args.exists {
case `param` => true
case tp: TypeVar => tp.origin eq param
case _ => false
}
case _ => false
}
then
false
else
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
Expand Down Expand Up @@ -432,8 +445,9 @@ trait ConstraintHandling[AbstractContext] {
if lower.nonEmpty && !bounds.lo.isRef(defn.NothingClass)
|| upper.nonEmpty && !bounds.hi.isAny
then constr.println(i"INIT*** $tl")
lower.forall(addOneBound(_, bounds.hi, isUpper = true)) &&
upper.forall(addOneBound(_, bounds.lo, isUpper = false))

lower.forall(loParam => addOneBound(loParam, bounds.hi, isUpper = true) && addOneBound(param, constraint.nonParamBounds(loParam).lo, isUpper = false))
&& upper.forall(upParam => addOneBound(upParam, bounds.lo, isUpper = false) && addOneBound(param, constraint.nonParamBounds(upParam).hi, isUpper = true))
case _ =>
// Happens if param was already solved while processing earlier params of the same TypeLambda.
// See #4720.
Expand Down
15 changes: 13 additions & 2 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
* @param isUpper If true, `bound` is an upper bound, else a lower bound.
*/
private def stripParams(tp: Type, paramBuf: mutable.ListBuffer[TypeParamRef],
isUpper: Boolean)(implicit ctx: Context): Type = tp match {
isUpper: Boolean)(implicit ctx: Context): Type = tp.stripTypeVar match {
case param: TypeParamRef if contains(param) =>
if (!paramBuf.contains(param)) paramBuf += param
NoType
Expand Down Expand Up @@ -310,6 +310,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =

def recur(tp: Type, fromBelow: Boolean): Type = tp match
case tp: NamedType =>
val underlying1 = recur(tp.underlying, fromBelow)
if underlying1 ne tp.underlying then underlying1 else tp
case tp: AndOrType =>
val r1 = recur(tp.tp1, fromBelow)
val r2 = recur(tp.tp2, fromBelow)
Expand Down Expand Up @@ -365,7 +368,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
* Q <: tp implies Q <: P and isUpper = true, or
* tp <: Q implies P <: Q and isUpper = false
*/
private def dependentParams(tp: Type, isUpper: Boolean): List[TypeParamRef] = tp match
private def dependentParams(tp: Type, isUpper: Boolean)(using Context): List[TypeParamRef] = tp.stripTypeVar match
case param: TypeParamRef if contains(param) =>
param :: (if (isUpper) upper(param) else lower(param))
case tp: AndType if isUpper =>
Expand Down Expand Up @@ -479,6 +482,12 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
case (e1: TypeBounds, _) if e1 contains e2 => e2
case (_, e2: TypeBounds) if e2 contains e1 => e1
case (tv1: TypeVar, tv2: TypeVar) if tv1 eq tv2 => e1

// Should this be based on the merged entries instead of
// using this.entry/other.entry ?
case (e1: TypeParamRef, e2) if this.entry(e1).bounds.contains(e2) => e2
case (e1, e2: TypeParamRef) if other.entry(e2).bounds.contains(e1) => e1

case _ =>
if (otherHasErrors)
e1
Expand Down Expand Up @@ -607,6 +616,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
def occursAtToplevel(param: TypeParamRef, inst: Type)(implicit ctx: Context): Boolean =

def occurs(tp: Type)(using Context): Boolean = tp match
case tp: NamedType =>
occurs(tp.underlying)
case tp: AndOrType =>
occurs(tp.tp1) || occurs(tp.tp2)
case tp: TypeParamRef =>
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1983,8 +1983,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
/** Merge `t1` into `tp2` if t1 is a subtype of some &-summand of tp2.
*/
private def mergeIfSub(tp1: Type, tp2: Type): Type =
if (isSubTypeWhenFrozen(tp1, tp2))
if (isSubTypeWhenFrozen(tp2, tp1)) tp2 else tp1 // keep existing type if possible
if (isSubTypeWhenFrozen(tp1, tp2)) tp1
else tp2 match {
case tp2 @ AndType(tp21, tp22) =>
val lower1 = mergeIfSub(tp1, tp21)
Expand All @@ -2004,8 +2003,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
* @param canConstrain If true, new constraints might be added to make the merge possible.
*/
private def mergeIfSuper(tp1: Type, tp2: Type, canConstrain: Boolean): Type =
if (isSubType(tp2, tp1, whenFrozen = !canConstrain))
if (isSubType(tp1, tp2, whenFrozen = !canConstrain)) tp2 else tp1 // keep existing type if possible
if (isSubType(tp2, tp1, whenFrozen = !canConstrain)) tp1
else tp2 match {
case tp2 @ OrType(tp21, tp22) =>
val higher1 = mergeIfSuper(tp1, tp21, canConstrain)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3927,7 +3927,7 @@ object Types {
NoType
}

def tyconTypeParams(implicit ctx: Context): List[ParamInfo] = {
def tyconTypeParams(implicit ctx: Context): List[TypeApplications.TypeParamInfo] = {
val tparams = tycon.typeParams
if (tparams.isEmpty) HKTypeLambda.any(args.length).typeParams else tparams
}
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
case tp @ AppliedType(tycon, args) =>
if (defn.isCompiletimeAppliedType(tycon.typeSymbol)) tp.tryCompiletimeConstantFold
else tycon.dealias.appliedTo(args)
// Workaround for https://github.com/lampepfl/dotty/issues/8988
case tp @ AnnotatedType(underlying @ AnnotatedType(_, annot2), annot1)
if (annot1.symbol eq defn.UncheckedVarianceAnnot) && (annot1.symbol eq annot2.symbol) =>
homogenize(underlying)
case _ =>
tp
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ object Inferencing {
*
* we want to instantiate U to x.type right away. No need to wait further.
*/
private def variances(tp: Type)(using Context): VarianceMap = {
def variances(tp: Type)(using Context): VarianceMap = {
Stats.record("variances")
val constraint = ctx.typerState.constraint

Expand Down
27 changes: 21 additions & 6 deletions compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,12 @@ object ProtoTypes {
* or as an upper bound of a prefix or underlying type.
*/
private def hasUnknownMembers(tp: Type)(using Context): Boolean = tp match {
case tp: TypeVar => !tp.isInstantiated
case tp: TypeVar =>
// FIXME: This used to be `!tp.isInstantiated` but that prevents
// extension methods from being selected with the changes in this PR.
// This change doesn't break any testcase, can we construct a testcase
// where this matters?
false
case tp: WildcardType => true
case NoType => true
case tp: TypeRef =>
Expand All @@ -152,20 +157,30 @@ object ProtoTypes {
case _ => false
}

override def isMatchedBy(tp1: Type, keepConstraint: Boolean)(using Context): Boolean =
name == nme.WILDCARD || hasUnknownMembers(tp1) ||
{
val mbr = if (privateOK) tp1.member(name) else tp1.nonPrivateMember(name)
override def isMatchedBy(tp1: Type, keepConstraint: Boolean)(using Context): Boolean = {
if name == nme.WILDCARD || hasUnknownMembers(tp1) then
return true

def go(pre: Type): Boolean = {
val mbr = if (privateOK) pre.member(name) else pre.nonPrivateMember(name)
def qualifies(m: SingleDenotation) =
memberProto.isRef(defn.UnitClass) ||
tp1.isValueType && compat.normalizedCompatible(NamedType(tp1, name, m), memberProto, keepConstraint)
pre.isValueType && compat.normalizedCompatible(NamedType(pre, name, m), memberProto, keepConstraint)
// Note: can't use `m.info` here because if `m` is a method, `m.info`
// loses knowledge about `m`'s default arguments.
mbr match { // hasAltWith inlined for performance
case mbr: SingleDenotation => mbr.exists && qualifies(mbr)
case _ => mbr hasAltWith qualifies
}
}
tp1.widenDealias.stripTypeVar match {
case tp: TypeParamRef =>
val bounds = ctx.typeComparer.bounds(tp)
go(bounds.hi) || go(bounds.lo)
case _ =>
go(tp1)
}
}

def underlying(using Context): Type = WildcardType

Expand Down
Loading