Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions flask_pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,22 @@ def _model_dump_json(model: V1OrV2BaseModel, **kwargs):
return model.json(**kwargs)


def _sanitize_ctx_errors(errors):
"""
Make Pydantic `ctx["error"]` JSON-serializable by replacing
exception instances with {type, message}.
"""
for error in errors:
ctx = error.get("ctx")
if isinstance(ctx, dict) and isinstance(ctx.get("error"), Exception):
exc = ctx["error"]
ctx["error"] = {
"type": type(exc).__name__,
"message": str(exc),
}
return errors


def make_json_response(
content: Union[V1OrV2BaseModel, Iterable[V1OrV2BaseModel]],
status_code: int,
Expand Down Expand Up @@ -80,7 +96,7 @@ def validate_many_models(

raise ManyModelValidationError(err) from te
except (ValidationError, V1ValidationError) as ve:
raise ManyModelValidationError(ve.errors()) from ve
raise ManyModelValidationError(_sanitize_ctx_errors(ve.errors())) from ve


def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
Expand Down Expand Up @@ -196,7 +212,7 @@ def wrapper(*args, **kwargs):
try:
q = query_model(**query_params)
except (ValidationError, V1ValidationError) as ve:
err["query_params"] = ve.errors()
err["query_params"] = _sanitize_ctx_errors(ve.errors())
body_in_kwargs = func.__annotations__.get("body")
body_model = body_in_kwargs or body
if body_model:
Expand All @@ -208,12 +224,12 @@ def wrapper(*args, **kwargs):
try:
b = body_model(__root__=body_params).__root__
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
err["body_params"] = _sanitize_ctx_errors(ve.errors())
elif issubclass(body_model, RootModel):
try:
b = body_model(body_params)
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
err["body_params"] = _sanitize_ctx_errors(ve.errors())
elif request_body_many:
try:
b = validate_many_models(body_model, body_params)
Expand All @@ -230,7 +246,7 @@ def wrapper(*args, **kwargs):
else:
raise JsonBodyParsingError() from te
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
err["body_params"] = _sanitize_ctx_errors(ve.errors())
form_in_kwargs = func.__annotations__.get("form")
form_model = form_in_kwargs or form
if form_model:
Expand All @@ -242,12 +258,12 @@ def wrapper(*args, **kwargs):
try:
f = form_model(form_params)
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
err["form_params"] = _sanitize_ctx_errors(ve.errors())
elif issubclass(form_model, RootModel):
try:
f = form_model(form_params)
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
err["form_params"] = _sanitize_ctx_errors(ve.errors())
else:
try:
f = form_model(**form_params)
Expand All @@ -259,7 +275,7 @@ def wrapper(*args, **kwargs):
else:
raise JsonBodyParsingError from te
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
err["form_params"] = _sanitize_ctx_errors(ve.errors())
request.query_params = q
request.body_params = b
request.form_params = f
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ build-backend = "flit_core.buildapi"
name = "flask_pydantic"

[tool.pytest]
testpaths = "tests"
addopts = "-vv --ruff --ruff-format --cov --cov-config=pyproject.toml -s"
testpaths = ["tests"]
addopts = ["-vv", "--ruff", "--ruff-format", "--cov", "--cov-config=pyproject.toml", "-s"]

[tool.ruff]
src = ["flask_pydantic"]
Expand Down
87 changes: 86 additions & 1 deletion tests/func/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from flask import jsonify, request
from flask_pydantic import ValidationError, validate
from pydantic import BaseModel, ConfigDict, RootModel
from pydantic import BaseModel, ConfigDict, RootModel, field_validator, model_validator

from ..util import assert_matches

Expand Down Expand Up @@ -163,6 +163,32 @@ async def compute(body: RequestModel):
return ResultModel(result=2 * body.n)


@pytest.fixture
def app_with_field_and_model_validators(app):
class RequestModel(BaseModel):
value: str
other: str

@field_validator("value")
def must_be_foo(cls, v):
if v != "foo":
raise ValueError("value must be foo")
return v

@model_validator(mode="after")
def check_other(self):
if self.other != "ok":
raise ValueError("other must be ok")
return self

@app.route("/validate", methods=["POST"])
@validate()
def handler(body: RequestModel):
return body

return app


test_cases = [
pytest.param(
"?limit=limit",
Expand Down Expand Up @@ -516,3 +542,62 @@ def test_succeeds(self, client):
response = client.post("/compute", json={"n": 1})

assert_matches(expected_response, response.json)


@pytest.mark.usefixtures("app_with_field_and_model_validators")
class TestValidatorResponse:
def test_fail_field_validator(self, client):
response = client.post("/validate", json={"value": "foo1", "other": "ok"})

assert_matches(
{
"validation_error": {
"body_params": [
{
"input": "foo1",
"loc": ["value"],
"msg": "Value error, value must be foo",
"type": "value_error",
"url": re.compile(
r"https://errors\.pydantic\.dev/.*/v/value_error"
),
"ctx": {
"error": {
"message": "value must be foo",
"type": "ValueError",
}
},
}
]
}
},
response.json,
)

def test_fail_model_validator(self, client):
response = client.post("/validate", json={"value": "foo", "other": "no"})

assert_matches(
{
"validation_error": {
"body_params": [
{
"input": {"value": "foo", "other": "no"},
"loc": [],
"msg": "Value error, other must be ok",
"type": "value_error",
"url": re.compile(
r"https://errors\.pydantic\.dev/.*/v/value_error"
),
"ctx": {
"error": {
"message": "other must be ok",
"type": "ValueError",
}
},
}
]
}
},
response.json,
)
67 changes: 66 additions & 1 deletion tests/pydantic_v1/func/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from flask import jsonify, request
from flask_pydantic import ValidationError, validate
from pydantic.v1 import BaseModel
from pydantic.v1 import BaseModel, root_validator, validator

from ...util import assert_matches

Expand Down Expand Up @@ -149,6 +149,32 @@ def compute(query: RequestModel):
)


@pytest.fixture
def app_with_field_and_root_validators(app):
class RequestModel(BaseModel):
value: str
other: str

@validator("value", allow_reuse=True)
def must_be_foo(cls, v):
if v != "foo":
raise ValueError("value must be foo")
return v

@root_validator(allow_reuse=True)
def check_other(cls, values):
if values.get("other") != "ok":
raise ValueError("other must be ok")
return values

@app.route("/validate", methods=["POST"])
@validate()
def handler(body: RequestModel):
return body

return app


test_cases = [
pytest.param(
"?limit=limit",
Expand Down Expand Up @@ -449,3 +475,42 @@ def test_silent(self, client):
response.json["body"],
)
assert response.status_code == 422


@pytest.mark.usefixtures("app_with_field_and_root_validators")
class TestValidatorResponse:
def test_fail_field_validator(self, client):
response = client.post("/validate", json={"value": "foo1", "other": "ok"})

assert_matches(
{
"validation_error": {
"body_params": [
{
"loc": ["value"],
"msg": "value must be foo",
"type": "value_error",
}
]
}
},
response.json,
)

def test_fail_model_validator(self, client):
response = client.post("/validate", json={"value": "foo", "other": "no"})

assert_matches(
{
"validation_error": {
"body_params": [
{
"loc": ["__root__"],
"msg": "other must be ok",
"type": "value_error",
}
]
}
},
response.json,
)
Loading