diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a67a582..ab3788bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Added escaping of enum values which are Python keywords by appending `_` to them. - Fixed `enums_module_name` option not being passed to generators. - Added additional base clients supporting the Open Telemetry tracing. Added `opentelemetry_client` config option. +- Changed generated client's methods to pass `**kwargs` to base client's `execute` and `execute_ws` methods (breaking change for custom base clients). ## 0.9.0 (2023-09-11) diff --git a/EXAMPLE.md b/EXAMPLE.md index cf626e78..83f6ace6 100644 --- a/EXAMPLE.md +++ b/EXAMPLE.md @@ -180,7 +180,7 @@ Generated client class inherits from `AsyncBaseClient` and has async method for ```py # graphql_client/client.py -from typing import AsyncIterator, Optional, Union +from typing import Any, AsyncIterator, Dict, Optional, Union from .async_base_client import AsyncBaseClient from .base_model import UNSET, UnsetType, Upload @@ -197,7 +197,9 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def create_user(self, user_data: UserCreateInput) -> CreateUser: + async def create_user( + self, user_data: UserCreateInput, **kwargs: Any + ) -> CreateUser: query = gql( """ mutation CreateUser($userData: UserCreateInput!) { @@ -207,12 +209,12 @@ class Client(AsyncBaseClient): } """ ) - variables: dict[str, object] = {"userData": user_data} - response = await self.execute(query=query, variables=variables) + variables: Dict[str, object] = {"userData": user_data} + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return CreateUser.model_validate(data) - async def list_all_users(self) -> ListAllUsers: + async def list_all_users(self, **kwargs: Any) -> ListAllUsers: query = gql( """ query ListAllUsers { @@ -228,13 +230,13 @@ class Client(AsyncBaseClient): } """ ) - variables: dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + variables: Dict[str, object] = {} + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListAllUsers.model_validate(data) async def list_users_by_country( - self, country: Union[Optional[str], UnsetType] = UNSET + self, country: Union[Optional[str], UnsetType] = UNSET, **kwargs: Any ) -> ListUsersByCountry: query = gql( """ @@ -257,12 +259,12 @@ class Client(AsyncBaseClient): } """ ) - variables: dict[str, object] = {"country": country} - response = await self.execute(query=query, variables=variables) + variables: Dict[str, object] = {"country": country} + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListUsersByCountry.model_validate(data) - async def get_users_counter(self) -> AsyncIterator[GetUsersCounter]: + async def get_users_counter(self, **kwargs: Any) -> AsyncIterator[GetUsersCounter]: query = gql( """ subscription GetUsersCounter { @@ -270,11 +272,11 @@ class Client(AsyncBaseClient): } """ ) - variables: dict[str, object] = {} - async for data in self.execute_ws(query=query, variables=variables): + variables: Dict[str, object] = {} + async for data in self.execute_ws(query=query, variables=variables, **kwargs): yield GetUsersCounter.model_validate(data) - async def upload_file(self, file: Upload) -> UploadFile: + async def upload_file(self, file: Upload, **kwargs: Any) -> UploadFile: query = gql( """ mutation uploadFile($file: Upload!) { @@ -282,8 +284,8 @@ class Client(AsyncBaseClient): } """ ) - variables: dict[str, object] = {"file": file} - response = await self.execute(query=query, variables=variables) + variables: Dict[str, object] = {"file": file} + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return UploadFile.model_validate(data) ``` diff --git a/ariadne_codegen/client_generators/arguments.py b/ariadne_codegen/client_generators/arguments.py index 3825890e..3f9f2ebb 100644 --- a/ariadne_codegen/client_generators/arguments.py +++ b/ariadne_codegen/client_generators/arguments.py @@ -27,7 +27,14 @@ from ..exceptions import ParsingError from ..plugins.manager import PluginManager from ..utils import process_name -from .constants import ANY, INPUT_SCALARS_MAP, OPTIONAL, UNSET_NAME, UNSET_TYPE_NAME +from .constants import ( + ANY, + INPUT_SCALARS_MAP, + KWARGS_NAMES, + OPTIONAL, + UNSET_NAME, + UNSET_TYPE_NAME, +) from .scalars import ScalarData @@ -83,6 +90,7 @@ def generate( arguments = generate_arguments( args=required_args + optional_args, defaults=[generate_name(UNSET_NAME) for _ in optional_args], + kwarg=generate_arg(KWARGS_NAMES, annotation=generate_name(ANY)), ) if self.plugin_manager: diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 23b6798b..f27b721f 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -33,6 +33,7 @@ ANY, ASYNC_ITERATOR, DICT, + KWARGS_NAMES, LIST, MODEL_VALIDATE_METHOD, OPTIONAL, @@ -290,10 +291,13 @@ def _generate_execute_call(self) -> ast.Call: return generate_call( func=generate_attribute(generate_name("self"), "execute"), keywords=[ - generate_keyword("query", generate_name(self._operation_str_variable)), generate_keyword( - "variables", generate_name(self._variables_dict_variable) + value=generate_name(self._operation_str_variable), arg="query" ), + generate_keyword( + value=generate_name(self._variables_dict_variable), arg="variables" + ), + generate_keyword(value=generate_name(KWARGS_NAMES)), ], ) @@ -325,13 +329,13 @@ def _generate_async_generator_loop( func=generate_attribute(value=generate_name("self"), attr="execute_ws"), keywords=[ generate_keyword( - arg="query", - value=generate_name(self._operation_str_variable), + value=generate_name(self._operation_str_variable), arg="query" ), generate_keyword( - arg="variables", value=generate_name(self._variables_dict_variable), + arg="variables", ), + generate_keyword(value=generate_name(KWARGS_NAMES)), ], ), body=[self._generate_yield_parsed_obj(return_type)], diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index 94e8ef75..0eef24b4 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -68,6 +68,7 @@ SKIP_DIRECTIVE_NAME = "skip" INCLUDE_DIRECTIVE_NAME = "include" +KWARGS_NAMES = "kwargs" DEFAULT_ASYNC_BASE_CLIENT_PATH = ( Path(__file__).parent / "dependencies" / "async_base_client.py" diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index 3d33b247..25618471 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index 40ac419c..4b0ed449 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -128,12 +128,14 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: if self.tracer: - return await self._execute_with_telemetry(query=query, variables=variables) + return await self._execute_with_telemetry( + query=query, variables=variables, **kwargs + ) - return await self._execute(query=query, variables=variables) + return await self._execute(query=query, variables=variables, **kwargs) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -160,20 +162,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: if self.tracer: generator = self._execute_ws_with_telemetry( - query=query, variables=variables + query=query, variables=variables, **kwargs ) else: - generator = self._execute_ws(query=query, variables=variables) + generator = self._execute_ws(query=query, variables=variables, **kwargs) async for message in generator: yield message async def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -183,9 +185,12 @@ async def _execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -262,6 +267,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -270,30 +276,42 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -367,7 +385,7 @@ async def _handle_ws_message( return None async def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: with self.tracer.start_as_current_span( # type: ignore self.root_span_name, context=self.root_context @@ -383,10 +401,14 @@ async def _execute_with_telemetry( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) return await self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables + root_span=root_span, + query=query, + variables=processed_variables, + **kwargs, ) async def _execute_multipart_with_telemetry( @@ -396,6 +418,7 @@ async def _execute_multipart_with_telemetry( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: with self.tracer.start_as_current_span( # type: ignore "multipart request", context=set_span_in_context(root_span) @@ -409,14 +432,15 @@ async def _execute_multipart_with_telemetry( span.set_attribute("variables", serialized_variables) span.set_attribute("map", serialized_map) return await self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map + query=query, + variables=variables, + files=files, + files_map=files_map, + **kwargs, ) async def _execute_json_with_telemetry( - self, - root_span: Span, - query: str, - variables: Dict[str, Any], + self, root_span: Span, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: with self.tracer.start_as_current_span( # type: ignore "json request", context=set_span_in_context(root_span) @@ -427,21 +451,28 @@ async def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) - return await self._execute_json(query=query, variables=variables) + return await self._execute_json(query=query, variables=variables, **kwargs) async def _execute_ws_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: with self.tracer.start_as_current_span( # type: ignore self.ws_root_span_name, context=self.ws_root_context ) as root_span: root_span.set_attribute("component", "GraphQL Client") + + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init_with_telemetry( root_span=root_span, diff --git a/ariadne_codegen/client_generators/dependencies/base_client.py b/ariadne_codegen/client_generators/dependencies/base_client.py index fa981789..db71b14a 100644 --- a/ariadne_codegen/client_generators/dependencies/base_client.py +++ b/ariadne_codegen/client_generators/dependencies/base_client.py @@ -39,7 +39,7 @@ def __exit__( self.http_client.close() def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -49,9 +49,10 @@ def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return self._execute_json(query=query, variables=processed_variables) + return self._execute_json(query=query, variables=processed_variables, **kwargs) def get_data(self, response: httpx.Response) -> dict[str, Any]: if not response.is_success: @@ -152,6 +153,7 @@ def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -160,13 +162,21 @@ def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return self.http_client.post(url=self.url, data=data, files=files) + return self.http_client.post(url=self.url, data=data, files=files, **kwargs) + + def _execute_json( + self, query: str, variables: Dict[str, Any], **kwargs: Any + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers - def _execute_json(self, query: str, variables: Dict[str, Any]) -> httpx.Response: return self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) diff --git a/ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py index 033629f0..5696b87d 100644 --- a/ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/base_client_open_telemetry.py @@ -68,11 +68,13 @@ def __exit__( self.http_client.close() def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: if self.tracer: - return self._execute_with_telemetry(query=query, variables=variables) - return self._execute(query=query, variables=variables) + return self._execute_with_telemetry( + query=query, variables=variables, **kwargs + ) + return self._execute(query=query, variables=variables, **kwargs) def get_data(self, response: httpx.Response) -> dict[str, Any]: if not response.is_success: @@ -99,7 +101,7 @@ def get_data(self, response: httpx.Response) -> dict[str, Any]: return cast(dict[str, Any], data) def _execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -109,9 +111,10 @@ def _execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return self._execute_json(query=query, variables=processed_variables) + return self._execute_json(query=query, variables=processed_variables, **kwargs) def _process_variables( self, variables: Optional[Dict[str, Any]] @@ -188,6 +191,7 @@ def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -196,19 +200,27 @@ def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return self.http_client.post(url=self.url, data=data, files=files) + return self.http_client.post(url=self.url, data=data, files=files, **kwargs) + + def _execute_json( + self, query: str, variables: Dict[str, Any], **kwargs: Any + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers - def _execute_json(self, query: str, variables: Dict[str, Any]) -> httpx.Response: return self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) def _execute_with_telemetry( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: with self.tracer.start_as_current_span( # type: ignore self.root_span_name, context=self.root_context @@ -224,10 +236,14 @@ def _execute_with_telemetry( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) return self._execute_json_with_telemetry( - root_span=root_span, query=query, variables=processed_variables + root_span=root_span, + query=query, + variables=processed_variables, + **kwargs, ) def _execute_multipart_with_telemetry( @@ -237,6 +253,7 @@ def _execute_multipart_with_telemetry( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: with self.tracer.start_as_current_span( # type: ignore "multipart request", context=set_span_in_context(root_span) @@ -251,11 +268,15 @@ def _execute_multipart_with_telemetry( span.set_attribute("map", serialized_map) return self._execute_multipart( - query=query, variables=variables, files=files, files_map=files_map + query=query, + variables=variables, + files=files, + files_map=files_map, + **kwargs, ) def _execute_json_with_telemetry( - self, root_span: Span, query: str, variables: Dict[str, Any] + self, root_span: Span, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: with self.tracer.start_as_current_span( # type: ignore "json request", context=set_span_in_context(root_span) @@ -267,4 +288,4 @@ def _execute_json_with_telemetry( span.set_attribute("query", query) span.set_attribute("variables", serialized_variables) - return self._execute_json(query=query, variables=variables) + return self._execute_json(query=query, variables=variables, **kwargs) diff --git a/ariadne_codegen/client_generators/input_fields.py b/ariadne_codegen/client_generators/input_fields.py index 6cb1fe63..bf65bc60 100644 --- a/ariadne_codegen/client_generators/input_fields.py +++ b/ariadne_codegen/client_generators/input_fields.py @@ -157,8 +157,7 @@ def parse_input_const_value_node( func=generate_name(FIELD_CLASS), keywords=[ generate_keyword( - arg="default_factory", - value=generate_lambda(body=list_), + value=generate_lambda(body=list_), arg="default_factory" ) ], ) @@ -182,7 +181,6 @@ def parse_input_const_value_node( func=generate_name(FIELD_CLASS), keywords=[ generate_keyword( - arg="default_factory", value=generate_lambda( body=generate_call( func=generate_attribute( @@ -197,6 +195,7 @@ def parse_input_const_value_node( args=[dict_], ) ), + arg="default_factory", ) ], ) diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 06d73826..6ed4b0d5 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -172,10 +172,7 @@ def _process_field_value( field_with_alias.keywords.extend(field_implementation.value.keywords) else: field_with_alias.keywords.append( - generate_keyword( - arg="default", - value=field_implementation.value, - ) + generate_keyword(value=field_implementation.value, arg="default") ) return field_with_alias diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index 7ea57c5e..c5a74df8 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -64,7 +64,9 @@ def generate_arg( def generate_arguments( - args: Optional[List[ast.arg]] = None, defaults: Optional[List[ast.expr]] = None + args: Optional[List[ast.arg]] = None, + defaults: Optional[List[ast.expr]] = None, + kwarg: Optional[ast.arg] = None, ) -> ast.arguments: """Generate arguments.""" return ast.arguments( @@ -72,6 +74,7 @@ def generate_arguments( args=args if args else [], kwonlyargs=[], kw_defaults=[], + kwarg=kwarg, defaults=defaults or [], ) @@ -182,7 +185,7 @@ def generate_attribute(value: ast.expr, attr: str) -> ast.Attribute: return ast.Attribute(value=value, attr=attr) -def generate_keyword(arg: str, value: ast.expr) -> ast.keyword: +def generate_keyword(value: ast.expr, arg: Optional[str] = None) -> ast.keyword: """Generate keyword object.""" return ast.keyword(arg=arg, value=value) @@ -273,7 +276,7 @@ def generate_pydantic_field(keywords: Dict[str, ast.expr]) -> ast.Call: return generate_call( func=generate_name(FIELD_CLASS), keywords=[ - generate_keyword(arg=arg, value=value) for arg, value in keywords.items() + generate_keyword(value=value, arg=arg) for arg, value in keywords.items() ], ) diff --git a/ariadne_codegen/graphql_schema_generators/directives.py b/ariadne_codegen/graphql_schema_generators/directives.py index 25768c82..411571ec 100644 --- a/ariadne_codegen/graphql_schema_generators/directives.py +++ b/ariadne_codegen/graphql_schema_generators/directives.py @@ -23,17 +23,17 @@ def generate_directive(directive: GraphQLDirective, type_map_name: str) -> ast.C return generate_call( func=generate_name("GraphQLDirective"), keywords=[ - generate_keyword(arg="name", value=generate_constant(directive.name)), + generate_keyword(value=generate_constant(directive.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(directive.description) + value=generate_constant(directive.description), arg="description" ), generate_keyword( - arg="is_repeatable", value=generate_constant(directive.is_repeatable) + value=generate_constant(directive.is_repeatable), arg="is_repeatable" ), generate_keyword( - arg="locations", value=generate_directive_locations(directive.locations) + value=generate_directive_locations(directive.locations), arg="locations" ), - generate_keyword(arg="args", value=args), + generate_keyword(value=args, arg="args"), ], ) diff --git a/ariadne_codegen/graphql_schema_generators/fields.py b/ariadne_codegen/graphql_schema_generators/fields.py index 8ecfdbb6..dc5cab85 100644 --- a/ariadne_codegen/graphql_schema_generators/fields.py +++ b/ariadne_codegen/graphql_schema_generators/fields.py @@ -55,14 +55,14 @@ def generate_field(field: GraphQLField, type_map_name: str) -> ast.Call: args=[generate_field_type(field.type, type_map_name)], keywords=[ generate_keyword( - arg="args", value=generate_args(field.args, type_map_name) + value=generate_args(field.args, type_map_name), arg="args" ), generate_keyword( - arg="description", value=generate_constant(field.description) + value=generate_constant(field.description), arg="description" ), generate_keyword( - arg="deprecation_reason", value=generate_constant(field.deprecation_reason), + arg="deprecation_reason", ), ], ) @@ -117,14 +117,15 @@ def generate_arg(arg: GraphQLArgument, type_map_name: str) -> ast.Call: args=[generate_field_type(arg.type, type_map_name)], keywords=[ generate_keyword( - arg="default_value", value=generate_constant(arg.default_value) + value=generate_constant(arg.default_value), arg="default_value" ), generate_keyword( - arg="description", value=generate_constant(arg.description) + value=generate_constant(arg.description), + arg="description", ), generate_keyword( - arg="deprecation_reason", value=generate_constant(arg.deprecation_reason), + arg="deprecation_reason", ), ], ) @@ -142,13 +143,13 @@ def generate_enum_value(value: GraphQLEnumValue) -> ast.Call: return generate_call( func=generate_name("GraphQLEnumValue"), keywords=[ - generate_keyword(arg="value", value=generate_constant(value.value)), + generate_keyword(value=generate_constant(value.value), arg="value"), generate_keyword( - arg="description", value=generate_constant(value.description) + value=generate_constant(value.description), arg="description" ), generate_keyword( - arg="deprecation_reason", value=generate_constant(value.deprecation_reason), + arg="deprecation_reason", ), ], ) @@ -178,14 +179,14 @@ def generate_input_field( args=[generate_field_type(input_field.type, type_map_name)], keywords=[ generate_keyword( - arg="default_value", value=generate_constant(input_field.default_value) + value=generate_constant(input_field.default_value), arg="default_value" ), generate_keyword( - arg="description", value=generate_constant(input_field.description) + value=generate_constant(input_field.description), arg="description" ), generate_keyword( - arg="deprecation_reason", value=generate_constant(input_field.deprecation_reason), + arg="deprecation_reason", ), ], ) diff --git a/ariadne_codegen/graphql_schema_generators/named_types.py b/ariadne_codegen/graphql_schema_generators/named_types.py index 606b4827..1bd83b51 100644 --- a/ariadne_codegen/graphql_schema_generators/named_types.py +++ b/ariadne_codegen/graphql_schema_generators/named_types.py @@ -37,12 +37,12 @@ def generate_scalar_type(type_: GraphQLScalarType, *_) -> ast.Call: return generate_call( func=generate_name("GraphQLScalarType"), keywords=[ - generate_keyword(arg="name", value=generate_constant(type_.name)), + generate_keyword(value=generate_constant(type_.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(type_.description) + value=generate_constant(type_.description), arg="description" ), generate_keyword( - arg="specified_by_url", value=generate_constant(type_.specified_by_url) + value=generate_constant(type_.specified_by_url), arg="specified_by_url" ), ], ) @@ -52,20 +52,20 @@ def generate_object_type(type_: GraphQLObjectType, type_map_name: str) -> ast.Ca return generate_call( func=generate_name("GraphQLObjectType"), keywords=[ - generate_keyword(arg="name", value=generate_constant(type_.name)), + generate_keyword(value=generate_constant(type_.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(type_.description) + value=generate_constant(type_.description), arg="description" ), generate_keyword( - arg="interfaces", value=get_list_of_named_types( [i.name for i in type_.interfaces], type_map_name, GraphQLInterfaceType.__name__, ), + arg="interfaces", ), generate_keyword( - arg="fields", value=generate_field_map(type_.fields, type_map_name) + value=generate_field_map(type_.fields, type_map_name), arg="fields" ), ], ) @@ -77,20 +77,20 @@ def generate_interface_type( return generate_call( func=generate_name("GraphQLInterfaceType"), keywords=[ - generate_keyword(arg="name", value=generate_constant(type_.name)), + generate_keyword(value=generate_constant(type_.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(type_.description) + value=generate_constant(type_.description), arg="description" ), generate_keyword( - arg="interfaces", value=get_list_of_named_types( [i.name for i in type_.interfaces], type_map_name, GraphQLInterfaceType.__name__, ), + arg="interfaces", ), generate_keyword( - arg="fields", value=generate_field_map(type_.fields, type_map_name) + value=generate_field_map(type_.fields, type_map_name), arg="fields" ), ], ) @@ -100,17 +100,17 @@ def generate_union_type(type_: GraphQLUnionType, type_map_name: str) -> ast.Call return generate_call( func=generate_name("GraphQLUnionType"), keywords=[ - generate_keyword(arg="name", value=generate_constant(type_.name)), + generate_keyword(value=generate_constant(type_.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(type_.description) + value=generate_constant(type_.description), arg="description" ), generate_keyword( - arg="types", value=get_list_of_named_types( [t.name for t in type_.types], type_map_name, GraphQLObjectType.__name__, ), + arg="types", ), ], ) @@ -120,11 +120,11 @@ def generate_enum_type(type_: GraphQLEnumType, *_) -> ast.Call: return generate_call( func=generate_name("GraphQLEnumType"), keywords=[ - generate_keyword(arg="name", value=generate_constant(type_.name)), + generate_keyword(value=generate_constant(type_.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(type_.description) + value=generate_constant(type_.description), arg="description" ), - generate_keyword(arg="values", value=generate_enum_values(type_.values)), + generate_keyword(value=generate_enum_values(type_.values), arg="values"), ], ) @@ -135,13 +135,13 @@ def generate_input_object_type( return generate_call( func=generate_name("GraphQLInputObjectType"), keywords=[ - generate_keyword(arg="name", value=generate_constant(type_.name)), + generate_keyword(value=generate_constant(type_.name), arg="name"), generate_keyword( - arg="description", value=generate_constant(type_.description) + value=generate_constant(type_.description), arg="description" ), generate_keyword( - arg="fields", value=generate_input_field_map(type_.fields, type_map_name), + arg="fields", ), ], ) diff --git a/ariadne_codegen/graphql_schema_generators/schema.py b/ariadne_codegen/graphql_schema_generators/schema.py index 7f352717..a1587b14 100644 --- a/ariadne_codegen/graphql_schema_generators/schema.py +++ b/ariadne_codegen/graphql_schema_generators/schema.py @@ -104,35 +104,35 @@ def generate_schema(schema: GraphQLSchema, type_map_name: str) -> ast.Call: func=generate_name("GraphQLSchema"), keywords=[ generate_keyword( - arg="query", value=get_optional_named_type(schema.query_type, type_map_name), + arg="query", ), generate_keyword( - arg="mutation", value=get_optional_named_type(schema.mutation_type, type_map_name), + arg="mutation", ), generate_keyword( - arg="subscription", value=get_optional_named_type(schema.subscription_type, type_map_name), + arg="subscription", ), generate_keyword( - arg="types", value=generate_call( func=generate_attribute( value=generate_name(type_map_name), attr="values" ), ), + arg="types", ), generate_keyword( - arg="directives", value=generate_list( elements=[ generate_directive(d, type_map_name) for d in schema.directives ] ), + arg="directives", ), generate_keyword( - arg="description", value=generate_constant(schema.description) + value=generate_constant(schema.description), arg="description" ), ], ) diff --git a/tests/client_generators/dependencies/test_async_base_client.py b/tests/client_generators/dependencies/test_async_base_client.py index b7281164..9cc175d2 100644 --- a/tests/client_generators/dependencies/test_async_base_client.py +++ b/tests/client_generators/dependencies/test_async_base_client.py @@ -211,6 +211,42 @@ async def test_execute_sends_request_with_extra_headers_and_correct_content_type assert request.headers["Content-Type"] == "application/json" +@pytest.mark.asyncio +async def test_execute_passes_kwargs_to_json_post(mocker): + http_client = mocker.AsyncMock() + + await AsyncBaseClient(url="http://base_url", http_client=http_client).execute( + "query Abc { abc }", {}, timeout=333, follow_redirects=False + ) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +@pytest.mark.asyncio +async def test_execute_sends_json_request_with_headers_from_passed_kwargs(httpx_mock): + httpx_mock.add_response() + client = AsyncBaseClient( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + ) + + await client.execute( + "query Abc { abc }", + {}, + headers={"Other-Header": "other", "Client-Header-A": "execute_value"}, + ) + + request = httpx_mock.get_request() + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-b"] == "client_value_b" + + @pytest.mark.asyncio async def test_execute_sends_file_with_multipart_form_data_content_type( httpx_mock, txt_file @@ -378,6 +414,49 @@ async def test_execute_sends_each_file_only_once(httpx_mock, txt_file): assert sent_parts["0"].content == b"abcdefgh" +@pytest.mark.asyncio +async def test_execute_passes_kwargs_to_multipart_post(mocker, txt_file): + http_client = mocker.AsyncMock() + + await AsyncBaseClient(url="http://base_url", http_client=http_client).execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + timeout=333, + follow_redirects=False, + ) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +@pytest.mark.asyncio +async def test_execute_sends_multipart_request_with_headers_from_passed_kwargs( + httpx_mock, txt_file +): + httpx_mock.add_response() + client = AsyncBaseClient( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + ) + + await client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + headers={ + "Other-Header": "other", + "Client-Header-A": "execute_value", + }, + ) + + request = httpx_mock.get_request() + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-B"] == "client_value_b" + + @pytest.mark.parametrize( "status_code, response_content", [ diff --git a/tests/client_generators/dependencies/test_async_base_client_open_telemetry.py b/tests/client_generators/dependencies/test_async_base_client_open_telemetry.py index 66361525..6989c795 100644 --- a/tests/client_generators/dependencies/test_async_base_client_open_telemetry.py +++ b/tests/client_generators/dependencies/test_async_base_client_open_telemetry.py @@ -214,6 +214,47 @@ async def test_execute_sends_request_with_extra_headers_and_correct_content_type assert request.headers["Content-Type"] == "application/json" +@pytest.mark.asyncio +@pytest.mark.parametrize("tracer", ["tracer name", None]) +async def test_execute_passes_kwargs_to_json_post(mocker, tracer): + http_client = mocker.AsyncMock() + + await AsyncBaseClientOpenTelemetry( + url="http://base_url", http_client=http_client, tracer=tracer + ).execute("query Abc { abc }", {}, timeout=333, follow_redirects=False) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tracer", ["tracer name", None]) +async def test_execute_sends_json_request_with_headers_from_passed_kwargs( + httpx_mock, tracer +): + httpx_mock.add_response() + client = AsyncBaseClientOpenTelemetry( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + tracer=tracer, + ) + + await client.execute( + "query Abc { abc }", + {}, + headers={"Other-Header": "other", "Client-Header-A": "execute_value"}, + ) + + request = httpx_mock.get_request() + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-b"] == "client_value_b" + + @pytest.mark.asyncio async def test_execute_sends_file_with_multipart_form_data_content_type( httpx_mock, txt_file @@ -381,6 +422,54 @@ async def test_execute_sends_each_file_only_once(httpx_mock, txt_file): assert sent_parts["0"].content == b"abcdefgh" +@pytest.mark.asyncio +@pytest.mark.parametrize("tracer", ["tracer name", None]) +async def test_execute_passes_kwargs_to_multipart_post(mocker, tracer, txt_file): + http_client = mocker.AsyncMock() + + await AsyncBaseClientOpenTelemetry( + url="http://base_url", http_client=http_client, tracer=tracer + ).execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + timeout=333, + follow_redirects=False, + ) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tracer", ["tracer name", None]) +async def test_execute_sends_multipart_request_with_headers_from_passed_kwargs( + httpx_mock, tracer, txt_file +): + httpx_mock.add_response() + client = AsyncBaseClientOpenTelemetry( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + tracer=tracer, + ) + + await client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + headers={ + "Other-Header": "other", + "Client-Header-A": "execute_value", + }, + ) + + request = httpx_mock.get_request() + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-B"] == "client_value_b" + + @pytest.mark.parametrize( "status_code, response_content", [ diff --git a/tests/client_generators/dependencies/test_base_client.py b/tests/client_generators/dependencies/test_base_client.py index 24fd4087..65641849 100644 --- a/tests/client_generators/dependencies/test_base_client.py +++ b/tests/client_generators/dependencies/test_base_client.py @@ -194,6 +194,40 @@ def test_execute_sends_request_with_extra_headers_and_correct_content_type(httpx assert request.headers["Content-Type"] == "application/json" +def test_execute_passes_kwargs_to_json_post(mocker): + http_client = mocker.MagicMock() + + BaseClient(url="http://base_url", http_client=http_client).execute( + "query Abc { abc }", {}, timeout=333, follow_redirects=False + ) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +def test_execute_sends_json_request_with_headers_from_passed_kwargs(httpx_mock): + httpx_mock.add_response() + client = BaseClient( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + ) + + client.execute( + "query Abc { abc }", + {}, + headers={"Other-Header": "other", "Client-Header-A": "execute_value"}, + ) + + request = httpx_mock.get_request() + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-b"] == "client_value_b" + + def test_execute_sends_file_with_multipart_form_data_content_type(httpx_mock, txt_file): httpx_mock.add_response() @@ -351,6 +385,47 @@ def test_execute_sends_each_file_only_once(httpx_mock, txt_file): assert sent_parts["0"].content == b"abcdefgh" +def test_execute_passes_kwargs_to_multipart_post(mocker, txt_file): + http_client = mocker.MagicMock() + + BaseClient(url="http://base_url", http_client=http_client).execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + timeout=333, + follow_redirects=False, + ) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +def test_execute_sends_multipart_request_with_headers_from_passed_kwargs( + httpx_mock, txt_file +): + httpx_mock.add_response() + client = BaseClient( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + ) + + client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + headers={ + "Other-Header": "other", + "Client-Header-A": "execute_value", + }, + ) + + request = httpx_mock.get_request() + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-B"] == "client_value_b" + + @pytest.mark.parametrize( "status_code, response_content", [ diff --git a/tests/client_generators/dependencies/test_base_client_open_telemetry.py b/tests/client_generators/dependencies/test_base_client_open_telemetry.py index 2a855055..c387953a 100644 --- a/tests/client_generators/dependencies/test_base_client_open_telemetry.py +++ b/tests/client_generators/dependencies/test_base_client_open_telemetry.py @@ -199,6 +199,43 @@ def test_execute_sends_request_with_extra_headers_and_correct_content_type(httpx assert request.headers["Content-Type"] == "application/json" +@pytest.mark.parametrize("tracer", ["tracer name", None]) +def test_execute_passes_kwargs_to_json_post(mocker, tracer): + http_client = mocker.MagicMock() + + BaseClientOpenTelemetry( + url="http://base_url", http_client=http_client, tracer=tracer + ).execute("query Abc { abc }", {}, timeout=333, follow_redirects=False) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +@pytest.mark.parametrize("tracer", ["tracer name", None]) +def test_execute_sends_json_request_with_headers_from_passed_kwargs(httpx_mock, tracer): + httpx_mock.add_response() + client = BaseClientOpenTelemetry( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + tracer=tracer, + ) + + client.execute( + "query Abc { abc }", + {}, + headers={"Other-Header": "other", "Client-Header-A": "execute_value"}, + ) + + request = httpx_mock.get_request() + assert request.headers["Content-Type"] == "application/json" + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-b"] == "client_value_b" + + def test_execute_sends_file_with_multipart_form_data_content_type(httpx_mock, txt_file): httpx_mock.add_response() @@ -356,6 +393,52 @@ def test_execute_sends_each_file_only_once(httpx_mock, txt_file): assert sent_parts["0"].content == b"abcdefgh" +@pytest.mark.parametrize("tracer", ["tracer name", None]) +def test_execute_passes_kwargs_to_multipart_post(mocker, tracer, txt_file): + http_client = mocker.MagicMock() + + BaseClientOpenTelemetry( + url="http://base_url", http_client=http_client, tracer=tracer + ).execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + timeout=333, + follow_redirects=False, + ) + + assert http_client.post.call_args.kwargs["timeout"] == 333 + assert http_client.post.call_args.kwargs["follow_redirects"] is False + + +@pytest.mark.parametrize("tracer", ["tracer name", None]) +def test_execute_sends_multipart_request_with_headers_from_passed_kwargs( + httpx_mock, tracer, txt_file +): + httpx_mock.add_response() + client = BaseClientOpenTelemetry( + url="http://base_url", + headers={ + "Client-Header-A": "client_value_A", + "Client-Header-B": "client_value_b", + }, + tracer=tracer, + ) + + client.execute( + "query Abc($file: Upload!) { abc(file: $file) }", + {"file": txt_file}, + headers={ + "Other-Header": "other", + "Client-Header-A": "execute_value", + }, + ) + + request = httpx_mock.get_request() + assert request.headers["Other-Header"] == "other" + assert request.headers["Client-Header-A"] == "execute_value" + assert request.headers["Client-Header-B"] == "client_value_b" + + @pytest.mark.parametrize( "status_code, response_content", [ diff --git a/tests/client_generators/dependencies/test_websockets.py b/tests/client_generators/dependencies/test_websockets.py index 995e2a30..5b907c77 100644 --- a/tests/client_generators/dependencies/test_websockets.py +++ b/tests/client_generators/dependencies/test_websockets.py @@ -77,6 +77,37 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( } +@pytest.mark.asyncio +async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers( + mocked_ws_connect, +): + async for _ in AsyncBaseClient( + ws_headers={"Client-A": "client_value_a", "Client-B": "client_value_b"} + ).execute_ws( + "", extra_headers={"Client-A": "execute_value_a", "Execute-Other": "other"} + ): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["extra_headers"] == { + "Client-A": "execute_value_a", + "Client-B": "client_value_b", + "Execute-Other": "other", + } + + +@pytest.mark.asyncio +async def test_execute_ws_creates_websocket_connection_with_passed_kwargs( + mocked_ws_connect, +): + async for _ in AsyncBaseClient().execute_ws("", open_timeout=15, close_timeout=30): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["open_timeout"] == 15 + assert mocked_ws_connect.call_args.kwargs["close_timeout"] == 30 + + @pytest.mark.asyncio async def test_execute_ws_sends_correct_init_connection_data(mocked_websocket): async for _ in AsyncBaseClient( diff --git a/tests/client_generators/dependencies/test_websockets_open_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py index 09616e5e..ef209502 100644 --- a/tests/client_generators/dependencies/test_websockets_open_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -79,6 +79,42 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( } +@pytest.mark.asyncio +@pytest.mark.parametrize("tracer", ["tracer name", None]) +async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers( + mocked_ws_connect, tracer +): + async for _ in AsyncBaseClientOpenTelemetry( + ws_headers={"Client-A": "client_value_a", "Client-B": "client_value_b"}, + tracer=tracer, + ).execute_ws( + "", extra_headers={"Client-A": "execute_value_a", "Execute-Other": "other"} + ): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["extra_headers"] == { + "Client-A": "execute_value_a", + "Client-B": "client_value_b", + "Execute-Other": "other", + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("tracer", ["tracer name", None]) +async def test_execute_ws_creates_websocket_connection_with_passed_kwargs( + mocked_ws_connect, tracer +): + async for _ in AsyncBaseClientOpenTelemetry(tracer=tracer).execute_ws( + "", open_timeout=15, close_timeout=30 + ): + pass + + assert mocked_ws_connect.called + assert mocked_ws_connect.call_args.kwargs["open_timeout"] == 15 + assert mocked_ws_connect.call_args.kwargs["close_timeout"] == 30 + + @pytest.mark.asyncio async def test_execute_ws_sends_correct_init_connection_data(mocked_websocket): async for _ in AsyncBaseClientOpenTelemetry( diff --git a/tests/client_generators/package_generator/test_generated_files.py b/tests/client_generators/package_generator/test_generated_files.py index a2ad4417..f92a91c4 100644 --- a/tests/client_generators/package_generator/test_generated_files.py +++ b/tests/client_generators/package_generator/test_generated_files.py @@ -392,7 +392,9 @@ def test_generate_with_enum_as_query_argument_generates_client_with_correct_meth } """ - expected_method_def = "def custom_query(self, val: CustomEnum) -> CustomQuery:" + expected_method_def = ( + "def custom_query(self, val: CustomEnum, **kwargs: Any) -> CustomQuery:" + ) expected_enum_import = f"from .{generator.enums_module_name} import CustomEnum" generator.add_operation(parse(query_str).definitions[0]) diff --git a/tests/client_generators/test_arguments_generator.py b/tests/client_generators/test_arguments_generator.py index 1687a43f..f0cc04d0 100644 --- a/tests/client_generators/test_arguments_generator.py +++ b/tests/client_generators/test_arguments_generator.py @@ -5,6 +5,7 @@ from ariadne_codegen.client_generators.arguments import ArgumentsGenerator from ariadne_codegen.client_generators.constants import ( ANY, + KWARGS_NAMES, OPTIONAL, UNION, UNSET_NAME, @@ -90,6 +91,7 @@ def test_generate_returns_arguments_with_correct_optional_annotation(): ), ), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[ast.Name(id="UNSET")], @@ -145,6 +147,7 @@ def test_generate_returns_arguments_with_default_value_for_optional_args(): ), ), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[ast.Name(id=UNSET_NAME), ast.Name(id=UNSET_NAME)], @@ -155,18 +158,21 @@ def test_generate_returns_arguments_with_default_value_for_optional_args(): assert compare_ast(arguments, expected_arguments) -def test_generate_returns_arguments_with_only_self_argument_without_annotation(): +def test_generate_returns_arguments_with_only_self_and_kwargs_arguments(): generator = ArgumentsGenerator(schema=GraphQLSchema()) - query = "query q {r}" - variable_definitions = _get_variable_definitions_from_query_str(query) + variable_definitions = _get_variable_definitions_from_query_str("query q {r}") + expected_arguments = ast.arguments( + posonlyargs=[], + args=[ast.arg(arg="self")], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ) arguments, _ = generator.generate(variable_definitions) - assert isinstance(arguments, ast.arguments) - assert len(arguments.args) == 1 - self_arg = arguments.args[0] - assert self_arg.arg == "self" - assert not self_arg.annotation + assert compare_ast(arguments, expected_arguments) def test_generate_saves_used_non_scalar_types(): @@ -203,6 +209,7 @@ def test_generate_returns_arguments_and_dictionary_with_snake_case_names(): ast.arg(arg="camel_case", annotation=ast.Name(id="str")), ast.arg(arg="snake_case", annotation=ast.Name(id="str")), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[], @@ -236,6 +243,7 @@ def test_generate_returns_arguments_and_dictionary_with_valid_names(): ast.arg(arg="field_a", annotation=ast.Name(id="str")), ast.arg(arg="field_b", annotation=ast.Name(id="str")), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[], @@ -276,6 +284,7 @@ def test_generate_returns_arguments_with_not_mapped_custom_scalar(): ast.arg(arg="self"), ast.arg(arg="arg", annotation=ast.Name(id=ANY)), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[], @@ -313,6 +322,7 @@ def test_generate_returns_arguments_with_custom_scalar_and_used_serialize_method ast.arg(arg="self"), ast.arg(arg="arg", annotation=ast.Name(id="ScalarABC")), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[], @@ -351,6 +361,7 @@ def test_generate_returns_arguments_with_upload_scalar(): ast.arg(arg="self"), ast.arg(arg="arg", annotation=ast.Name(id=UPLOAD_CLASS_NAME)), ], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[], diff --git a/tests/client_generators/test_client_generator.py b/tests/client_generators/test_client_generator.py index f8a812e4..f43a0709 100644 --- a/tests/client_generators/test_client_generator.py +++ b/tests/client_generators/test_client_generator.py @@ -10,6 +10,7 @@ ANY, ASYNC_ITERATOR, DICT, + KWARGS_NAMES, LIST, MODEL_VALIDATE_METHOD, OPTIONAL, @@ -277,6 +278,7 @@ def test_add_method_generates_correct_async_method_body(async_base_client_import keywords=[ ast.keyword(arg="query", value=ast.Name(id="query")), ast.keyword(arg="variables", value=ast.Name(id="variables")), + ast.keyword(value=ast.Name(id=KWARGS_NAMES)), ], ) ), @@ -414,6 +416,7 @@ def test_add_method_generates_correct_method_body(base_client_import): keywords=[ ast.keyword(arg="query", value=ast.Name(id="query")), ast.keyword(arg="variables", value=ast.Name(id="variables")), + ast.keyword(value=ast.Name(id=KWARGS_NAMES)), ], ), ), @@ -470,6 +473,7 @@ def test_add_method_generates_async_generator_for_subscription_definition( args=ast.arguments( posonlyargs=[], args=[ast.arg(arg="self")], + kwarg=ast.arg(arg=KWARGS_NAMES, annotation=ast.Name(id=ANY)), kwonlyargs=[], kw_defaults=[], defaults=[], @@ -502,6 +506,7 @@ def test_add_method_generates_async_generator_for_subscription_definition( keywords=[ ast.keyword(arg="query", value=ast.Name(id="query")), ast.keyword(arg="variables", value=ast.Name(id="variables")), + ast.keyword(value=ast.Name(id=KWARGS_NAMES)), ], ), body=[ diff --git a/tests/codegen/test_generated_calls.py b/tests/codegen/test_generated_calls.py index d57654b9..005c9f1b 100644 --- a/tests/codegen/test_generated_calls.py +++ b/tests/codegen/test_generated_calls.py @@ -48,7 +48,7 @@ def test_generate_await_returns_await_object(): def test_generate_keyword_returns_keyword_object(): - result = generate_keyword(arg="test_variable", value=ast.Name(id="test value")) + result = generate_keyword(value=ast.Name(id="test value"), arg="test_variable") assert isinstance(result, ast.keyword) assert result.arg == "test_variable" diff --git a/tests/main/clients/custom_base_client/expected_client/client.py b/tests/main/clients/custom_base_client/expected_client/client.py index a46eff79..c0a77c67 100644 --- a/tests/main/clients/custom_base_client/expected_client/client.py +++ b/tests/main/clients/custom_base_client/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .custom_base_client import CustomAsyncBaseClient from .get_query_a import GetQueryA @@ -10,7 +10,7 @@ def gql(q: str) -> str: class Client(CustomAsyncBaseClient): - async def get_query_a(self, data_a: inputA) -> GetQueryA: + async def get_query_a(self, data_a: inputA, **kwargs: Any) -> GetQueryA: query = gql( """ query getQueryA($dataA: inputA!) { @@ -21,6 +21,6 @@ async def get_query_a(self, data_a: inputA) -> GetQueryA: """ ) variables: Dict[str, object] = {"dataA": data_a} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetQueryA.model_validate(data) diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/custom_config_file/expected_client/client.py b/tests/main/clients/custom_config_file/expected_client/client.py index 8ec4e828..c9b8b7f3 100644 --- a/tests/main/clients/custom_config_file/expected_client/client.py +++ b/tests/main/clients/custom_config_file/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .test import Test @@ -9,7 +9,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def test(self) -> Test: + async def test(self, **kwargs: Any) -> Test: query = gql( """ query test { @@ -18,6 +18,6 @@ async def test(self) -> Test: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return Test.model_validate(data) diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/custom_files_names/expected_client/custom_client.py b/tests/main/clients/custom_files_names/expected_client/custom_client.py index cf9881fc..5c2c56e4 100644 --- a/tests/main/clients/custom_files_names/expected_client/custom_client.py +++ b/tests/main/clients/custom_files_names/expected_client/custom_client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .custom_input_types import inputA @@ -10,7 +10,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def get_query_a(self, data_a: inputA) -> GetQueryA: + async def get_query_a(self, data_a: inputA, **kwargs: Any) -> GetQueryA: query = gql( """ query getQueryA($dataA: inputA!) { @@ -21,6 +21,6 @@ async def get_query_a(self, data_a: inputA) -> GetQueryA: """ ) variables: Dict[str, object] = {"dataA": data_a} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetQueryA.model_validate(data) diff --git a/tests/main/clients/custom_scalars/expected_client/__init__.py b/tests/main/clients/custom_scalars/expected_client/__init__.py index 23dc8f79..443db892 100644 --- a/tests/main/clients/custom_scalars/expected_client/__init__.py +++ b/tests/main/clients/custom_scalars/expected_client/__init__.py @@ -8,15 +8,15 @@ GraphQLClientHttpError, GraphQlClientInvalidResponseError, ) -from .get_test import GetTest, GetTestTestQuery +from .get_a import GetA, GetATestQuery from .input_types import TestInput __all__ = [ "AsyncBaseClient", "BaseModel", "Client", - "GetTest", - "GetTestTestQuery", + "GetA", + "GetATestQuery", "GraphQLClientError", "GraphQLClientGraphQLError", "GraphQLClientGraphQLMultiError", diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/custom_scalars/expected_client/client.py b/tests/main/clients/custom_scalars/expected_client/client.py index 68f45add..06ff315f 100644 --- a/tests/main/clients/custom_scalars/expected_client/client.py +++ b/tests/main/clients/custom_scalars/expected_client/client.py @@ -3,7 +3,7 @@ from .async_base_client import AsyncBaseClient from .custom_scalars import Code, serialize_code -from .get_test import GetTest +from .get_a import GetA from .input_types import TestInput @@ -12,12 +12,18 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def get_test( - self, date: datetime, code: Code, id: int, input: TestInput, other: Any - ) -> GetTest: + async def get_a( + self, + date: datetime, + code: Code, + id: int, + input: TestInput, + other: Any, + **kwargs: Any + ) -> GetA: query = gql( """ - query getTest($date: DATETIME!, $code: CODE!, $id: CUSTOMID!, $input: TestInput!, $other: NOTMAPPED!) { + query getA($date: DATETIME!, $code: CODE!, $id: CUSTOMID!, $input: TestInput!, $other: NOTMAPPED!) { testQuery(date: $date, code: $code, id: $id, input: $input, other: $other) { date code @@ -34,6 +40,6 @@ async def get_test( "input": input, "other": other, } - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) - return GetTest.model_validate(data) + return GetA.model_validate(data) diff --git a/tests/main/clients/custom_scalars/expected_client/get_test.py b/tests/main/clients/custom_scalars/expected_client/get_a.py similarity index 62% rename from tests/main/clients/custom_scalars/expected_client/get_test.py rename to tests/main/clients/custom_scalars/expected_client/get_a.py index 2a77cfa0..d1170bf0 100644 --- a/tests/main/clients/custom_scalars/expected_client/get_test.py +++ b/tests/main/clients/custom_scalars/expected_client/get_a.py @@ -7,16 +7,16 @@ from .custom_scalars import Code, parse_code -class GetTest(BaseModel): - test_query: "GetTestTestQuery" = Field(alias="testQuery") +class GetA(BaseModel): + test_query: "GetATestQuery" = Field(alias="testQuery") -class GetTestTestQuery(BaseModel): +class GetATestQuery(BaseModel): date: datetime code: Annotated[Code, BeforeValidator(parse_code)] id: int other: Any -GetTest.model_rebuild() -GetTestTestQuery.model_rebuild() +GetA.model_rebuild() +GetATestQuery.model_rebuild() diff --git a/tests/main/clients/custom_scalars/queries.graphql b/tests/main/clients/custom_scalars/queries.graphql index 2b3e0331..f2ef5fb7 100644 --- a/tests/main/clients/custom_scalars/queries.graphql +++ b/tests/main/clients/custom_scalars/queries.graphql @@ -1,4 +1,4 @@ -query getTest( +query getA( $date: DATETIME! $code: CODE! $id: CUSTOMID! diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/example/expected_client/client.py b/tests/main/clients/example/expected_client/client.py index ff271b89..1d76f679 100644 --- a/tests/main/clients/example/expected_client/client.py +++ b/tests/main/clients/example/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Dict, Optional, Union +from typing import Any, AsyncIterator, Dict, Optional, Union from .async_base_client import AsyncBaseClient from .base_model import UNSET, UnsetType, Upload @@ -15,7 +15,9 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def create_user(self, user_data: UserCreateInput) -> CreateUser: + async def create_user( + self, user_data: UserCreateInput, **kwargs: Any + ) -> CreateUser: query = gql( """ mutation CreateUser($userData: UserCreateInput!) { @@ -26,11 +28,11 @@ async def create_user(self, user_data: UserCreateInput) -> CreateUser: """ ) variables: Dict[str, object] = {"userData": user_data} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return CreateUser.model_validate(data) - async def list_all_users(self) -> ListAllUsers: + async def list_all_users(self, **kwargs: Any) -> ListAllUsers: query = gql( """ query ListAllUsers { @@ -47,12 +49,12 @@ async def list_all_users(self) -> ListAllUsers: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListAllUsers.model_validate(data) async def list_users_by_country( - self, country: Union[Optional[str], UnsetType] = UNSET + self, country: Union[Optional[str], UnsetType] = UNSET, **kwargs: Any ) -> ListUsersByCountry: query = gql( """ @@ -76,11 +78,11 @@ async def list_users_by_country( """ ) variables: Dict[str, object] = {"country": country} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListUsersByCountry.model_validate(data) - async def get_users_counter(self) -> AsyncIterator[GetUsersCounter]: + async def get_users_counter(self, **kwargs: Any) -> AsyncIterator[GetUsersCounter]: query = gql( """ subscription GetUsersCounter { @@ -89,10 +91,10 @@ async def get_users_counter(self) -> AsyncIterator[GetUsersCounter]: """ ) variables: Dict[str, object] = {} - async for data in self.execute_ws(query=query, variables=variables): + async for data in self.execute_ws(query=query, variables=variables, **kwargs): yield GetUsersCounter.model_validate(data) - async def upload_file(self, file: Upload) -> UploadFile: + async def upload_file(self, file: Upload, **kwargs: Any) -> UploadFile: query = gql( """ mutation uploadFile($file: Upload!) { @@ -101,6 +103,6 @@ async def upload_file(self, file: Upload) -> UploadFile: """ ) variables: Dict[str, object] = {"file": file} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return UploadFile.model_validate(data) diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/extended_models/expected_client/client.py b/tests/main/clients/extended_models/expected_client/client.py index 14be0337..5fc55bf9 100644 --- a/tests/main/clients/extended_models/expected_client/client.py +++ b/tests/main/clients/extended_models/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .fragments_with_mixins import FragmentsWithMixins @@ -12,7 +12,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def get_query_a(self) -> GetQueryA: + async def get_query_a(self, **kwargs: Any) -> GetQueryA: query = gql( """ query getQueryA { @@ -23,11 +23,11 @@ async def get_query_a(self) -> GetQueryA: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetQueryA.model_validate(data) - async def get_query_b(self) -> GetQueryB: + async def get_query_b(self, **kwargs: Any) -> GetQueryB: query = gql( """ query getQueryB { @@ -38,11 +38,11 @@ async def get_query_b(self) -> GetQueryB: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetQueryB.model_validate(data) - async def get_query_a_with_fragment(self) -> GetQueryAWithFragment: + async def get_query_a_with_fragment(self, **kwargs: Any) -> GetQueryAWithFragment: query = gql( """ query getQueryAWithFragment { @@ -57,11 +57,11 @@ async def get_query_a_with_fragment(self) -> GetQueryAWithFragment: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetQueryAWithFragment.model_validate(data) - async def fragments_with_mixins(self) -> FragmentsWithMixins: + async def fragments_with_mixins(self, **kwargs: Any) -> FragmentsWithMixins: query = gql( """ query fragmentsWithMixins { @@ -83,6 +83,6 @@ async def fragments_with_mixins(self) -> FragmentsWithMixins: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return FragmentsWithMixins.model_validate(data) diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/client.py index ca0b775a..57b277e2 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .query_with_fragment_on_sub_interface import QueryWithFragmentOnSubInterface @@ -14,7 +14,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): async def query_with_fragment_on_sub_interface( - self, + self, **kwargs: Any ) -> QueryWithFragmentOnSubInterface: query = gql( """ @@ -32,12 +32,12 @@ async def query_with_fragment_on_sub_interface( """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnSubInterface.model_validate(data) async def query_with_fragment_on_sub_interface_with_inline_fragment( - self, + self, **kwargs: Any ) -> QueryWithFragmentOnSubInterfaceWithInlineFragment: query = gql( """ @@ -58,12 +58,12 @@ async def query_with_fragment_on_sub_interface_with_inline_fragment( """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnSubInterfaceWithInlineFragment.model_validate(data) async def query_with_fragment_on_union_member( - self, + self, **kwargs: Any ) -> QueryWithFragmentOnUnionMember: query = gql( """ @@ -81,6 +81,6 @@ async def query_with_fragment_on_union_member( """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnUnionMember.model_validate(data) diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/inline_fragments/expected_client/client.py b/tests/main/clients/inline_fragments/expected_client/client.py index b468f22c..fe2b3d91 100644 --- a/tests/main/clients/inline_fragments/expected_client/client.py +++ b/tests/main/clients/inline_fragments/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .interface_a import InterfaceA @@ -22,7 +22,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def interface_a(self) -> InterfaceA: + async def interface_a(self, **kwargs: Any) -> InterfaceA: query = gql( """ query InterfaceA { @@ -40,11 +40,11 @@ async def interface_a(self) -> InterfaceA: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return InterfaceA.model_validate(data) - async def interface_b(self) -> InterfaceB: + async def interface_b(self, **kwargs: Any) -> InterfaceB: query = gql( """ query InterfaceB { @@ -59,11 +59,11 @@ async def interface_b(self) -> InterfaceB: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return InterfaceB.model_validate(data) - async def interface_c(self) -> InterfaceC: + async def interface_c(self, **kwargs: Any) -> InterfaceC: query = gql( """ query InterfaceC { @@ -75,11 +75,11 @@ async def interface_c(self) -> InterfaceC: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return InterfaceC.model_validate(data) - async def list_interface(self) -> ListInterface: + async def list_interface(self, **kwargs: Any) -> ListInterface: query = gql( """ query ListInterface { @@ -97,11 +97,11 @@ async def list_interface(self) -> ListInterface: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListInterface.model_validate(data) - async def interface_with_typename(self) -> InterfaceWithTypename: + async def interface_with_typename(self, **kwargs: Any) -> InterfaceWithTypename: query = gql( """ query InterfaceWithTypename { @@ -113,11 +113,11 @@ async def interface_with_typename(self) -> InterfaceWithTypename: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return InterfaceWithTypename.model_validate(data) - async def union_a(self) -> UnionA: + async def union_a(self, **kwargs: Any) -> UnionA: query = gql( """ query UnionA { @@ -136,11 +136,11 @@ async def union_a(self) -> UnionA: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return UnionA.model_validate(data) - async def union_b(self) -> UnionB: + async def union_b(self, **kwargs: Any) -> UnionB: query = gql( """ query UnionB { @@ -155,11 +155,11 @@ async def union_b(self) -> UnionB: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return UnionB.model_validate(data) - async def list_union(self) -> ListUnion: + async def list_union(self, **kwargs: Any) -> ListUnion: query = gql( """ query ListUnion { @@ -178,11 +178,13 @@ async def list_union(self) -> ListUnion: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListUnion.model_validate(data) - async def query_with_fragment_on_interface(self) -> QueryWithFragmentOnInterface: + async def query_with_fragment_on_interface( + self, **kwargs: Any + ) -> QueryWithFragmentOnInterface: query = gql( """ query queryWithFragmentOnInterface { @@ -204,11 +206,13 @@ async def query_with_fragment_on_interface(self) -> QueryWithFragmentOnInterface """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnInterface.model_validate(data) - async def query_with_fragment_on_union(self) -> QueryWithFragmentOnUnion: + async def query_with_fragment_on_union( + self, **kwargs: Any + ) -> QueryWithFragmentOnUnion: query = gql( """ query queryWithFragmentOnUnion { @@ -231,12 +235,12 @@ async def query_with_fragment_on_union(self) -> QueryWithFragmentOnUnion: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnUnion.model_validate(data) async def query_with_fragment_on_query_with_interface( - self, + self, **kwargs: Any ) -> QueryWithFragmentOnQueryWithInterface: query = gql( """ @@ -258,12 +262,12 @@ async def query_with_fragment_on_query_with_interface( """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnQueryWithInterface.model_validate(data) async def query_with_fragment_on_query_with_union( - self, + self, **kwargs: Any ) -> QueryWithFragmentOnQueryWithUnion: query = gql( """ @@ -286,6 +290,6 @@ async def query_with_fragment_on_query_with_union( """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return QueryWithFragmentOnQueryWithUnion.model_validate(data) diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/multiple_fragments/expected_client/client.py b/tests/main/clients/multiple_fragments/expected_client/client.py index 1b00b314..874682fc 100644 --- a/tests/main/clients/multiple_fragments/expected_client/client.py +++ b/tests/main/clients/multiple_fragments/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .example_query_1 import ExampleQuery1 @@ -11,7 +11,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def example_query_1(self) -> ExampleQuery1: + async def example_query_1(self, **kwargs: Any) -> ExampleQuery1: query = gql( """ query exampleQuery1 { @@ -34,11 +34,11 @@ async def example_query_1(self) -> ExampleQuery1: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ExampleQuery1.model_validate(data) - async def example_query_2(self) -> ExampleQuery2: + async def example_query_2(self, **kwargs: Any) -> ExampleQuery2: query = gql( """ query exampleQuery2 { @@ -62,11 +62,11 @@ async def example_query_2(self) -> ExampleQuery2: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ExampleQuery2.model_validate(data) - async def example_query_3(self) -> ExampleQuery3: + async def example_query_3(self, **kwargs: Any) -> ExampleQuery3: query = gql( """ query exampleQuery3 { @@ -86,6 +86,6 @@ async def example_query_3(self) -> ExampleQuery3: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ExampleQuery3.model_validate(data) diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/remote_schema/expected_client/client.py b/tests/main/clients/remote_schema/expected_client/client.py index 8ec4e828..c9b8b7f3 100644 --- a/tests/main/clients/remote_schema/expected_client/client.py +++ b/tests/main/clients/remote_schema/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Any, Dict from .async_base_client import AsyncBaseClient from .test import Test @@ -9,7 +9,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def test(self) -> Test: + async def test(self, **kwargs: Any) -> Test: query = gql( """ query test { @@ -18,6 +18,6 @@ async def test(self) -> Test: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return Test.model_validate(data) diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index 3d33b247..25618471 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -84,7 +84,7 @@ async def __aexit__( await self.http_client.aclose() async def execute( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -94,9 +94,12 @@ async def execute( variables=processed_variables, files=files, files_map=files_map, + **kwargs, ) - return await self._execute_json(query=query, variables=processed_variables) + return await self._execute_json( + query=query, variables=processed_variables, **kwargs + ) def get_data(self, response: httpx.Response) -> Dict[str, Any]: if not response.is_success: @@ -123,14 +126,20 @@ def get_data(self, response: httpx.Response) -> Dict[str, Any]: return cast(Dict[str, Any], data) async def execute_ws( - self, query: str, variables: Optional[Dict[str, Any]] = None + self, query: str, variables: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + operation_id = str(uuid4()) async with ws_connect( self.ws_url, subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], - origin=self.ws_origin, - extra_headers=self.ws_headers, + **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) await self._send_subscribe( @@ -220,6 +229,7 @@ async def _execute_multipart( variables: Dict[str, Any], files: Dict[str, Tuple[str, IO[bytes], str]], files_map: Dict[str, List[str]], + **kwargs: Any, ) -> httpx.Response: data = { "operations": json.dumps( @@ -228,19 +238,25 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post(url=self.url, data=data, files=files) + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) async def _execute_json( - self, - query: str, - variables: Dict[str, Any], + self, query: str, variables: Dict[str, Any], **kwargs: Any ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + return await self.http_client.post( url=self.url, content=json.dumps( {"query": query, "variables": variables}, default=to_jsonable_python ), - headers={"Content-Type": "application/json"}, + **merged_kwargs, ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: diff --git a/tests/main/clients/shorter_results/expected_client/client.py b/tests/main/clients/shorter_results/expected_client/client.py index 52fe3b2e..e6fd6463 100644 --- a/tests/main/clients/shorter_results/expected_client/client.py +++ b/tests/main/clients/shorter_results/expected_client/client.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Dict, List, Optional, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Union from .async_base_client import AsyncBaseClient from .custom_scalars import ComplexScalar, SimpleScalar @@ -33,7 +33,7 @@ def gql(q: str) -> str: class Client(AsyncBaseClient): - async def get_authenticated_user(self) -> GetAuthenticatedUserMe: + async def get_authenticated_user(self, **kwargs: Any) -> GetAuthenticatedUserMe: query = gql( """ query GetAuthenticatedUser { @@ -45,11 +45,11 @@ async def get_authenticated_user(self) -> GetAuthenticatedUserMe: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetAuthenticatedUser.model_validate(data).me - async def list_strings_1(self) -> Optional[List[Optional[str]]]: + async def list_strings_1(self, **kwargs: Any) -> Optional[List[Optional[str]]]: query = gql( """ query ListStrings_1 { @@ -58,11 +58,11 @@ async def list_strings_1(self) -> Optional[List[Optional[str]]]: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListStrings1.model_validate(data).optional_list_optional_string - async def list_strings_2(self) -> Optional[List[str]]: + async def list_strings_2(self, **kwargs: Any) -> Optional[List[str]]: query = gql( """ query ListStrings_2 { @@ -71,11 +71,11 @@ async def list_strings_2(self) -> Optional[List[str]]: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListStrings2.model_validate(data).optional_list_string - async def list_strings_3(self) -> List[Optional[str]]: + async def list_strings_3(self, **kwargs: Any) -> List[Optional[str]]: query = gql( """ query ListStrings_3 { @@ -84,11 +84,11 @@ async def list_strings_3(self) -> List[Optional[str]]: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListStrings3.model_validate(data).list_optional_string - async def list_strings_4(self) -> List[str]: + async def list_strings_4(self, **kwargs: Any) -> List[str]: query = gql( """ query ListStrings_4 { @@ -97,11 +97,13 @@ async def list_strings_4(self) -> List[str]: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListStrings4.model_validate(data).list_string - async def list_type_a(self) -> List[Optional[ListTypeAListOptionalTypeA]]: + async def list_type_a( + self, **kwargs: Any + ) -> List[Optional[ListTypeAListOptionalTypeA]]: query = gql( """ query ListTypeA { @@ -112,12 +114,12 @@ async def list_type_a(self) -> List[Optional[ListTypeAListOptionalTypeA]]: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListTypeA.model_validate(data).list_optional_type_a async def get_animal_by_name( - self, name: str + self, name: str, **kwargs: Any ) -> Union[ GetAnimalByNameAnimalByNameAnimal, GetAnimalByNameAnimalByNameCat, @@ -140,12 +142,12 @@ async def get_animal_by_name( """ ) variables: Dict[str, object] = {"name": name} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetAnimalByName.model_validate(data).animal_by_name async def list_animals( - self, + self, **kwargs: Any ) -> List[ Union[ ListAnimalsListAnimalsAnimal, @@ -170,11 +172,13 @@ async def list_animals( """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return ListAnimals.model_validate(data).list_animals - async def get_animal_fragment_with_extra(self) -> GetAnimalFragmentWithExtra: + async def get_animal_fragment_with_extra( + self, **kwargs: Any + ) -> GetAnimalFragmentWithExtra: query = gql( """ query GetAnimalFragmentWithExtra { @@ -190,11 +194,11 @@ async def get_animal_fragment_with_extra(self) -> GetAnimalFragmentWithExtra: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetAnimalFragmentWithExtra.model_validate(data) - async def get_simple_scalar(self) -> SimpleScalar: + async def get_simple_scalar(self, **kwargs: Any) -> SimpleScalar: query = gql( """ query GetSimpleScalar { @@ -203,11 +207,11 @@ async def get_simple_scalar(self) -> SimpleScalar: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetSimpleScalar.model_validate(data).just_simple_scalar - async def get_complex_scalar(self) -> ComplexScalar: + async def get_complex_scalar(self, **kwargs: Any) -> ComplexScalar: query = gql( """ query GetComplexScalar { @@ -216,11 +220,13 @@ async def get_complex_scalar(self) -> ComplexScalar: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return GetComplexScalar.model_validate(data).just_complex_scalar - async def subscribe_strings(self) -> AsyncIterator[Optional[List[str]]]: + async def subscribe_strings( + self, **kwargs: Any + ) -> AsyncIterator[Optional[List[str]]]: query = gql( """ subscription SubscribeStrings { @@ -229,10 +235,12 @@ async def subscribe_strings(self) -> AsyncIterator[Optional[List[str]]]: """ ) variables: Dict[str, object] = {} - async for data in self.execute_ws(query=query, variables=variables): + async for data in self.execute_ws(query=query, variables=variables, **kwargs): yield SubscribeStrings.model_validate(data).optional_list_string - async def unwrap_fragment(self) -> FragmentWithSingleFieldQueryUnwrapFragment: + async def unwrap_fragment( + self, **kwargs: Any + ) -> FragmentWithSingleFieldQueryUnwrapFragment: query = gql( """ query UnwrapFragment { @@ -247,6 +255,6 @@ async def unwrap_fragment(self) -> FragmentWithSingleFieldQueryUnwrapFragment: """ ) variables: Dict[str, object] = {} - response = await self.execute(query=query, variables=variables) + response = await self.execute(query=query, variables=variables, **kwargs) data = self.get_data(response) return UnwrapFragment.model_validate(data).query_unwrap_fragment