Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add include_all_inputs #243

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Removed `model_rebuild` calls for generated input, fragment and result models.
- Added `NoReimportsPlugin` that makes the `__init__.py` of generated client package empty.
- Added `include_all_inputs` config flag to generate only inputs used in supplied operations.


## 0.10.0 (2023-11-15)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Optional settings:
- `fragments_module_name` (defaults to `"fragments"`) - name of file with generated fragments models
- `include_comments` (defaults to `"stable"`) - option which sets content of comments included at the top of every generated file. Valid choices are: `"none"` (no comments), `"timestamp"` (comment with generation timestamp), `"stable"` (comment contains a message that this is a generated file)
- `convert_to_snake_case` (defaults to `true`) - a flag that specifies whether to convert fields and arguments names to snake case
- `include_all_inputs` (defaults to `true`) - a flag specifying whether to include all inputs defined in the schema, or only those used in supplied operations
- `async_client` (defaults to `true`) - default generated client is `async`, change this to option `false` to generate synchronous client instead
- `opentelemetry_client` (defaults to `false`) - default base clients don't support any performance tracing. Change this option to `true` to use the base client with Open Telemetry support.
- `files_to_include` (defaults to `[]`) - list of files which will be copied into generated package
Expand Down
47 changes: 41 additions & 6 deletions ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ def __init__(
self._class_defs: List[ast.ClassDef] = [
self._parse_input_definition(d) for d in self._filter_input_types()
]
self._generated_public_names: List[str] = []

def generate(self) -> ast.Module:
def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module:
if self._used_enums:
self._imports.append(
generate_import_from(self._used_enums, self.enums_module, 1)
Expand All @@ -81,16 +82,19 @@ def generate(self) -> ast.Module:
scalar_data = self.custom_scalars[scalar_name]
self._imports.extend(generate_scalar_imports(scalar_data))

class_defs = self._filter_class_defs(types_to_include=types_to_include)
self._generated_public_names = [class_def.name for class_def in class_defs]
module_body = cast(List[ast.stmt], self._imports) + cast(
List[ast.stmt], self._class_defs
List[ast.stmt], class_defs
)
module = generate_module(body=module_body)

if self.plugin_manager:
module = self.plugin_manager.generate_inputs_module(module)
return module

def get_generated_public_names(self) -> List[str]:
return [c.name for c in self._class_defs]
return self._generated_public_names

def _filter_input_types(self) -> List[GraphQLInputObjectType]:
return [
Expand All @@ -100,6 +104,35 @@ def _filter_input_types(self) -> List[GraphQLInputObjectType]:
and not name.startswith("__")
]

def _filter_class_defs(
self, types_to_include: Optional[List[str]] = None
) -> List[ast.ClassDef]:
if types_to_include is None:
return self._class_defs

types_names = set()
for name in types_to_include:
types_names.update(self._get_dependencies_of_type(name))

return [
class_def for class_def in self._class_defs if class_def.name in types_names
]

def _get_dependencies_of_type(self, type_name: str) -> List[str]:
visited = set()
result = []

def dfs(node):
if node not in visited:
visited.add(node)
result.append(node)

for neighbor in self._dependencies[node]:
dfs(neighbor)

dfs(type_name)
return result

def _parse_input_definition(
self, definition: GraphQLInputObjectType
) -> ast.ClassDef:
Expand Down Expand Up @@ -137,7 +170,7 @@ def _parse_input_definition(
field_implementation, input_field=field, field_name=org_name
)
class_def.body.append(field_implementation)
self._save_used_enums_and_scalars(field_type=field_type)
self._save_dependencies(root_type=definition.name, field_type=field_type)

if self.plugin_manager:
class_def = self.plugin_manager.generate_input_class(
Expand Down Expand Up @@ -167,10 +200,12 @@ def _process_field_value(
)
return field_with_alias

def _save_used_enums_and_scalars(self, field_type: str = "") -> None:
def _save_dependencies(self, root_type: str, field_type: str = "") -> None:
if not field_type:
return
if isinstance(self.schema.type_map[field_type], GraphQLEnumType):
if isinstance(self.schema.type_map[field_type], GraphQLInputObjectType):
self._dependencies[root_type].append(field_type)
elif isinstance(self.schema.type_map[field_type], GraphQLEnumType):
self._used_enums.append(field_type)
elif isinstance(self.schema.type_map[field_type], GraphQLScalarType):
self._used_scalars.append(field_type)
10 changes: 9 additions & 1 deletion ariadne_codegen/client_generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
queries_source: str = "",
schema_source: str = "",
convert_to_snake_case: bool = True,
include_all_inputs: bool = True,
base_model_file_path: str = BASE_MODEL_FILE_PATH.as_posix(),
base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT,
upload_import: ast.ImportFrom = UPLOAD_IMPORT,
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
self.schema_source = schema_source

self.convert_to_snake_case = convert_to_snake_case
self.include_all_inputs = include_all_inputs

self.base_model_file_path = Path(base_model_file_path)
self.base_model_import = base_model_import
Expand Down Expand Up @@ -242,7 +244,12 @@ def _generate_enums(self):
)

def _generate_input_types(self):
module = self.input_types_generator.generate()
if self.include_all_inputs:
module = self.input_types_generator.generate()
else:
used_inputs = self.client_generator.arguments_generator.get_used_inputs()
module = self.input_types_generator.generate(types_to_include=used_inputs)

input_types_file_path = self.package_path / f"{self.input_types_module_name}.py"
code = self._add_comments_to_code(ast_to_str(module), self.schema_source)
if self.plugin_manager:
Expand Down Expand Up @@ -388,6 +395,7 @@ def get_package_generator(
queries_source=settings.queries_path,
schema_source=settings.schema_source,
convert_to_snake_case=settings.convert_to_snake_case,
include_all_inputs=settings.include_all_inputs,
base_model_file_path=BASE_MODEL_FILE_PATH.as_posix(),
base_model_import=BASE_MODEL_IMPORT,
upload_import=UPLOAD_IMPORT,
Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class ClientSettings(BaseSettings):
fragments_module_name: str = "fragments"
include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE)
convert_to_snake_case: bool = True
include_all_inputs: bool = True
async_client: bool = True
opentelemetry_client: bool = False
files_to_include: List[str] = field(default_factory=list)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import ast

import pytest
from graphql import build_ast_schema, parse

from ariadne_codegen.client_generators.input_types import InputTypesGenerator


@pytest.mark.parametrize(
"used_types, expected_classes",
[
(
None,
["InputA", "InputAA", "InputAAA", "InputAB", "InputX", "InputY", "InputZ"],
),
(["InputA"], ["InputA", "InputAA", "InputAAA", "InputAB"]),
(["InputAA"], ["InputAA", "InputAAA"]),
(["InputX"], ["InputX", "InputY", "InputZ"]),
(
["InputA", "InputX"],
["InputA", "InputAA", "InputAAA", "InputAB", "InputX", "InputY", "InputZ"],
),
(["InputAB"], ["InputA", "InputAA", "InputAAA", "InputAB"]),
(["InputAAA", "InputZ"], ["InputAAA", "InputZ"]),
(
["InputA", "InputA", "InputA", "InputAA", "InputAAA"],
["InputA", "InputAA", "InputAAA", "InputAB"],
),
],
)
def test_generator_returns_module_with_filtered_classes(used_types, expected_classes):
schema_str = """
input InputA {
valueAA: InputAA!
valueAB: InputAB
}

input InputAA {
valueAAA: InputAAA!
}

input InputAAA {
val: String!
}

input InputAB {
val: String!
valueA: InputA
}

input InputX {
valueY: InputY
}

input InputY {
valueZ: InputZ
}

input InputZ {
val: String
}
"""

generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str)))

module = generator.generate(used_types)

assert [
class_def.name
for class_def in module.body
if isinstance(class_def, ast.ClassDef)
] == expected_classes
33 changes: 33 additions & 0 deletions tests/main/clients/only_used_inputs/expected_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .async_base_client import AsyncBaseClient
from .base_model import BaseModel, Upload
from .client import Client
from .exceptions import (
GraphQLClientError,
GraphQLClientGraphQLError,
GraphQLClientGraphQLMultiError,
GraphQLClientHttpError,
GraphQlClientInvalidResponseError,
)
from .get_a import GetA
from .get_a_2 import GetA2
from .get_b import GetB
from .input_types import InputA, InputAA, InputAAA, InputAB

__all__ = [
"AsyncBaseClient",
"BaseModel",
"Client",
"GetA",
"GetA2",
"GetB",
"GraphQLClientError",
"GraphQLClientGraphQLError",
"GraphQLClientGraphQLMultiError",
"GraphQLClientHttpError",
"GraphQlClientInvalidResponseError",
"InputA",
"InputAA",
"InputAAA",
"InputAB",
"Upload",
]
Loading