diff --git a/python/cog/predictor.py b/python/cog/predictor.py index f86753420f..27c65d616a 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -282,20 +282,30 @@ class Config: parameter.default.default is PydanticUndefined or parameter.default.default is ... ): - if PYDANTIC_V2: - parameter.default.default = None - else: + parameter.default.default = None + if not PYDANTIC_V2: parameter.default.default_factory = None - parameter.default.default = None default = parameter.default + extra: Dict[str, Any] = {} if PYDANTIC_V2: # https://github.com/pydantic/pydantic/blob/2.7/pydantic/json_schema.py#L1436-L1446 # json_schema_extra can be a callable, but we don't set that and users shouldn't set that if not default.json_schema_extra: # type: ignore - default.json_schema_extra = {} # type: ignore + default.json_schema_extra = {"x-order": order} # type: ignore assert isinstance(default.json_schema_extra, dict) # type: ignore - extra = default.json_schema_extra # type: ignore + # In Pydantic 2.12.0 the json_schema_extra field is copied into a variable called "_attributes_set" + # that gets created in the constructor. + # This means that changes to that dictionary after the construction don't take effect during the render + # to openapi schema JSON. + # To get around this, we will reference the dictionary in the attributes_set variable and make changes to + # json_schema_extra take effect. + if hasattr(default, "_attributes_set"): + if "json_schema_extra" not in default._attributes_set: # type: ignore + default._attributes_set["json_schema_extra"] = {"x-order": order} + extra = default._attributes_set["json_schema_extra"] # type: ignore + else: + extra = default.json_schema_extra # type: ignore else: extra = default.extra # type: ignore extra["x-order"] = order diff --git a/python/cog/server/helpers.py b/python/cog/server/helpers.py index c7af410aa2..9c86efa438 100644 --- a/python/cog/server/helpers.py +++ b/python/cog/server/helpers.py @@ -461,6 +461,7 @@ def update_openapi_schema_for_pydantic_2( _extract_enum_properties(openapi_schema) _set_default_enumeration_description(openapi_schema) _restore_allof_for_prediction_id_put(openapi_schema) + _ensure_nullable_properties_not_required(openapi_schema) def _remove_webhook_events_filter_title( @@ -501,7 +502,7 @@ def _update_nullable_anyof( if len(non_null_items) < len(value) and not in_header: openapi_schema["nullable"] = True - elif isinstance(openapi_schema, list): + elif isinstance(openapi_schema, list): # type: ignore for item in openapi_schema: _update_nullable_anyof(item, in_header=in_header) @@ -593,3 +594,13 @@ def _restore_allof_for_prediction_id_put( ref = value["$ref"] del value["$ref"] value["allOf"] = [{"$ref": ref}] + + +def _ensure_nullable_properties_not_required(openapi_schema: Dict[str, Any]) -> None: + schemas = openapi_schema["components"]["schemas"] + for schema in schemas.values(): + properties = schema.get("properties", {}) + nullable = {k for k, v in properties.items() if v.get("nullable", False)} + + if "required" in schema and nullable: + schema["required"] = [k for k in schema["required"] if k not in nullable]