Skip to content

Commit 5682e4f

Browse files
committed
WIP: Handle complex enums with interface types.
Part of #340.
1 parent 1866144 commit 5682e4f

File tree

4 files changed

+775
-22
lines changed

4 files changed

+775
-22
lines changed

internal/generate/responses.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func populateResponseType(name string, r *openapi3.Response) ([]TypeTemplate, []
6262
continue
6363
}
6464

65-
tt, et := populateTypeTemplates(respName, s.Value, "")
65+
tt, et := populateTypeTemplates(respName, s.Value, "", "")
6666
types = append(types, tt...)
6767
enumTypes = append(enumTypes, et...)
6868

internal/generate/test_utils/types_output

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@ type DiskCreate struct {
1111
DiskSource DiskSource `json:"disk_source,omitempty" yaml:"disk_source,omitempty"`
1212
}
1313

14+
func (DiskCreate) isDiskCreate() {}
15+
1416
// DiskIdentifier is parameters for the [`Disk`](omicron_common::api::external::Disk) to be attached or
1517
// detached to an instance
1618
type DiskIdentifier struct {
1719
Name Name `json:"name,omitempty" yaml:"name,omitempty"`
1820
}
1921

22+
func (DiskIdentifier) isDiskIdentifier() {}
23+
2024
// DiskSourceType is the type definition for a DiskSourceType.
2125
type DiskSourceType string
2226
// DiskSourceSnapshot is create a disk from a disk snapshot
@@ -25,12 +29,16 @@ type DiskSourceSnapshot struct {
2529
Type DiskSourceType `json:"type,omitempty" yaml:"type,omitempty"`
2630
}
2731

32+
func (DiskSourceSnapshot) isDiskSourceSnapshot() {}
33+
2834
// DiskSourceImage is create a disk from a project image
2935
type DiskSourceImage struct {
3036
ImageId string `json:"image_id,omitempty" yaml:"image_id,omitempty"`
3137
Type DiskSourceType `json:"type,omitempty" yaml:"type,omitempty"`
3238
}
3339

40+
func (DiskSourceImage) isDiskSourceImage() {}
41+
3442
// DiskSource is the type definition for a DiskSource.
3543
type DiskSource struct {
3644
// SnapshotId is the type definition for a SnapshotId.

internal/generate/types.go

Lines changed: 118 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
package main
66

77
import (
8+
"encoding/json"
89
"fmt"
10+
"maps"
911
"os"
1012
"slices"
1113
"sort"
@@ -34,6 +36,17 @@ type TypeTemplate struct {
3436
Type string
3537
// Fields holds the information for the field
3638
Fields []TypeFields
39+
40+
// InterfaceType
41+
InterfaceType string
42+
// DiscriminatorKey
43+
DiscriminatorKey string
44+
// DiscriminatorType
45+
DiscriminatorType string
46+
// DiscriminatorToType
47+
DiscriminatorToType [][]string
48+
49+
GenericFieldName string
3750
}
3851

3952
// TypeFields holds the information for each type field
@@ -256,7 +269,7 @@ func constructTypes(schemas openapi3.Schemas) ([]TypeTemplate, []EnumTemplate) {
256269

257270
// Set name as a valid Go type name
258271
name = strcase.ToCamel(name)
259-
typeTpl, enumTpl := populateTypeTemplates(name, s.Value, "")
272+
typeTpl, enumTpl := populateTypeTemplates(name, s.Value, "", "")
260273
typeCollection = append(typeCollection, typeTpl...)
261274
enumCollection = append(enumCollection, enumTpl...)
262275
}
@@ -326,7 +339,12 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
326339

327340
fmt.Fprintf(f, "%s\n", splitDocString(tt.Description))
328341
fmt.Fprintf(f, "type %s %s", tt.Name, tt.Type)
329-
if tt.Fields != nil {
342+
343+
if tt.Type == "interface" {
344+
fmt.Fprintf(f, " {\n")
345+
fmt.Fprintf(f, "\tis%s()\n", tt.Name)
346+
fmt.Fprint(f, "}\n")
347+
} else if tt.Fields != nil {
330348
fmt.Fprint(f, " {\n")
331349
for _, ft := range tt.Fields {
332350
if ft.Description != "" {
@@ -335,6 +353,34 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
335353
fmt.Fprintf(f, "\t%s %s %s\n", ft.Name, ft.Type, ft.SerializationInfo)
336354
}
337355
fmt.Fprint(f, "}\n")
356+
357+
if tt.DiscriminatorKey != "" && len(tt.DiscriminatorToType) > 0 && tt.GenericFieldName != "" {
358+
fmt.Fprintf(f, "func (v *%s) UnmarshalJSON(data []byte) error {\n", tt.Name)
359+
fmt.Fprintf(f, "\tvar peek struct {\n")
360+
fmt.Fprintf(f, "\t\tdiscriminator %s`json:\"%s\"`\n", tt.DiscriminatorType, tt.DiscriminatorKey)
361+
fmt.Fprintf(f, "\t}\n")
362+
fmt.Fprintf(f, "\tif err := json.Unmarshal(data, &peek); err != nil {\n")
363+
fmt.Fprintf(f, "\t\treturn err\n")
364+
fmt.Fprintf(f, "\t}\n")
365+
fmt.Fprintf(f, "\tswitch peek.discriminator {\n")
366+
367+
for _, dtt := range tt.DiscriminatorToType {
368+
fmt.Fprintf(f, "\tcase %s:\n", dtt[0])
369+
fmt.Fprintf(f, "\t\tvar val %s\n", dtt[1])
370+
fmt.Fprintf(f, "\t\tif err := json.Unmarshal(data, &val); err != nil {\n")
371+
fmt.Fprintf(f, "\t\t\treturn err\n")
372+
fmt.Fprintf(f, "\t\t}\n")
373+
fmt.Fprintf(f, "\tv.%s = val\n", tt.GenericFieldName)
374+
}
375+
376+
fmt.Fprintf(f, "\t}\n")
377+
fmt.Fprintf(f, "\treturn nil\n")
378+
fmt.Fprint(f, "}\n")
379+
}
380+
}
381+
if tt.InterfaceType != "" {
382+
fmt.Fprintf(f, "\n")
383+
fmt.Fprintf(f, "func (%s) is%s() {}\n", tt.Name, tt.InterfaceType)
338384
}
339385
fmt.Fprint(f, "\n")
340386
}
@@ -381,7 +427,7 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
381427
// populateTypeTemplates populates the template of a type definition for the given schema.
382428
// The additional parameter is only used as a suffix for the type name.
383429
// This is mostly for oneOf types.
384-
func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string) ([]TypeTemplate, []EnumTemplate) {
430+
func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string, interfaceName string) ([]TypeTemplate, []EnumTemplate) {
385431
typeName := name
386432

387433
// Type name will change for each enum type
@@ -418,21 +464,24 @@ func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string
418464
typeTpl.Type = fmt.Sprintf("[]%s", s.Items.Value.Type)
419465
typeTpl.Name = typeName
420466
case "object":
467+
b, _ := json.Marshal(s)
468+
fmt.Printf("DEBUG IN OBJECT %s %+v %+v\n", typeName, string(b), enumFieldName)
421469
typeTpl = createTypeObject(s, name, typeName, formatTypeDescription(typeName, s))
470+
typeTpl.InterfaceType = interfaceName
422471

423472
// Iterate over the properties and append the types, if we need to.
424473
properties := sortedKeys(s.Properties)
425474
for _, k := range properties {
426475
v := s.Properties[k]
427476
if isLocalEnum(v) {
428-
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "")
477+
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "", "")
429478
types = append(types, tt...)
430479
enumTypes = append(enumTypes, et...)
431480
}
432481

433482
// TODO: So far this code is never hit with the current openapi spec
434483
if isLocalObject(v) {
435-
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "")
484+
tt, et := populateTypeTemplates(fmt.Sprintf("%s%s", name, strcase.ToCamel(k)), v.Value, "", "")
436485
types = append(types, tt...)
437486
enumTypes = append(enumTypes, et...)
438487
}
@@ -649,6 +698,7 @@ func createAllOf(s *openapi3.Schema, stringEnums map[string][]string, name, type
649698
}
650699

651700
func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []EnumTemplate) {
701+
fmt.Printf("DEBUG IN CREATEONEOF %v\n", typeName)
652702
var parsedProperties []string
653703
var properties []string
654704
var genericTypes []string
@@ -698,6 +748,13 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
698748
}
699749
}
700750

751+
fmt.Printf("DEBUG GENERIC TYPES %+v\n", genericTypes)
752+
typeInterfaces := map[string]string{}
753+
maybeDiscriminators := map[string]struct{}{}
754+
discriminatorToType := [][]string{}
755+
discriminatorToDiscriminatorType := map[string]string{}
756+
interfaceName := ""
757+
701758
for _, v := range s.OneOf {
702759
// We want to iterate over the properties of the embedded object
703760
// and find the type that is a string.
@@ -707,17 +764,27 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
707764
keys := sortedKeys(v.Value.Properties)
708765
for _, prop := range keys {
709766
p := v.Value.Properties[prop]
767+
propertyName := strcase.ToCamel(prop)
768+
710769
// We want to collect all the unique properties to create our global oneOf type.
711770
propertyType := convertToValidGoType(prop, typeName, p)
712771

772+
if propertyType == "string" && len(p.Value.Enum) == 1 {
773+
fmt.Printf("DEBUG PROP IS %s\n", prop)
774+
maybeDiscriminators[prop] = struct{}{}
775+
discriminatorToDiscriminatorType[prop] = typeName + strcase.ToCamel(prop)
776+
discriminatorToType = append(discriminatorToType, []string{
777+
fmt.Sprintf("%s%s%s", typeName, propertyName, strcase.ToCamel(p.Value.Enum[0].(string))),
778+
fmt.Sprintf("%s%s", typeName, strcase.ToCamel(p.Value.Enum[0].(string))),
779+
})
780+
// discriminatorToType[p.Value.Enum[0].(string)] = fmt.Sprintf("%s%s%s", typeName, propertyName, strcase.ToCamel(p.Value.Enum[0].(string)))
781+
}
713782
// Check if we have an enum in order to use the corresponding type instead of
714783
// "string"
715784
if propertyType == "string" && len(p.Value.Enum) != 0 {
716785
propertyType = typeName + strcase.ToCamel(prop)
717786
}
718787

719-
propertyName := strcase.ToCamel(prop)
720-
721788
// Avoids duplication for every enum
722789
if !containsMatchFirstWord(parsedProperties, propertyName) {
723790
field := TypeFields{
@@ -729,7 +796,11 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
729796

730797
// We set the type of a field as "any" if every element of the oneOf property isn't the same
731798
if slices.Contains(genericTypes, prop) {
732-
field.Type = "any"
799+
fmt.Printf("DEBUG CONTAINS %+v %+v %s\n", genericTypes, prop, typeName)
800+
interfaceName = fmt.Sprintf("%s%s", typeName, propertyName)
801+
field.Type = fmt.Sprintf("%s%s", typeName, propertyName)
802+
803+
// typeInterfaces[field.Type] = fmt.Sprintf("%s%s", typeName, propertyName)
733804
}
734805

735806
// Check if the field is nullable and use omitzero instead of omitempty.
@@ -765,7 +836,15 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
765836

766837
// TODO: This is the only place that has an "additional name" at the end
767838
// TODO: This is where the "allOf" is being detected
768-
tt, et := populateTypeTemplates(name, v.Value, enumFieldName)
839+
fmt.Printf("DEBUG ABOUT TO POPULATE %s %+v %s\n", name, enumFieldName, interfaceName)
840+
tt, et := populateTypeTemplates(name, v.Value, enumFieldName, interfaceName)
841+
842+
fmt.Printf("DEBUG TYPE INTERFACES %+v\n", typeInterfaces)
843+
// for idx := range tt {
844+
// if interfaceType, ok := typeInterfaces[tt[idx].Name]; ok {
845+
// tt[idx].InterfaceType = interfaceType
846+
// }
847+
// }
769848
typeTpls = append(typeTpls, tt...)
770849
enumTpls = append(enumTpls, et...)
771850
}
@@ -778,17 +857,42 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
778857
return typeTpls, enumTpls
779858
}
780859
}
860+
discriminator := ""
861+
if len(maybeDiscriminators) == 1 {
862+
discriminator = slices.Collect(maps.Keys(maybeDiscriminators))[0]
863+
fmt.Printf("DEBUG DISCRIMINATOR %s\n", discriminator)
864+
}
865+
866+
genericFieldName := ""
867+
if len(genericTypes) == 1 {
868+
genericFieldName = strcase.ToCamel(genericTypes[0])
869+
}
870+
fmt.Printf("DEBUG GENERIC FIELD %+v %d %s\n", genericTypes, len(genericTypes), genericFieldName)
781871

782872
// Make sure to only create structs if the oneOf is not a replacement for enums on the API spec
783873
if len(fields) > 0 {
784874
typeTpl := TypeTemplate{
785-
Description: formatTypeDescription(typeName, s),
786-
Name: typeName,
787-
Type: "struct",
788-
Fields: fields,
789-
}
875+
Description: formatTypeDescription(typeName, s),
876+
Name: typeName,
877+
Type: "struct",
878+
Fields: fields,
879+
DiscriminatorKey: discriminator,
880+
DiscriminatorType: discriminatorToDiscriminatorType[discriminator],
881+
DiscriminatorToType: discriminatorToType,
882+
GenericFieldName: genericFieldName,
883+
}
884+
fmt.Printf("DEBUG TT %+v\n", typeTpl)
790885
typeTpls = append(typeTpls, typeTpl)
791886
}
887+
888+
if interfaceName != "" {
889+
typeTpls = append(typeTpls, TypeTemplate{
890+
Name: interfaceName,
891+
// Name: fmt.Sprintf("%s%s", typeName, interfaceName),
892+
Type: "interface",
893+
})
894+
}
895+
792896
return typeTpls, enumTpls
793897
}
794898

0 commit comments

Comments
 (0)