diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c6bebc..ee50112 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## 0.14.0 (Unreleased) + +- Re-added `model_rebuild` calls for input types with forward references. + + ## 0.13.0 (2024-03-4) - Fixed `str_to_snake_case` utility to capture fully capitalized words followed by an underscore. diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 4d5d4ca..3a94541 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -13,10 +13,13 @@ generate_ann_assign, generate_class_def, generate_constant, + generate_expr, generate_import_from, generate_keyword, + generate_method_call, generate_module, generate_pydantic_field, + model_has_forward_refs, ) from ..plugins.manager import PluginManager from ..utils import process_name @@ -28,6 +31,7 @@ BASE_MODEL_IMPORT, FIELD_CLASS, LIST, + MODEL_REBUILD_METHOD, OPTIONAL, PLAIN_SERIALIZER, PYDANTIC_MODULE, @@ -85,8 +89,16 @@ def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module: scalar_data = self.custom_scalars[scalar_name] self._imports.extend(generate_scalar_imports(scalar_data)) - module_body = cast(List[ast.stmt], self._imports) + cast( - List[ast.stmt], class_defs + model_rebuild_calls = [ + generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD)) + for class_def in class_defs + if model_has_forward_refs(class_def) + ] + + module_body = ( + cast(List[ast.stmt], self._imports) + + cast(List[ast.stmt], class_defs) + + cast(List[ast.stmt], model_rebuild_calls) ) module = generate_module(body=module_body) diff --git a/ariadne_codegen/client_generators/result_types.py b/ariadne_codegen/client_generators/result_types.py index 6820e5a..b43ac4a 100644 --- a/ariadne_codegen/client_generators/result_types.py +++ b/ariadne_codegen/client_generators/result_types.py @@ -40,6 +40,7 @@ generate_module, generate_pass, generate_pydantic_field, + model_has_forward_refs, ) from ..exceptions import NotSupported, ParsingError from ..plugins.manager import PluginManager @@ -158,7 +159,7 @@ def generate(self) -> ast.Module: model_rebuild_calls = [ generate_expr(generate_method_call(class_def.name, MODEL_REBUILD_METHOD)) for class_def in self._class_defs - if self.include_model_rebuild(class_def) + if model_has_forward_refs(class_def) ] module_body = ( @@ -174,11 +175,6 @@ def generate(self) -> ast.Module: ) return module - def include_model_rebuild(self, class_def: ast.ClassDef) -> bool: - visitor = ClassDefNamesVisitor() - visitor.visit(class_def) - return visitor.found_name_with_quote - def get_imports(self) -> List[ast.ImportFrom]: return self._imports @@ -576,19 +572,3 @@ def enter_field(node: FieldNode, *_args: Any) -> FieldNode: copied_node = deepcopy(node) visit(copied_node, RemoveMixinVisitor()) return copied_node - - -class ClassDefNamesVisitor(ast.NodeVisitor): - def __init__(self): - self.found_name_with_quote = False - - def visit_Name(self, node): # pylint: disable=C0103 - if '"' in node.id: - self.found_name_with_quote = True - self.generic_visit(node) - - def visit_Subscript(self, node): # pylint: disable=C0103 - if isinstance(node.value, ast.Name) and node.value.id == "Literal": - return - - self.generic_visit(node) diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index c5a74df..40a8b89 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -332,3 +332,25 @@ def generate_yield(value: Optional[ast.expr] = None) -> ast.Yield: def generate_pass() -> ast.Pass: return ast.Pass() + + +def model_has_forward_refs(class_def: ast.ClassDef) -> bool: + visitor = ClassDefNamesVisitor() + visitor.visit(class_def) + return visitor.found_name_with_quote + + +class ClassDefNamesVisitor(ast.NodeVisitor): + def __init__(self): + self.found_name_with_quote = False + + def visit_Name(self, node): # pylint: disable=C0103 + if '"' in node.id: + self.found_name_with_quote = True + self.generic_visit(node) + + def visit_Subscript(self, node): # pylint: disable=C0103 + if isinstance(node.value, ast.Name) and node.value.id == "Literal": + return + + self.generic_visit(node) diff --git a/tests/main/clients/example/expected_client/input_types.py b/tests/main/clients/example/expected_client/input_types.py index e44cb44..933cfe8 100644 --- a/tests/main/clients/example/expected_client/input_types.py +++ b/tests/main/clients/example/expected_client/input_types.py @@ -46,3 +46,7 @@ class NotificationsPreferencesInput(BaseModel): receive_push_notifications: bool = Field(alias="receivePushNotifications") receive_sms: bool = Field(alias="receiveSms") title: str + + +UserCreateInput.model_rebuild() +UserPreferencesInput.model_rebuild() diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/input_types.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/input_types.py index 392b4cb..57f8515 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/input_types.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/input_types.py @@ -26,3 +26,8 @@ class InputAB(BaseModel): class InputE(BaseModel): val: EnumE + + +InputA.model_rebuild() +InputAA.model_rebuild() +InputAB.model_rebuild()