Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugin to reduce client imports #287

Merged
merged 11 commits into from
Apr 3, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## 0.14.0 (Unreleased)

- Added `NoGlobalImportsPlugin` to standard plugins.
- Re-added `model_rebuild` calls for input types with forward references.


Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ Ariadne Codegen ships with optional plugins importable from the `ariadne_codegen

- [`ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin`](ariadne_codegen/contrib/extract_operations.py) - This extracts query strings from generated client's methods into separate `operations.py` module. It also modifies the generated client to import these definitions. Generated module name can be customized by adding `operations_module_name="custom_name"` to the `[tool.ariadne-codegen.operations]` section in config. Eg.:

- [`ariadne_codegen.contrib.no_global_imports.NoGlobalImportsPlugin`](ariadne_codegen/contrib/no_global_imports.py) - This plugin processes generated client module and convert all input arguments and return types to strings. The types will be imported only for type checking.
bombsimon marked this conversation as resolved.
Show resolved Hide resolved

```toml
[tool.ariadne-codegen]
...
Expand Down
8 changes: 7 additions & 1 deletion ariadne_codegen/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .extract_operations import ExtractOperationsPlugin
from .no_global_imports import NoGlobalImportsPlugin
from .no_reimports import NoReimportsPlugin
from .shorter_results import ShorterResultsPlugin

__all__ = ["ExtractOperationsPlugin", "NoReimportsPlugin", "ShorterResultsPlugin"]
__all__ = [
bombsimon marked this conversation as resolved.
Show resolved Hide resolved
"ExtractOperationsPlugin",
"NoReimportsPlugin",
"ShorterResultsPlugin",
"NoGlobalImportsPlugin",
]
365 changes: 365 additions & 0 deletions ariadne_codegen/contrib/no_global_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
"""
Plugin to only import types for GraphQL responses when you call methods.

This will massively reduce import times for larger projects since you only have
to load the input types when loading the client.

All input and return types that's used to process the server response will
only be imported when the method is called.
"""

import ast
from typing import Dict, List, Optional, Set, Union

from graphql import GraphQLSchema

from ariadne_codegen import Plugin

TYPE_CHECKING_MODULE: str = "typing"
TYPE_CHECKING_FLAG: str = "TYPE_CHECKING"


class NoGlobalImportsPlugin(Plugin):
bombsimon marked this conversation as resolved.
Show resolved Hide resolved
"""Only import types when you call an endpoint needing it"""

def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None:
"""Constructor"""
# Types that should only be imported in a `TYPE_CHECKING` context. This
# is all the types used as arguments to a method or as a return type,
# i.e. for type checking.
self.input_and_return_types: Set[str] = set()

# Imported classes are classes imported from local imports. We keep a
# map between name and module so we know how to import them in each
# method.
self.imported_classes: Dict[str, str] = {}

# Imported classes in each method definition.
self.imported_in_method: Set[str] = set()

super().__init__(schema, config_dict)

def generate_client_module(self, module: ast.Module) -> ast.Module:
"""
Update the generated client.

This will parse all current imports to map them to a path. It will then
traverse all methods and look for the actual return type. The return
node will be converted to an `ast.Constant` if it's an `ast.Name` and
the return type will be imported only under `if TYPE_CHECKING`
conditions.

It will also move all imports of the types used to parse the response
inside each method since that's the only place where they're used. The
result will be that we end up with imports in the global scope only for
types used as input types.

:param module: The ast for the module
:returns: A modified `ast.Module`
"""
self._store_imported_classes(module.body)

# Find the actual client class so we can grab all input and output
# types. We also ensure to manipulate the ast while we do this.
client_class_def = next(
filter(lambda o: isinstance(o, ast.ClassDef), module.body), None
)
if not client_class_def or not isinstance(client_class_def, ast.ClassDef):
return super().generate_client_module(module)

for method_def in [
m
for m in client_class_def.body
if isinstance(m, (ast.FunctionDef, ast.AsyncFunctionDef))
]:
method_def = self._rewrite_input_args_to_constants(method_def)

# If the method returns anything, update whatever it returns.
if method_def.returns:
method_def.returns = self._update_name_to_constant(method_def.returns)

self._insert_import_statement_in_method(method_def)

self._update_imports(module)

return super().generate_client_module(module)

def _store_imported_classes(self, module_body: List[ast.stmt]):
"""Fetch and store imported classes.

Grab all imported classes with level 1 or starting with `.` because
these are the ones generated by us. We store a map between the class and
which module it was imported from so we can easily import it when
needed. This can be in a `TYPE_CHECKING` condition or inside a method.

:param module_body: The body of an `ast.Module`
"""
for node in module_body:
if not isinstance(node, ast.ImportFrom):
continue

if node.module is None:
continue

# We only care about local imports from our generated code.
if node.level != 1 and not node.module.startswith("."):
continue

for name in node.names:
from_ = "." * node.level + node.module
if isinstance(name, ast.alias):
self.imported_classes[name.name] = from_

def _rewrite_input_args_to_constants(
self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> Union[ast.FunctionDef, ast.AsyncFunctionDef]:
"""Rewrite the arguments to a method.

For any `ast.Name` that requires an import convert it to an
`ast.Constant` instead. The actual class will be noted and imported
in a `TYPE_CHECKING` context.

:param method_def: Method definition
:returns: The same definition but updated
"""
if not isinstance(method_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
return method_def

for i, input_arg in enumerate(method_def.args.args):
annotation = input_arg.annotation
if isinstance(annotation, (ast.Name, ast.Subscript, ast.Tuple)):
method_def.args.args[i].annotation = self._update_name_to_constant(
annotation
)

return method_def

def _insert_import_statement_in_method(
self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef]
):
"""Insert import statement in method.

Each method will eventually pass the returned value to a class we've
generated. Since we only need it in the scope of the method ensure we
add it at the top of the method only. It will be removed from the global
scope.

:param method_def: The method definition to updated
"""
# Find the last statement in the body, the call to this class is
# what we need to import first.
return_stmt = method_def.body[-1]
if isinstance(return_stmt, ast.Return):
call = self._get_call_arg_from_return(return_stmt)
elif isinstance(return_stmt, ast.AsyncFor):
call = self._get_call_arg_from_async_for(return_stmt)
else:
return

if call is None:
return

import_class = self._get_class_from_call(call)
if import_class is None:
return

import_class_id = import_class.id

# We add the class to our set of imported in methods - these classes
# don't need to be imported at all in the global scope.
self.imported_in_method.add(import_class.id)
method_def.body.insert(
0,
ast.ImportFrom(
module=self.imported_classes[import_class_id],
names=[import_class],
),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I developed this plugin outside of ariadne-codegen so I didn't use any helpers to generate these. Tbh I don't think it helps, it just makes it tricker to read.

Let me know if you'd rather me to import and use all the generators even for plugins or if I can keep this as is.

)

def _get_call_arg_from_return(self, return_stmt: ast.Return) -> Optional[ast.Call]:
"""Get the class used in the return statement.

:param return_stmt: The statement used for return
"""
# If it's a call of the class like produced by
# `ShorterResultsPlugin` we have an attribute.
if isinstance(return_stmt.value, ast.Attribute) and isinstance(
return_stmt.value.value, ast.Call
):
return return_stmt.value.value

# If not it's just a call statement to the generated class.
if isinstance(return_stmt.value, ast.Call):
return return_stmt.value

return None

def _get_call_arg_from_async_for(
self, last_stmt: ast.AsyncFor
) -> Optional[ast.Call]:
"""Get the class used in the yield expression.

:param last_stmt: The statement used in `ast.AsyncFor`
"""
if isinstance(last_stmt.body, list) and isinstance(last_stmt.body[0], ast.Expr):
body = last_stmt.body[0]
elif isinstance(last_stmt.body, ast.Expr):
body = last_stmt.body
else:
return None

if not isinstance(body, ast.Expr):
return None

if not isinstance(body.value, ast.Yield):
return None

# If it's a call of the class like produced by
# `ShorterResultsPlugin` we have an attribute.
if isinstance(body.value.value, ast.Attribute) and isinstance(
body.value.value.value, ast.Call
):
return body.value.value.value

# If not it's just a call statement to the generated class.
if isinstance(body.value.value, ast.Call):
return body.value.value

return None

def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]:
"""Get the class from an `ast.Call`.

:param call: The `ast.Call` arg
:returns: `ast.Name` or `None`
"""
if not isinstance(call.func, ast.Attribute):
return None

if not isinstance(call.func.value, ast.Name):
return None

return call.func.value

def _update_imports(self, module: ast.Module):
"""Update all imports.

Iterate over all imports and remove the aliases that we use as input or
return value. These will be moved and added to an `if TYPE_CHECKING`
block.

**NOTE** If an `ast.ImportFrom` ends up without any names we must remove
it completely otherwise formatting will not work (it would remove the
empty `import from` but not format the rest of the code without running
it twice).

We do this by storing all imports that we want to keep in an array, we
then drop all from the body and re-insert the ones to keep. Lastly we
import `TYPE_CHECKING` and add all our imports in the `if TYPE_CHECKING`
block.

:param module: The ast for the whole module.
"""
# We now know all our input types and all our return types. The return
# types that are _not_ used as import types should be in an `if
# TYPE_CHECKING` import block.
return_types_not_used_as_input = set(self.input_and_return_types)

# The ones we import in the method don't need to be imported at all -
# unless that's the type we return. This behaviour can differ if you use
# a plugin such as `ShorterResultsPlugin` that will import a type that
# is different from the type returned.
return_types_not_used_as_input.update(
{k for k in self.imported_in_method if k not in self.input_and_return_types}
)

if len(return_types_not_used_as_input) == 0:
return None

# We sadly have to iterate over all imports again and remove the imports
# we will do conditionally.
# It's very important that we get this right, if we keep any
# `ImportFrom` that ends up without any names, the formatting will not
# work! It will only remove the empty `import from` but not other unused
# imports.
non_empty_imports: List[Union[ast.Import, ast.ImportFrom]] = []
last_import_at = 0
for i, node in enumerate(module.body):
if isinstance(node, ast.Import):
last_import_at = i
non_empty_imports.append(node)

if not isinstance(node, ast.ImportFrom):
continue

last_import_at = i
reduced_names = []
for name in node.names:
if name.name not in return_types_not_used_as_input:
reduced_names.append(name)

node.names = reduced_names

if len(reduced_names) > 0:
non_empty_imports.append(node)

# We can now remove all imports and re-insert the ones that's not empty.
module.body = non_empty_imports + module.body[last_import_at + 1 :]

# Create import to use for type checking. These will be put in an `if
# TYPE_CHECKING` block.
type_checking_imports = {}
for cls in self.input_and_return_types:
module_name = self.imported_classes[cls]
if module_name not in type_checking_imports:
type_checking_imports[module_name] = ast.ImportFrom(
module=module_name, names=[]
)

type_checking_imports[module_name].names.append(ast.alias(cls))

import_if_type_checking = ast.If(
test=ast.Name(id=TYPE_CHECKING_FLAG),
body=list(type_checking_imports.values()),
orelse=[],
)

module.body.insert(len(non_empty_imports), import_if_type_checking)

# Import `TYPE_CHECKING`.
module.body.insert(
len(non_empty_imports),
ast.ImportFrom(
module=TYPE_CHECKING_MODULE,
names=[ast.Name(TYPE_CHECKING_FLAG)],
),
)

return None

def _update_name_to_constant(self, node: ast.expr) -> ast.expr:
"""Update return types.

If the return type contains any type that resolves to an `ast.Name`,
convert it to an `ast.Constant`. We only need the type for type checking
and can avoid importing the type in the global scope unless needed.

:param node: The ast node used as return type
:returns: A modified ast node
"""
if isinstance(node, ast.Name):
if node.id in self.imported_classes:
self.input_and_return_types.add(node.id)
return ast.Constant(value=node.id)

if isinstance(node, ast.Subscript):
node.slice = self._update_name_to_constant(node.slice)
return node

if isinstance(node, ast.Tuple):
for i, _ in enumerate(node.elts):
node.elts[i] = self._update_name_to_constant(node.elts[i])

return node

return node
Loading
Loading