@@ -361,8 +361,11 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
361361 fmt .Fprint (f , "}\n " )
362362 }
363363
364+ // Write custom UnmarshalJSON method for oneOf types.
364365 if tt .DiscriminatorKey != "" && tt .VariantField != "" {
365366 fmt .Fprintf (f , "func (v *%s) UnmarshalJSON(data []byte) error {\n " , tt .Name )
367+
368+ // Check the discriminator to decide which type to unmarshal to.
366369 fmt .Fprintf (f , "\t var peek struct {\n " )
367370 fmt .Fprintf (f , "\t \t Discriminator %s `json:\" %s\" `\n " , tt .DiscriminatorType , tt .DiscriminatorKey )
368371 fmt .Fprintf (f , "\t }\n " )
@@ -371,10 +374,12 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
371374 fmt .Fprintf (f , "\t }\n " )
372375 fmt .Fprintf (f , "\t switch peek.Discriminator {\n " )
373376
377+ // Construct a case for each possible variant.
374378 for _ , mapping := range tt .DiscriminatorMappings {
375379 fmt .Fprintf (f , "\t case %s:\n " , mapping .EnumConstant )
376380
377- if slices .Contains ([]string {"string" , "int" , "*bool" }, mapping .ObjectType ) {
381+ // For objects, unmarshal into the corresponding struct. For simple types, unmarshal into a temporary struct, then grab the value from it.
382+ if isSimpleType (mapping .ObjectType ) {
378383 fmt .Fprintf (f , "\t \t type value struct {\n " )
379384 fmt .Fprintf (f , "\t \t \t Value %s `json:\" %s\" `\n " , mapping .ConcreteType , strings .ToLower (tt .VariantField ))
380385 fmt .Fprintf (f , "\t \t }\n " )
@@ -392,7 +397,7 @@ func writeTypes(f *os.File, typeCollection []TypeTemplate, typeValidationCollect
392397 }
393398 }
394399 fmt .Fprintf (f , "\t default:\n " )
395- fmt .Fprintf (f , "\t \t return fmt.Errorf(\" unknown %s discriminator value for %s %%s : %%v\" , string(data) , peek.Discriminator)\n " , tt .Name , tt .DiscriminatorKey )
400+ fmt .Fprintf (f , "\t \t return fmt.Errorf(\" unknown %s discriminator value for %s: %%v\" , peek.Discriminator)\n " , tt .Name , tt .DiscriminatorKey )
396401 fmt .Fprintf (f , "\t }\n " )
397402 fmt .Fprintf (f , "\t v.%s = peek.Discriminator\n " , tt .DiscriminatorField )
398403 fmt .Fprintf (f , "\t return nil\n " )
@@ -479,7 +484,7 @@ func populateTypeTemplates(name string, s *openapi3.Schema, enumFieldName string
479484 case "string" , "*bool" , "int" , "int8" , "int16" , "int32" , "int64" , "uint" , "uint8" ,
480485 "uint16" , "uint32" , "uint64" , "uintptr" , "float32" , "float64" :
481486 typeTpl .Description = formatTypeDescription (typeName , s )
482- typeTpl .Type = ot
487+ typeTpl .Type = strings . TrimPrefix ( ot , "*" )
483488 typeTpl .Name = typeName
484489 case "array" :
485490 typeTpl .Description = formatTypeDescription (typeName , s )
@@ -850,33 +855,29 @@ func createOneOf(s *openapi3.Schema, name, typeName string) ([]TypeTemplate, []E
850855 // TODO: This is the only place that has an "additional name" at the end
851856 // TODO: This is where the "allOf" is being detected
852857 if len (variantFields ) == 1 && v .Value .Properties [variantFields [0 ]] != nil {
853- variantType := getObjectType ( v . Value . Properties [ variantFields [0 ]]. Value )
854- // if slices.Contains([]string{"string", "*bool", "int", "int8", "int16", "int32", "int64", "uint", "uint8",
855- // "uint16", "uint32", "uint64", "uintptr", "float32", "float64", "bytes"}, variantType) {
856- if ! slices . Contains ([] string { "array" , "object" , "all_of" , "any_of" , "one_of" , "string_enum" }, variantType ) {
857- tt , _ := populateTypeTemplates (name , v .Value .Properties [variantFields [ 0 ] ].Value , enumFieldName , variantInterface )
858+ variantField := variantFields [0 ]
859+ variantType := getObjectType ( v . Value . Properties [ variantField ]. Value )
860+ if isSimpleType ( variantType ) {
861+ // Special case: process the variant field property separately
862+ tt , _ := populateTypeTemplates (name , v .Value .Properties [variantField ].Value , enumFieldName , variantInterface )
858863 typeTpls = append (typeTpls , tt ... )
859- // enumTpls = append(enumTpls, et...)
860- fooTT , et := populateTypeTemplates (name , v .Value , enumFieldName , variantInterface )
861- // typeTpls = append(typeTpls, tt...)
864+ parentTT , et := populateTypeTemplates (name , v .Value , enumFieldName , variantInterface )
862865 enumTpls = append (enumTpls , et ... )
863866
864- for idx , tt := range fooTT {
865- if strings . HasSuffix ( tt . Name , "Type" ) {
866- fmt . Printf ( "DEBUG TT TYPE %d %s %s %+v \n " , idx , name , enumFieldName , tt )
867+ // Only include parent types with Type suffix
868+ for _ , tt := range parentTT {
869+ if strings . HasSuffix ( tt . Name , strcase . ToCamel ( discriminator )) {
867870 typeTpls = append (typeTpls , tt )
868871 }
869872 }
870- } else {
871- tt , et := populateTypeTemplates (name , v .Value , enumFieldName , variantInterface )
872- typeTpls = append (typeTpls , tt ... )
873- enumTpls = append (enumTpls , et ... )
873+ continue
874874 }
875- } else {
876- tt , et := populateTypeTemplates (name , v .Value , enumFieldName , variantInterface )
877- typeTpls = append (typeTpls , tt ... )
878- enumTpls = append (enumTpls , et ... )
879875 }
876+
877+ // Normal case: process the parent schema
878+ tt , et := populateTypeTemplates (name , v .Value , enumFieldName , variantInterface )
879+ typeTpls = append (typeTpls , tt ... )
880+ enumTpls = append (enumTpls , et ... )
880881 }
881882
882883 // TODO: For now AllOf values within a OneOf are treated as enums
@@ -958,3 +959,9 @@ func formatTypeDescription(name string, s *openapi3.Schema) string {
958959 }
959960 return fmt .Sprintf ("// %s is the type definition for a %s." , name , name )
960961}
962+
963+ func isSimpleType (t string ) bool {
964+ simpleTypes := []string {"string" , "*bool" , "int" , "int8" , "int16" , "int32" , "int64" , "uint" , "uint8" ,
965+ "uint16" , "uint32" , "uint64" , "uintptr" , "float32" , "float64" }
966+ return slices .Contains (simpleTypes , t )
967+ }
0 commit comments