|
1 | 1 | import re |
2 | | -from typing import Any, List, NamedTuple, Optional, Type, Union |
| 2 | +import sys |
| 3 | +from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union |
3 | 4 |
|
4 | 5 | import pytest |
5 | 6 | from flask import jsonify |
|
15 | 16 | from ..util import assert_matches |
16 | 17 |
|
17 | 18 |
|
| 19 | +class EmptyModel(BaseModel): |
| 20 | + pass |
| 21 | + |
| 22 | + |
18 | 23 | class ValidateParams(NamedTuple): |
19 | | - body_model: Optional[Type[BaseModel]] = None |
20 | | - query_model: Optional[Type[BaseModel]] = None |
21 | | - form_model: Optional[Type[BaseModel]] = None |
22 | | - response_model: Type[BaseModel] = None |
| 24 | + body_model: Type[BaseModel] = EmptyModel |
| 25 | + query_model: Type[BaseModel] = EmptyModel |
| 26 | + form_model: Type[BaseModel] = EmptyModel |
| 27 | + response_model: Type[BaseModel] = EmptyModel |
23 | 28 | on_success_status: int = 200 |
24 | 29 | request_query: ImmutableMultiDict = ImmutableMultiDict({}) |
| 30 | + flat_request_query: bool = True |
25 | 31 | request_body: Union[dict, List[dict]] = {} |
26 | 32 | request_form: ImmutableMultiDict = ImmutableMultiDict({}) |
27 | 33 | expected_response_body: Optional[dict] = None |
@@ -50,7 +56,25 @@ class RequestBodyModel(BaseModel): |
50 | 56 |
|
51 | 57 | class FormModel(BaseModel): |
52 | 58 | f1: int |
53 | | - f2: str = None |
| 59 | + f2: Optional[str] = None |
| 60 | + |
| 61 | + |
| 62 | +class RequestWithIterableModel(BaseModel): |
| 63 | + b1: List |
| 64 | + b2: List[str] |
| 65 | + b3: Tuple[str, int] |
| 66 | + b4: Optional[List[int]] = None |
| 67 | + b5: Union[Tuple[str, int], None] = None |
| 68 | + |
| 69 | + |
| 70 | +if sys.version_info >= (3, 10): |
| 71 | + # New Python(>=3.10) syntax tests |
| 72 | + class RequestWithIterableModelPy310(BaseModel): |
| 73 | + b1: list |
| 74 | + b2: list[str] |
| 75 | + b3: tuple[str, int] |
| 76 | + b4: list[int] | None = None |
| 77 | + b5: tuple[str, int] | None = None |
54 | 78 |
|
55 | 79 |
|
56 | 80 | class RequestBodyModelRoot(RootModel): |
@@ -195,8 +219,76 @@ class RequestBodyModelRoot(RootModel): |
195 | 219 | ), |
196 | 220 | id="invalid form param", |
197 | 221 | ), |
| 222 | + pytest.param( |
| 223 | + ValidateParams( |
| 224 | + request_query=ImmutableMultiDict( |
| 225 | + [ |
| 226 | + ("b1", "str1"), |
| 227 | + ("b1", "str2"), |
| 228 | + ("b2", "str1"), |
| 229 | + ("b2", "str2"), |
| 230 | + ("b3", "str"), |
| 231 | + ("b3", 123), |
| 232 | + ("b4", 1), |
| 233 | + ("b4", 2), |
| 234 | + ("b4", 3), |
| 235 | + ("b5", "str"), |
| 236 | + ("b5", 321), |
| 237 | + ] |
| 238 | + ), |
| 239 | + flat_request_query=False, |
| 240 | + expected_response_body={ |
| 241 | + "b1": ["str1", "str2"], |
| 242 | + "b2": ["str1", "str2"], |
| 243 | + "b3": ("str", 123), |
| 244 | + "b4": [1, 2, 3], |
| 245 | + "b5": ("str", 321), |
| 246 | + }, |
| 247 | + query_model=RequestWithIterableModel, |
| 248 | + response_model=RequestWithIterableModel, |
| 249 | + expected_status_code=200, |
| 250 | + ), |
| 251 | + id="iterable and Optional[Iterable] fields in pydantic model in query", |
| 252 | + ), |
198 | 253 | ] |
199 | 254 |
|
| 255 | +if sys.version_info >= (3, 10): |
| 256 | + validate_test_cases.extend( |
| 257 | + [ |
| 258 | + pytest.param( |
| 259 | + ValidateParams( |
| 260 | + request_query=ImmutableMultiDict( |
| 261 | + [ |
| 262 | + ("b1", "str1"), |
| 263 | + ("b1", "str2"), |
| 264 | + ("b2", "str1"), |
| 265 | + ("b2", "str2"), |
| 266 | + ("b3", "str"), |
| 267 | + ("b3", 123), |
| 268 | + ("b4", 1), |
| 269 | + ("b4", 2), |
| 270 | + ("b4", 3), |
| 271 | + ("b5", "str"), |
| 272 | + ("b5", 321), |
| 273 | + ] |
| 274 | + ), |
| 275 | + flat_request_query=False, |
| 276 | + expected_response_body={ |
| 277 | + "b1": ["str1", "str2"], |
| 278 | + "b2": ["str1", "str2"], |
| 279 | + "b3": ("str", 123), |
| 280 | + "b4": [1, 2, 3], |
| 281 | + "b5": ("str", 321), |
| 282 | + }, |
| 283 | + query_model=RequestWithIterableModelPy310, |
| 284 | + response_model=RequestWithIterableModelPy310, |
| 285 | + expected_status_code=200, |
| 286 | + ), |
| 287 | + id="iterable and Iterable | None fields in pydantic model in query (Python 3.10+)", |
| 288 | + ), |
| 289 | + ] |
| 290 | + ) |
| 291 | + |
200 | 292 |
|
201 | 293 | class TestValidate: |
202 | 294 | @pytest.mark.parametrize("parameters", validate_test_cases) |
@@ -230,17 +322,17 @@ def f(): |
230 | 322 | assert response.status_code == parameters.expected_status_code |
231 | 323 | assert_matches(parameters.expected_response_body, response.json) |
232 | 324 | if 200 <= response.status_code < 300: |
233 | | - assert ( |
| 325 | + assert_matches( |
| 326 | + parameters.request_body, |
234 | 327 | mock_request.body_params.model_dump( |
235 | 328 | exclude_none=True, exclude_defaults=True |
236 | | - ) |
237 | | - == parameters.request_body |
| 329 | + ), |
238 | 330 | ) |
239 | | - assert ( |
| 331 | + assert_matches( |
| 332 | + parameters.request_query.to_dict(flat=parameters.flat_request_query), |
240 | 333 | mock_request.query_params.model_dump( |
241 | 334 | exclude_none=True, exclude_defaults=True |
242 | | - ) |
243 | | - == parameters.request_query.to_dict() |
| 335 | + ), |
244 | 336 | ) |
245 | 337 |
|
246 | 338 | @pytest.mark.parametrize("parameters", validate_test_cases) |
@@ -269,17 +361,17 @@ def f( |
269 | 361 | assert_matches(parameters.expected_response_body, response.json) |
270 | 362 | assert response.status_code == parameters.expected_status_code |
271 | 363 | if 200 <= response.status_code < 300: |
272 | | - assert ( |
| 364 | + assert_matches( |
| 365 | + parameters.request_body, |
273 | 366 | mock_request.body_params.model_dump( |
274 | 367 | exclude_none=True, exclude_defaults=True |
275 | | - ) |
276 | | - == parameters.request_body |
| 368 | + ), |
277 | 369 | ) |
278 | | - assert ( |
| 370 | + assert_matches( |
| 371 | + parameters.request_query.to_dict(flat=parameters.flat_request_query), |
279 | 372 | mock_request.query_params.model_dump( |
280 | 373 | exclude_none=True, exclude_defaults=True |
281 | | - ) |
282 | | - == parameters.request_query.to_dict() |
| 374 | + ), |
283 | 375 | ) |
284 | 376 |
|
285 | 377 | @pytest.mark.usefixtures("request_ctx") |
@@ -468,17 +560,17 @@ def f() -> Any: |
468 | 560 | assert response.status_code == parameters.expected_status_code |
469 | 561 | assert_matches(parameters.expected_response_body, response.json) |
470 | 562 | if 200 <= response.status_code < 300: |
471 | | - assert ( |
| 563 | + assert_matches( |
| 564 | + parameters.request_body, |
472 | 565 | mock_request.body_params.model_dump( |
473 | 566 | exclude_none=True, exclude_defaults=True |
474 | | - ) |
475 | | - == parameters.request_body |
| 567 | + ), |
476 | 568 | ) |
477 | | - assert ( |
| 569 | + assert_matches( |
| 570 | + parameters.request_query.to_dict(flat=parameters.flat_request_query), |
478 | 571 | mock_request.query_params.model_dump( |
479 | 572 | exclude_none=True, exclude_defaults=True |
480 | | - ) |
481 | | - == parameters.request_query.to_dict() |
| 573 | + ), |
482 | 574 | ) |
483 | 575 |
|
484 | 576 | def test_fail_validation_custom_status_code(self, app, request_ctx, mocker): |
|
0 commit comments