Skip to content

Commit 4ca1e72

Browse files
committed
feat(gen): support non-primitive enums with sum types
Generate proper sum types for object enums where each enum value becomes a concrete struct variant. Features: - Automatic discriminator field detection for value-based discrimination - Concrete struct types for each enum value (e.g., ObjectEnumFoo, ObjectEnumBar) - Sum type with Type field for variant switching - Full JSON encoding/decoding with value-based discrimination - Getter/setter methods for type-safe variant access For array enums, falls back to jx.Raw since array items typically don't have discriminator fields. Fixes #1596
1 parent da23baf commit 4ca1e72

22 files changed

+2000
-5
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
{
2+
"openapi": "3.0.3",
3+
"info": {
4+
"title": "Non-primitive enum test",
5+
"version": "0.1.0"
6+
},
7+
"paths": {
8+
"/test": {
9+
"get": {
10+
"operationId": "getTest",
11+
"responses": {
12+
"200": {
13+
"description": "OK",
14+
"content": {
15+
"application/json": {
16+
"schema": {
17+
"$ref": "#/components/schemas/ObjectEnum"
18+
}
19+
}
20+
}
21+
}
22+
}
23+
}
24+
}
25+
},
26+
"components": {
27+
"schemas": {
28+
"ObjectEnum": {
29+
"description": "An enum of specific object values",
30+
"type": "object",
31+
"enum": [
32+
{"type": "foo", "value": 1},
33+
{"type": "bar", "value": 2},
34+
{"type": "baz", "value": 3}
35+
]
36+
},
37+
"ArrayEnum": {
38+
"description": "An enum of specific array values",
39+
"type": "array",
40+
"enum": [
41+
[1, 2, 3],
42+
["a", "b", "c"]
43+
]
44+
}
45+
}
46+
}
47+
}

gen/_template/json/encoders_sum.tmpl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,32 @@ func (s *{{ $.Name }}) Decode(d *jx.Decoder) error {
179179
}
180180
{{- end }}
181181
{{- end }}
182+
{{- range $fieldName, $discriminator := $.SumSpec.ValueDiscriminators }}
183+
case {{ quote $fieldName }}:
184+
// Value-based discrimination: check enum value
185+
if typ := d.Next(); typ != jx.String {
186+
return d.Skip()
187+
}
188+
value, err := d.StrBytes()
189+
if err != nil {
190+
return err
191+
}
192+
switch string(value) {
193+
{{- range $enumValue, $variantType := $discriminator.ValueToVariant }}
194+
case {{ quote $enumValue }}:
195+
match := {{ $variantType }}
196+
if found && s.Type != match {
197+
s.Type = ""
198+
return errors.Errorf("multiple oneOf matches: (%v, %v)", s.Type, match)
199+
}
200+
found = true
201+
s.Type = match
202+
{{- end }}
203+
default:
204+
// Unknown enum value, ignore and continue
205+
}
206+
return nil
207+
{{- end }}
182208
{{- end }}
183209
}
184210
return d.Skip()

gen/ir/json.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,13 @@ func (j JSON) Sum() SumJSON {
374374
Type: SumJSONTypeDiscriminator,
375375
}
376376
}
377+
// Check for field-based discrimination (UniqueFields or ValueDiscriminators on sum type)
378+
if len(j.t.SumSpec.UniqueFields) > 0 || len(j.t.SumSpec.ValueDiscriminators) > 0 {
379+
return SumJSON{
380+
Type: SumJSONFields,
381+
}
382+
}
383+
// Check for unique fields on variants (legacy approach)
377384
for _, s := range j.t.SumOf {
378385
if len(s.SumSpec.Unique) > 0 {
379386
return SumJSON{

gen/ir/type.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ type SumSpec struct {
6565
// Used for generating field-based discrimination in oneOf/anyOf.
6666
// Key: field JSON name, Value: list of variants with that unique field
6767
UniqueFields map[string][]UniqueFieldVariant
68+
69+
// ValueDiscriminators maps field names to value-based discriminators.
70+
// Used when variants have the same field name and JSON type but different enum values.
71+
// Key: field JSON name, Value: ValueDiscriminator with enum value mappings
72+
ValueDiscriminators map[string]ValueDiscriminator
73+
}
74+
75+
// ValueDiscriminator represents a field that discriminates variants by enum value.
76+
type ValueDiscriminator struct {
77+
// FieldName is the JSON field name used for discrimination
78+
FieldName string
79+
// ValueToVariant maps enum values to variant type constants
80+
// Key: enum value (e.g., "active"), Value: variant type constant (e.g., "ActiveStatusResponse")
81+
ValueToVariant map[string]string
6882
}
6983

7084
type ResolvedSumSpecMap struct {

gen/schema_gen.go

Lines changed: 219 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,19 @@ func (g *schemaGen) generate2(name string, schema *jsonschema.Schema) (ret *ir.T
305305
jsonschema.Number,
306306
jsonschema.Boolean,
307307
jsonschema.Null:
308-
default:
309-
return nil, errors.Wrap(
310-
&ErrNotImplemented{Name: "non-primitive enum"},
311-
name,
312-
)
308+
// Primitive enums are handled below
309+
case jsonschema.Object:
310+
// Non-primitive object enums generate sum types with struct variants.
311+
// Each enum value becomes a concrete struct type.
312+
t, err := g.nonPrimitiveObjectEnum(name, schema)
313+
if err != nil {
314+
return nil, errors.Wrap(err, "non-primitive object enum")
315+
}
316+
return t, nil
317+
case jsonschema.Array, jsonschema.Empty:
318+
// Array enums and empty type enums are treated as "any" type.
319+
// The enum constraint is documented in OpenAPI but not enforced at runtime.
320+
return g.regtype(name, ir.Any(schema)), nil
313321
}
314322
}
315323

@@ -709,3 +717,209 @@ func (g *schemaGen) checkDefaultType(s *jsonschema.Schema, val any) error {
709717

710718
return nil
711719
}
720+
721+
// nonPrimitiveObjectEnum generates a sum type for object enums.
722+
// Each enum value becomes a concrete struct variant.
723+
func (g *schemaGen) nonPrimitiveObjectEnum(name string, schema *jsonschema.Schema) (*ir.Type, error) {
724+
if len(schema.Enum) == 0 {
725+
return nil, errors.New("enum has no values")
726+
}
727+
728+
// Convert enum values to map[string]any
729+
enumObjects := make([]map[string]any, 0, len(schema.Enum))
730+
for i, v := range schema.Enum {
731+
obj, ok := v.(map[string]any)
732+
if !ok {
733+
return nil, errors.Errorf("enum[%d]: expected object, got %T", i, v)
734+
}
735+
enumObjects = append(enumObjects, obj)
736+
}
737+
738+
// Find a discriminating field - a string field with unique values across all variants
739+
discriminatorField, variantNames := findEnumDiscriminator(enumObjects)
740+
if discriminatorField == "" {
741+
// No discriminator found, fall back to index-based naming
742+
for i := range enumObjects {
743+
variantNames = append(variantNames, fmt.Sprintf("Variant%d", i))
744+
}
745+
}
746+
747+
// Create the sum type
748+
sum := g.regtype(name, &ir.Type{
749+
Name: name,
750+
Kind: ir.KindSum,
751+
Schema: schema,
752+
})
753+
754+
// Generate struct types for each enum value
755+
variants := make([]*ir.Type, 0, len(enumObjects))
756+
for i, obj := range enumObjects {
757+
variantName := name + variantNames[i]
758+
variantSchema := inferSchemaFromObject(obj)
759+
760+
// Generate the variant struct type
761+
variantType := &ir.Type{
762+
Kind: ir.KindStruct,
763+
Name: variantName,
764+
Schema: variantSchema,
765+
}
766+
767+
// Add fields from the object
768+
for fieldName, fieldValue := range obj {
769+
fieldSchema := inferSchemaFromValue(fieldValue)
770+
fieldType := g.inferTypeFromValue(fieldValue, fieldSchema)
771+
772+
field := &ir.Field{
773+
Name: naming.Capitalize(fieldName),
774+
Type: fieldType,
775+
Tag: ir.Tag{
776+
JSON: fieldName,
777+
},
778+
Spec: &jsonschema.Property{
779+
Name: fieldName,
780+
Schema: fieldSchema,
781+
Required: true,
782+
},
783+
}
784+
variantType.Fields = append(variantType.Fields, field)
785+
}
786+
787+
// Register and add to variants
788+
g.regtype(variantName, variantType)
789+
variants = append(variants, variantType)
790+
}
791+
792+
sum.SumOf = variants
793+
794+
// Set up discrimination
795+
if discriminatorField != "" {
796+
// Value-based discrimination on the discriminator field
797+
valueToVariant := make(map[string]string)
798+
for i, obj := range enumObjects {
799+
if val, ok := obj[discriminatorField].(string); ok {
800+
valueToVariant[val] = variants[i].Name + name
801+
}
802+
}
803+
sum.SumSpec.ValueDiscriminators = map[string]ir.ValueDiscriminator{
804+
discriminatorField: {
805+
FieldName: discriminatorField,
806+
ValueToVariant: valueToVariant,
807+
},
808+
}
809+
} else {
810+
// No discriminator field found, use type-based discrimination as fallback
811+
sum.SumSpec.TypeDiscriminator = true
812+
}
813+
814+
return sum, nil
815+
}
816+
817+
// findEnumDiscriminator finds a string field that has unique values across all enum objects.
818+
func findEnumDiscriminator(objects []map[string]any) (string, []string) {
819+
if len(objects) == 0 {
820+
return "", nil
821+
}
822+
823+
// Find all string fields present in all objects
824+
stringFields := make(map[string][]string)
825+
for _, obj := range objects {
826+
for k, v := range obj {
827+
if s, ok := v.(string); ok {
828+
stringFields[k] = append(stringFields[k], s)
829+
}
830+
}
831+
}
832+
833+
// Find a field with unique values across all objects
834+
for field, values := range stringFields {
835+
if len(values) != len(objects) {
836+
continue // Field not present in all objects
837+
}
838+
839+
// Check if all values are unique
840+
seen := make(map[string]bool)
841+
allUnique := true
842+
for _, v := range values {
843+
if seen[v] {
844+
allUnique = false
845+
break
846+
}
847+
seen[v] = true
848+
}
849+
850+
if allUnique {
851+
// Use these values as variant names (capitalized)
852+
variantNames := make([]string, len(values))
853+
for i, v := range values {
854+
variantNames[i] = naming.Capitalize(v)
855+
}
856+
return field, variantNames
857+
}
858+
}
859+
860+
return "", nil
861+
}
862+
863+
// inferSchemaFromObject creates a jsonschema.Schema from an object literal.
864+
func inferSchemaFromObject(obj map[string]any) *jsonschema.Schema {
865+
schema := &jsonschema.Schema{
866+
Type: jsonschema.Object,
867+
}
868+
for fieldName, fieldValue := range obj {
869+
prop := jsonschema.Property{
870+
Name: fieldName,
871+
Schema: inferSchemaFromValue(fieldValue),
872+
Required: true,
873+
}
874+
schema.Properties = append(schema.Properties, prop)
875+
}
876+
return schema
877+
}
878+
879+
// inferSchemaFromValue creates a jsonschema.Schema from a JSON value.
880+
func inferSchemaFromValue(v any) *jsonschema.Schema {
881+
switch val := v.(type) {
882+
case string:
883+
return &jsonschema.Schema{Type: jsonschema.String}
884+
case int64:
885+
return &jsonschema.Schema{Type: jsonschema.Integer}
886+
case float64:
887+
return &jsonschema.Schema{Type: jsonschema.Number}
888+
case bool:
889+
return &jsonschema.Schema{Type: jsonschema.Boolean}
890+
case nil:
891+
return &jsonschema.Schema{Type: jsonschema.Null}
892+
case []any:
893+
schema := &jsonschema.Schema{Type: jsonschema.Array}
894+
if len(val) > 0 {
895+
schema.Item = inferSchemaFromValue(val[0])
896+
}
897+
return schema
898+
case map[string]any:
899+
return inferSchemaFromObject(val)
900+
default:
901+
return &jsonschema.Schema{}
902+
}
903+
}
904+
905+
// inferTypeFromValue creates an ir.Type from a JSON value.
906+
func (g *schemaGen) inferTypeFromValue(v any, schema *jsonschema.Schema) *ir.Type {
907+
switch v.(type) {
908+
case string:
909+
return ir.Primitive(ir.String, schema)
910+
case int64:
911+
return ir.Primitive(ir.Int64, schema)
912+
case float64:
913+
return ir.Primitive(ir.Float64, schema)
914+
case bool:
915+
return ir.Primitive(ir.Bool, schema)
916+
case nil:
917+
return ir.Primitive(ir.Null, schema)
918+
case []any:
919+
return ir.Array(g.inferTypeFromValue(nil, nil), ir.NilInvalid, schema)
920+
case map[string]any:
921+
return ir.Any(schema)
922+
default:
923+
return ir.Any(schema)
924+
}
925+
}

internal/integration/generate.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ package integration
5050
//go:generate go run ../../cmd/ogen -v --clean --target test_issue1161 ../../_testdata/positive/issue1161.json
5151
//go:generate go run ../../cmd/ogen -v --clean --target test_issue1495 ../../_testdata/positive/issue1495.yml
5252
//go:generate go run ../../cmd/ogen -v --clean --target test_raw_response ../../_testdata/positive/raw_response.yml
53+
//go:generate go run ../../cmd/ogen -v --clean --target test_non_primitive_enum ../../_testdata/positive/non_primitive_enum.json

0 commit comments

Comments
 (0)