From 8922ce08c5da32ffdb29b8347879f20e2b639c30 Mon Sep 17 00:00:00 2001 From: Michal Korbela Date: Mon, 6 Nov 2023 13:07:03 +0100 Subject: [PATCH] chore: Upgrade pydantic>=2.0, flask>=3.0 --- .github/workflows/build.yml | 4 +- flask_ninja/__init__.py | 2 +- flask_ninja/api.py | 29 +- flask_ninja/model_field.py | 68 ++++ flask_ninja/models.py | 98 +++--- flask_ninja/operation.py | 137 ++++----- flask_ninja/param.py | 33 +- flask_ninja/param_functions.py | 52 +++- flask_ninja/router.py | 2 +- flask_ninja/utils.py | 290 +++++++++--------- pyproject.toml | 6 +- .../test_api/test_get_schema/api_schema | 2 +- tests/unit/test_operation.py | 160 +++++----- tests/unit/test_router.py | 4 +- tests/unit/test_utils.py | 0 tox.ini | 2 +- 16 files changed, 479 insertions(+), 410 deletions(-) create mode 100644 flask_ninja/model_field.py create mode 100644 tests/unit/test_utils.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ca8e38e..863c27c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,12 +33,14 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - tox_job: [py39, py310] + tox_job: [py39, py310, py311] include: - tox_job: py39 python: 3.9 - tox_job: py310 python: 3.10.4 + - tox_job: py311 + python: 3.11 steps: - uses: actions/checkout@v1 diff --git a/flask_ninja/__init__.py b/flask_ninja/__init__.py index 5b931c8..43920bf 100644 --- a/flask_ninja/__init__.py +++ b/flask_ninja/__init__.py @@ -1,6 +1,6 @@ from flask_ninja.api import NinjaAPI, Server from flask_ninja.constants import ParamType from flask_ninja.operation import ApiConfigError, Callback, Operation -from flask_ninja.param_functions import Header, Path, Query +from flask_ninja.param_functions import Body, Header, Path, Query from flask_ninja.router import Router from flask_ninja.security import HttpAuthBase, HttpBearer diff --git a/flask_ninja/api.py b/flask_ninja/api.py index b837cee..85ef32d 100644 --- a/flask_ninja/api.py +++ b/flask_ninja/api.py @@ -1,15 +1,13 @@ -import json import re from typing import Any, Callable, Optional from flask import Blueprint, Flask, render_template -from pydantic.schema import get_flat_models_from_fields, get_model_name_map +from pydantic.json_schema import GenerateJsonSchema from .constants import NOT_SET from .models import Components, Info, OpenAPI, Server from .router import Router from .swagger_ui import swagger_ui_path -from .utils import get_model_definitions class NinjaAPI: @@ -70,7 +68,7 @@ def delete(self, path: str, **kwargs: Any) -> Callable: def add_router(self, router: Router, prefix: str = "") -> None: self.router.add_router(router, f"{self.prefix}{prefix}") - def get_schema(self) -> str: + def get_schema(self) -> dict[str, Any]: """Creates OpenAPI schema for the API.""" # At first we collect all pydantic models used anywhere @@ -79,17 +77,20 @@ def get_schema(self) -> str: for operation in self.router.operations: models += operation.get_models() - # Then we create from them flat models - it means we extract the models from Generics - flat_models = get_flat_models_from_fields(models, known_models=set()) - # Then we generate unique names for them - if there are two models from different modules - # but with the same name, we need provide different names for them in the Definitions list - model_name_map = get_model_name_map(flat_models) + schema_generator = GenerateJsonSchema( + ref_template="#/components/schemas/{model}" + ) + + inputs = [ + (field, field.mode, field.type_adapter.core_schema) for field in models + ] + field_mapping, definitions = schema_generator.generate_definitions( + inputs=inputs + ) + paths: dict = {} security_schemes: dict = {} # Create OpenAPI schemas for all models - definitions = get_model_definitions( - flat_models=flat_models, model_name_map=model_name_map - ) # Create OpenAPI schema for all operations for operation in self.router.operations: @@ -99,7 +100,7 @@ def get_schema(self) -> str: if swagger_path not in paths: paths[swagger_path] = {} paths[swagger_path][operation.method.lower()] = operation.get_schema( - model_name_map=model_name_map + field_mapping=field_mapping ) if operation.auth: security_schemes.update(operation.auth.schema()) @@ -118,4 +119,4 @@ def get_schema(self) -> str: servers=self.servers or None, ) - return json.loads(schema.json(by_alias=True, exclude_none=True)) + return schema.model_dump(mode="json", by_alias=True, exclude_none=True) diff --git a/flask_ninja/model_field.py b/flask_ninja/model_field.py new file mode 100644 index 0000000..d766048 --- /dev/null +++ b/flask_ninja/model_field.py @@ -0,0 +1,68 @@ +import types +from dataclasses import dataclass +from typing import Annotated, Any, Dict, List, Literal, Sequence, Set, Tuple, Union + +from pydantic import TypeAdapter +from pydantic.fields import FieldInfo +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import PydanticUndefined, PydanticUndefinedType + +Required = PydanticUndefined +Undefined = PydanticUndefined +UndefinedType = PydanticUndefinedType +IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] +UnionType = getattr(types, "UnionType", Union) + + +@dataclass +class ModelField: + field_info: FieldInfo + name: str + mode: Literal["validation", "serialization"] = "validation" + + @property + def alias(self) -> str: + a = self.field_info.alias + return a if a is not None else self.name + + @property + def required(self) -> bool: + return self.field_info.is_required() + + @property + def default(self) -> Any: + return self.get_default() + + @property + def type_(self) -> Any: + return self.field_info.annotation + + def __post_init__(self) -> None: + self.type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info] + ) + + def get_default(self) -> Any: + if self.field_info.is_required(): + return Undefined + return self.field_info.get_default(call_default_factory=True) + + def __hash__(self) -> int: + # Each ModelField is unique for our purposes, to allow making a dict from + # ModelField to its JSON Schema. + return id(self) + + +def _regenerate_error_with_loc( + *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] +) -> List[Dict[str, Any]]: + updated_loc_errors: List[Any] = [ + {**err, "loc": loc_prefix + err.get("loc", ())} for err in errors + ] + + return updated_loc_errors + + +FieldMapping = Dict[ + tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue +] diff --git a/flask_ninja/models.py b/flask_ninja/models.py index e3b1c80..fa9c2e6 100644 --- a/flask_ninja/models.py +++ b/flask_ninja/models.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from pydantic import AnyUrl, BaseModel, Field +from pydantic import AnyUrl, BaseModel, ConfigDict, Field +from pydantic.json_schema import DefsRef class Contact(BaseModel): @@ -9,16 +10,14 @@ class Contact(BaseModel): url: Optional[AnyUrl] = None email: Optional[str] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class License(BaseModel): name: str url: Optional[AnyUrl] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Info(BaseModel): @@ -29,8 +28,7 @@ class Info(BaseModel): license: Optional[License] = None version: str - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class ServerVariable(BaseModel): @@ -38,8 +36,7 @@ class ServerVariable(BaseModel): default: str description: Optional[str] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Server(BaseModel): @@ -47,8 +44,7 @@ class Server(BaseModel): description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Reference(BaseModel): @@ -67,16 +63,14 @@ class XML(BaseModel): attribute: Optional[bool] = None wrapped: Optional[bool] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class ExternalDocumentation(BaseModel): description: Optional[str] = None url: AnyUrl - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Schema(BaseModel): @@ -87,14 +81,14 @@ class Schema(BaseModel): exclusiveMaximum: Optional[float] = None minimum: Optional[float] = None exclusiveMinimum: Optional[float] = None - maxLength: Optional[int] = Field(default=None, gte=0) - minLength: Optional[int] = Field(default=None, gte=0) + maxLength: Optional[int] = Field(default=None, ge=0) + minLength: Optional[int] = Field(default=None, ge=0) pattern: Optional[str] = None - maxItems: Optional[int] = Field(default=None, gte=0) - minItems: Optional[int] = Field(default=None, gte=0) + maxItems: Optional[int] = Field(default=None, ge=0) + minItems: Optional[int] = Field(default=None, ge=0) uniqueItems: Optional[bool] = None - maxProperties: Optional[int] = Field(default=None, gte=0) - minProperties: Optional[int] = Field(default=None, gte=0) + maxProperties: Optional[int] = Field(default=None, ge=0) + minProperties: Optional[int] = Field(default=None, ge=0) required: Optional[List[str]] = None enum: Optional[List[Any]] = None type: Optional[str] = None @@ -117,8 +111,7 @@ class Schema(BaseModel): example: Optional[Any] = None deprecated: Optional[bool] = None - class Config: - extra: str = "allow" + model_config = ConfigDict(extra="allow") class Example(BaseModel): @@ -127,8 +120,7 @@ class Example(BaseModel): value: Optional[Any] = None externalValue: Optional[AnyUrl] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class ParameterInType(Enum): @@ -145,8 +137,7 @@ class Encoding(BaseModel): explode: Optional[bool] = None allowReserved: Optional[bool] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class MediaType(BaseModel): @@ -155,9 +146,7 @@ class MediaType(BaseModel): examples: Optional[Dict[str, Union[Example, Reference]]] = None encoding: Optional[Dict[str, Encoding]] = None - class Config: - extra = "allow" - allow_population_by_field_name = True + model_config = ConfigDict(extra="allow", populate_by_name=True) class ParameterBase(BaseModel): @@ -170,13 +159,11 @@ class ParameterBase(BaseModel): allowReserved: Optional[bool] = None schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") example: Optional[Any] = None - examples: Optional[Dict[str, Union[Example, Reference]]] = None + examples: Optional[List[Any]] = None # Serialization rules for more complex scenarios content: Optional[Dict[str, MediaType]] = None - class Config: - extra = "allow" - allow_population_by_field_name = True + model_config = ConfigDict(extra="allow", populate_by_name=True) class Parameter(ParameterBase): @@ -193,8 +180,7 @@ class RequestBody(BaseModel): content: Dict[str, MediaType] required: Optional[bool] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Link(BaseModel): @@ -205,8 +191,7 @@ class Link(BaseModel): description: Optional[str] = None server: Optional[Server] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Response(BaseModel): @@ -215,8 +200,7 @@ class Response(BaseModel): content: Optional[Dict[str, MediaType]] = None links: Optional[Dict[str, Union[Link, Reference]]] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Operation(BaseModel): @@ -229,13 +213,12 @@ class Operation(BaseModel): requestBody: Optional[Union[RequestBody, Reference]] = None # Using Any for Specification Extensions responses: Dict[str, Union[Response, Any]] - callbacks: Optional[Dict[str, Union[Dict[str, "PathItem"], Reference]]] = None + callbacks: Optional[Dict[str, Union[Dict[str, "PathItem"]]]] = None deprecated: Optional[bool] = None security: Optional[List[Dict[str, List[str]]]] = None servers: Optional[List[Server]] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class PathItem(BaseModel): @@ -253,8 +236,7 @@ class PathItem(BaseModel): servers: Optional[List[Server]] = None parameters: Optional[List[Union[Parameter, Reference]]] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class SecuritySchemeType(Enum): @@ -268,8 +250,7 @@ class SecurityBase(BaseModel): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class APIKeyIn(Enum): @@ -298,8 +279,7 @@ class OAuthFlow(BaseModel): refreshUrl: Optional[str] = None scopes: Dict[str, str] = {} - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class OAuthFlowImplicit(OAuthFlow): @@ -325,8 +305,7 @@ class OAuthFlows(BaseModel): clientCredentials: Optional[OAuthFlowClientCredentials] = None authorizationCode: Optional[OAuthFlowAuthorizationCode] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class OAuth2(SecurityBase): @@ -343,7 +322,7 @@ class OpenIdConnect(SecurityBase): class Components(BaseModel): - schemas: Optional[Dict[str, Union[Schema, Reference]]] = None + schemas: Optional[Dict[DefsRef, Dict[str, Any]]] = None responses: Optional[Dict[str, Union[Response, Reference]]] = None parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None @@ -354,8 +333,7 @@ class Components(BaseModel): # Using Any for Specification Extensions callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class Tag(BaseModel): @@ -363,8 +341,7 @@ class Tag(BaseModel): description: Optional[str] = None externalDocs: Optional[ExternalDocumentation] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class OpenAPI(BaseModel): @@ -378,10 +355,9 @@ class OpenAPI(BaseModel): tags: Optional[List[Tag]] = None externalDocs: Optional[ExternalDocumentation] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") -Schema.update_forward_refs() -Operation.update_forward_refs() -Encoding.update_forward_refs() +Schema.model_rebuild() +Operation.model_rebuild() +Encoding.model_rebuild() diff --git a/flask_ninja/operation.py b/flask_ninja/operation.py index c8a262e..555b877 100644 --- a/flask_ninja/operation.py +++ b/flask_ninja/operation.py @@ -1,6 +1,5 @@ # pylint:disable=comparison-with-callable import inspect -import json import re from collections import defaultdict from enum import Enum @@ -8,11 +7,10 @@ from docstring_parser import parse as doc_parse from flask import jsonify, request -from pydantic import BaseModel, ValidationError, parse_obj_as -from pydantic.fields import ModelField, Undefined -from pydantic.schema import field_schema +from pydantic import BaseModel, ConfigDict, ValidationError -from .constants import NOT_SET, REF_PREFIX, ApiConfigError, ParamType +from .constants import NOT_SET, ApiConfigError, ParamType +from .model_field import FieldMapping, ModelField, Undefined from .models import MediaType from .models import Operation as OAPIOperation from .models import ( @@ -27,7 +25,7 @@ from .param import FuncParam from .parse_rule import parse_rule from .security import HttpAuthBase -from .utils import create_model_field, get_param_model_field, is_scalar_sequence_field +from .utils import analyze_param, create_model_field, is_scalar_sequence_field ModelNameMapType = dict[Union[Type[BaseModel], Type[Enum]], str] @@ -35,20 +33,26 @@ class SerializationModel(BaseModel): data: Any - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class Callback(BaseModel): name: str url: str method: str - request_body: Optional[Any] - params: Optional[list[ModelField]] + request_body: Optional[Type] = None + params: Optional[list[ModelField]] = None response_codes: dict[int, str] + field: Optional[ModelField] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) + + def model_post_init(self, __context: Any) -> None: + self.field = ( + create_model_field(self.name, self.request_body) + if self.request_body + else None + ) class Operation: @@ -88,20 +92,22 @@ def run(self, *args: Any, **kwargs: Any) -> Any: field_info = cast(FuncParam, param.field_info) if field_info.in_ == ParamType.QUERY and param.name in request.args: if is_scalar_sequence_field(param): - kwargs[param.name] = parse_obj_as( - param.outer_type_, request.args.getlist(param.alias) + kwargs[param.name] = param.type_adapter.validate_python( + request.args.getlist(param.alias) ) else: - kwargs[param.name] = parse_obj_as( - param.type_, request.args[param.alias] + kwargs[param.name] = param.type_adapter.validate_python( + request.args[param.alias] ) elif field_info.in_ == ParamType.HEADER: - kwargs[param.name] = parse_obj_as( - param.type_, request.headers.get(param.alias) + kwargs[param.name] = param.type_adapter.validate_python( + request.headers.get(param.alias) ) # Parse request body elif field_info.in_ == ParamType.BODY: - kwargs[param.name] = parse_obj_as(param.outer_type_, request.json) + kwargs[param.name] = param.type_adapter.validate_python( + request.json + ) except ValidationError as validation_error: return validation_error.json(), 400 @@ -114,7 +120,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any: # e.g. list[str], dict[str, Response], etc, we can't use isinstance and # at first need to get the unspecified generic type - e.g. list, dict, etc # TODO match also the inner types of generics - but that's a corner case - if isinstance(resp, get_origin(model.outer_type_) or model.outer_type_): + if isinstance(resp, get_origin(model.type_) or model.type_): # hotfix: if the resp is str we shouldn't use jsonify as it # changes the response adding additional characters. if isinstance(resp, str): @@ -154,30 +160,28 @@ def _sanitize_responses( if func_return_type: if get_origin(func_return_type) == Union: for ret_type in func_return_type.__args__: - if not any( - resp.outer_type_ != ret_type for resp in responses.values() - ): + if not any(resp.type_ != ret_type for resp in responses.values()): raise ApiConfigError( f"Return type {ret_type} http code must be specified explicitly." ) # If we specified different return type as we specified as response - elif 200 in responses and responses[200].outer_type_ != func_return_type: + elif ( + 200 in responses + and responses[200].field_info.annotation != func_return_type + ): raise ApiConfigError( - f"Return type of the function {type(func_return_type)} does not match response type {type(responses[200].outer_type_)}" + f"Return type of the function {type(func_return_type)} does not match response type {type(responses[200].type_)}" ) return responses @classmethod def serialize(cls, resp: Any) -> Any: - """Convert response object into json serializable object. - - TODO: Avoid json serialization and deserialization. - """ - return json.loads(SerializationModel(data=resp).json())["data"] + """Convert response object into json serializable object.""" + return SerializationModel(data=resp).model_dump(mode="json")["data"] def get_callback_schema( - self, cb: Callback, model_name_map: ModelNameMapType + self, cb: Callback, field_mapping: FieldMapping ) -> dict[str, PathItem]: """Generate schema for a callback. @@ -186,12 +190,8 @@ def get_callback_schema( to declare callbacks. """ - if cb.request_body: - request_body, _, _ = field_schema( - create_model_field("Callback", cb.request_body), - model_name_map=model_name_map, - ref_prefix=REF_PREFIX, - ) + if cb.field: + request_body = field_mapping[(cb.field, "validation")] else: request_body = None @@ -207,12 +207,8 @@ def get_callback_schema( name=param.alias, in_=ParameterInType(field_info.in_.value), # Undefined type is tricky, because it can't be serialized - required=None if param.required == Undefined else bool(param.required), - schema_=Schema.parse_obj( - field_schema( - param, model_name_map=model_name_map, ref_prefix=REF_PREFIX - )[0] - ), + required=param.required, + schema_=Schema.model_validate(field_mapping[(param, "validation")]), description=field_info.description, examples=field_info.examples, example=field_info.example if field_info.example != Undefined else None, @@ -236,12 +232,9 @@ def get_callback_schema( }, ) - return {cb.url: PathItem.parse_obj({cb.method.lower(): schema})} + return {cb.url: PathItem.model_validate({cb.method.lower(): schema})} - def get_openapi_parameters( - self, - model_name_map: ModelNameMapType, - ) -> list[Parameter]: + def get_openapi_parameters(self, field_mapping: FieldMapping) -> list[Parameter]: """Create OpenAPI schema for parameters of this operation.""" parameters = [] for param in self.params: @@ -254,23 +247,19 @@ def get_openapi_parameters( parameter = Parameter( name=param.alias, in_=ParameterInType(field_info.in_.value), - required=None if param.required == Undefined else bool(param.required), - schema_=Schema.parse_obj( - field_schema( - param, model_name_map=model_name_map, ref_prefix=REF_PREFIX - )[0] - ), + required=param.required, + schema_=Schema.model_validate(field_mapping[(param, "validation")]), description=field_info.description, examples=field_info.examples, example=field_info.example if field_info.example != Undefined else None, deprecated=field_info.deprecated, ) - parameters.append(Parameter.parse_obj(parameter)) + parameters.append(Parameter.model_validate(parameter)) return parameters def get_openapi_request_body( - self, model_name_map: ModelNameMapType + self, field_mapping: FieldMapping ) -> Optional[RequestBody]: """Create OpenAPI schema for request body of this operation. @@ -279,14 +268,11 @@ def get_openapi_request_body( for param in self.params: field_info = cast(FuncParam, param.field_info) if field_info.in_ == ParamType.BODY: - request_body, _, _ = field_schema( - param, model_name_map=model_name_map, ref_prefix=REF_PREFIX - ) - + request_body = field_mapping[(param, "validation")] return RequestBody( content={ "application/json": MediaType( - schema_=Schema.parse_obj(request_body), + schema_=Schema.model_validate(request_body), ) }, description="", @@ -294,27 +280,24 @@ def get_openapi_request_body( ) return None - def get_schema(self, model_name_map: ModelNameMapType) -> OAPIOperation: + def get_schema(self, field_mapping: FieldMapping) -> OAPIOperation: """Create OpenAPI schema for this operation.""" doc = doc_parse(self.view_func.__doc__ or "") responses: Dict[str, Response] = {} for code, response in self.responses.items(): - response_schema, _, _ = field_schema( - response, model_name_map=model_name_map, ref_prefix=REF_PREFIX - ) - + response_schema = field_mapping[(response, "validation")] responses[str(code)] = Response( content={ "application/json": MediaType( - schema_=Schema.parse_obj(response_schema) + schema_=Schema.model_validate(response_schema) ) }, description="", ) callbacks = { - cb.name: self.get_callback_schema(cb, model_name_map=model_name_map) + cb.name: self.get_callback_schema(cb, field_mapping=field_mapping) for cb in (self.callbacks or []) } @@ -322,9 +305,9 @@ def get_schema(self, model_name_map: ModelNameMapType) -> OAPIOperation: summary=doc.short_description or self.summary, description=doc.long_description or self.description, responses=responses, - parameters=self.get_openapi_parameters(model_name_map=model_name_map) + parameters=self.get_openapi_parameters(field_mapping=field_mapping) or None, # type:ignore - requestBody=self.get_openapi_request_body(model_name_map=model_name_map), + requestBody=self.get_openapi_request_body(field_mapping=field_mapping), security=[{self.auth.schema_name: []}] if self.auth else None, callbacks=callbacks or None, ) @@ -352,10 +335,12 @@ def _parse_params(self, path: str) -> list[ModelField]: # Additional attributes for a parameter are set via the default value # we retrieve the default value using inspect, and we convert it # to a ModelField get_param_model_field function - for param in inspect.signature(self.view_func).parameters.values(): - model_field = get_param_model_field( - param=param, - force_type=ParamType.PATH if param.name in path_param_names else None, + for param_name, param in inspect.signature(self.view_func).parameters.items(): + model_field = analyze_param( + param_name=param_name, + annotation=param.annotation, + value=param.default, + is_path_param=param.name in path_param_names, ) if param.name in param_docs and not model_field.field_info.description: model_field.field_info.description = param_docs[param.name] @@ -384,10 +369,12 @@ def get_models(self) -> list[ModelField]: return ( self.params + list(self.responses.values()) + + list(cb.field for cb in (self.callbacks or []) if cb.field) + list( - create_model_field(cb.name, cb.request_body) + param for cb in (self.callbacks or []) if cb.request_body + for param in (cb.params or []) ) ) diff --git a/flask_ninja/param.py b/flask_ninja/param.py index 84e7a36..ffb2d47 100644 --- a/flask_ninja/param.py +++ b/flask_ninja/param.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, List, Optional -from pydantic.fields import FieldInfo, Undefined +from pydantic.fields import FieldInfo from flask_ninja.constants import ParamType +from flask_ninja.model_field import Undefined class FuncParam(FieldInfo): @@ -25,7 +26,7 @@ def __init__( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -71,7 +72,7 @@ def __init__( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -115,7 +116,7 @@ def __init__( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -159,7 +160,7 @@ def __init__( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -203,7 +204,7 @@ def __init__( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -228,13 +229,13 @@ def __init__( ) -class Body(FieldInfo): +class Body(FuncParam): + in_ = ParamType.BODY + def __init__( self, default: Any = Undefined, *, - embed: bool = False, - media_type: str = "application/json", alias: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, @@ -246,13 +247,11 @@ def __init__( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, **extra: Any, ): - self.embed = embed - self.media_type = media_type - self.example = example - self.examples = examples super().__init__( default=default, alias=alias, @@ -265,6 +264,10 @@ def __init__( min_length=min_length, max_length=max_length, regex=regex, + deprecated=deprecated, + example=example, + examples=examples, + include_in_schema=include_in_schema, **extra, ) diff --git a/flask_ninja/param_functions.py b/flask_ninja/param_functions.py index 6700fef..7f45f2d 100644 --- a/flask_ninja/param_functions.py +++ b/flask_ninja/param_functions.py @@ -1,8 +1,7 @@ -from typing import Any, Dict, Optional - -from pydantic.fields import Undefined +from typing import Any, List, Optional from . import param +from .model_field import Undefined def Path( @@ -19,7 +18,7 @@ def Path( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -58,7 +57,7 @@ def Query( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -98,7 +97,7 @@ def Header( max_length: Optional[int] = None, regex: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, **extra: Any, @@ -122,3 +121,44 @@ def Header( include_in_schema=include_in_schema, **extra, ) + + +def Body( + default: Any = Undefined, + *, + alias: Optional[str] = None, + convert_underscores: bool = True, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + regex: Optional[str] = None, + example: Any = Undefined, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + **extra: Any, +) -> Any: + return param.Body( + default=default, + alias=alias, + convert_underscores=convert_underscores, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + regex=regex, + example=example, + examples=examples, + deprecated=deprecated, + include_in_schema=include_in_schema, + **extra, + ) diff --git a/flask_ninja/router.py b/flask_ninja/router.py index d83bb88..4b45bd9 100644 --- a/flask_ninja/router.py +++ b/flask_ninja/router.py @@ -3,9 +3,9 @@ from typing import Any, Callable, Optional from flask import Flask -from pydantic.fields import ModelField from flask_ninja.constants import NOT_SET +from flask_ninja.model_field import ModelField from flask_ninja.operation import Callback, Operation diff --git a/flask_ninja/utils.py b/flask_ninja/utils.py index b165460..139896a 100644 --- a/flask_ninja/utils.py +++ b/flask_ninja/utils.py @@ -1,120 +1,103 @@ +# pylint: disable=protected-access import dataclasses import inspect -from enum import Enum -from typing import Any, Dict, Optional, Set, Type, Union - -from pydantic import BaseConfig, BaseModel -from pydantic.class_validators import Validator -from pydantic.fields import ( - SHAPE_FROZENSET, - SHAPE_LIST, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_SINGLETON, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - FieldInfo, - ModelField, - Required, - Undefined, - UndefinedType, -) -from pydantic.schema import model_process_schema -from pydantic.utils import lenient_issubclass - -from flask_ninja.constants import REF_PREFIX, ApiConfigError -from flask_ninja.param import FuncParam, Header, ParamType, Query - -sequence_shapes = { - SHAPE_LIST, - SHAPE_SET, - SHAPE_FROZENSET, - SHAPE_TUPLE, - SHAPE_SEQUENCE, - SHAPE_TUPLE_ELLIPSIS, -} +from typing import Annotated, Any, Mapping, Optional, Type, Union, get_args, get_origin + +from pydantic import BaseModel +from pydantic._internal._utils import lenient_issubclass +from pydantic.fields import FieldInfo + +from flask_ninja import param +from flask_ninja.constants import ApiConfigError +from flask_ninja.model_field import ModelField, Required, Undefined, UnionType +from flask_ninja.param import FuncParam + sequence_types = (list, set, tuple) def create_model_field( name: str, type_: Type[Any], - class_validators: Optional[Dict[str, Validator]] = None, - default: Optional[Any] = None, - required: Union[bool, UndefinedType] = True, - model_config: Type[BaseConfig] = BaseConfig, + default: Optional[Any] = Undefined, field_info: Optional[FieldInfo] = None, alias: Optional[str] = None, ) -> ModelField: - class_validators = class_validators or {} - field_info = field_info or FieldInfo() - - return ModelField( - name=name, - type_=type_, - class_validators=class_validators, - default=default, - required=required, - model_config=model_config, - alias=alias, - field_info=field_info, + field_info = field_info or FieldInfo(annotation=type_, default=default, alias=alias) + + return ModelField(name=name, field_info=field_info, mode="validation") + + +def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + return ( + lenient_issubclass(annotation, (BaseModel, Mapping)) + or _annotation_is_sequence(annotation) + or dataclasses.is_dataclass(annotation) ) -def is_scalar_field(field: ModelField) -> bool: - if ( - field.shape != SHAPE_SINGLETON - or lenient_issubclass(field.type_, BaseModel) - or lenient_issubclass(field.type_, sequence_types + (dict,)) - or dataclasses.is_dataclass(field.type_) - ): +def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) + + return ( + _annotation_is_complex(annotation) + or _annotation_is_complex(origin) + or hasattr(origin, "__pydantic_core_schema__") + or hasattr(origin, "__get_pydantic_core_schema__") + ) + + +def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if lenient_issubclass(annotation, (str, bytes)): return False + return lenient_issubclass(annotation, sequence_types) - if field.sub_fields: - if not all(is_scalar_field(f) for f in field.sub_fields): - return False - return True +def field_annotation_is_scalar(annotation: Any) -> bool: + # handle Ellipsis here to make tuple[int, ...] work nicely + return annotation is Ellipsis or not field_annotation_is_complex(annotation) -def is_scalar_sequence_field(field: ModelField) -> bool: - if (field.shape in sequence_shapes) and not lenient_issubclass( - field.type_, BaseModel - ): - if field.sub_fields is not None: - for sub_field in field.sub_fields: - if not is_scalar_field(sub_field): - return False - return True - if lenient_issubclass(field.type_, sequence_types): - return True - return False - - -def get_model_definitions( - flat_models: Set[Union[Type[BaseModel], Type[Enum]]], - model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str], -) -> Dict[str, Any]: - definitions: Dict[str, Dict[str, Any]] = {} - for model in flat_models: - m_schema, m_definitions, _ = model_process_schema( - model, model_name_map=model_name_map, ref_prefix=REF_PREFIX - ) - definitions.update(m_definitions) - model_name = model_name_map[model] - if "description" in m_schema: - m_schema["description"] = m_schema["description"].split("\f")[0] - definitions[model_name] = m_schema +def is_scalar_field(field: ModelField) -> bool: + return field_annotation_is_scalar(field.field_info.annotation) and not isinstance( + field.field_info, param.Body + ) + + +def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + return _annotation_is_sequence(annotation) or _annotation_is_sequence( + get_origin(annotation) + ) + + +def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_sequence = False + for arg in get_args(annotation): + if field_annotation_is_scalar_sequence(arg): + at_least_one_scalar_sequence = True + continue + if not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_sequence + return field_annotation_is_sequence(annotation) and all( + field_annotation_is_scalar(sub_annotation) + for sub_annotation in get_args(annotation) + ) + - return definitions +def is_scalar_sequence_field(field: ModelField) -> bool: + return field_annotation_is_scalar_sequence(field.field_info.annotation) -def get_param_model_field( +def analyze_param( *, - param: inspect.Parameter, - default_field_info: Type[FuncParam] = FuncParam, - force_type: Optional[ParamType] = None, - ignore_default: bool = False, + param_name: str, + annotation: Any, + value: Any, + is_path_param: bool, ) -> ModelField: """Converts inspected parameter into pydantic ModelField object. @@ -129,58 +112,83 @@ def foo(arg:int = Header(description="Sample")): Then we combine it with the annotation info - argument name, type, etc. and create a Model field containing all the information about the parameter. """ - default_value: Any = Undefined - - if not param.default == param.empty and ignore_default is False: - default_value = param.default - if isinstance(default_value, FieldInfo): - field_info = default_value - default_value = field_info.default - if ( - isinstance(field_info, FuncParam) - and getattr(field_info, "in_", None) is None - ): - field_info.in_ = default_field_info.in_ - else: - field_info = default_field_info(default=default_value) - - if force_type: - field_info.in_ = force_type # type: ignore - required = True - if default_value is Required or ignore_default: - required = True - default_value = None - elif default_value is not Undefined: - required = False + field_info = None + type_annotation: Any = Any - if not field_info.alias and getattr(field_info, "convert_underscores", None): - alias = param.name.replace("_", "-") - else: - alias = field_info.alias or param.name - - field = ModelField( - name=param.name, - type_=param.annotation, - default=default_value, - alias=alias, - required=required, - field_info=field_info, - class_validators={}, - model_config=BaseConfig, - ) + if ( + annotation is not inspect.Signature.empty + and get_origin(annotation) is Annotated + ): + # Handle annotated types + # We need to extract info from the annotated type + annotated_args = get_args(annotation) + type_annotation = annotated_args[0] + annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] + if len(annotations) > 1: + raise ApiConfigError( + f"Cannot specify multiple `Annotated` arguments for {param_name!r}" + ) + next_annotation = next(iter(annotations), None) + if isinstance(next_annotation, FieldInfo): + field_info = type(next_annotation).from_annotation(annotation) + type_annotation = field_info.annotation + + # Handle not annotated types + elif annotation is not inspect.Signature.empty: + type_annotation = annotation + + # If the type wasn't annotated or didn't contain any field info annotation + if field_info is None: + default_value = ( + value + if value is not inspect.Signature.empty and not isinstance(value, FieldInfo) + else Required + ) + field_info = FieldInfo(annotation=type_annotation) + if not is_path_param: + field_info.default = default_value + + # If the type has assigned a default value that is an instance of FuncParam, e.g. Body() + # we need to merge field info from the type and from the value. + # + # It may happen that we have type e.g. + # Notification = Annotated[Union[SuccessNotification, ErrorNotification], Field(discriminator="result")] + # and a function + # def compute(body: Notification = Body(description="Some description") + # + # In this case we need to merge the filed infos, otherwise we either lose the information about the discriminator + # or about the description + if isinstance(value, FuncParam): + field_info = value.__class__( + **(value._attributes_set | field_info._attributes_set) + ) - if getattr(field.field_info, "in_", None) is None: - if is_scalar_field(field): - field.field_info.in_ = ParamType.QUERY # type:ignore + # If it wasn't set explicitly, determinite the type of the field, e.g. Query, Body or Path + # based on the type of the object + if isinstance(field_info, FieldInfo) and not isinstance(field_info, FuncParam): + if is_path_param: + field_info = param.Path(**field_info._attributes_set) + elif not field_annotation_is_scalar(annotation=type_annotation): + field_info = param.Body(**field_info._attributes_set) else: - field.field_info.in_ = ParamType.BODY # type:ignore + field_info = param.Query(**field_info._attributes_set) - if isinstance(param.default, (Query, Header)) and is_scalar_sequence_field(field): - return field - if field.field_info.in_ != ParamType.BODY and not is_scalar_field( # type:ignore - field + # Check consistency of the objects + if is_path_param and not isinstance(field_info, param.Path): + raise ApiConfigError( + f"Cannot use `{field_info.__class__.__name__}` for path param {param_name!r}" + ) + + if isinstance(field_info, param.Path) and not field_annotation_is_scalar( + annotation=field_info.annotation ): raise ApiConfigError("Path param must be of a simple type.") - return field + if not field_info.alias and getattr(field_info, "convert_underscores", None): + alias = param_name.replace("_", "-") + else: + alias = field_info.alias or param_name + field_info.alias = alias + + return ModelField(name=param_name, field_info=field_info, mode="validation") diff --git a/pyproject.toml b/pyproject.toml index 6c08447..3295b0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,14 @@ [tool.poetry] name = "flask-ninja" -version = "1.2.2" +version = "1.3.0" description = "Flask Ninja is a web framework for building APIs with Flask and Python 3.9+ type hints." readme = "README.md" authors = ["Michal Korbela "] [tool.poetry.dependencies] python = "^3.9" -Flask = ">=1.1.2" -pydantic = "^1.9.0" +Flask = ">=2.3.0" +pydantic = "^2.4.2" docstring-parser = "^0.14.1" [tool.poetry.dev-dependencies] diff --git a/tests/unit/snapshots/test_api/test_get_schema/api_schema b/tests/unit/snapshots/test_api/test_get_schema/api_schema index 1c6413c..7dddeeb 100644 --- a/tests/unit/snapshots/test_api/test_get_schema/api_schema +++ b/tests/unit/snapshots/test_api/test_get_schema/api_schema @@ -1 +1 @@ -{"openapi": "3.0.3", "info": {"title": "", "description": "", "version": "1.0.0"}, "paths": {"/some_endpoint/{param}": {"get": {"summary": "", "description": "", "parameters": [{"required": true, "schema": {"title": "Param", "type": "integer"}, "name": "param", "in": "path"}], "requestBody": {"description": "", "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Server"}}}, "required": true}, "responses": {"200": {"description": "", "content": {"application/json": {"schema": {"title": "Response 200", "type": "integer"}}}}}, "security": [{"bearerTokenAuth": []}]}}}, "components": {"schemas": {"ServerVariable": {"title": "ServerVariable", "required": ["default"], "type": "object", "properties": {"enum": {"title": "Enum", "type": "array", "items": {"type": "string"}}, "default": {"title": "Default", "type": "string"}, "description": {"title": "Description", "type": "string"}}}, "Server": {"title": "Server", "required": ["url"], "type": "object", "properties": {"url": {"title": "Url", "anyOf": [{"maxLength": 65536, "minLength": 1, "type": "string", "format": "uri"}, {"type": "string"}]}, "description": {"title": "Description", "type": "string"}, "variables": {"title": "Variables", "type": "object", "additionalProperties": {"$ref": "#/components/schemas/ServerVariable"}}}}}, "securitySchemes": {"bearerTokenAuth": {"type": "http", "scheme": "bearer"}}}} \ No newline at end of file +{"openapi": "3.0.3", "info": {"title": "", "description": "", "version": "1.0.0"}, "paths": {"/some_endpoint/{param}": {"get": {"summary": "", "description": "", "parameters": [{"required": true, "schema": {"type": "integer"}, "name": "param", "in": "path"}], "requestBody": {"description": "", "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Server"}}}, "required": true}, "responses": {"200": {"description": "", "content": {"application/json": {"schema": {"type": "integer"}}}}}, "security": [{"bearerTokenAuth": []}]}}}, "components": {"schemas": {"Server": {"additionalProperties": true, "properties": {"url": {"anyOf": [{"format": "uri", "minLength": 1, "type": "string"}, {"type": "string"}], "title": "Url"}, "description": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Description"}, "variables": {"anyOf": [{"additionalProperties": {"$ref": "#/components/schemas/ServerVariable"}, "type": "object"}, {"type": "null"}], "default": null, "title": "Variables"}}, "required": ["url"], "title": "Server", "type": "object"}, "ServerVariable": {"additionalProperties": true, "properties": {"enum": {"anyOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}], "default": null, "title": "Enum"}, "default": {"title": "Default", "type": "string"}, "description": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Description"}}, "required": ["default"], "title": "ServerVariable", "type": "object"}}, "securitySchemes": {"bearerTokenAuth": {"type": "http", "scheme": "bearer"}}}} \ No newline at end of file diff --git a/tests/unit/test_operation.py b/tests/unit/test_operation.py index fc2a770..3797fef 100644 --- a/tests/unit/test_operation.py +++ b/tests/unit/test_operation.py @@ -1,11 +1,10 @@ # pylint: disable=protected-access, unused-argument,disallowed-name -import json from datetime import datetime -from typing import Mapping, Union +from typing import Union from unittest.mock import MagicMock import pytest -from pydantic.schema import get_flat_models_from_fields, get_model_name_map +from pydantic.json_schema import GenerateJsonSchema from werkzeug.datastructures import MultiDict from flask_ninja import Header, Query @@ -52,13 +51,13 @@ def view_func_none_return() -> None: pytest.param( None, view_func_str, - {200: create_model_field(name="Response 200", type_=str, required=True)}, + {200: create_model_field(name="Response 200", type_=str)}, id="generate str 200 response", ), pytest.param( None, view_func_pydantic_object, - {200: create_model_field(name="Response 200", type_=Server, required=True)}, + {200: create_model_field(name="Response 200", type_=Server)}, id="generate object response", ), pytest.param( @@ -66,7 +65,8 @@ def view_func_none_return() -> None: view_func_list, { 200: create_model_field( - name="Response 200", type_=list[Server], required=True + name="Response 200", + type_=list[Server], ) }, id="generate list response", @@ -76,7 +76,8 @@ def view_func_none_return() -> None: view_func_list_dict, { 200: create_model_field( - name="Response 200", type_=list[dict[str, Server]], required=True + name="Response 200", + type_=list[dict[str, Server]], ) }, id="generate list dict response", @@ -85,25 +86,21 @@ def view_func_none_return() -> None: {200: list[Server], 202: dict[str, Server]}, view_func_union, { - 200: create_model_field( - name="Response 200", type_=list[Server], required=True - ), - 202: create_model_field( - name="Response 202", type_=Mapping[str, Server], required=True - ), + 200: create_model_field(name="Response 200", type_=list[Server]), + 202: create_model_field(name="Response 202", type_=dict[str, Server]), }, id="Multiple responses - Union return type", ), pytest.param( Server, view_func_pydantic_object, - {200: create_model_field(name="Response 200", type_=Server, required=True)}, + {200: create_model_field(name="Response 200", type_=Server)}, id="Specified response", ), pytest.param( {"200": str}, view_func_str, - {200: create_model_field(name="Response 200", type_=str, required=True)}, + {200: create_model_field(name="Response 200", type_=str)}, id="string return codes", ), ], @@ -171,89 +168,76 @@ def test_get_schema(): ) models = operation.get_models() - flat_models = get_flat_models_from_fields(models, known_models=set()) - model_name_map = get_model_name_map(flat_models) - - assert operation.get_schema(model_name_map=model_name_map).json( - by_alias=True, exclude_none=True - ) == json.dumps( - { - "summary": "Some title.", - "description": "Some long description", - "parameters": [ - { - "description": "Some int", - "required": True, - "schema": { - "title": "Bid", - "type": "integer", - "description": "Some int", - }, - "name": "bid", - "in": "query", - } - ], - "requestBody": { + + schema_generator = GenerateJsonSchema(ref_template="#/components/schemas/{model}") + inputs = [(field, field.mode, field.type_adapter.core_schema) for field in models] + field_mapping, _ = schema_generator.generate_definitions(inputs=inputs) + + assert operation.get_schema(field_mapping=field_mapping).model_dump( + by_alias=True, exclude_none=True, mode="json" + ) == { + "summary": "Some title.", + "description": "Some long description", + "parameters": [ + { + "description": "Some int", + "required": True, + "schema": {"type": "integer"}, + "name": "bid", + "in": "query", + } + ], + "requestBody": { + "description": "", + "content": { + "application/json": {"schema": {"$ref": "#/components/schemas/Server"}} + }, + "required": True, + }, + "responses": { + "200": { "description": "", "content": { "application/json": { - "schema": { - "title": "Server", - "allOf": [{"$ref": "#/components/schemas/Server"}], - "description": "Some server", - } + "schema": {"$ref": "#/components/schemas/Server"} } }, - "required": True, - }, - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/Server"} - } - }, - } - }, - "callbacks": { - "callback": { - "someurl": { - "get": { - "parameters": [ - { + } + }, + "callbacks": { + "callback": { + "someurl": { + "get": { + "parameters": [ + { + "description": "Some callback param description", + "required": True, + "schema": { + "type": "null", "description": "Some callback param description", - "required": True, - "schema": { - "title": "Get Param", - "type": "integer", - "description": "Some callback param description", - }, - "name": "get_param", - "in": "query", - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Server" - } - } }, - "required": True, - }, - "responses": { - "200": {"description": "Success"}, - "500": {"description": "Error"}, + "name": "get_param", + "in": "query", + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Server"} + } }, - } + "required": True, + }, + "responses": { + "200": {"description": "Success"}, + "500": {"description": "Error"}, + }, } } - }, - "security": [{"bearerTokenAuth": []}], - } - ) + } + }, + "security": [{"bearerTokenAuth": []}], + } @pytest.mark.parametrize( diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index 8e977c4..f4c2e77 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -35,7 +35,7 @@ def sample_method(): assert router.operations[0].path == "/foo" assert router.operations[0].method == "GET" assert str(router.operations[0].responses) == str( - {200: create_model_field(name="Response 200", type_=str, required=True)} + {200: create_model_field(name="Response 200", type_=str)} ) assert router.operations[0].callbacks == [callback] assert router.operations[0].summary == "some_summary" @@ -57,7 +57,7 @@ def sample_method() -> str: assert router.operations[0].path == "/foo" assert router.operations[0].method == "GET" assert str(router.operations[0].responses) == str( - {200: create_model_field(name="Response 200", type_=str, required=True)} + {200: create_model_field(name="Response 200", type_=str)} ) assert router.operations[0].callbacks is None assert router.operations[0].summary == "" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000..e69de29 diff --git a/tox.ini b/tox.ini index 5e4d4a9..5e6333a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py3.9, py3.10 +envlist = py3.9, py3.10, py3.11 isolated_build = True [testenv]