Skip to content

Commit

Permalink
Merge pull request #278 from mirumee/model_rebuild_calls
Browse files Browse the repository at this point in the history
add model_rebuild_calls for result_types.py
  • Loading branch information
rafalp authored Feb 27, 2024
2 parents b5a67e1 + 6cf3077 commit fdc5fb8
Show file tree
Hide file tree
Showing 36 changed files with 142 additions and 2 deletions.
36 changes: 34 additions & 2 deletions ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
generate_ann_assign,
generate_class_def,
generate_constant,
generate_expr,
generate_import_from,
generate_method_call,
generate_module,
generate_pass,
generate_pydantic_field,
Expand All @@ -55,6 +57,7 @@
MIXIN_FROM_NAME,
MIXIN_IMPORT_NAME,
MIXIN_NAME,
MODEL_REBUILD_METHOD,
OPTIONAL,
PYDANTIC_MODULE,
TYPENAME_ALIAS,
Expand Down Expand Up @@ -152,8 +155,16 @@ def _get_operation_type_name(self, definition: ExecutableDefinitionNode) -> str:
raise NotSupported(f"Not supported operation type: {definition}")

def generate(self) -> ast.Module:
module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], self._class_defs
model_rebuild_calls = [
generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD))
for class_def in self._class_defs
if self.include_model_rebuild(class_def)
]

module_body = (
cast(List[ast.stmt], self._imports)
+ cast(List[ast.stmt], self._class_defs)
+ cast(List[ast.stmt], model_rebuild_calls)
)

module = generate_module(module_body)
Expand All @@ -163,6 +174,11 @@ def generate(self) -> ast.Module:
)
return module

def include_model_rebuild(self, class_def: ast.ClassDef) -> bool:
visitor = ClassDefNamesVisitor()
visitor.visit(class_def)
return visitor.found_name_with_quote

def get_imports(self) -> List[ast.ImportFrom]:
return self._imports

Expand Down Expand Up @@ -560,3 +576,19 @@ def enter_field(node: FieldNode, *_args: Any) -> FieldNode:
copied_node = deepcopy(node)
visit(copied_node, RemoveMixinVisitor())
return copied_node


class ClassDefNamesVisitor(ast.NodeVisitor):
def __init__(self):
self.found_name_with_quote = False

def visit_Name(self, node): # pylint: disable=C0103
if '"' in node.id:
self.found_name_with_quote = True
self.generic_visit(node)

def visit_Subscript(self, node): # pylint: disable=C0103
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
return

self.generic_visit(node)
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/custom_scalars/expected_client/get_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ class GetATestQuery(BaseModel):
code: Annotated[Code, BeforeValidator(parse_code)]
id: int
other: Any


GetA.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/example/expected_client/create_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ class CreateUser(BaseModel):

class CreateUserUserCreate(BaseModel):
id: str


CreateUser.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/example/expected_client/list_all_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ class ListAllUsersUsers(BaseModel):

class ListAllUsersUsersLocation(BaseModel):
country: Optional[str]


ListAllUsers.model_rebuild()
ListAllUsersUsers.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ class ListUsersByCountry(BaseModel):

class ListUsersByCountryUsers(BasicUser, UserPersonalData):
favourite_color: Optional[Color] = Field(alias="favouriteColor")


ListUsersByCountry.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ class FragmentsWithMixinsQueryA(FragmentA, CommonMixin):

class FragmentsWithMixinsQueryB(FragmentB, CommonMixin):
pass


FragmentsWithMixins.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ class GetQueryA(BaseModel):

class GetQueryAQueryA(BaseModel, MixinA, CommonMixin):
field_a: int = Field(alias="fieldA")


GetQueryA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ class GetQueryB(BaseModel):

class GetQueryBQueryB(BaseModel, MixinB, CommonMixin):
field_b: str = Field(alias="fieldB")


GetQueryB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ class QueryWithFragmentOnSubInterfaceQueryInterfaceBaseInterface(BaseModel):

class QueryWithFragmentOnSubInterfaceQueryInterfaceInterfaceA(FragmentA):
typename__: Literal["InterfaceA"] = Field(alias="__typename")


QueryWithFragmentOnSubInterface.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ class QueryWithFragmentOnSubInterfaceWithInlineFragmentQueryInterfaceTypeA(BaseM
id: str
value_a: str = Field(alias="valueA")
another: str


QueryWithFragmentOnSubInterfaceWithInlineFragment.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ class QueryWithFragmentOnUnionMemberQueryUnionTypeA(BaseModel):

class QueryWithFragmentOnUnionMemberQueryUnionTypeB(FragmentB):
typename__: Literal["TypeB"] = Field(alias="__typename")


QueryWithFragmentOnUnionMember.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ class InterfaceAQueryITypeB(BaseModel):
typename__: Literal["TypeB"] = Field(alias="__typename")
id: str
field_b: str = Field(alias="fieldB")


InterfaceA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ class InterfaceBQueryITypeA(BaseModel):
typename__: Literal["TypeA"] = Field(alias="__typename")
id: str
field_a: str = Field(alias="fieldA")


InterfaceB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ class InterfaceCQueryI(BaseModel):
alias="__typename"
)
id: str


InterfaceC.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ class InterfaceWithTypenameQueryI(BaseModel):
alias="__typename"
)
id: str


InterfaceWithTypename.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ class ListInterfaceQueryListITypeB(BaseModel):
typename__: Literal["TypeB"] = Field(alias="__typename")
id: str
field_b: str = Field(alias="fieldB")


ListInterface.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ class ListUnionQueryListUTypeB(BaseModel):

class ListUnionQueryListUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


ListUnion.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ class QueryWithFragmentOnInterfaceQueryITypeB(BaseModel):
typename__: Literal["TypeB"] = Field(alias="__typename")
id: str
field_b: str = Field(alias="fieldB")


QueryWithFragmentOnInterface.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ class QueryWithFragmentOnUnionQueryUTypeB(BaseModel):

class QueryWithFragmentOnUnionQueryUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


QueryWithFragmentOnUnion.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ class UnionAQueryUTypeB(BaseModel):

class UnionAQueryUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


UnionA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ class UnionBQueryUTypeB(BaseModel):

class UnionBQueryUTypeC(BaseModel):
typename__: Literal["TypeC"] = Field(alias="__typename")


UnionB.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ class ExampleQuery1(BaseModel):

class ExampleQuery1ExampleQuery(MinimalA):
value: str


ExampleQuery1.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ class ExampleQuery2(BaseModel):

class ExampleQuery2ExampleQuery(FullA):
pass


ExampleQuery2.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ class ExampleQuery3(BaseModel):

class ExampleQuery3ExampleQuery(CompleteA):
pass


ExampleQuery3.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class GetF(BaseModel):

class GetFF(BaseModel):
val: EnumF


GetF.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class GetG(BaseModel):

class GetGG(FragmentG):
pass


GetG.model_rebuild()
4 changes: 4 additions & 0 deletions tests/main/clients/operations/expected_client/get_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ class GetAA(BaseModel):

class GetAAValueB(BaseModel):
value: str


GetA.model_rebuild()
GetAA.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class GetAWithFragmentA(BaseModel):

class GetAWithFragmentAValueB(FragmentB):
pass


GetAWithFragment.model_rebuild()
GetAWithFragmentA.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/operations/expected_client/get_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ class GetS(BaseModel):

class GetSS(BaseModel):
id: int


GetS.model_rebuild()
3 changes: 3 additions & 0 deletions tests/main/clients/operations/expected_client/get_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ class GetXYZXyzTypeY(FragmentY):

class GetXYZXyzTypeZ(BaseModel):
typename__: Literal["TypeZ"] = Field(alias="__typename")


GetXYZ.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ class GetAnimalByNameAnimalByNameDog(BaseModel):
typename__: Literal["Dog"] = Field(alias="__typename")
name: str
puppies: int


GetAnimalByName.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class GetAuthenticatedUser(BaseModel):
class GetAuthenticatedUserMe(BaseModel):
id: str
username: str


GetAuthenticatedUser.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ class ListAnimalsListAnimalsDog(BaseModel):
typename__: Literal["Dog"] = Field(alias="__typename")
name: str
puppies: int


ListAnimals.model_rebuild()
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ class ListTypeA(BaseModel):

class ListTypeAListOptionalTypeA(BaseModel):
id: int


ListTypeA.model_rebuild()

0 comments on commit fdc5fb8

Please sign in to comment.