|
| 1 | +import scala.compiletime.ops.int.{`*`, +} |
| 2 | + |
| 3 | +// HList |
| 4 | +sealed trait Shape |
| 5 | +final case class #:[H <: Int & Singleton, T <: Shape](head: H, tail: T) extends Shape |
| 6 | +case object Ø extends Shape |
| 7 | +type Ø = Ø.type |
| 8 | + |
| 9 | +// Reduce |
| 10 | +def reduce[T, S <: Shape, A <: Shape](shape: S, axes: A): Reduce[S, A, 0] = ??? |
| 11 | +type Reduce[S, Axes <: Shape, I <: Int] <: Shape = S match { |
| 12 | + case head #: tail => Contains[Axes, I] match { |
| 13 | + case true => Reduce[tail, Remove[Axes, I], I + 1] |
| 14 | + case false => head #: Reduce[tail, Axes, I + 1] |
| 15 | + } |
| 16 | + case Ø => Axes match { |
| 17 | + case Ø => Ø |
| 18 | + // otherwise, do not reduce further |
| 19 | + } |
| 20 | +} |
| 21 | +type Contains[Haystack <: Shape, Needle <: Int] <: Boolean = Haystack match { |
| 22 | + case Ø => false |
| 23 | + case head #: tail => head match { |
| 24 | + case Needle => true |
| 25 | + case _ => Contains[tail, Needle] |
| 26 | + } |
| 27 | +} |
| 28 | +type Remove[From <: Shape, Value <: Int] <: Shape = From match { |
| 29 | + case Ø => Ø |
| 30 | + case head #: tail => head match { |
| 31 | + case Value => Remove[tail, Value] |
| 32 | + case _ => head #: Remove[tail, Value] |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +// Reshape |
| 37 | +def reshape[From <: Shape, To <: Shape](from: From, to: To) |
| 38 | + (using ev: NumElements[From] =:= NumElements[To]): To = ??? |
| 39 | +type NumElements[X <: Shape] <: Int = X match { |
| 40 | + case Ø => 1 |
| 41 | + case head #: tail => head * NumElements[tail] |
| 42 | +} |
| 43 | + |
| 44 | +// Test cases |
| 45 | +val input = #:(25, #:(256, #:(256, #:(3, Ø)))) |
| 46 | +val reduced = reduce(input, #:(3, #:(1, #:(2, Ø)))) |
| 47 | +val reshaped: 5 #: 5 #: Ø = reshape(reduced, #:(5, #:(5, Ø))) |
0 commit comments