diff --git a/ariadne_codegen/contrib/no_global_imports.py b/ariadne_codegen/contrib/no_global_imports.py index eb69e40..f3fc08e 100644 --- a/ariadne_codegen/contrib/no_global_imports.py +++ b/ariadne_codegen/contrib/no_global_imports.py @@ -15,6 +15,9 @@ from ariadne_codegen import Plugin +TYPE_CHECKING_MODULE: str = "typing" +TYPE_CHECKING_FLAG: str = "TYPE_CHECKING" + class NoGlobalImportsPlugin(Plugin): """Only import types when you call an endpoint needing it""" @@ -313,7 +316,7 @@ def _update_imports(self, module: ast.Module) -> ast.Name | None: type_checking_imports[module_name].names.append(ast.alias(cls)) import_if_type_checking = ast.If( - test=ast.Name(id="TYPE_CHECKING"), + test=ast.Name(id=TYPE_CHECKING_FLAG), body=list(type_checking_imports.values()), orelse=[], ) @@ -324,8 +327,8 @@ def _update_imports(self, module: ast.Module) -> ast.Name | None: module.body.insert( len(non_empty_imports), ast.ImportFrom( - module="typing", - names=[ast.Name("TYPE_CHECKING")], + module=TYPE_CHECKING_MODULE, + names=[ast.Name(TYPE_CHECKING_FLAG)], ), )