Skip to content

Commit

Permalink
chore: Upgrade pydantic>=2.0, flask>=3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Korbela committed Nov 6, 2023
1 parent d19a00e commit 829e069
Show file tree
Hide file tree
Showing 14 changed files with 467 additions and 407 deletions.
2 changes: 1 addition & 1 deletion flask_ninja/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask_ninja.api import NinjaAPI, Server
from flask_ninja.constants import ParamType
from flask_ninja.operation import ApiConfigError, Callback, Operation
from flask_ninja.param_functions import Header, Path, Query
from flask_ninja.param_functions import Body, Header, Path, Query
from flask_ninja.router import Router
from flask_ninja.security import HttpAuthBase, HttpBearer
29 changes: 15 additions & 14 deletions flask_ninja/api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import json
import re
from typing import Any, Callable, Optional

from flask import Blueprint, Flask, render_template
from pydantic.schema import get_flat_models_from_fields, get_model_name_map
from pydantic.json_schema import GenerateJsonSchema

from .constants import NOT_SET
from .models import Components, Info, OpenAPI, Server
from .router import Router
from .swagger_ui import swagger_ui_path
from .utils import get_model_definitions


class NinjaAPI:
Expand Down Expand Up @@ -70,7 +68,7 @@ def delete(self, path: str, **kwargs: Any) -> Callable:
def add_router(self, router: Router, prefix: str = "") -> None:
self.router.add_router(router, f"{self.prefix}{prefix}")

def get_schema(self) -> str:
def get_schema(self) -> dict[str, Any]:
"""Creates OpenAPI schema for the API."""

# At first we collect all pydantic models used anywhere
Expand All @@ -79,17 +77,20 @@ def get_schema(self) -> str:
for operation in self.router.operations:
models += operation.get_models()

# Then we create from them flat models - it means we extract the models from Generics
flat_models = get_flat_models_from_fields(models, known_models=set())
# Then we generate unique names for them - if there are two models from different modules
# but with the same name, we need provide different names for them in the Definitions list
model_name_map = get_model_name_map(flat_models)
schema_generator = GenerateJsonSchema(
ref_template="#/components/schemas/{model}"
)

inputs = [
(field, field.mode, field.type_adapter.core_schema) for field in models
]
field_mapping, definitions = schema_generator.generate_definitions(
inputs=inputs
)
print(field_mapping)
paths: dict = {}
security_schemes: dict = {}
# Create OpenAPI schemas for all models
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
)

# Create OpenAPI schema for all operations
for operation in self.router.operations:
Expand All @@ -99,7 +100,7 @@ def get_schema(self) -> str:
if swagger_path not in paths:
paths[swagger_path] = {}
paths[swagger_path][operation.method.lower()] = operation.get_schema(
model_name_map=model_name_map
field_mapping=field_mapping
)
if operation.auth:
security_schemes.update(operation.auth.schema())
Expand All @@ -118,4 +119,4 @@ def get_schema(self) -> str:
servers=self.servers or None,
)

return json.loads(schema.json(by_alias=True, exclude_none=True))
return schema.model_dump(mode="json", by_alias=True, exclude_none=True)
68 changes: 68 additions & 0 deletions flask_ninja/model_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import types
from dataclasses import dataclass
from typing import Annotated, Any, Dict, List, Literal, Sequence, Set, Tuple, Union

from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import PydanticUndefined, PydanticUndefinedType

Required = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
UnionType = getattr(types, "UnionType", Union)


@dataclass
class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"

@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name

@property
def required(self) -> bool:
return self.field_info.is_required()

@property
def default(self) -> Any:
return self.get_default()

@property
def type_(self) -> Any:
return self.field_info.annotation

def __post_init__(self) -> None:
self.type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
)

def get_default(self) -> Any:
if self.field_info.is_required():
return Undefined
return self.field_info.get_default(call_default_factory=True)

def __hash__(self) -> int:
# Each ModelField is unique for our purposes, to allow making a dict from
# ModelField to its JSON Schema.
return id(self)


def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())} for err in errors
]

return updated_loc_errors


FieldMapping = Dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
]
Loading

0 comments on commit 829e069

Please sign in to comment.