55package main
66
77import (
8+ "encoding/json"
89 "fmt"
10+ "maps"
911 "os"
1012 "slices"
1113 "sort"
@@ -34,6 +36,17 @@ type TypeTemplate struct {
3436 Type string
3537 // Fields holds the information for the field
3638 Fields []TypeFields
39+
40+ // InterfaceType
41+ InterfaceType string
42+ // DiscriminatorKey
43+ DiscriminatorKey string
44+ // DiscriminatorType
45+ DiscriminatorType string
46+ // DiscriminatorToType
47+ DiscriminatorToType [][]string
48+
49+ GenericFieldName string
3750}
3851
3952// TypeFields holds the information for each type field
@@ -256,7 +269,7 @@ func constructTypes(schemas openapi3.Schemas) ([]TypeTemplate, []EnumTemplate) {
256269
257270 // Set name as a valid Go type name
258271 name = strcase .ToCamel (name )
259- typeTpl , enumTpl := populateTypeTemplates (name , s .Value , "" )
272+ typeTpl , enumTpl := populateTypeTemplates (name , s .Value , "" , "" )
260273 typeCollection = append (typeCollection , typeTpl ... )
261274 enumCollection = append (enumCollection , enumTpl ... )
262275 }
@@ -326,7 +339,12 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
326339
327340 fmt .Fprintf (f , "%s\n " , splitDocString (tt .Description ))
328341 fmt .Fprintf (f , "type %s %s" , tt .Name , tt .Type )
329- if tt .Fields != nil {
342+
343+ if tt .Type == "interface" {
344+ fmt .Fprintf (f , " {\n " )
345+ fmt .Fprintf (f , "\t is%s()\n " , tt .Name )
346+ fmt .Fprint (f , "}\n " )
347+ } else if tt .Fields != nil {
330348 fmt .Fprint (f , " {\n " )
331349 for _ , ft := range tt .Fields {
332350 if ft .Description != "" {
@@ -335,6 +353,34 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
335353 fmt .Fprintf (f , "\t %s %s %s\n " , ft .Name , ft .Type , ft .SerializationInfo )
336354 }
337355 fmt .Fprint (f , "}\n " )
356+
357+ if tt .DiscriminatorKey != "" && len (tt .DiscriminatorToType ) > 0 && tt .GenericFieldName != "" {
358+ fmt .Fprintf (f , "func (v *%s) UnmarshalJSON(data []byte) error {\n " , tt .Name )
359+ fmt .Fprintf (f , "\t var peek struct {\n " )
360+ fmt .Fprintf (f , "\t \t discriminator %s`json:\" %s\" `\n " , tt .DiscriminatorType , tt .DiscriminatorKey )
361+ fmt .Fprintf (f , "\t }\n " )
362+ fmt .Fprintf (f , "\t if err := json.Unmarshal(data, &peek); err != nil {\n " )
363+ fmt .Fprintf (f , "\t \t return err\n " )
364+ fmt .Fprintf (f , "\t }\n " )
365+ fmt .Fprintf (f , "\t switch peek.discriminator {\n " )
366+
367+ for _ , dtt := range tt .DiscriminatorToType {
368+ fmt .Fprintf (f , "\t case %s:\n " , dtt [0 ])
369+ fmt .Fprintf (f , "\t \t var val %s\n " , dtt [1 ])
370+ fmt .Fprintf (f , "\t \t if err := json.Unmarshal(data, &val); err != nil {\n " )
371+ fmt .Fprintf (f , "\t \t \t return err\n " )
372+ fmt .Fprintf (f , "\t \t }\n " )
373+ fmt .Fprintf (f , "\t v.%s = val\n " , tt .GenericFieldName )
374+ }
375+
376+ fmt .Fprintf (f , "\t }\n " )
377+ fmt .Fprintf (f , "\t return nil\n " )
378+ fmt .Fprint (f , "}\n " )
379+ }
380+ }
381+ if tt .InterfaceType != "" {
382+ fmt .Fprintf (f , "\n " )
383+ fmt .Fprintf (f , "func (%s) is%s() {}\n " , tt .Name , tt .InterfaceType )
338384 }
339385 fmt .Fprint (f , "\n " )
340386 }
@@ -381,7 +427,7 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
381427// populateTypeTemplates populates the template of a type definition for the given schema.
382428// The additional parameter is only used as a suffix for the type name.
383429// This is mostly for oneOf types.
384- func populateTypeTemplates (name string , s * openapi3.Schema , enumFieldName string ) ([]TypeTemplate , []EnumTemplate ) {
430+ func populateTypeTemplates (name string , s * openapi3.Schema , enumFieldName string , interfaceName string ) ([]TypeTemplate , []EnumTemplate ) {
385431 typeName := name
386432
387433 // Type name will change for each enum type
@@ -418,21 +464,24 @@ func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string
418464 typeTpl .Type = fmt .Sprintf ("[]%s" , s .Items .Value .Type )
419465 typeTpl .Name = typeName
420466 case "object" :
467+ b , _ := json .Marshal (s )
468+ fmt .Printf ("DEBUG IN OBJECT %s %+v %+v\n " , typeName , string (b ), enumFieldName )
421469 typeTpl = createTypeObject (s , name , typeName , formatTypeDescription (typeName , s ))
470+ typeTpl .InterfaceType = interfaceName
422471
423472 // Iterate over the properties and append the types, if we need to.
424473 properties := sortedKeys (s .Properties )
425474 for _ , k := range properties {
426475 v := s .Properties [k ]
427476 if isLocalEnum (v ) {
428- tt , et := populateTypeTemplates (fmt .Sprintf ("%s%s" , name , strcase .ToCamel (k )), v .Value , "" )
477+ tt , et := populateTypeTemplates (fmt .Sprintf ("%s%s" , name , strcase .ToCamel (k )), v .Value , "" , "" )
429478 types = append (types , tt ... )
430479 enumTypes = append (enumTypes , et ... )
431480 }
432481
433482 // TODO: So far this code is never hit with the current openapi spec
434483 if isLocalObject (v ) {
435- tt , et := populateTypeTemplates (fmt .Sprintf ("%s%s" , name , strcase .ToCamel (k )), v .Value , "" )
484+ tt , et := populateTypeTemplates (fmt .Sprintf ("%s%s" , name , strcase .ToCamel (k )), v .Value , "" , "" )
436485 types = append (types , tt ... )
437486 enumTypes = append (enumTypes , et ... )
438487 }
@@ -649,6 +698,7 @@ func createAllOf(s *openapi3.Schema, stringEnums map[string][]string, name, type
649698}
650699
651700func createOneOf (s * openapi3.Schema , name , typeName string ) ([]TypeTemplate , []EnumTemplate ) {
701+ fmt .Printf ("DEBUG IN CREATEONEOF %v\n " , typeName )
652702 var parsedProperties []string
653703 var properties []string
654704 var genericTypes []string
@@ -698,6 +748,13 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
698748 }
699749 }
700750
751+ fmt .Printf ("DEBUG GENERIC TYPES %+v\n " , genericTypes )
752+ typeInterfaces := map [string ]string {}
753+ maybeDiscriminators := map [string ]struct {}{}
754+ discriminatorToType := [][]string {}
755+ discriminatorToDiscriminatorType := map [string ]string {}
756+ interfaceName := ""
757+
701758 for _ , v := range s .OneOf {
702759 // We want to iterate over the properties of the embedded object
703760 // and find the type that is a string.
@@ -707,17 +764,27 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
707764 keys := sortedKeys (v .Value .Properties )
708765 for _ , prop := range keys {
709766 p := v .Value .Properties [prop ]
767+ propertyName := strcase .ToCamel (prop )
768+
710769 // We want to collect all the unique properties to create our global oneOf type.
711770 propertyType := convertToValidGoType (prop , typeName , p )
712771
772+ if propertyType == "string" && len (p .Value .Enum ) == 1 {
773+ fmt .Printf ("DEBUG PROP IS %s\n " , prop )
774+ maybeDiscriminators [prop ] = struct {}{}
775+ discriminatorToDiscriminatorType [prop ] = typeName + strcase .ToCamel (prop )
776+ discriminatorToType = append (discriminatorToType , []string {
777+ fmt .Sprintf ("%s%s%s" , typeName , propertyName , strcase .ToCamel (p .Value .Enum [0 ].(string ))),
778+ fmt .Sprintf ("%s%s" , typeName , strcase .ToCamel (p .Value .Enum [0 ].(string ))),
779+ })
780+ // discriminatorToType[p.Value.Enum[0].(string)] = fmt.Sprintf("%s%s%s", typeName, propertyName, strcase.ToCamel(p.Value.Enum[0].(string)))
781+ }
713782 // Check if we have an enum in order to use the corresponding type instead of
714783 // "string"
715784 if propertyType == "string" && len (p .Value .Enum ) != 0 {
716785 propertyType = typeName + strcase .ToCamel (prop )
717786 }
718787
719- propertyName := strcase .ToCamel (prop )
720-
721788 // Avoids duplication for every enum
722789 if ! containsMatchFirstWord (parsedProperties , propertyName ) {
723790 field := TypeFields {
@@ -729,7 +796,11 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
729796
730797 // We set the type of a field as "any" if every element of the oneOf property isn't the same
731798 if slices .Contains (genericTypes , prop ) {
732- field .Type = "any"
799+ fmt .Printf ("DEBUG CONTAINS %+v %+v %s\n " , genericTypes , prop , typeName )
800+ interfaceName = fmt .Sprintf ("%s%s" , typeName , propertyName )
801+ field .Type = fmt .Sprintf ("%s%s" , typeName , propertyName )
802+
803+ // typeInterfaces[field.Type] = fmt.Sprintf("%s%s", typeName, propertyName)
733804 }
734805
735806 // Check if the field is nullable and use omitzero instead of omitempty.
@@ -765,7 +836,15 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
765836
766837 // TODO: This is the only place that has an "additional name" at the end
767838 // TODO: This is where the "allOf" is being detected
768- tt , et := populateTypeTemplates (name , v .Value , enumFieldName )
839+ fmt .Printf ("DEBUG ABOUT TO POPULATE %s %+v %s\n " , name , enumFieldName , interfaceName )
840+ tt , et := populateTypeTemplates (name , v .Value , enumFieldName , interfaceName )
841+
842+ fmt .Printf ("DEBUG TYPE INTERFACES %+v\n " , typeInterfaces )
843+ // for idx := range tt {
844+ // if interfaceType, ok := typeInterfaces[tt[idx].Name]; ok {
845+ // tt[idx].InterfaceType = interfaceType
846+ // }
847+ // }
769848 typeTpls = append (typeTpls , tt ... )
770849 enumTpls = append (enumTpls , et ... )
771850 }
@@ -778,17 +857,42 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
778857 return typeTpls , enumTpls
779858 }
780859 }
860+ discriminator := ""
861+ if len (maybeDiscriminators ) == 1 {
862+ discriminator = slices .Collect (maps .Keys (maybeDiscriminators ))[0 ]
863+ fmt .Printf ("DEBUG DISCRIMINATOR %s\n " , discriminator )
864+ }
865+
866+ genericFieldName := ""
867+ if len (genericTypes ) == 1 {
868+ genericFieldName = strcase .ToCamel (genericTypes [0 ])
869+ }
870+ fmt .Printf ("DEBUG GENERIC FIELD %+v %d %s\n " , genericTypes , len (genericTypes ), genericFieldName )
781871
782872 // Make sure to only create structs if the oneOf is not a replacement for enums on the API spec
783873 if len (fields ) > 0 {
784874 typeTpl := TypeTemplate {
785- Description : formatTypeDescription (typeName , s ),
786- Name : typeName ,
787- Type : "struct" ,
788- Fields : fields ,
789- }
875+ Description : formatTypeDescription (typeName , s ),
876+ Name : typeName ,
877+ Type : "struct" ,
878+ Fields : fields ,
879+ DiscriminatorKey : discriminator ,
880+ DiscriminatorType : discriminatorToDiscriminatorType [discriminator ],
881+ DiscriminatorToType : discriminatorToType ,
882+ GenericFieldName : genericFieldName ,
883+ }
884+ fmt .Printf ("DEBUG TT %+v\n " , typeTpl )
790885 typeTpls = append (typeTpls , typeTpl )
791886 }
887+
888+ if interfaceName != "" {
889+ typeTpls = append (typeTpls , TypeTemplate {
890+ Name : interfaceName ,
891+ // Name: fmt.Sprintf("%s%s", typeName, interfaceName),
892+ Type : "interface" ,
893+ })
894+ }
895+
792896 return typeTpls , enumTpls
793897}
794898
0 commit comments