Skip to content

Commit be28902

Browse files
committed
Synthesize implicits for product and sum mirrors
1 parent 797f5f6 commit be28902

File tree

8 files changed

+137
-10
lines changed

8 files changed

+137
-10
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,9 @@ class Definitions {
688688
lazy val ModuleSerializationProxyConstructor: TermSymbol =
689689
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))
690690

691-
//lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror")
691+
lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror")
692+
def MirrorClass(implicit ctx: Context): ClassSymbol = MirrorType.symbol.asClass
693+
692694
lazy val Mirror_ProductType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Product")
693695
def Mirror_ProductClass(implicit ctx: Context): ClassSymbol = Mirror_ProductType.symbol.asClass
694696

@@ -709,7 +711,7 @@ class Definitions {
709711
def ShapeCaseClass(implicit ctx: Context): ClassSymbol = ShapeCaseType.symbol.asClass
710712
lazy val ShapeCasesType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape.Cases")
711713
def ShapeCasesClass(implicit ctx: Context): ClassSymbol = ShapeCasesType.symbol.asClass
712-
lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.reflect.Mirror")
714+
lazy val ReflectMirrorType: TypeRef = ctx.requiredClassRef("scala.reflect.Mirror")
713715
lazy val GenericClassType: TypeRef = ctx.requiredClassRef("scala.reflect.GenericClass")
714716

715717
lazy val LanguageModuleRef: TermSymbol = ctx.requiredModule("scala.language")

compiler/src/dotty/tools/dotc/core/StdNames.scala

+3
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,13 @@ object StdNames {
326326
val AnnotatedType: N = "AnnotatedType"
327327
val AppliedTypeTree: N = "AppliedTypeTree"
328328
val ArrayAnnotArg: N = "ArrayAnnotArg"
329+
val CaseLabel: N = "CaseLabel"
329330
val CAP: N = "CAP"
330331
val Constant: N = "Constant"
331332
val ConstantType: N = "ConstantType"
332333
val doubleHash: N = "doubleHash"
334+
val ElemLabels: N = "ElemLabels"
335+
val ElemTypes: N = "ElemTypes"
333336
val ExistentialTypeTree: N = "ExistentialTypeTree"
334337
val Flag : N = "Flag"
335338
val floatHash: N = "floatHash"

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import TypeErasure.ErasedValueType
77
import Types._
88
import Contexts._
99
import Symbols._
10+
import Names.Name
1011

1112
object TypeUtils {
1213
/** A decorator that provides methods on types
@@ -63,5 +64,7 @@ object TypeUtils {
6364
}
6465
extractAlias(lo)
6566
}
67+
68+
def refinedWith(name: Name, info: Type)(implicit ctx: Context) = RefinedType(self, name, info)
6669
}
6770
}

compiler/src/dotty/tools/dotc/typer/Deriving.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ trait Deriving { this: Typer =>
372372
TypeDef(shapeAlias)
373373
}
374374
val reflectMethod: DefDef = {
375-
val meth = newMethod(nme.reflect, MethodType(clsArg :: Nil, defn.MirrorType)).entered
375+
val meth = newMethod(nme.reflect, MethodType(clsArg :: Nil, defn.ReflectMirrorType)).entered
376376
def rhs(paramRef: Tree)(implicit ctx: Context): Tree = {
377377
def reflectCase(scrut: Tree, idx: Int, elems: List[Type]): Tree = {
378378
val ordinal = Literal(Constant(idx))
@@ -401,7 +401,7 @@ trait Deriving { this: Typer =>
401401
}
402402

403403
val reifyMethod: DefDef = {
404-
val meth = newMethod(nme.reify, MethodType(defn.MirrorType :: Nil, clsArg)).entered
404+
val meth = newMethod(nme.reify, MethodType(defn.ReflectMirrorType :: Nil, clsArg)).entered
405405
def rhs(paramRef: Tree)(implicit ctx: Context): Tree = {
406406
def reifyCase(caseType: Type, elems: List[Type]): Tree = caseType match {
407407
case caseType: TermRef =>

compiler/src/dotty/tools/dotc/typer/Implicits.scala

+101-6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import ErrorReporting._
2727
import reporting.diagnostic.Message
2828
import Inferencing.fullyDefinedType
2929
import Trees._
30+
import transform.SymUtils._
31+
import transform.TypeUtils._
3032
import Hashable._
3133
import util.{Property, SourceFile, NoSource}
3234
import config.Config
@@ -812,17 +814,103 @@ trait Implicits { self: Typer =>
812814
EmptyTree
813815
}
814816

817+
lazy val synthesizedProductMirror: SpecialHandler =
818+
(formal: Type, span: Span) => implicit (ctx: Context) => {
819+
formal.member(tpnme.MonoType).info match {
820+
case monoAlias @ TypeAlias(monoType) =>
821+
if (monoType.termSymbol.is(CaseVal)) {
822+
val modul = monoType.termSymbol
823+
val caseLabel = ConstantType(Constant(modul.name.toString))
824+
val mirrorType = defn.Mirror_SingletonType
825+
.refinedWith(tpnme.MonoType, monoAlias)
826+
.refinedWith(tpnme.CaseLabel, TypeAlias(caseLabel))
827+
ref(modul).withSpan(span).cast(mirrorType)
828+
}
829+
else if (monoType.classSymbol.isGenericProduct) {
830+
val cls = monoType.classSymbol
831+
val accessors = cls.caseAccessors.filterNot(_.is(PrivateLocal))
832+
val elemTypes = accessors.map(monoType.memberInfo(_))
833+
val caseLabel = ConstantType(Constant(cls.name.toString))
834+
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
835+
val mirrorType =
836+
defn.Mirror_ProductType
837+
.refinedWith(tpnme.MonoType, monoAlias)
838+
.refinedWith(tpnme.ElemTypes, TypeAlias(TypeOps.nestedPairs(elemTypes)))
839+
.refinedWith(tpnme.CaseLabel, TypeAlias(caseLabel))
840+
.refinedWith(tpnme.ElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
841+
val modul = cls.linkedClass.sourceModule
842+
assert(modul.is(Module))
843+
ref(modul).withSpan(span).cast(mirrorType)
844+
}
845+
else EmptyTree
846+
case _ => EmptyTree
847+
}
848+
}
849+
850+
lazy val synthesizedSumMirror: SpecialHandler =
851+
(formal: Type, span: Span) => implicit (ctx: Context) =>
852+
formal.member(tpnme.MonoType).info match {
853+
case monoAlias @ TypeAlias(monoType) if monoType.classSymbol.isGenericSum =>
854+
val cls = monoType.classSymbol
855+
val elemTypes = cls.children.map {
856+
case caseClass: ClassSymbol =>
857+
assert(caseClass.is(Case))
858+
if (caseClass.is(Module))
859+
caseClass.sourceModule.termRef
860+
else caseClass.primaryConstructor.info match {
861+
case info: PolyType =>
862+
def instantiate(implicit ctx: Context) = {
863+
val poly = constrained(info, untpd.EmptyTree)._1
864+
val mono @ MethodType(_) = poly.resultType
865+
val resType = mono.finalResultType
866+
resType <:< cls.appliedRef
867+
val tparams = poly.paramRefs
868+
val variances = caseClass.typeParams.map(_.paramVariance)
869+
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
870+
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
871+
resType.substParams(poly, instanceTypes)
872+
}
873+
instantiate(ctx.fresh.setExploreTyperState().setOwner(caseClass))
874+
case _ =>
875+
caseClass.typeRef
876+
}
877+
case child => child.termRef
878+
}
879+
val mirrorType =
880+
defn.Mirror_SumType
881+
.refinedWith(tpnme.MonoType, monoAlias)
882+
.refinedWith(tpnme.ElemTypes, TypeAlias(TypeOps.nestedPairs(elemTypes)))
883+
var modul = cls.linkedClass.sourceModule
884+
if (!modul.exists) ???
885+
ref(modul).withSpan(span).cast(mirrorType)
886+
case _ =>
887+
EmptyTree
888+
}
889+
890+
lazy val synthesizedMirror: SpecialHandler =
891+
(formal: Type, span: Span) => implicit (ctx: Context) =>
892+
formal.member(tpnme.MonoType).info match {
893+
case monoAlias @ TypeAlias(monoType) =>
894+
if (monoType.termSymbol.is(CaseVal) || monoType.classSymbol.isGenericProduct)
895+
synthesizedProductMirror(formal, span)(ctx)
896+
else
897+
synthesizedSumMirror(formal, span)(ctx)
898+
}
899+
815900
private var mySpecialHandlers: SpecialHandlers = null
816901

817902
private def specialHandlers(implicit ctx: Context) = {
818903
if (mySpecialHandlers == null)
819904
mySpecialHandlers = List(
820-
defn.ClassTagClass -> synthesizedClassTag,
821-
defn.QuotedTypeClass -> synthesizedTypeTag,
822-
defn.GenericClass -> synthesizedGeneric,
905+
defn.ClassTagClass -> synthesizedClassTag,
906+
defn.QuotedTypeClass -> synthesizedTypeTag,
907+
defn.GenericClass -> synthesizedGeneric,
823908
defn.TastyReflectionClass -> synthesizedTastyContext,
824-
defn.EqlClass -> synthesizedEq,
825-
defn.ValueOfClass -> synthesizedValueOf
909+
defn.EqlClass -> synthesizedEq,
910+
defn.ValueOfClass -> synthesizedValueOf,
911+
defn.Mirror_ProductClass -> synthesizedProductMirror,
912+
defn.Mirror_SumClass -> synthesizedSumMirror,
913+
defn.MirrorClass -> synthesizedMirror
826914
)
827915
mySpecialHandlers
828916
}
@@ -836,7 +924,14 @@ trait Implicits { self: Typer =>
836924
case fail @ SearchFailure(failed) =>
837925
def trySpecialCases(handlers: SpecialHandlers): Tree = handlers match {
838926
case (cls, handler) :: rest =>
839-
val base = formal.baseType(cls)
927+
def baseWithRefinements(tp: Type): Type = tp.dealias match {
928+
case tp @ RefinedType(parent, rname, rinfo) =>
929+
tp.derivedRefinedType(baseWithRefinements(parent), rname, rinfo)
930+
case _ =>
931+
tp.baseType(cls)
932+
}
933+
val base = baseWithRefinements(formal)
934+
println(i"try special $formal/$cls/$base")
840935
val result =
841936
if (base <:< formal) {
842937
// With the subtype test we enforce that the searched type `formal` is of the right form

library/src/scala/deriving.scala

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ object deriving {
3939

4040
trait Singleton extends Product {
4141
type MonoType = this.type
42+
type ElemTypes = Unit
43+
type ElemLabels = Unit
44+
4245
def fromProduct(p: scala.Product) = this
4346

4447
def productElement(n: Int): Any = throw new IndexOutOfBoundsException(n.toString)

tests/run/deriving.check

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
A(1,2)
2+
A(1,2)
3+
B
4+
1

tests/run/deriving.scala

+17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,23 @@ case class A(x: Int, y: Int) extends T
55
case object B extends T
66

77
object Test extends App {
8+
import deriving.{Mirror, EmptyProduct}
89

910
case class AA[X >: Null <: AnyRef](x: X, y: X, z: String)
11+
12+
println(the[Mirror.ProductOf[A]].fromProduct(A(1, 2)))
13+
assert(the[Mirror.SumOf[T]].ordinal(A(1, 2)) == 0)
14+
assert(the[Mirror.Sum { type MonoType = T }].ordinal(B) == 1)
15+
the[Mirror.Of[A]] match {
16+
case m: Mirror.Product =>
17+
println(m.fromProduct(A(1, 2)))
18+
}
19+
the[Mirror.Of[B.type]] match {
20+
case m: Mirror.Product =>
21+
println(m.fromProduct(EmptyProduct))
22+
}
23+
the[Mirror.Of[T]] match {
24+
case m: Mirror.SumOf[T] =>
25+
println(m.ordinal(B))
26+
}
1027
}

0 commit comments

Comments
 (0)