Skip to content

Commit 1d5b70f

Browse files
committed
Proper union type optimization
1 parent f8aeb9f commit 1d5b70f

File tree

4 files changed

+139
-39
lines changed

4 files changed

+139
-39
lines changed

lib/Typer.fs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,7 +1370,7 @@ type ResolvedUnion = {
13701370
caseUndefined: bool
13711371
typeofableTypes: Set<Typeofable>
13721372
caseArray: Set<Type> option
1373-
caseEnum: Set<Choice<Enum * EnumCase, Literal>>
1373+
caseEnum: Set<Choice<Enum * EnumCase * Type, Literal>>
13741374
discriminatedUnions: Map<string, Map<Literal, Type>>
13751375
otherTypes: Set<Type>
13761376
}
@@ -1393,8 +1393,8 @@ module ResolvedUnion =
13931393
ru.caseEnum
13941394
|> Set.toSeq
13951395
|> Seq.map (function
1396-
| Choice1Of2 ({ name = ty }, { name = name; value = Some value }) -> sprintf "%s.%s=%s" ty name (Literal.toString value)
1397-
| Choice1Of2 ({ name = ty }, { name = name; value = None }) -> sprintf "%s.%s=?" ty name
1396+
| Choice1Of2 ({ name = ty }, { name = name; value = Some value }, _) -> sprintf "%s.%s=%s" ty name (Literal.toString value)
1397+
| Choice1Of2 ({ name = ty }, { name = name; value = None }, _) -> sprintf "%s.%s=?" ty name
13981398
| Choice2Of2 l -> Literal.toString l)
13991399
yield sprintf "enum<%s>" (cases |> String.concat " | ")
14001400
for k, m in ru.discriminatedUnions |> Map.toSeq do
@@ -1411,7 +1411,7 @@ module ResolvedUnion =
14111411
let hasUndefined = nullOrUndefined |> List.contains (Prim Undefined)
14121412
{| hasNull = hasNull; hasUndefined = hasUndefined; rest = rest |}
14131413

1414-
let rec private getEnumFromUnion ctx (u: UnionType) : Set<Choice<Enum * EnumCase, Literal>> * UnionType =
1414+
let rec private getEnumFromUnion ctx (u: UnionType) : Set<Choice<Enum * EnumCase * Type, Literal>> * UnionType =
14151415
let (|Dummy|) _ = []
14161416

14171417
let rec go t =
@@ -1429,9 +1429,9 @@ module ResolvedUnion =
14291429
let bindings = Type.createBindings i.name loc a.typeParams tyargs
14301430
go (a.target |> Type.substTypeVar bindings ())
14311431
| Definition.Enum e ->
1432-
e.cases |> Seq.map (fun c -> Choice1Of2 (Choice1Of2 (e, c)))
1432+
e.cases |> Seq.map (fun c -> Choice1Of2 (Choice1Of2 (e, c, t)))
14331433
| Definition.EnumCase (c, e) ->
1434-
Seq.singleton (Choice1Of2 (Choice1Of2 (e, c)))
1434+
Seq.singleton (Choice1Of2 (Choice1Of2 (e, c, t)))
14351435
| _ -> Seq.empty
14361436
let result =
14371437
i |> Ident.getDefinitions ctx

src/Targets/JsOfOCaml/Writer.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,19 +161,19 @@ module OverrideFunc =
161161
| Some text -> Some text
162162
| None -> f1 _flags _emitType _ctx ty
163163

164-
let emitEnum (flags: EmitTypeFlags) ctx (cases: Set<Choice<Enum * EnumCase, Literal>>) =
164+
let emitEnum (flags: EmitTypeFlags) ctx (cases: Set<Choice<Enum * EnumCase * _, Literal>>) =
165165
let forceSkipAttr text = if flags.forceSkipAttributes then empty else text
166166
let usedValues =
167167
cases
168-
|> Seq.choose (function Choice1Of2 (_, { value = v }) -> v | _ -> None)
168+
|> Seq.choose (function Choice1Of2 (_, { value = v }, _) -> v | _ -> None)
169169
|> Set.ofSeq
170170
let cases =
171171
cases
172172
// Remove literal cases (e.g. `42`) when it is a duplicate of some enum case (e.g. `Case = 42`).
173173
|> Set.filter (function Choice2Of2 l when usedValues |> Set.contains l -> false | _ -> true)
174174
// Convert to identifiers while merging duplicate enum cases
175175
|> Set.map (function
176-
| Choice1Of2 (e, c) -> enumCaseToIdentifier e c |> str, c.value
176+
| Choice1Of2 (e, c, _) -> enumCaseToIdentifier e c |> str, c.value
177177
| Choice2Of2 l -> "L_" @+ literalToIdentifier ctx l, Some l)
178178
between "[" "]" (concat (str " | ") [
179179
for name, value in Set.toSeq cases do

src/Targets/ReScript/ReScriptHelper.fs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,11 @@ module Type =
310310
let void_ = str "unit"
311311
let string = str "string"
312312
let boolean = str "bool"
313+
let int = str "int"
314+
let float = str "float"
313315
let number (opt: Options) =
314-
if opt.numberAsInt then str "int"
315-
else str "float"
316+
if opt.numberAsInt then int
317+
else float
316318
let array = str "array"
317319
let readonlyArray = str "array"
318320
let option t = app (str "option") [t]
@@ -518,4 +520,4 @@ module Statement =
518520
modules
519521
|> List.filter (fun x -> sccSet |> Set.contains x.origName |> not)
520522
|> emitNonRec
521-
sccModules @ otherModules
523+
sccModules @ otherModules

src/Targets/ReScript/Writer.fs

Lines changed: 125 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ module State =
3232
type Context = TyperContext<Options, State>
3333
module Context = TyperContext
3434

35-
type Variance = Covariant | Contravariant | Invariant with
36-
static member (~-) (v: Variance) =
37-
match v with
38-
| Covariant -> Contravariant
39-
| Contravariant -> Covariant
40-
| Invariant -> Invariant
41-
4235
type Label =
4336
| Case of text * text list
4437
| TagType of text * text list
@@ -52,9 +45,7 @@ type [<RequireQualifiedAccess>] External =
5245
type EmitTypeFlags = {
5346
resolveUnion: bool
5447
needParen: bool
55-
variance: Variance
5648
external: External
57-
simplifyContravariantUnion: bool
5849
avoidTheseArgumentNames: Set<string>
5950
}
6051

@@ -63,17 +54,14 @@ module EmitTypeFlags =
6354
{
6455
resolveUnion = true
6556
needParen = false
66-
variance = Covariant
6757
external = External.None
68-
simplifyContravariantUnion = false
6958
avoidTheseArgumentNames = Set.empty
7059
}
7160

7261
let noExternal flags =
7362
{ flags with external = External.None }
7463
let ofFuncArg isVariadic flags =
7564
{ flags with
76-
variance = -flags.variance
7765
external =
7866
match flags.external with
7967
| External.Root _ -> External.Argument isVariadic
@@ -347,18 +335,130 @@ and emitFuncType (flags: EmitTypeFlags) (overrideFunc: OverrideFunc) (ctx: Conte
347335
| _ -> Type.curriedArrow (args ()) (retTy flags) |> paren
348336

349337
and emitUnion (flags: EmitTypeFlags) (overrideFunc: OverrideFunc) (ctx: Context) (u: UnionType) : text =
350-
// TODO: more classification
351-
let u = ResolvedUnion.checkNullOrUndefined u
352-
let rest =
353-
let rest = u.rest |> List.map (emitTypeImpl (EmitTypeFlags.noExternal flags) overrideFunc ctx)
354-
if List.isEmpty rest then Type.never
355-
else Type.union rest
356-
match u.hasNull, u.hasUndefined with
357-
| true, _ | _, true when flags.external = External.Return true -> Type.option rest
358-
| true, true -> Type.null_or_undefined_or rest
359-
| true, false -> Type.null_or rest
360-
| false, true -> Type.undefined_or rest
361-
| false, false -> rest
338+
if flags.resolveUnion = false then
339+
u.types
340+
|> List.distinct
341+
|> List.map (emitTypeImpl (EmitTypeFlags.noExternal flags) overrideFunc ctx)
342+
|> Type.union
343+
else if flags.external = External.Return true then
344+
let u = ResolvedUnion.checkNullOrUndefined u
345+
let rest =
346+
if List.isEmpty u.rest then Type.never
347+
else
348+
let t = Union { types = u.rest }
349+
emitTypeImpl (EmitTypeFlags.noExternal flags) overrideFunc ctx t
350+
match u.hasNull, u.hasUndefined with
351+
| true, _ | _, true -> Type.option rest
352+
| false, false -> rest
353+
else
354+
let u = ResolvedUnion.resolve ctx u
355+
356+
let treatEnum (cases: Set<Choice<Enum * EnumCase * Type, Literal>>) =
357+
let handleLiteral l attr ty =
358+
match l with
359+
| LString s -> Choice1Of2 {| name = Choice1Of2 s; value = None; attr = attr |}
360+
| LInt i -> Choice1Of2 {| name = Choice2Of2 i; value = None; attr = attr |}
361+
| LFloat _ -> Choice2Of2 (ty |? Type.float)
362+
| LBool _ -> Choice2Of2 (ty |? Type.boolean)
363+
let cases = [
364+
for c in cases do
365+
match c with
366+
| Choice1Of2 (_, _, ty) ->
367+
let ty = emitTypeImpl (EmitTypeFlags.noExternal flags) overrideFunc ctx ty
368+
yield Choice2Of2 ty
369+
| Choice2Of2 l -> yield handleLiteral l None None
370+
]
371+
let cases, rest = List.splitChoice2 cases
372+
[
373+
if List.isEmpty cases |> not then
374+
yield Type.polyVariant cases
375+
yield! rest
376+
]
377+
378+
let treatArray (ts: Set<Type>) =
379+
// TODO: think how to map multiple array cases properly
380+
let elemT =
381+
let elemT =
382+
match Set.toList ts with
383+
| [t] -> t
384+
| ts -> Union { types = ts }
385+
emitTypeImpl (EmitTypeFlags.noExternal flags) overrideFunc ctx elemT
386+
Type.app Type.array [elemT]
387+
388+
let treatDUMany du =
389+
// TODO: anonymous DU?
390+
let types =
391+
du
392+
|> Map.toList
393+
|> List.collect (fun (_, cases) -> Map.toList cases)
394+
|> List.map (fun (_, t) -> t)
395+
types
396+
|> List.map (emitTypeImpl (EmitTypeFlags.noExternal { flags with resolveUnion = false }) overrideFunc ctx)
397+
|> List.distinct
398+
399+
let baseTypes = [
400+
if not (Set.isEmpty u.caseEnum) then
401+
yield! treatEnum u.caseEnum
402+
if not (Map.isEmpty u.discriminatedUnions) then
403+
yield! treatDUMany u.discriminatedUnions
404+
match u.caseArray with
405+
| Some ts -> yield treatArray ts
406+
| None -> ()
407+
for t in u.otherTypes do
408+
yield emitTypeImpl (EmitTypeFlags.noExternal { flags with resolveUnion = false }) overrideFunc ctx t
409+
]
410+
411+
let case name value = {| name = Choice1Of2 name; value = value; attr = None |}
412+
let genPoly unwrap =
413+
let cases = [
414+
for t in u.typeofableTypes do
415+
match t with
416+
| Typeofable.String -> yield case "String" (Some Type.string)
417+
| Typeofable.Number -> yield case "Number" (Some (Type.number ctx.options))
418+
| Typeofable.Boolean -> yield case "Boolean" (Some Type.boolean)
419+
| Typeofable.Symbol -> yield case "Symbol" (Some Type.symbol)
420+
| Typeofable.BigInt -> yield case "BigInt" (Some Type.bigint)
421+
422+
if u.caseNull then
423+
yield case "Null" (if unwrap then Some Type.null_ else None)
424+
if u.caseUndefined then
425+
yield case "Undefined" (if unwrap then Some Type.undefined else None)
426+
427+
match List.distinct baseTypes with
428+
| [] -> ()
429+
| ts ->
430+
if unwrap then
431+
for i, t in ts |> List.indexed do
432+
yield case (sprintf "U%d" (i+1)) (Some t)
433+
else
434+
yield case "Other" (Some (Type.union ts))
435+
]
436+
Type.polyVariant cases
437+
438+
let createNullable isNull isUndefined t =
439+
match isNull, isUndefined with
440+
| false, false -> t
441+
| true, false -> Type.null_or t
442+
| false, true -> Type.undefined_or t
443+
| true, true -> Type.null_or_undefined_or t
444+
445+
let emitTypeofableType t = emitTypeImpl flags overrideFunc ctx (TypeofableType.toType t)
446+
447+
let isExternalArg = match flags.external with External.Argument _ -> true | _ -> false
448+
449+
match baseTypes, Set.toList u.typeofableTypes, u.caseNull, u.caseUndefined with
450+
| [], [], false, false -> impossible "emitUnion_empty_union"
451+
| [], [], true, false -> Type.null_
452+
| [], [], false, true -> Type.undefined
453+
| [], [], true, true -> Type.null_or_undefined_or Type.never
454+
| [t], [], isNull, isUndefined -> createNullable isNull isUndefined t
455+
| ts, [], isNull, isUndefined when not isExternalArg ->
456+
createNullable isNull isUndefined (Type.union ts)
457+
| [], [t], isNull, isUndefined -> createNullable isNull isUndefined (emitTypeofableType t)
458+
| _, _, _, _ ->
459+
match flags.external with
460+
| External.Argument _ -> Attr.PolyVariant.unwrap +@ " " + genPoly true
461+
| _ -> Type.app (str "Primitive.t") [genPoly false]
362462

363463
/// `[ #A | #B | ... ]`
364464
and emitLabels (ctx: Context) labels =
@@ -621,7 +721,6 @@ let extValue flags overrideFunc ctx (t: Type) =
621721
ty, attr
622722

623723
let rec emitMembers flags overrideFunc ctx (selfTy: Type) (isExportDefaultClass: bool) (ma: MemberAttribute) m =
624-
let flags = { flags with simplifyContravariantUnion = true }
625724
let emitType_ = emitTypeImpl flags overrideFunc
626725

627726
let comments = emitComments ma.comments
@@ -1052,7 +1151,7 @@ let rec emitClass flags overrideFunc (ctx: Context) (current: StructuredText) (c
10521151

10531152
let builder =
10541153
let emitType_ ctx ty =
1055-
emitTypeImpl { flags with needParen = true; variance = Contravariant } overrideFunc ctx ty
1154+
emitTypeImpl { flags with needParen = true } overrideFunc ctx ty
10561155
if not c.isPOJO then []
10571156
else
10581157
let field (fl: FieldLike) =
@@ -1362,7 +1461,6 @@ let createStructuredText (rootCtx: Context) (stmts: Statement list) : Structured
13621461

13631462
/// convert interface members to appropriate statements
13641463
let intfToStmts (moduleIntf: Class<_>) ctx flags overrideFunc =
1365-
let flags = { flags with simplifyContravariantUnion = true }
13661464
let inline extFunc ft = extFunc flags overrideFunc ctx ft
13671465
let inline func ft = func flags overrideFunc ctx ft
13681466
let inline newableFunc ft = newableFunc flags overrideFunc ctx ft

0 commit comments

Comments
 (0)