Skip to content

Commit

Permalink
Merge pull request #267 from bombsimon/nested_fragments
Browse files Browse the repository at this point in the history
Restore model_rebuild calls to top level fragments
  • Loading branch information
rafalp authored Jan 25, 2024
2 parents fa7b04a + da102b5 commit 9edc90b
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Fixed `graphql-transport-ws` protocol implementation not waiting for the `connection_ack` message on new connection.
- Fixed `get_client_settings` mutating `config_dict` instance.
- Added support to `graphqlschema` for saving schema as a GraphQL file.
- Restored `model_rebuild` calls for top level fragment models.


## 0.11.0 (2023-12-05)
Expand Down
4 changes: 4 additions & 0 deletions EXAMPLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ class BasicUser(BaseModel):
class UserPersonalData(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")


BasicUser.model_rebuild()
UserPersonalData.model_rebuild()
```

### Init file
Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
MODEL_VALIDATE_METHOD = "model_validate"
PLAIN_SERIALIZER = "PlainSerializer"
BEFORE_VALIDATOR = "BeforeValidator"
MODEL_REBUILD_METHOD = "model_rebuild"

ENUM_MODULE = "enum"
ENUM_CLASS = "Enum"
Expand Down
32 changes: 28 additions & 4 deletions ariadne_codegen/client_generators/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from graphql import FragmentDefinitionNode, GraphQLSchema

from ..codegen import generate_module
from ..codegen import generate_expr, generate_method_call, generate_module
from ..plugins.manager import PluginManager
from .constants import BASE_MODEL_IMPORT
from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD
from .result_types import ResultTypesGenerator
from .scalars import ScalarData

Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(
def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict: Dict[str, List[ast.ClassDef]] = {}
imports: List[ast.ImportFrom] = []
top_level_class_names: List[str] = []
dependencies_dict: Dict[str, Set[str]] = {}

names_to_exclude = exclude_names or set()
Expand All @@ -53,7 +54,10 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
plugin_manager=self.plugin_manager,
)
imports.extend(generator.get_imports())
class_defs_dict[name] = generator.get_classes()
class_defs = generator.get_classes()
class_defs_dict[name] = class_defs
if class_defs:
top_level_class_names.append(class_defs[0].name)
dependencies_dict[name] = generator.get_fragments_used_as_mixins()
self._generated_public_names.extend(generator.get_generated_public_names())
self._used_enums.extend(generator.get_used_enums())
Expand All @@ -62,7 +66,15 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict
)
module = generate_module(
body=cast(List[ast.stmt], imports) + cast(List[ast.stmt], sorted_class_defs)
body=cast(List[ast.stmt], imports)
+ cast(List[ast.stmt], sorted_class_defs)
+ cast(
List[ast.stmt],
self._get_model_rebuild_calls(
top_level_fragments_names=top_level_class_names,
class_defs=sorted_class_defs,
),
)
)
if self.plugin_manager:
module = self.plugin_manager.generate_fragments_module(
Expand Down Expand Up @@ -108,3 +120,15 @@ def visit(name):
visit(name)

return sorted_names

def _get_model_rebuild_calls(
self, top_level_fragments_names: List[str], class_defs: List[ast.ClassDef]
) -> List[ast.Call]:
class_names = [c.name for c in class_defs]
sorted_fragments_names = sorted(
top_level_fragments_names, key=class_names.index
)
return [
generate_expr(generate_method_call(name, MODEL_REBUILD_METHOD))
for name in sorted_fragments_names
]
4 changes: 4 additions & 0 deletions tests/main/clients/example/expected_client/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class BasicUser(BaseModel):
class UserPersonalData(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")


BasicUser.model_rebuild()
UserPersonalData.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,8 @@ class GetQueryAFragment(BaseModel):

class GetQueryAFragmentQueryA(BaseModel, MixinA, CommonMixin):
field_a: int = Field(alias="fieldA")


FragmentA.model_rebuild()
FragmentB.model_rebuild()
GetQueryAFragment.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class FragmentA(BaseModel):
class FragmentB(BaseModel):
id: str
value_b: str = Field(alias="valueB")


FragmentA.model_rebuild()
FragmentB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ class FragmentOnQueryWithUnionQueryUTypeC(BaseModel):
class UnusedFragmentOnTypeA(BaseModel):
id: str
field_a: str = Field(alias="fieldA")


FragmentOnQueryWithInterface.model_rebuild()
FragmentOnQueryWithUnion.model_rebuild()
UnusedFragmentOnTypeA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ class MinimalA(BaseModel):

class MinimalAFieldB(MinimalB):
pass


CompleteA.model_rebuild()
FullB.model_rebuild()
FullA.model_rebuild()
MinimalB.model_rebuild()
MinimalA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class FragmentG(BaseModel):

class FragmentGG(BaseModel):
val: EnumGG


FragmentG.model_rebuild()
FragmentGG.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/operations/expected_client/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ class FragmentB(BaseModel):

class FragmentY(BaseModel):
value_y: int = Field(alias="valueY")


FragmentB.model_rebuild()
FragmentY.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ class ListAnimalsFragment(BaseModel):
class ListAnimalsFragmentListAnimals(BaseModel):
typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename")
name: str


FragmentWithSingleField.model_rebuild()
ListAnimalsFragment.model_rebuild()
57 changes: 57 additions & 0 deletions tests/main/test_model_rebuild_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
To ensure all models with nested dependencies are fully rebuilt those tests
create an instance of the query from `multiple_fragments` containing the
`FullA` fragment (used by the `ExampleQuery2ExampleQuery`) which itself includes
a field of type `FullAFieldB` that extends the `FullB` fragment.
If this model is not rebuilt with `FullA.model_rebuild()` `ExampleQuery2` will
not be fully defined and we will raise a `PydanticUserError`.
Reference to Pydantic documentation about when and why we need to call
`model_rebuild` on our types:
https://errors.pydantic.dev/2.5/u/class-not-fully-defined
"""

import pytest
from pydantic_core import ValidationError

from .clients.multiple_fragments.expected_client.example_query_2 import (
ExampleQuery2,
ExampleQuery2ExampleQuery,
)
from .clients.multiple_fragments.expected_client.fragments import FullA


def test_json_schema_contains_all_properties():
json_schema = ExampleQuery2.model_json_schema()
assert "ExampleQuery2ExampleQuery" in json_schema["$defs"]
assert "FullAFieldB" in json_schema["$defs"]

query_props = json_schema["$defs"]["ExampleQuery2ExampleQuery"]["properties"]
assert "id" in query_props
assert "value" in query_props
assert "fieldB" in query_props
assert query_props["fieldB"]["$ref"] == "#/$defs/FullAFieldB"


@pytest.fixture
def field_a_data():
field_b = {"id": "321", "value": 13.37}
field_a = {"id": "123", "value": "A", "field_b": field_b}

return field_a


def test_validate_field_a_on_faulty_model(field_a_data):
with pytest.raises(ValidationError):
ExampleQuery2.model_validate(field_a_data)


def test_validate_field_a_on_correct_model(field_a_data):
FullA.model_validate(field_a_data)
ExampleQuery2ExampleQuery.model_validate(field_a_data)


def test_validate_field_a_in_example_query(field_a_data):
example_query_2 = {"example_query": field_a_data}
ExampleQuery2.model_validate(example_query_2)

0 comments on commit 9edc90b

Please sign in to comment.