From 831c273961c8b0c3e7fcc455a1e499686351c471 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 22 Nov 2023 15:28:02 +0100 Subject: [PATCH 1/4] Change InputTypesGenerator to generate module with only provided list of inputs --- .../client_generators/input_types.py | 47 ++++++++++-- .../test_filtering_names.py | 72 +++++++++++++++++++ 2 files changed, 113 insertions(+), 6 deletions(-) create mode 100644 tests/client_generators/input_types_generator/test_filtering_names.py diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index e8612d79..1c5fcd58 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -70,8 +70,9 @@ def __init__( self._class_defs: List[ast.ClassDef] = [ self._parse_input_definition(d) for d in self._filter_input_types() ] + self._generated_public_names: List[str] = [] - def generate(self) -> ast.Module: + def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module: if self._used_enums: self._imports.append( generate_import_from(self._used_enums, self.enums_module, 1) @@ -81,16 +82,19 @@ def generate(self) -> ast.Module: scalar_data = self.custom_scalars[scalar_name] self._imports.extend(generate_scalar_imports(scalar_data)) + class_defs = self._filter_class_defs(types_to_include=types_to_include) + self._generated_public_names = [class_def.name for class_def in class_defs] module_body = cast(List[ast.stmt], self._imports) + cast( - List[ast.stmt], self._class_defs + List[ast.stmt], class_defs ) module = generate_module(body=module_body) + if self.plugin_manager: module = self.plugin_manager.generate_inputs_module(module) return module def get_generated_public_names(self) -> List[str]: - return [c.name for c in self._class_defs] + return self._generated_public_names def _filter_input_types(self) -> List[GraphQLInputObjectType]: return [ @@ -100,6 +104,35 @@ def _filter_input_types(self) -> List[GraphQLInputObjectType]: and not name.startswith("__") ] + def _filter_class_defs( + self, types_to_include: Optional[List[str]] = None + ) -> List[ast.ClassDef]: + if types_to_include is None: + return self._class_defs + + types_names = set() + for name in types_to_include: + types_names.update(self._get_dependencies_of_type(name)) + + return [ + class_def for class_def in self._class_defs if class_def.name in types_names + ] + + def _get_dependencies_of_type(self, type_name: str) -> List[str]: + visited = set() + result = [] + + def dfs(node): + if node not in visited: + visited.add(node) + result.append(node) + + for neighbor in self._dependencies[node]: + dfs(neighbor) + + dfs(type_name) + return result + def _parse_input_definition( self, definition: GraphQLInputObjectType ) -> ast.ClassDef: @@ -137,7 +170,7 @@ def _parse_input_definition( field_implementation, input_field=field, field_name=org_name ) class_def.body.append(field_implementation) - self._save_used_enums_and_scalars(field_type=field_type) + self._save_dependencies(root_type=definition.name, field_type=field_type) if self.plugin_manager: class_def = self.plugin_manager.generate_input_class( @@ -167,10 +200,12 @@ def _process_field_value( ) return field_with_alias - def _save_used_enums_and_scalars(self, field_type: str = "") -> None: + def _save_dependencies(self, root_type: str, field_type: str = "") -> None: if not field_type: return - if isinstance(self.schema.type_map[field_type], GraphQLEnumType): + if isinstance(self.schema.type_map[field_type], GraphQLInputObjectType): + self._dependencies[root_type].append(field_type) + elif isinstance(self.schema.type_map[field_type], GraphQLEnumType): self._used_enums.append(field_type) elif isinstance(self.schema.type_map[field_type], GraphQLScalarType): self._used_scalars.append(field_type) diff --git a/tests/client_generators/input_types_generator/test_filtering_names.py b/tests/client_generators/input_types_generator/test_filtering_names.py new file mode 100644 index 00000000..14106564 --- /dev/null +++ b/tests/client_generators/input_types_generator/test_filtering_names.py @@ -0,0 +1,72 @@ +import ast + +import pytest +from graphql import build_ast_schema, parse + +from ariadne_codegen.client_generators.input_types import InputTypesGenerator + + +@pytest.mark.parametrize( + "used_types, expected_classes", + [ + ( + None, + ["InputA", "InputAA", "InputAAA", "InputAB", "InputX", "InputY", "InputZ"], + ), + (["InputA"], ["InputA", "InputAA", "InputAAA", "InputAB"]), + (["InputAA"], ["InputAA", "InputAAA"]), + (["InputX"], ["InputX", "InputY", "InputZ"]), + ( + ["InputA", "InputX"], + ["InputA", "InputAA", "InputAAA", "InputAB", "InputX", "InputY", "InputZ"], + ), + (["InputAB"], ["InputA", "InputAA", "InputAAA", "InputAB"]), + (["InputAAA", "InputZ"], ["InputAAA", "InputZ"]), + ( + ["InputA", "InputA", "InputA", "InputAA", "InputAAA"], + ["InputA", "InputAA", "InputAAA", "InputAB"], + ), + ], +) +def test_generator_returns_module_with_filtered_classes(used_types, expected_classes): + schema_str = """ + input InputA { + valueAA: InputAA! + valueAB: InputAB + } + + input InputAA { + valueAAA: InputAAA! + } + + input InputAAA { + val: String! + } + + input InputAB { + val: String! + valueA: InputA + } + + input InputX { + valueY: InputY + } + + input InputY { + valueZ: InputZ + } + + input InputZ { + val: String + } + """ + + generator = InputTypesGenerator(schema=build_ast_schema(parse(schema_str))) + + module = generator.generate(used_types) + + assert [ + class_def.name + for class_def in module.body + if isinstance(class_def, ast.ClassDef) + ] == expected_classes From 85bbd895b01a935a7fdfe7bded18a0d58ec99a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 22 Nov 2023 15:28:17 +0100 Subject: [PATCH 2/4] Add include_all_inputs config option --- README.md | 1 + ariadne_codegen/client_generators/package.py | 10 +++++++++- ariadne_codegen/settings.py | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 56fb84b5..18eaef1d 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ Optional settings: - `fragments_module_name` (defaults to `"fragments"`) - name of file with generated fragments models - `include_comments` (defaults to `"stable"`) - option which sets content of comments included at the top of every generated file. Valid choices are: `"none"` (no comments), `"timestamp"` (comment with generation timestamp), `"stable"` (comment contains a message that this is a generated file) - `convert_to_snake_case` (defaults to `true`) - a flag that specifies whether to convert fields and arguments names to snake case +- `include_all_inputs` (defaults to `true`) - a flag specifying whether to include all inputs defined in the schema, or only those used in supplied operations - `async_client` (defaults to `true`) - default generated client is `async`, change this to option `false` to generate synchronous client instead - `opentelemetry_client` (defaults to `false`) - default base clients don't support any performance tracing. Change this option to `true` to use the base client with Open Telemetry support. - `files_to_include` (defaults to `[]`) - list of files which will be copied into generated package diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index 47c349ec..0536b8b7 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -58,6 +58,7 @@ def __init__( queries_source: str = "", schema_source: str = "", convert_to_snake_case: bool = True, + include_all_inputs: bool = True, base_model_file_path: str = BASE_MODEL_FILE_PATH.as_posix(), base_model_import: ast.ImportFrom = BASE_MODEL_IMPORT, upload_import: ast.ImportFrom = UPLOAD_IMPORT, @@ -94,6 +95,7 @@ def __init__( self.schema_source = schema_source self.convert_to_snake_case = convert_to_snake_case + self.include_all_inputs = include_all_inputs self.base_model_file_path = Path(base_model_file_path) self.base_model_import = base_model_import @@ -242,7 +244,12 @@ def _generate_enums(self): ) def _generate_input_types(self): - module = self.input_types_generator.generate() + if self.include_all_inputs: + module = self.input_types_generator.generate() + else: + used_inputs = self.client_generator.arguments_generator.get_used_inputs() + module = self.input_types_generator.generate(types_to_include=used_inputs) + input_types_file_path = self.package_path / f"{self.input_types_module_name}.py" code = self._add_comments_to_code(ast_to_str(module), self.schema_source) if self.plugin_manager: @@ -388,6 +395,7 @@ def get_package_generator( queries_source=settings.queries_path, schema_source=settings.schema_source, convert_to_snake_case=settings.convert_to_snake_case, + include_all_inputs=settings.include_all_inputs, base_model_file_path=BASE_MODEL_FILE_PATH.as_posix(), base_model_import=BASE_MODEL_IMPORT, upload_import=UPLOAD_IMPORT, diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 64fef2c5..035a8c89 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -65,6 +65,7 @@ class ClientSettings(BaseSettings): fragments_module_name: str = "fragments" include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE) convert_to_snake_case: bool = True + include_all_inputs: bool = True async_client: bool = True opentelemetry_client: bool = False files_to_include: List[str] = field(default_factory=list) From abb299a3f844b2df1df08082d8325bfa9b9f9070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 22 Nov 2023 15:28:56 +0100 Subject: [PATCH 3/4] Add e2e test with include_all_inputs set to false --- .../expected_client/__init__.py | 33 ++ .../expected_client/async_base_client.py | 318 ++++++++++++++++++ .../expected_client/base_model.py | 27 ++ .../expected_client/client.py | 52 +++ .../only_used_inputs/expected_client/enums.py | 0 .../expected_client/exceptions.py | 79 +++++ .../only_used_inputs/expected_client/get_a.py | 5 + .../expected_client/get_a_2.py | 5 + .../only_used_inputs/expected_client/get_b.py | 5 + .../expected_client/input_types.py | 23 ++ .../clients/only_used_inputs/pyproject.toml | 6 + .../clients/only_used_inputs/queries.graphql | 11 + .../clients/only_used_inputs/schema.graphql | 35 ++ tests/main/test_main.py | 11 + 14 files changed, 610 insertions(+) create mode 100644 tests/main/clients/only_used_inputs/expected_client/__init__.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/async_base_client.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/base_model.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/client.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/enums.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/exceptions.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/get_a.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/get_a_2.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/get_b.py create mode 100644 tests/main/clients/only_used_inputs/expected_client/input_types.py create mode 100644 tests/main/clients/only_used_inputs/pyproject.toml create mode 100644 tests/main/clients/only_used_inputs/queries.graphql create mode 100644 tests/main/clients/only_used_inputs/schema.graphql diff --git a/tests/main/clients/only_used_inputs/expected_client/__init__.py b/tests/main/clients/only_used_inputs/expected_client/__init__.py new file mode 100644 index 00000000..605e9f4b --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/__init__.py @@ -0,0 +1,33 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQlClientInvalidResponseError, +) +from .get_a import GetA +from .get_a_2 import GetA2 +from .get_b import GetB +from .input_types import InputA, InputAA, InputAAA, InputAB + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "GetA", + "GetA2", + "GetB", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQlClientInvalidResponseError", + "InputA", + "InputAA", + "InputAAA", + "InputAB", + "Upload", +] diff --git a/tests/main/clients/only_used_inputs/expected_client/async_base_client.py b/tests/main/clients/only_used_inputs/expected_client/async_base_client.py new file mode 100644 index 00000000..25618471 --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/async_base_client.py @@ -0,0 +1,318 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQlClientInvalidResponseError, +) + +try: + from websockets.client import WebSocketClientProtocol, connect as ws_connect + from websockets.typing import Data, Origin, Subprotocol +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore + Data = Any # type: ignore + Origin = Any # type: ignore + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQlClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ("data" not in response_json): + raise GraphQlClientInvalidResponseError(response=response) + + data = response_json["data"] + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, query: str, variables: Dict[str, Any], **kwargs: Any + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + {"query": query, "variables": variables}, default=to_jsonable_python + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, message: Data, websocket: WebSocketClientProtocol + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/only_used_inputs/expected_client/base_model.py b/tests/main/clients/only_used_inputs/expected_client/base_model.py new file mode 100644 index 00000000..ccde3975 --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/only_used_inputs/expected_client/client.py b/tests/main/clients/only_used_inputs/expected_client/client.py new file mode 100644 index 00000000..621bd97e --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/client.py @@ -0,0 +1,52 @@ +from typing import Any, Dict + +from .async_base_client import AsyncBaseClient +from .get_a import GetA +from .get_a_2 import GetA2 +from .get_b import GetB +from .input_types import InputA, InputAB + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def get_a(self, arg_a: InputA, **kwargs: Any) -> GetA: + query = gql( + """ + query getA($argA: InputA!) { + a(argA: $argA) + } + """ + ) + variables: Dict[str, object] = {"argA": arg_a} + response = await self.execute(query=query, variables=variables, **kwargs) + data = self.get_data(response) + return GetA.model_validate(data) + + async def get_a_2(self, arg_a: InputA, **kwargs: Any) -> GetA2: + query = gql( + """ + query getA2($argA: InputA!) { + a(argA: $argA) + } + """ + ) + variables: Dict[str, object] = {"argA": arg_a} + response = await self.execute(query=query, variables=variables, **kwargs) + data = self.get_data(response) + return GetA2.model_validate(data) + + async def get_b(self, arg_aa: InputAB, **kwargs: Any) -> GetB: + query = gql( + """ + query getB($argAA: InputAB!) { + b(argAA: $argAA) + } + """ + ) + variables: Dict[str, object] = {"argAA": arg_aa} + response = await self.execute(query=query, variables=variables, **kwargs) + data = self.get_data(response) + return GetB.model_validate(data) diff --git a/tests/main/clients/only_used_inputs/expected_client/enums.py b/tests/main/clients/only_used_inputs/expected_client/enums.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/main/clients/only_used_inputs/expected_client/exceptions.py b/tests/main/clients/only_used_inputs/expected_client/exceptions.py new file mode 100644 index 00000000..97a44022 --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/exceptions.py @@ -0,0 +1,79 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQlClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__(self, errors: List[GraphQLClientGraphQLError], data: Dict[str, Any]): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Dict[str, Any] + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/only_used_inputs/expected_client/get_a.py b/tests/main/clients/only_used_inputs/expected_client/get_a.py new file mode 100644 index 00000000..4aa89ddc --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/get_a.py @@ -0,0 +1,5 @@ +from .base_model import BaseModel + + +class GetA(BaseModel): + a: str diff --git a/tests/main/clients/only_used_inputs/expected_client/get_a_2.py b/tests/main/clients/only_used_inputs/expected_client/get_a_2.py new file mode 100644 index 00000000..1a9211e4 --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/get_a_2.py @@ -0,0 +1,5 @@ +from .base_model import BaseModel + + +class GetA2(BaseModel): + a: str diff --git a/tests/main/clients/only_used_inputs/expected_client/get_b.py b/tests/main/clients/only_used_inputs/expected_client/get_b.py new file mode 100644 index 00000000..d6d8adf8 --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/get_b.py @@ -0,0 +1,5 @@ +from .base_model import BaseModel + + +class GetB(BaseModel): + b: str diff --git a/tests/main/clients/only_used_inputs/expected_client/input_types.py b/tests/main/clients/only_used_inputs/expected_client/input_types.py new file mode 100644 index 00000000..e1e8944e --- /dev/null +++ b/tests/main/clients/only_used_inputs/expected_client/input_types.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class InputA(BaseModel): + value_aa: "InputAA" = Field(alias="valueAA") + value_ab: Optional["InputAB"] = Field(alias="valueAB", default=None) + + +class InputAA(BaseModel): + value_aaa: "InputAAA" = Field(alias="valueAAA") + + +class InputAAA(BaseModel): + val: str + + +class InputAB(BaseModel): + val: str + value_a: Optional["InputA"] = Field(alias="valueA", default=None) diff --git a/tests/main/clients/only_used_inputs/pyproject.toml b/tests/main/clients/only_used_inputs/pyproject.toml new file mode 100644 index 00000000..a1a559d7 --- /dev/null +++ b/tests/main/clients/only_used_inputs/pyproject.toml @@ -0,0 +1,6 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +include_comments = "none" +include_all_inputs = false +target_package_name = "client_only_used_inputs" diff --git a/tests/main/clients/only_used_inputs/queries.graphql b/tests/main/clients/only_used_inputs/queries.graphql new file mode 100644 index 00000000..091793a5 --- /dev/null +++ b/tests/main/clients/only_used_inputs/queries.graphql @@ -0,0 +1,11 @@ +query getA($argA: InputA!) { + a(argA: $argA) +} + +query getA2($argA: InputA!) { + a(argA: $argA) +} + +query getB($argAA: InputAB!) { + b(argAA: $argAA) +} diff --git a/tests/main/clients/only_used_inputs/schema.graphql b/tests/main/clients/only_used_inputs/schema.graphql new file mode 100644 index 00000000..80c194b8 --- /dev/null +++ b/tests/main/clients/only_used_inputs/schema.graphql @@ -0,0 +1,35 @@ +type Query { + a(argA: InputA!): String! + b(argAA: InputAB!): String! + c(argX: InputX!): String! +} + +input InputA { + valueAA: InputAA! + valueAB: InputAB +} + +input InputAA { + valueAAA: InputAAA! +} + +input InputAAA { + val: String! +} + +input InputAB { + val: String! + valueA: InputA +} + +input InputX { + valueY: InputY +} + +input InputY { + valueZ: InputZ +} + +input InputZ { + val: String +} diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 1a3d2635..71a672b5 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -175,6 +175,17 @@ def test_main_shows_version(): "client_with_operations", CLIENTS_PATH / "operations" / "expected_client", ), + ( + ( + CLIENTS_PATH / "only_used_inputs" / "pyproject.toml", + ( + CLIENTS_PATH / "only_used_inputs" / "queries.graphql", + CLIENTS_PATH / "only_used_inputs" / "schema.graphql", + ), + ), + "client_only_used_inputs", + CLIENTS_PATH / "only_used_inputs" / "expected_client", + ), ], indirect=["project_dir"], ) From 2cba392abcda23138e37848d86ff6ced52053a43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Wed, 22 Nov 2023 15:31:14 +0100 Subject: [PATCH 4/4] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 78107688..0874178e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Removed `model_rebuild` calls for generated input, fragment and result models. - Added `NoReimportsPlugin` that makes the `__init__.py` of generated client package empty. +- Added `include_all_inputs` config flag to generate only inputs used in supplied operations. ## 0.10.0 (2023-11-15)