diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 86c58118372e..fd09af8a9ee5 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -818,6 +818,25 @@ object desugar { } } + /** Transforms + * + * type $T >: Low <: Hi + * + * to + * + * @patternBindHole type $T >: Low <: Hi + * + * if the type is a type splice. + */ + def quotedPatternTypeDef(tree: TypeDef)(implicit ctx: Context): TypeDef = { + assert(ctx.mode.is(Mode.QuotedPattern)) + if (tree.name.startsWith("$") && !tree.isBackquoted) { + val patternBindHoleAnnot = New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(tree.span) + val mods = tree.mods.withAddedAnnotation(patternBindHoleAnnot) + tree.withMods(mods) + } else tree + } + /** The normalized name of `mdef`. This means * 1. Check that the name does not redefine a Scala core class. * If it does redefine, issue an error and return a mangled name instead of the original one. @@ -1031,7 +1050,9 @@ object desugar { checkModifiers(tree) match { case tree: ValDef => valDef(tree) case tree: TypeDef => - if (tree.isClassDef) classDef(tree) else tree + if (tree.isClassDef) classDef(tree) + else if (ctx.mode.is(Mode.QuotedPattern)) quotedPatternTypeDef(tree) + else tree case tree: DefDef => if (tree.name.isConstructorName) tree // was already handled by enclosing classDef else defDef(tree) diff --git a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala index 714845f1fdfd..06db968a7a3d 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala @@ -68,6 +68,20 @@ class TreeMapWithImplicits extends tpd.TreeMap { nestedCtx } + private def patternScopeCtx(pattern: Tree)(implicit ctx: Context): Context = { + val nestedCtx = ctx.fresh.setNewScope + new TreeTraverser { + def traverse(tree: Tree)(implicit ctx: Context): Unit = { + tree match { + case d: DefTree => nestedCtx.enter(d.symbol) + case _ => + } + traverseChildren(tree) + } + }.traverse(pattern) + nestedCtx + } + override def transform(tree: Tree)(implicit ctx: Context): Tree = { def localCtx = if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx @@ -93,6 +107,13 @@ class TreeMapWithImplicits extends tpd.TreeMap { Nil, transformSelf(self), transformStats(impl.body, tree.symbol)) + case tree: CaseDef => + val patCtx = patternScopeCtx(tree.pat)(ctx) + cpy.CaseDef(tree)( + transform(tree.pat), + transform(tree.guard)(patCtx), + transform(tree.body)(patCtx) + ) case _ => super.transform(tree) } diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 3ed62c7d24ac..76c72db824b3 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -768,6 +768,8 @@ object Trees { /** Is this a definition of a class? */ def isClassDef: Boolean = rhs.isInstanceOf[Template[_]] + + def isBackquoted: Boolean = hasAttachment(Backquoted) } /** extends parents { self => body } diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 277f1dda299d..22d3dc989566 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1208,7 +1208,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { /** An extractor for typed splices */ object Splice { def apply(tree: Tree)(implicit ctx: Context): Tree = { - val baseType = tree.tpe.baseType(defn.QuotedExprClass) + val baseType = tree.tpe.baseType(defn.QuotedExprClass).orElse(tree.tpe.baseType(defn.QuotedTypeClass)) val argType = if (baseType != NoType) baseType.argTypesHi.head else defn.NothingType @@ -1342,6 +1342,17 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { } } + /** Creates the tuple type tree repesentation of the type trees in `ts` */ + def tupleTypeTree(elems: List[Tree])(implicit ctx: Context): Tree = { + val arity = elems.length + if (arity <= Definitions.MaxTupleArity && defn.TupleType(arity) != null) AppliedTypeTree(TypeTree(defn.TupleType(arity)), elems) + else nestedPairsTypeTree(elems) + } + + /** Creates the nested pairs type tree repesentation of the type trees in `ts` */ + def nestedPairsTypeTree(ts: List[Tree])(implicit ctx: Context): Tree = + ts.foldRight[Tree](TypeTree(defn.UnitType))((x, acc) => AppliedTypeTree(TypeTree(defn.PairType), x :: acc :: Nil)) + /** Replaces all positions in `tree` with zero-extent positions */ private def focusPositions(tree: Tree)(implicit ctx: Context): Tree = { val transformer = new tpd.TreeMap { diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index 0669c4c74cbd..496eacbdf546 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -210,8 +210,8 @@ trait Symbols { this: Context => Nil, decls) /** Define a new symbol associated with a Bind or pattern wildcard and, by default, make it gadt narrowable. */ - def newPatternBoundSymbol(name: Name, info: Type, span: Span, addToGadt: Boolean = true): Symbol = { - val sym = newSymbol(owner, name, Case, info, coord = span) + def newPatternBoundSymbol(name: Name, info: Type, span: Span, addToGadt: Boolean = true, flags: FlagSet = EmptyFlags): Symbol = { + val sym = newSymbol(owner, name, Case | flags, info, coord = span) if (addToGadt && name.isTypeName) gadt.addToConstraint(sym) sym } diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 3f4b8c8f19ef..40a7107913c4 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -2644,10 +2644,15 @@ object Parsers { def typeDefOrDcl(start: Offset, mods: Modifiers): Tree = { newLinesOpt() atSpan(start, nameStart) { - val name = ident().toTypeName + val nameIdent = typeIdent() val tparams = typeParamClauseOpt(ParamOwner.Type) - def makeTypeDef(rhs: Tree): Tree = - finalizeDef(TypeDef(name, lambdaAbstract(tparams, rhs)), mods, start) + def makeTypeDef(rhs: Tree): Tree = { + val rhs1 = lambdaAbstract(tparams, rhs) + val tdef = TypeDef(nameIdent.name.toTypeName, rhs1) + if (nameIdent.isBackquoted) + tdef.pushAttachment(Backquoted, ()) + finalizeDef(tdef, mods, start) + } in.token match { case EQUALS => in.nextToken() diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 8efd21fe35ac..1d02dd716b11 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -479,6 +479,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if (lo eq hi) optText(lo)(" = " ~ _) else optText(lo)(" >: " ~ _) ~ optText(hi)(" <: " ~ _) case Bind(name, body) => + ("given ": Text).provided(tree.symbol.is(Implicit) && !homogenizedView) ~ // Used for scala.quoted.Type in quote patterns (not pickled) changePrec(InfixPrec) { toText(name) ~ " @ " ~ toText(body) } case Alternative(trees) => changePrec(OrPrec) { toText(trees, " | ") } @@ -610,6 +611,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { else keywordStr("'{") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}") case Splice(tree) => keywordStr("${") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}") + case TypSplice(tree) => + keywordStr("${") ~ toTextGlobal(dropBlock(tree)) ~ keywordStr("}") case tree: Applications.IntegratedTypeArgs => toText(tree.app) ~ Str("(with integrated type args)").provided(printDebug) case Thicket(trees) => diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala b/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala index afe9622d25d8..0e2dddf69555 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala @@ -35,6 +35,15 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util. def Context_source(self: Context): java.nio.file.Path = self.compilationUnit.source.file.jpath + def Context_GADT_setFreshGADTBounds(self: Context): Context = + self.fresh.setFreshGADTBounds.addMode(Mode.GadtConstraintInference) + + def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean = + self.gadt.addToConstraint(syms) + + def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type = + self.gadt.approximation(sym, fromBelow) + // // REPORTING // diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index faff91716c4a..3621d093c897 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1967,42 +1967,166 @@ class Typer extends Namer typedTypeApply(untpd.TypeApply(untpd.ref(defn.InternalQuoted_typeQuoteR), quoted :: Nil), pt)(quoteContext).withSpan(tree.span) case quoted => ctx.compilationUnit.needsStaging = true - if (ctx.mode.is(Mode.Pattern) && level == 0) { - val exprPt = pt.baseType(defn.QuotedExprClass) - val quotedPt = if (exprPt.exists) exprPt.argTypesHi.head else defn.AnyType - val quoted1 = typedExpr(quoted, quotedPt)(quoteContext.addMode(Mode.QuotedPattern)) - val (shape, splices) = splitQuotePattern(quoted1) - val patType = defn.tupleType(splices.tpes.map(_.widen)) - val splicePat = typed(untpd.Tuple(splices.map(untpd.TypedSplice(_))).withSpan(quoted.span), patType) - UnApply( - fun = ref(defn.InternalQuotedMatcher_unapplyR).appliedToType(patType), - implicits = - ref(defn.InternalQuoted_exprQuoteR).appliedToType(shape.tpe).appliedTo(shape) :: - implicitArgTree(defn.QuoteContextType, tree.span) :: Nil, - patterns = splicePat :: Nil, - proto = pt) - } + if (ctx.mode.is(Mode.Pattern) && level == 0) + typedQuotePattern(quoted, pt, tree.span) else typedApply(untpd.Apply(untpd.ref(defn.InternalQuoted_exprQuoteR), quoted), pt)(quoteContext).withSpan(tree.span) } } - def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = { + /** Type a quote pattern `case '{ } =>` qiven the a current prototype. Typing the pattern + * will also transform it into a call to `scala.internal.quoted.Matcher.unapply`. + * + * Code directly inside the quote is typed as an expression using Mode.QuotedPattern. Splices + * within the quotes become patterns again and typed acordingly. + * + * ``` + * case '{ ($ls: List[$t]) } => + * // `t` is of type `Type[T$1]` for some unknown T$1 + * // `t` is implicitly available + * // `l` is of type `Expr[List[T$1]]` + * '{ val h: $t = $ls.head } + * ``` + * + * For each type splice we will create a new type binding in the pattern match ($t @ _ in this case) + * and a corresponding type in the quoted pattern as a hole (@patternBindHole type $t in this case). + * All these generated types are inserted at the start of the quoted code. + * + * After typing the tree will resemble + * + * ``` + * case '{ type ${given t: Type[$t @ _]}; ${ls: Expr[List[$t]]} } => ... + * ``` + * + * Then the pattern is _split_ into the expression containd in the pattern replacing the splices by holes, + * and the patterns in the splices. All these are recombined into a call to `Matcher.unapply`. + * + * ``` + * case scala.internal.quoted.Matcher.unapply[ + * Tuple1[$t @ _], // Type binging definition + * Tuple2[Type[$t], Expr[List[$t]]] // Typing the result of the pattern match + * ]( + * Tuple2.unapply + * [Type[$t], Expr[List[$t]]] //Propagated from the tuple above + * (implict t @ _, ls @ _: Expr[List[$t]]) // from the spliced patterns + * )( + * '{ // Runtime quote Matcher.unapply uses to mach against. Expression directly inside the quoted pattern without the splices + * @scala.internal.Quoted.patternBindHole type $t + * scala.internal.Quoted.patternHole[List[$t]] + * }, + * true, // If there is at least one type splice. Used to instantiate the context with or without GADT constraints + * x$2 // tasty.Reflection instance + * ) => ... + * ``` + */ + private def typedQuotePattern(quoted: untpd.Tree, pt: Type, quoteSpan: Span)(implicit ctx: Context): Tree = { + val exprPt = pt.baseType(defn.QuotedExprClass) + val quotedPt = if (exprPt.exists) exprPt.argTypesHi.head else defn.AnyType + val quoted1 = typedExpr(quoted, quotedPt)(quoteContext.addMode(Mode.QuotedPattern)) + + val (typeBindings, shape, splices) = splitQuotePattern(quoted1) + + class ReplaceBindings extends TypeMap() { + override def apply(tp: Type): Type = tp match { + case tp: TypeRef => + val tp1 = if (tp.typeSymbol == defn.QuotedType_splice) tp.dealias else tp + typeBindings.get(tp1.typeSymbol).fold(tp)(_.symbol.typeRef) + case tp => mapOver(tp) + } + } + val replaceBindings = new ReplaceBindings + val patType = defn.tupleType(splices.tpes.map(tpe => replaceBindings(tpe.widen))) + + val typeBindingsTuple = tpd.tupleTypeTree(typeBindings.values.toList) + + val replaceBindingsInTree = new TreeMap { + private[this] var bindMap = Map.empty[Symbol, Symbol] + override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = { + tree match { + case tree: Bind => + val sym = tree.symbol + val newInfo = replaceBindings(sym.info) + val newSym = ctx.newSymbol(sym.owner, sym.name, sym.flags, newInfo, sym.privateWithin, sym.coord) + bindMap += sym -> newSym + Bind(newSym, transform(tree.body)).withSpan(sym.span) + case _ => + super.transform(tree).withType(replaceBindingsInType(tree.tpe)) + } + } + private[this] val replaceBindingsInType = new ReplaceBindings { + override def apply(tp: Type): Type = tp match { + case tp: TermRef => bindMap.get(tp.termSymbol).fold[Type](tp)(_.typeRef) + case tp => super.apply(tp) + } + } + } + + val splicePat = typed(untpd.Tuple(splices.map(x => untpd.TypedSplice(replaceBindingsInTree.transform(x)))).withSpan(quoted.span), patType) + + UnApply( + fun = ref(defn.InternalQuotedMatcher_unapplyR).appliedToTypeTrees(typeBindingsTuple :: TypeTree(patType) :: Nil), + implicits = + ref(defn.InternalQuoted_exprQuoteR).appliedToType(shape.tpe).appliedTo(shape) :: + Literal(Constant(typeBindings.nonEmpty)) :: + implicitArgTree(defn.QuoteContextType, quoteSpan) :: Nil, + patterns = splicePat :: Nil, + proto = pt) + } + + /** Split a typed quoted pattern is split into its type bindings, pattern expression and inner patterns. + * Type definitions with `@patternBindHole` will be inserted in the pattern expression for each type binding. + * + * A quote pattern + * ``` + * case '{ type ${given t: Type[$t @ _]}; ${ls: Expr[List[$t]]} } => ... + * ``` + * will return + * ``` + * ( + * Map(<$t>: Symbol -> <$t @ _>: Bind), + * <'{ + * @scala.internal.Quoted.patternBindHole type $t + * scala.internal.Quoted.patternHole[List[$t]] + * }>: Tree, + * List(: Tree) + * ) + * ``` + */ + private def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Map[Symbol, Bind], Tree, List[Tree]) = { val ctx0 = ctx + + val typeBindings: collection.mutable.Map[Symbol, Bind] = collection.mutable.Map.empty + def getBinding(sym: Symbol): Bind = + typeBindings.getOrElseUpdate(sym, { + val bindingBounds = sym.info + val bsym = ctx.newPatternBoundSymbol(sym.name.toTypeName, bindingBounds, quoted.span) + Bind(bsym, untpd.Ident(nme.WILDCARD).withType(bindingBounds)).withSpan(quoted.span) + }) + object splitter extends tpd.TreeMap { val patBuf = new mutable.ListBuffer[Tree] + val freshTypePatBuf = new mutable.ListBuffer[Tree] + val freshTypeBindingsBuff = new mutable.ListBuffer[Tree] + val typePatBuf = new mutable.ListBuffer[Tree] override def transform(tree: Tree)(implicit ctx: Context) = tree match { case Typed(Splice(pat), tpt) if !tpt.tpe.derivesFrom(defn.RepeatedParamClass) => - val exprTpt = AppliedTypeTree(TypeTree(defn.QuotedExprType), tpt :: Nil) + val tpt1 = transform(tpt) // Transform type bindings + val exprTpt = AppliedTypeTree(TypeTree(defn.QuotedExprType), tpt1 :: Nil) transform(Splice(Typed(pat, exprTpt))) case Splice(pat) => - try patternHole(tree) + try ref(defn.InternalQuoted_patternHoleR).appliedToType(tree.tpe).withSpan(tree.span) finally { val patType = pat.tpe.widen val patType1 = patType.underlyingIfRepeated(isJava = false) val pat1 = if (patType eq patType1) pat else pat.withType(patType1) patBuf += pat1 } + case Select(pat, _) if tree.symbol == defn.QuotedType_splice => + val sym = tree.tpe.dealias.typeSymbol.asType + val tdef = TypeDef(sym).withSpan(sym.span) + freshTypeBindingsBuff += transformTypeBindingTypeDef(tdef, freshTypePatBuf) + TypeTree(tree.tpe.dealias).withSpan(tree.span) + case ddef: ValOrDefDef => if (ddef.symbol.hasAnnotation(defn.InternalQuoted_patternBindHoleAnnot)) { val bindingType = ddef.symbol.info match { @@ -2021,17 +2145,49 @@ class Typer extends Namer patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingExprTpe)).withSpan(ddef.span) } super.transform(tree) + case tdef: TypeDef if tdef.symbol.hasAnnotation(defn.InternalQuoted_patternBindHoleAnnot) => + transformTypeBindingTypeDef(tdef, typePatBuf) case _ => super.transform(tree) } + + def transformTypeBindingTypeDef(tdef: TypeDef, buff: mutable.Builder[Tree, List[Tree]]): Tree = { + val bindingType = getBinding(tdef.symbol).symbol.typeRef + val bindingTypeTpe = AppliedType(defn.QuotedTypeType, bindingType :: Nil) + assert(tdef.name.startsWith("$")) + val bindName = tdef.name.toString.stripPrefix("$").toTermName + val sym = ctx0.newPatternBoundSymbol(bindName, bindingTypeTpe, tdef.span, flags = ImplicitTerm) + buff += Bind(sym, untpd.Ident(nme.WILDCARD).withType(bindingTypeTpe)).withSpan(tdef.span) + super.transform(tdef) + } } - val result = splitter.transform(quoted) - (result, splitter.patBuf.toList) - } + val shape0 = splitter.transform(quoted) + val patterns = (splitter.freshTypePatBuf.iterator ++ splitter.typePatBuf.iterator ++ splitter.patBuf.iterator).toList + val freshTypeBindings = splitter.freshTypeBindingsBuff.result() - /** A hole the shape pattern of a quoted.Matcher.unapply, representing a splice */ - def patternHole(splice: Tree)(implicit ctx: Context): Tree = - ref(defn.InternalQuoted_patternHoleR).appliedToType(splice.tpe).withSpan(splice.span) + val shape1 = seq( + freshTypeBindings, + shape0 + ) + val shape2 = { + if (freshTypeBindings.isEmpty) shape1 + else { + val isFreshTypeBindings = freshTypeBindings.map(_.symbol).toSet + val typeMap = new TypeMap() { + def apply(tp: Type): Type = tp match { + case tp: TypeRef if tp.typeSymbol == defn.QuotedType_splice => + val tp1 = tp.dealias + if (isFreshTypeBindings(tp1.typeSymbol)) tp1 + else tp + case tp => mapOver(tp) + } + } + new TreeTypeMap(typeMap = typeMap).transform(shape1) + } + } + + (typeBindings.toMap, shape2, patterns) + } /** Translate `${ t: Expr[T] }` into expression `t.splice` while tracking the quotation level in the context */ def typedSplice(tree: untpd.Splice, pt: Type)(implicit ctx: Context): Tree = track("typedSplice") { @@ -2077,7 +2233,28 @@ class Typer extends Namer ctx.warning("Canceled quote directly inside a splice. ${ '[ XYZ ] } is equivalent to XYZ.", tree.sourcePos) typed(innerType, pt) case expr => - typedSelect(untpd.Select(tree.expr, tpnme.splice), pt)(spliceContext).withSpan(tree.span) + if (ctx.mode.is(Mode.QuotedPattern) && level == 1) { + if (isFullyDefined(pt, ForceDegree.all)) { + ctx.error(i"Spliced type pattern must not be fully defined. Consider using $pt directly", tree.expr.sourcePos) + tree.withType(UnspecifiedErrorType) + } else { + def spliceOwner(ctx: Context): Symbol = + if (ctx.mode.is(Mode.QuotedPattern)) spliceOwner(ctx.outer) else ctx.owner + val name = expr match { + case Ident(name) => ("$" + name).toTypeName + case Typed(Ident(name), _) => ("$" + name).toTypeName + case Bind(name, _) => ("$" + name).toTypeName + case _ => NameKinds.UniqueName.fresh("$".toTypeName) + } + val typeSym = ctx.newSymbol(spliceOwner(ctx), name, EmptyFlags, TypeBounds.empty, NoSymbol, expr.span) + typeSym.addAnnotation(Annotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(expr.span))) + val pat = typedPattern(expr, defn.QuotedTypeType.appliedTo(typeSym.typeRef))( + spliceContext.retractMode(Mode.QuotedPattern).withOwner(spliceOwner(ctx))) + pat.select(tpnme.splice) + } + } else { + typedSelect(untpd.Select(tree.expr, tpnme.splice), pt)(spliceContext).withSpan(tree.span) + } } } diff --git a/library/src-3.x/scala/internal/Quoted.scala b/library/src-3.x/scala/internal/Quoted.scala index 2e940dd85794..c9a5a2727320 100644 --- a/library/src-3.x/scala/internal/Quoted.scala +++ b/library/src-3.x/scala/internal/Quoted.scala @@ -25,6 +25,9 @@ object Quoted { @compileTimeOnly("Illegal reference to `scala.internal.Quoted.patternBindHole`") class patternBindHole extends Annotation + /** A splice of a name in a quoted pattern is that marks the definition of a type splice */ + class patternType extends Annotation + /** Artifact of pickled type splices * * During quote reification a quote `'{ ... F[$t] ... }` will be transformed into diff --git a/library/src-3.x/scala/internal/quoted/Matcher.scala b/library/src-bootstrapped/scala/internal/quoted/Matcher.scala similarity index 78% rename from library/src-3.x/scala/internal/quoted/Matcher.scala rename to library/src-bootstrapped/scala/internal/quoted/Matcher.scala index 383779742151..bb3d0c6b4e62 100644 --- a/library/src-3.x/scala/internal/quoted/Matcher.scala +++ b/library/src-bootstrapped/scala/internal/quoted/Matcher.scala @@ -26,17 +26,35 @@ object Matcher { * * @param scrutineeExpr `Expr[_]` on which we are pattern matching * @param patternExpr `Expr[_]` containing the pattern tree + * @param hasTypeSplices `Boolean` notify if the pattern has type splices (if so we use a GADT context) * @param qctx the current QuoteContext * @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]`` */ - def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], qctx: QuoteContext): Option[Tup] = { + def unapply[TypeBindings <: Tuple, Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], + hasTypeSplices: Boolean, qctx: QuoteContext): Option[Tup] = { + + // TODO improve performance import qctx.tasty.{Bind => BindPattern, _} import Matching._ type Env = Set[(Symbol, Symbol)] + class SymBinding(val sym: Symbol) + inline def withEnv[T](env: Env)(body: => given Env => T): T = body given env + def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match { + case Annotated(tpt2, annot) => isBindAnnotation(annot) || hasBindTypeAnnotation(tpt2) + case _ => false + } + + def hasBindAnnotation(sym: Symbol) = sym.annots.exists(isBindAnnotation) + + def isBindAnnotation(tree: Tree): Boolean = tree match { + case New(tpt) => tpt.symbol == kernel.Definitions_InternalQuoted_patternBindHoleAnnot + case annot => annot.symbol.owner == kernel.Definitions_InternalQuoted_patternBindHoleAnnot + } + /** Check that all trees match with `mtch` and concatenate the results with && */ def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match { case (x :: xs, y :: ys) => mtch(x, y) && matchLists(xs, ys)(mtch) @@ -45,7 +63,7 @@ object Matcher { } /** Check that all trees match with =#= and concatenate the results with && */ - def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching = + def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Context, Env: Matching = matchLists(scrutinees, patterns)(_ =#= _) /** Check that the trees match and return the contents from the pattern holes. @@ -56,7 +74,18 @@ object Matcher { * @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. * @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ - def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = { + def (scrutinee0: Tree) =#= (pattern0: Tree) given Context, Env: Matching = { + + /** Normalieze the tree */ + def normalize(tree: Tree): Tree = tree match { + case Block(Nil, expr) => normalize(expr) + case Block(stats1, Block(stats2, expr)) => normalize(Block(stats1 ::: stats2, expr)) + case Inlined(_, Nil, expr) => normalize(expr) + case _ => tree + } + + val scrutinee = normalize(scrutinee0) + val pattern = normalize(pattern0) /** Check that both are `val` or both are `lazy val` or both are `var` **/ def checkValFlags(): Boolean = { @@ -69,23 +98,7 @@ object Matcher { def bindingMatch(sym: Symbol) = matched(new Bind(sym.name, sym)) - def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match { - case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), ""), Nil)) => true - case Annotated(tpt2, _) => hasBindTypeAnnotation(tpt2) - case _ => false - } - - def hasBindAnnotation(sym: Symbol) = - sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),""),List()) => true; case _ => true } - - /** Normalieze the tree */ - def normalize(tree: Tree): Tree = tree match { - case Block(Nil, expr) => normalize(expr) - case Inlined(_, Nil, expr) => normalize(expr) - case _ => tree - } - - (normalize(scrutinee), normalize(pattern)) match { + (scrutinee, pattern) match { // Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree case (IsTerm(scrutinee @ Typed(s, tpt1)), Typed(TypeApply(patternHole, tpt :: Nil), tpt2)) @@ -110,6 +123,9 @@ object Matcher { case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => expr1 =#= expr2 && tpt1 =#= tpt2 + case (scrutinee, Typed(expr2, _)) => + scrutinee =#= expr2 + case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) => matched @@ -125,11 +141,20 @@ object Matcher { case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol => fn1 =#= fn2 && args1 =##= args2 - case (Block(stats1, expr1), Block(stats2, expr2)) => - withEnv(the[Env] ++ stats1.map(_.symbol).zip(stats2.map(_.symbol))) { - stats1 =##= stats2 && expr1 =#= expr2 + case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) => + qctx.tasty.kernel.Context_GADT_addToConstraint(the[Context])(binding.symbol :: Nil) + matched(new SymBinding(binding.symbol)) && Block(stats1, expr1) =#= Block(stats2, expr2) + + case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) => + withEnv(the[Env] + (stat1.symbol -> stat2.symbol)) { + stat1 =#= stat2 && Block(stats1, expr1) =#= Block(stats2, expr2) } + case (scrutinee, Block(typeBindings, expr2)) if typeBindings.forall(isTypeBinding) => + val bindingSymbols = typeBindings.map(_.symbol) + qctx.tasty.kernel.Context_GADT_addToConstraint(the[Context])(bindingSymbols) + bindingSymbols.foldRight(scrutinee =#= expr2)((x, acc) => matched(new SymBinding(x)) && acc) + case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2 @@ -142,9 +167,6 @@ object Matcher { case (While(cond1, body1), While(cond2, body2)) => cond1 =#= cond2 && body1 =#= body2 - case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 => - expr1 =#= expr2 - case (New(tpt1), New(tpt2)) => tpt1 =#= tpt2 @@ -157,10 +179,7 @@ object Matcher { case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size => elems1 =##= elems2 - case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol => - matched - - case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe => + case (IsTypeTree(scrutinee), IsTypeTree(pattern)) if scrutinee.tpe <:< pattern.tpe => matched case (Applied(tycon1, args1), Applied(tycon2, args2)) => @@ -171,7 +190,7 @@ object Matcher { if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol) else matched def rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) - bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given rhsEnv) + bindMatch && tpt1 =#= tpt2 && (treeOptMatches(rhs1, rhs2) given (the[Context], rhsEnv)) case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) => val bindMatch = @@ -227,7 +246,7 @@ object Matcher { } } - def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = { + def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Context, Env: Matching = { (scrutinee, pattern) match { case (Some(x), Some(y)) => x =#= y case (None, None) => matched @@ -235,7 +254,7 @@ object Matcher { } } - def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = { + def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Context, Env: Matching = { val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern withEnv(caseEnv) { patternMatch && @@ -254,7 +273,7 @@ object Matcher { * @return The new environment containing the bindings defined in this pattern tuppled with * `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. */ - def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match { + def (scrutinee: Pattern) =%= (pattern: Pattern) given Context, Env: (Env, Matching) = (scrutinee, pattern) match { case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil)) if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" => (the[Env], matched(v1.seal)) @@ -264,7 +283,7 @@ object Matcher { case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) => val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol) - (body1 =%= body2) given bindEnv + (body1 =%= body2) given (the[Context], bindEnv) case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) => val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2) @@ -300,16 +319,39 @@ object Matcher { (the[Env], notMatched) } - def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = { + def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Context, Env: (Env, Matching) = { if (patterns1.size != patterns2.size) (the[Env], notMatched) else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) => - val (env, res) = (x._1 =%= x._2) given acc._1 + val (env, res) = (x._1 =%= x._2) given (the[Context], acc._1) (env, acc._2 && res) } } + def isTypeBinding(tree: Tree): Boolean = tree match { + case IsTypeDef(tree) => hasBindAnnotation(tree.symbol) + case _ => false + } + implicit val env: Env = Set.empty - (scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]] + + val res = { + if (hasTypeSplices) { + implicit val ctx: Context = qctx.tasty.kernel.Context_GADT_setFreshGADTBounds(rootContext) + val matchings = scrutineeExpr.unseal.underlyingArgument =#= patternExpr.unseal.underlyingArgument + // After matching and doing all subtype check, we have to aproximate all the type bindings + // that we have found and seal them in a quoted.Type + matchings.asOptionOfTuple.map { tup => + Tuple.fromArray(tup.toArray.map { // TODO improve performace + case x: SymBinding => kernel.Context_GADT_approximation(the[Context])(x.sym, true).seal + case x => x + }) + } + } + else { + scrutineeExpr.unseal.underlyingArgument =#= patternExpr.unseal.underlyingArgument + } + } + res.asInstanceOf[Option[Tup]] } /** Result of matching a part of an expression */ diff --git a/library/src-non-bootstrapped/scala/internal/quoted/Matcher.scala b/library/src-non-bootstrapped/scala/internal/quoted/Matcher.scala new file mode 100644 index 000000000000..d00c5f6e8cdb --- /dev/null +++ b/library/src-non-bootstrapped/scala/internal/quoted/Matcher.scala @@ -0,0 +1,11 @@ +package scala.internal.quoted + +import scala.quoted._ +import scala.tasty._ + +object Matcher { + + def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = + throw new Exception("running on non bootstrapped library") + +} diff --git a/library/src/scala/tasty/reflect/Kernel.scala b/library/src/scala/tasty/reflect/Kernel.scala index e8ee85bd983e..cd9df6e059a4 100644 --- a/library/src/scala/tasty/reflect/Kernel.scala +++ b/library/src/scala/tasty/reflect/Kernel.scala @@ -140,6 +140,10 @@ trait Kernel { /** Returns the source file being compiled. The path is relative to the current working directory. */ def Context_source(self: Context): java.nio.file.Path + def Context_GADT_setFreshGADTBounds(self: Context): Context + def Context_GADT_addToConstraint(self: Context)(syms: List[Symbol]): Boolean + def Context_GADT_approximation(self: Context)(sym: Symbol, fromBelow: Boolean): Type + // // REPORTING // diff --git a/tests/neg/quotedPatterns-5.scala b/tests/neg/quotedPatterns-5.scala new file mode 100644 index 000000000000..4ef688259573 --- /dev/null +++ b/tests/neg/quotedPatterns-5.scala @@ -0,0 +1,12 @@ +object Test { + def test(x: quoted.Expr[Int]) given tasty.Reflection = x match { + case '{ type $t; poly[$t]($x); 4 } => ??? // error: duplicate pattern variable: $t + case '{ type `$t`; poly[`$t`]($x); 4 } => + val tt: quoted.Type[_] = t // error + ??? + case _ => + } + + def poly[T](x: T): Unit = () + +} diff --git a/tests/pos/quote-matching-implicit-types.scala b/tests/pos/quote-matching-implicit-types.scala new file mode 100644 index 000000000000..634447c31902 --- /dev/null +++ b/tests/pos/quote-matching-implicit-types.scala @@ -0,0 +1,16 @@ +import scala.quoted._ + +object Foo { + + def f(e: Expr[Any]) given tasty.Reflection : Unit = e match { + case '{ foo[$t]($x) } => bar(x) + case '{ foo[$t]($x) } if bar(x) => () + case '{ foo[$t]($x) } => '{ foo($x) } + case '{ foo[$t]($x) } if bar[Any]('{ foo($x) }) => () + } + + def foo[T](t: T): Unit = () + + def bar[T: Type](t: Expr[T]): Boolean = true + +} diff --git a/tests/pos/quotedPatterns.scala b/tests/pos/quotedPatterns.scala index 238a173bd9d1..1ac1c7487aa1 100644 --- a/tests/pos/quotedPatterns.scala +++ b/tests/pos/quotedPatterns.scala @@ -30,6 +30,18 @@ object Test { case '{ def $ff[T](i: T): Int = $z; 2 } => val a: quoted.matching.Bind[[T] =>> T => Int] = ff z + case '{ poly[$t]($x); 4 } => ??? + case '{ poly[${Foo(t)}]($x); 4 } => ??? + case '{ type $X; poly[`$X`]($x); 4 } => ??? + case '{ type $t; poly[${Foo(x: quoted.Type[`$t`])}]($x); 4 } => ??? + case '{ type $T; val x: `$T` = $a; val y: `$T` = x; 1 } => ??? + case '{ type $t <: AnyRef; val x: `$t` = $a; val y: `$t` = x; 1 } => ??? case _ => '{1} } + + def poly[T](x: T): Unit = () + + object Foo { + def unapply[T](arg: quoted.Type[T]): Option[quoted.Type[T]] = Some(arg) + } } \ No newline at end of file diff --git a/tests/run-macros/quote-matcher-runtime.check b/tests/run-macros/quote-matcher-runtime.check index a630c9d5285a..e772ae7edb1c 100644 --- a/tests/run-macros/quote-matcher-runtime.check +++ b/tests/run-macros/quote-matcher-runtime.check @@ -16,7 +16,7 @@ Result: Some(List()) Scrutinee: 1 Pattern: (1: scala.Int) -Result: None +Result: Some(List()) Scrutinee: 3 Pattern: scala.internal.Quoted.patternHole[scala.Int] @@ -714,3 +714,118 @@ Pattern: try scala.internal.Quoted.patternHole[scala.Int] finally { } Result: Some(List(Expr(1), Expr(2))) +Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x))) +Pattern: { + @scala.internal.Quoted.patternBindHole type T + scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]]) +} +Result: Some(List(Type(scala.Unit), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int]))), Expr(((x: scala.Int) => scala.Predef.println(x))))) + +Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x))) +Pattern: { + @scala.internal.Quoted.patternBindHole type T = scala.Unit + scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]]) +} +Result: Some(List(Type(scala.Unit), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int]))), Expr(((x: scala.Int) => scala.Predef.println(x))))) + +Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int])).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x))) +Pattern: { + @scala.internal.Quoted.patternBindHole type T <: scala.Predef.String + scala.internal.Quoted.patternHole[scala.List[scala.Int]].foreach[T](scala.internal.Quoted.patternHole[scala.Function1[scala.Int, T]]) +} +Result: None + +Scrutinee: { + val a: scala.Int = 4 + val b: scala.Int = 4 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole type T + val a: T = scala.internal.Quoted.patternHole[T] + val b: T = scala.internal.Quoted.patternHole[T] + () +} +Result: Some(List(Type(scala.Int), Expr(4), Expr(4))) + +Scrutinee: { + val a: scala.Int = 4 + val b: scala.Int = 5 + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole type T + val a: T = scala.internal.Quoted.patternHole[T] + val b: T = scala.internal.Quoted.patternHole[T] + () +} +Result: Some(List(Type(scala.Int), Expr(4), Expr(5))) + +Scrutinee: { + val a: scala.Int = 4 + val b: scala.Predef.String = "x" + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole type T + val a: T = scala.internal.Quoted.patternHole[T] + val b: T = scala.internal.Quoted.patternHole[T] + () +} +Result: Some(List(Type(scala.Int | java.lang.String), Expr(4), Expr("x"))) + +Scrutinee: { + val a: scala.Int = 4 + val b: scala.Predef.String = "x" + () +} +Pattern: { + @scala.internal.Quoted.patternBindHole type T <: scala.Int + val a: T = scala.internal.Quoted.patternHole[T] + val b: T = scala.internal.Quoted.patternHole[T] + () +} +Result: None + +Scrutinee: scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int])).map[scala.Double, scala.collection.immutable.List[scala.Double]](((x: scala.Int) => x.toDouble./(2)))(scala.collection.immutable.List.canBuildFrom[scala.Double]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((y: scala.Double) => y.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Pattern: { + @scala.internal.Quoted.patternBindHole type T + @scala.internal.Quoted.patternBindHole type U + @scala.internal.Quoted.patternBindHole type V + + (scala.internal.Quoted.patternHole[scala.List[T]].map[U, scala.collection.immutable.List[U]](scala.internal.Quoted.patternHole[scala.Function1[T, U]])(scala.collection.immutable.List.canBuildFrom[U]).map[V, scala.collection.immutable.List[V]](scala.internal.Quoted.patternHole[scala.Function1[U, V]])(scala.collection.immutable.List.canBuildFrom[V]): scala.collection.immutable.List[scala.Any]) +} +Result: Some(List(Type(scala.Int), Type(scala.Double), Type(java.lang.String), Expr(scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int]))), Expr(((x: scala.Int) => x.toDouble./(2))), Expr(((y: scala.Double) => y.toString())))) + +Scrutinee: ((x: scala.Int) => x) +Pattern: { + @scala.internal.Quoted.patternBindHole type T + + (scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any]) +} +Result: Some(List(Type(scala.Int), Expr(((x: scala.Int) => x)))) + +Scrutinee: ((x: scala.Int) => x.toString()) +Pattern: { + @scala.internal.Quoted.patternBindHole type T + + (scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any]) +} +Result: None + +Scrutinee: ((x: scala.Any) => scala.Predef.???) +Pattern: { + @scala.internal.Quoted.patternBindHole type T + + (scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any]) +} +Result: Some(List(Type(scala.Nothing), Expr(((x: scala.Any) => scala.Predef.???)))) + +Scrutinee: ((x: scala.Nothing) => (1: scala.Any)) +Pattern: { + @scala.internal.Quoted.patternBindHole type T + + (scala.internal.Quoted.patternHole[scala.Function1[T, T]]: scala.Function1[scala.Nothing, scala.Any]) +} +Result: None + diff --git a/tests/run-macros/quote-matcher-runtime/quoted_1.scala b/tests/run-macros/quote-matcher-runtime/quoted_1.scala index e23939bcdecb..1f1c324cc5ec 100644 --- a/tests/run-macros/quote-matcher-runtime/quoted_1.scala +++ b/tests/run-macros/quote-matcher-runtime/quoted_1.scala @@ -10,7 +10,7 @@ object Macros { private def impl[A, B](a: Expr[A], b: Expr[B]) given (qctx: QuoteContext): Expr[Unit] = { import qctx.tasty.{Bind => _, _} - val res = scala.internal.quoted.Matcher.unapply[Tuple](a)(b, qctx).map { tup => + val res = scala.internal.quoted.Matcher.unapply[Tuple, Tuple](a)(b, true, qctx).map { tup => tup.toArray.toList.map { case r: Expr[_] => s"Expr(${r.unseal.show})" diff --git a/tests/run-macros/quote-matcher-runtime/quoted_2.scala b/tests/run-macros/quote-matcher-runtime/quoted_2.scala index a88342a79657..800819c3ad01 100644 --- a/tests/run-macros/quote-matcher-runtime/quoted_2.scala +++ b/tests/run-macros/quote-matcher-runtime/quoted_2.scala @@ -3,7 +3,7 @@ import Macros._ import scala.internal.quoted.Matcher._ -import scala.internal.Quoted.{patternHole, patternBindHole} +import scala.internal.Quoted._ object Test { @@ -134,6 +134,18 @@ object Test { matches(try 1 finally 2, try 1 finally 2) matches(try 1 catch { case _ => 2 }, try patternHole[Int] catch { case _ => patternHole[Int] }) matches(try 1 finally 2, try patternHole[Int] finally patternHole[Int]) + matches(List(1, 2, 3).foreach(x => println(x)), { @patternBindHole type T; patternHole[List[Int]].foreach[T](patternHole[Int => T]) }) + matches(List(1, 2, 3).foreach(x => println(x)), { @patternBindHole type T = Unit; patternHole[List[Int]].foreach[T](patternHole[Int => T]) }) + matches(List(1, 2, 3).foreach(x => println(x)), { @patternBindHole type T <: String; patternHole[List[Int]].foreach[T](patternHole[Int => T]) }) + matches({ val a: Int = 4; val b: Int = 4 }, { @patternBindHole type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } }) + matches({ val a: Int = 4; val b: Int = 5 }, { @patternBindHole type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } }) + matches({ val a: Int = 4; val b: String = "x" }, { @patternBindHole type T; { val a: T = patternHole[T]; val b: T = patternHole[T] } }) + matches({ val a: Int = 4; val b: String = "x" }, { @patternBindHole type T <: Int; { val a: T = patternHole[T]; val b: T = patternHole[T] } }) + matches(List(1, 2, 3).map(x => x.toDouble / 2).map(y => y.toString), { @patternBindHole type T; @patternBindHole type U; @patternBindHole type V; patternHole[List[T]].map(patternHole[T => U]).map(patternHole[U => V]) }) + matches((x: Int) => x, { @patternBindHole type T; patternHole[T => T] }) + matches((x: Int) => x.toString, { @patternBindHole type T; patternHole[T => T] }) + matches((x: Any) => ???, { @patternBindHole type T; patternHole[T => T] }) + matches((x: Nothing) => (1 : Any), { @patternBindHole type T; patternHole[T => T] }) } } diff --git a/tests/run-macros/quote-matching-optimize-1.check b/tests/run-macros/quote-matching-optimize-1.check new file mode 100644 index 000000000000..5bf24a93d543 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-1.check @@ -0,0 +1,32 @@ +Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))) +Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1)))) +Result: List(2) + +Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((x: scala.Char) => x.>('a'))) +Optimized: ls2.filter(((x: scala.Char) => x.<('c').&&(x.>('a')))) +Result: List(b) + +Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))).filter(((x: scala.Int) => x.==(2))) +Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1).&&(x.==(2))))) +Result: List(2) + +1 +2 +Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x))) +Optimized: ls.foreach[scala.Unit](((x: scala.Int) => if (x.<(3)) scala.Predef.println(x) else ())) +Result: () + +Original: scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int])).map[scala.Int, scala.collection.immutable.List[scala.Int]](((a: scala.Int) => a.*(2)))(scala.collection.immutable.List.canBuildFrom[scala.Int]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((b: scala.Int) => b.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Optimized: scala.List.apply[scala.Int]((1, 2, 3: scala.[scala.Int])).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((x: scala.Int) => { + val x$5: scala.Int = x.*(2) + x$5.toString() +}))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Result: List(2, 4, 6) + +Original: scala.List.apply[scala.Int]((55, 67, 87: scala.[scala.Int])).map[scala.Char, scala.collection.immutable.List[scala.Char]](((a: scala.Int) => a.toChar))(scala.collection.immutable.List.canBuildFrom[scala.Char]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((b: scala.Char) => b.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Optimized: scala.List.apply[scala.Int]((55, 67, 87: scala.[scala.Int])).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((x: scala.Int) => { + val x$10: scala.Char = x.toChar + x$10.toString() +}))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Result: List(7, C, W) + diff --git a/tests/run-macros/quote-matching-optimize-1/Macro_1.scala b/tests/run-macros/quote-matching-optimize-1/Macro_1.scala new file mode 100644 index 000000000000..0c3d8529da29 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-1/Macro_1.scala @@ -0,0 +1,40 @@ +import scala.quoted._ +import scala.quoted.autolift._ + +import scala.tasty.Reflection + +object Macro { + + inline def optimize[T](x: => T): Any = ${ Macro.impl('x) } + + def impl[T: Type](x: Expr[T]) given Reflection: Expr[Any] = { + val reflect = the[Reflection] + import reflect._ // TODO remove + + def optimize(x: Expr[Any]): Expr[Any] = x match { + case '{ type $t; ($ls: List[`$t`]).filter($f).filter($g) } => + optimize('{ $ls.filter(x => ${f('x)} && ${g('x)}) }) + + case '{ type $t; type $u; type $v; ($ls: List[`$t`]).map[`$u`, List[`$u`]]($f).map[`$v`, List[`$v`]]($g) } => + optimize('{ $ls.map(x => ${g(f('x))}) }) + + case '{ type $t; ($ls: List[`$t`]).filter($f).foreach[Unit]($g) } => + optimize('{ $ls.foreach(x => if (${f('x)}) ${g('x)} else ()) }) + + case _ => x + } + + val res = optimize(x) + + '{ + val result = $res + val originalCode = ${x.show} + val optimizeCode = ${res.show} + println("Original: " + originalCode) + println("Optimized: " + optimizeCode) + println("Result: " + result) + println() + } + } + +} diff --git a/tests/run-macros/quote-matching-optimize-1/Test_2.scala b/tests/run-macros/quote-matching-optimize-1/Test_2.scala new file mode 100644 index 000000000000..de052ddc8445 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-1/Test_2.scala @@ -0,0 +1,15 @@ +object Test { + import Macro._ + + def main(args: Array[String]): Unit = { + val ls = List(1, 2, 3) + val ls2 = List('a', 'b', 'c') + optimize(ls.filter(x => x < 3).filter(x => x > 1)) + optimize(ls2.filter(x => x < 'c').filter(x => x > 'a')) + optimize(ls.filter(x => x < 3).filter(x => x > 1).filter(x => x == 2)) + optimize(ls.filter(x => x < 3).foreach(x => println(x))) + optimize(List(1, 2, 3).map(a => a * 2).map(b => b.toString)) + optimize(List(55, 67, 87).map(a => a.toChar).map(b => b.toString)) + } + +} diff --git a/tests/run-macros/quote-matching-optimize-2.check b/tests/run-macros/quote-matching-optimize-2.check new file mode 100644 index 000000000000..759f7db1bfaf --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-2.check @@ -0,0 +1,32 @@ +Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))) +Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1)))) +Result: List(2) + +Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((x: scala.Char) => x.>('a'))) +Optimized: ls2.filter(((x: scala.Char) => x.<('c').&&(x.>('a')))) +Result: List(b) + +Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))).filter(((x: scala.Int) => x.==(2))) +Optimized: ls.filter(((x: scala.Int) => x.<(3).&&(x.>(1).&&(x.==(2))))) +Result: List(2) + +1 +2 +Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x))) +Optimized: ls.foreach[scala.Any](((x: scala.Int) => if (x.<(3)) scala.Predef.println(x) else ())) +Result: () + +Original: ls.map[scala.Int, scala.collection.immutable.List[scala.Int]](((a: scala.Int) => a.*(2)))(scala.collection.immutable.List.canBuildFrom[scala.Int]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((b: scala.Int) => b.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Optimized: ls.map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((x: scala.Int) => { + val x$5: scala.Int = x.*(2) + x$5.toString() +}))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Result: List(2, 4, 6) + +Original: ls.map[scala.Char, scala.collection.immutable.List[scala.Char]](((a: scala.Int) => a.toChar))(scala.collection.immutable.List.canBuildFrom[scala.Char]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((b: scala.Char) => b.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Optimized: ls.map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((x: scala.Int) => { + val x$10: scala.Char = x.toChar + x$10.toString() +}))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Result: List(, , ) + diff --git a/tests/run-macros/quote-matching-optimize-2/Macro_1.scala b/tests/run-macros/quote-matching-optimize-2/Macro_1.scala new file mode 100644 index 000000000000..c6a31501ad50 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-2/Macro_1.scala @@ -0,0 +1,40 @@ +import scala.quoted._ +import scala.quoted.autolift._ + +import scala.tasty.Reflection + +object Macro { + + inline def optimize[T](x: => T): Any = ${ Macro.impl('x) } + + def impl[T: Type](x: Expr[T]) given Reflection: Expr[Any] = { + val reflect = the[Reflection] + import reflect._ // TODO remove + + def optimize(x: Expr[Any]): Expr[Any] = x match { + case '{ ($ls: List[$t]).filter($f).filter($g) } => + optimize('{ $ls.filter(x => ${f('x)} && ${g('x)}) }) + + case '{ type $u; type $v; ($ls: List[$t]).map[`$u`, List[`$u`]]($f).map[`$v`, List[`$v`]]($g) } => + optimize('{ $ls.map(x => ${g(f('x))}) }) + + case '{ ($ls: List[$t]).filter($f).foreach[$u]($g) } => + optimize('{ $ls.foreach[Any](x => if (${f('x)}) ${g('x)} else ()) }) + + case _ => x + } + + val res = optimize(x) + + '{ + val result = $res + val originalCode = ${x.show} + val optimizeCode = ${res.show} + println("Original: " + originalCode) + println("Optimized: " + optimizeCode) + println("Result: " + result) + println() + } + } + +} diff --git a/tests/run-macros/quote-matching-optimize-2/Test_2.scala b/tests/run-macros/quote-matching-optimize-2/Test_2.scala new file mode 100644 index 000000000000..bab33d5ba202 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-2/Test_2.scala @@ -0,0 +1,15 @@ +object Test { + import Macro._ + + def main(args: Array[String]): Unit = { + val ls = List(1, 2, 3) + val ls2 = List('a', 'b', 'c') + optimize(ls.filter(x => x < 3).filter(x => x > 1)) + optimize(ls2.filter(x => x < 'c').filter(x => x > 'a')) + optimize(ls.filter(x => x < 3).filter(x => x > 1).filter(x => x == 2)) + optimize(ls.filter(x => x < 3).foreach(x => println(x))) + optimize(ls.map(a => a * 2).map(b => b.toString)) + optimize(ls.map(a => a.toChar).map(b => b.toString)) + } + +} diff --git a/tests/run-macros/quote-matching-optimize-3.check b/tests/run-macros/quote-matching-optimize-3.check new file mode 100644 index 000000000000..257667305b70 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-3.check @@ -0,0 +1,26 @@ +Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))) +Optimized: ls.filter(((x: scala.Int) => ((x: scala.Int) => x.<(3)).apply(x).&&(((x: scala.Int) => x.>(1)).apply(x)))) +Result: List(2) + +Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((x: scala.Char) => x.>('a'))) +Optimized: ls2.filter(((x: scala.Char) => ((x: scala.Char) => x.<('c')).apply(x).&&(((x: scala.Char) => x.>('a')).apply(x)))) +Result: List(b) + +Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((x: scala.Int) => x.>(1))).filter(((x: scala.Int) => x.==(2))) +Optimized: ls.filter(((x: scala.Int) => ((x: scala.Int) => x.<(3)).apply(x).&&(((x: scala.Int) => ((x: scala.Int) => x.>(1)).apply(x).&&(((x: scala.Int) => x.==(2)).apply(x))).apply(x)))) +Result: List(2) + +1 +2 +Original: ls.filter(((x: scala.Int) => x.<(3))).foreach[scala.Unit](((x: scala.Int) => scala.Predef.println(x))) +Optimized: ls.foreach[scala.Any](((x: scala.Int) => if (((x: scala.Int) => x.<(3)).apply(x)) ((x: scala.Int) => scala.Predef.println(x)).apply(x) else ())) +Result: () + +Original: ls.map[scala.Long, scala.collection.immutable.List[scala.Long]](((a: scala.Int) => a.toLong))(scala.collection.immutable.List.canBuildFrom[scala.Long]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((b: scala.Long) => b.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Optimized: ls.map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((x: scala.Int) => ((b: scala.Long) => b.toString()).apply(((a: scala.Int) => a.toLong).apply(x))))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Result: List(1, 2, 3) + +Original: ls.map[scala.Char, scala.collection.immutable.List[scala.Char]](((a: scala.Int) => a.toChar))(scala.collection.immutable.List.canBuildFrom[scala.Char]).map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((b: scala.Char) => b.toString()))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Optimized: ls.map[java.lang.String, scala.collection.immutable.List[java.lang.String]](((x: scala.Int) => ((b: scala.Char) => b.toString()).apply(((a: scala.Int) => a.toChar).apply(x))))(scala.collection.immutable.List.canBuildFrom[java.lang.String]) +Result: List(, , ) + diff --git a/tests/run-macros/quote-matching-optimize-3/Macro_1.scala b/tests/run-macros/quote-matching-optimize-3/Macro_1.scala new file mode 100644 index 000000000000..7b25ff5a8270 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-3/Macro_1.scala @@ -0,0 +1,40 @@ +import scala.quoted._ +import scala.quoted.autolift._ + +import scala.tasty.Reflection + +object Macro { + + inline def optimize[T](x: => T): Any = ${ Macro.impl('x) } + + def impl[T: Type](x: Expr[T]) given Reflection: Expr[Any] = { + val reflect = the[Reflection] + import reflect._ // TODO remove + + def optimize(x: Expr[Any]): Expr[Any] = x match { + case '{ ($ls: List[$t]).filter($f).filter($g) } => + optimize('{ $ls.filter(x => $f(x) && $g(x)) }) + + case '{ type $uu; type $vv; ($ls: List[$tt]).map[`$uu`, List[`$uu`]]($f).map[String, List[String]]($g) } => + optimize('{ $ls.map(x => $g($f(x))) }) + + case '{ ($ls: List[$t]).filter($f).foreach[$u]($g) } => + optimize('{ $ls.foreach[Any](x => if ($f(x)) $g(x) else ()) }) + + case _ => x + } + + val res = optimize(x) + + '{ + val result = $res + val originalCode = ${x.show} + val optimizeCode = ${res.show} + println("Original: " + originalCode) + println("Optimized: " + optimizeCode) + println("Result: " + result) + println() + } + } + +} diff --git a/tests/run-macros/quote-matching-optimize-3/Test_2.scala b/tests/run-macros/quote-matching-optimize-3/Test_2.scala new file mode 100644 index 000000000000..fc6aa6c12a95 --- /dev/null +++ b/tests/run-macros/quote-matching-optimize-3/Test_2.scala @@ -0,0 +1,15 @@ +object Test { + import Macro._ + + def main(args: Array[String]): Unit = { + val ls = List(1, 2, 3) + val ls2 = List('a', 'b', 'c') + optimize(ls.filter(x => x < 3).filter(x => x > 1)) + optimize(ls2.filter(x => x < 'c').filter(x => x > 'a')) + optimize(ls.filter(x => x < 3).filter(x => x > 1).filter(x => x == 2)) + optimize(ls.filter(x => x < 3).foreach(x => println(x))) + optimize(ls.map(a => a.toLong).map(b => b.toString)) + optimize(ls.map(a => a.toChar).map(b => b.toString)) + } + +} diff --git a/tests/run-with-compiler/quote-matcher-type-bind.check b/tests/run-with-compiler/quote-matcher-type-bind.check new file mode 100644 index 000000000000..26bc59686c2f --- /dev/null +++ b/tests/run-with-compiler/quote-matcher-type-bind.check @@ -0,0 +1,2 @@ +g: 5 +f: abc diff --git a/tests/run-with-compiler/quote-matcher-type-bind/Macro_1.scala b/tests/run-with-compiler/quote-matcher-type-bind/Macro_1.scala new file mode 100644 index 000000000000..229ff01f71d7 --- /dev/null +++ b/tests/run-with-compiler/quote-matcher-type-bind/Macro_1.scala @@ -0,0 +1,30 @@ +import scala.quoted._ +import scala.quoted.matching._ + +import scala.tasty.Reflection + +import scala.internal.quoted.Matcher._ +import scala.internal.Quoted._ + +object Macros { + + inline def swapFandG(x: => Unit): Unit = ${impl('x)} + + private def impl(x: Expr[Unit])(implicit reflect: Reflection): Expr[Unit] = { + x match { + case '{ DSL.f[$t]($x) } => '{ DSL.g[$t]($x) } + case '{ DSL.g[$t]($x) } => '{ DSL.f[$t]($x) } + case _ => x + } + } + +} + +// +// DSL in which the user write the code +// + +object DSL { + def f[T](x: T): Unit = println("f: " + x.toString) + def g[T](x: T): Unit = println("g: " + x.toString) +} diff --git a/tests/run-with-compiler/quote-matcher-type-bind/Test_2.scala b/tests/run-with-compiler/quote-matcher-type-bind/Test_2.scala new file mode 100644 index 000000000000..4c388ce16d47 --- /dev/null +++ b/tests/run-with-compiler/quote-matcher-type-bind/Test_2.scala @@ -0,0 +1,11 @@ +import Macros._ + + +object Test { + + def main(args: Array[String]): Unit = { + swapFandG(DSL.f(5)) + swapFandG(DSL.g("abc")) + } + +}