From ff112f9e4cea6aafbba9f5a71798883702a4cfed Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Thu, 21 Dec 2023 08:28:43 +0100 Subject: [PATCH 01/34] Handle connection_ack message for subscriptions --- .../dependencies/async_base_client.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index d3ad17ef..8926e640 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) From 8393f634f54bcbfd2dbe15db8e4589670d7e335e Mon Sep 17 00:00:00 2001 From: Scott Lessans Date: Thu, 21 Dec 2023 10:48:42 -0800 Subject: [PATCH 02/34] fix scalar data issue in plugin settings --- ariadne_codegen/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/config.py b/ariadne_codegen/config.py index d6a7b8d8..476fa7b0 100644 --- a/ariadne_codegen/config.py +++ b/ariadne_codegen/config.py @@ -38,7 +38,9 @@ def get_client_settings(config_dict: Dict) -> ClientSettings: settings_fields_names = {f.name for f in fields(ClientSettings)} try: section["scalars"] = { - name: ScalarData( + name: data + if isinstance(data, ScalarData) + else ScalarData( type_=data["type"], serialize=data.get("serialize"), parse=data.get("parse"), From 54087f0371e833c85ca0a015573d04e59ad84068 Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Fri, 22 Dec 2023 09:00:38 +0100 Subject: [PATCH 03/34] Fix tests --- .../async_base_client_open_telemetry.py | 18 +++++++++++++++++- .../dependencies/test_websockets.py | 13 +++++++------ .../test_websockets_open_telemetry.py | 5 +++-- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- .../expected_client/async_base_client.py | 16 +++++++++++++++- 15 files changed, 207 insertions(+), 21 deletions(-) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index ecd786f5..5580cb90 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -563,6 +563,13 @@ async def _execute_ws_with_telemetry( root_span=root_span, websocket=websocket, ) + # wait for connection_ack from server + await self._handle_ws_message_with_telemetry( + root_span=root_span, + message=await websocket.recv(), + websocket=websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe_with_telemetry( root_span=root_span, websocket=websocket, @@ -628,7 +635,11 @@ async def _send_subscribe_with_telemetry( ) async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + self, + root_span: Span, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: with self.tracer.start_as_current_span( # type: ignore "received message", context=set_span_in_context(root_span) @@ -650,6 +661,11 @@ async def _handle_ws_message_with_telemetry( }: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/client_generators/dependencies/test_websockets.py b/tests/client_generators/dependencies/test_websockets.py index 0b3f8843..b2f02996 100644 --- a/tests/client_generators/dependencies/test_websockets.py +++ b/tests/client_generators/dependencies/test_websockets.py @@ -24,12 +24,13 @@ def mocked_websocket(mocked_ws_connect): websocket.__aiter__.return_value = [ json.dumps({"type": "connection_ack"}), ] + websocket.recv.return_value = json.dumps({"type": "connection_ack"}) return websocket @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient(ws_url="ws://test_url").execute_ws(""): pass @@ -40,7 +41,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_url( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient().execute_ws(""): pass @@ -53,7 +54,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_origin( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient(ws_origin="test_origin").execute_ws(""): pass @@ -64,7 +65,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_origin( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_headers( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient(ws_headers={"test_key": "test_value"}).execute_ws( "" @@ -79,7 +80,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient( ws_headers={"Client-A": "client_value_a", "Client-B": "client_value_b"} @@ -98,7 +99,7 @@ async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_passed_kwargs( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient().execute_ws("", open_timeout=15, close_timeout=30): pass diff --git a/tests/client_generators/dependencies/test_websockets_open_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py index e674a2f1..53f8347d 100644 --- a/tests/client_generators/dependencies/test_websockets_open_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -26,6 +26,7 @@ def mocked_websocket(mocked_ws_connect): websocket.__aiter__.return_value = [ json.dumps({"type": "connection_ack"}), ] + websocket.recv.return_value = json.dumps({"type": "connection_ack"}) return websocket @@ -82,7 +83,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( @pytest.mark.asyncio @pytest.mark.parametrize("tracer", ["tracer name", None]) async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers( - mocked_ws_connect, tracer + mocked_ws_connect, mocked_websocket, tracer # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry( ws_headers={"Client-A": "client_value_a", "Client-B": "client_value_b"}, @@ -103,7 +104,7 @@ async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers @pytest.mark.asyncio @pytest.mark.parametrize("tracer", ["tracer name", None]) async def test_execute_ws_creates_websocket_connection_with_passed_kwargs( - mocked_ws_connect, tracer + mocked_ws_connect, mocked_websocket, tracer # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry(tracer=tracer).execute_ws( "", open_timeout=15, close_timeout=30 diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/operations/expected_client/async_base_client.py b/tests/main/clients/operations/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/operations/expected_client/async_base_client.py +++ b/tests/main/clients/operations/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index d3ad17ef..8926e640 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: GraphQLTransportWSMessageType | None = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) From 16186baea94752bedc20eb879e1ec6a7548365df Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Fri, 22 Dec 2023 09:20:07 +0100 Subject: [PATCH 04/34] Add change log --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ce3a35..46335a11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## UNRELEASED + +- Improved `graphql-transport-ws` protocol compliance for `connection_ack` messages + + ## 0.11.0 (2023-12-05) - Removed `model_rebuild` calls for generated input, fragment and result models. From 46353fa1a6ed82eaf8661e03d922c04bf60d599c Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Fri, 22 Dec 2023 10:28:26 +0100 Subject: [PATCH 05/34] Fix compatibility with Python 3.9 --- .../client_generators/dependencies/async_base_client.py | 2 +- .../dependencies/async_base_client_open_telemetry.py | 2 +- .../custom_config_file/expected_client/async_base_client.py | 2 +- .../custom_files_names/expected_client/async_base_client.py | 2 +- .../clients/custom_scalars/expected_client/async_base_client.py | 2 +- tests/main/clients/example/expected_client/async_base_client.py | 2 +- .../extended_models/expected_client/async_base_client.py | 2 +- .../expected_client/async_base_client.py | 2 +- .../inline_fragments/expected_client/async_base_client.py | 2 +- .../multiple_fragments/expected_client/async_base_client.py | 2 +- .../expected_client/async_base_client.py | 2 +- .../clients/operations/expected_client/async_base_client.py | 2 +- .../clients/remote_schema/expected_client/async_base_client.py | 2 +- .../shorter_results/expected_client/async_base_client.py | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index 8926e640..a771269a 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index 5580cb90..56617b6e 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -639,7 +639,7 @@ async def _handle_ws_message_with_telemetry( root_span: Span, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: with self.tracer.start_as_current_span( # type: ignore "received message", context=set_span_in_context(root_span) diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/operations/expected_client/async_base_client.py b/tests/main/clients/operations/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/operations/expected_client/async_base_client.py +++ b/tests/main/clients/operations/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index 8926e640..a771269a 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -333,7 +333,7 @@ async def _handle_ws_message( self, message: Data, websocket: WebSocketClientProtocol, - expected_type: GraphQLTransportWSMessageType | None = None, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) From 0d889a9e64f9026594296d566991fcaf2ee419cc Mon Sep 17 00:00:00 2001 From: Scott Lessans Date: Sun, 24 Dec 2023 11:56:11 -0800 Subject: [PATCH 06/34] copy the section rather than modify in-place --- ariadne_codegen/config.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ariadne_codegen/config.py b/ariadne_codegen/config.py index 476fa7b0..c087ce7a 100644 --- a/ariadne_codegen/config.py +++ b/ariadne_codegen/config.py @@ -34,13 +34,11 @@ def get_config_dict(config_file_name: Optional[str] = None) -> Dict: def get_client_settings(config_dict: Dict) -> ClientSettings: """Parse configuration dict and return ClientSettings instance.""" - section = get_section(config_dict) + section = get_section(config_dict).copy() settings_fields_names = {f.name for f in fields(ClientSettings)} - try: + try: section["scalars"] = { - name: data - if isinstance(data, ScalarData) - else ScalarData( + name: ScalarData( type_=data["type"], serialize=data.get("serialize"), parse=data.get("parse"), From ceb71a00776593f688f538287b689fab3980b6cb Mon Sep 17 00:00:00 2001 From: Scott Lessans Date: Sun, 24 Dec 2023 11:57:13 -0800 Subject: [PATCH 07/34] whitespace --- ariadne_codegen/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne_codegen/config.py b/ariadne_codegen/config.py index c087ce7a..7aa4d792 100644 --- a/ariadne_codegen/config.py +++ b/ariadne_codegen/config.py @@ -36,7 +36,7 @@ def get_client_settings(config_dict: Dict) -> ClientSettings: """Parse configuration dict and return ClientSettings instance.""" section = get_section(config_dict).copy() settings_fields_names = {f.name for f in fields(ClientSettings)} - try: + try: section["scalars"] = { name: ScalarData( type_=data["type"], From 361df204b19f874e308cdaffffdfef802f890633 Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Tue, 9 Jan 2024 07:26:46 +0100 Subject: [PATCH 08/34] Add tests --- .../async_base_client_open_telemetry.py | 16 +++++++- .../dependencies/test_websockets.py | 19 ++++++++++ .../test_websockets_open_telemetry.py | 38 +++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index 56617b6e..bd479e98 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -373,6 +373,12 @@ async def _execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -414,7 +420,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -427,6 +436,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message type - expected {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/client_generators/dependencies/test_websockets.py b/tests/client_generators/dependencies/test_websockets.py index b2f02996..b82639da 100644 --- a/tests/client_generators/dependencies/test_websockets.py +++ b/tests/client_generators/dependencies/test_websockets.py @@ -28,6 +28,12 @@ def mocked_websocket(mocked_ws_connect): return websocket +@pytest.fixture +def mocked_faulty_websocket(mocked_ws_connect): + websocket = mocked_ws_connect.return_value.__aenter__.return_value + return websocket + + @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument @@ -251,3 +257,16 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type with pytest.raises(GraphQLClientGraphQLMultiError): async for _ in AsyncBaseClient().execute_ws(""): pass + + +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_missing_ack_after_init( + mocked_faulty_websocket, +): + mocked_faulty_websocket.recv.return_value = json.dumps( + {"type": "next", "payload": {"data": "test_data"}} + ) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClient().execute_ws(""): + pass diff --git a/tests/client_generators/dependencies/test_websockets_open_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py index 53f8347d..48c6fb48 100644 --- a/tests/client_generators/dependencies/test_websockets_open_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -30,6 +30,12 @@ def mocked_websocket(mocked_ws_connect): return websocket +@pytest.fixture +def mocked_faulty_websocket(mocked_ws_connect): + websocket = mocked_ws_connect.return_value.__aenter__.return_value + return websocket + + @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( mocked_ws_connect, @@ -260,6 +266,19 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type pass +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_missing_ack_after_init( + mocked_faulty_websocket, +): + mocked_faulty_websocket.recv.return_value = json.dumps( + {"type": "next", "payload": {"data": "test_data"}} + ) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): + pass + + @pytest.fixture def mocked_start_as_current_span(mocker): mocker_get_tracer = mocker.patch( @@ -395,3 +414,22 @@ async def test_execute_ws_creates_span_for_received_error_message( mocked_start_as_current_span.assert_any_call("received message", context=ANY) with mocked_start_as_current_span.return_value as span: span.set_attribute.assert_any_call("type", "error") + + +@pytest.mark.asyncio +async def test_execute_ws_ws_creates_span_for_missing_ack_after_init( + mocked_faulty_websocket, mocked_start_as_current_span +): + mocked_faulty_websocket.recv.return_value = json.dumps( + {"type": "next", "payload": {"data": "test_data"}} + ) + + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("received message", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "next") From b6e77a4fbd17e12fad101bf95899612cf0fe7a49 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 9 Jan 2024 09:36:21 +0100 Subject: [PATCH 09/34] Fix faulty import path for `NoReimportsPlugin` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d8945be2..57513eb6 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ plugins = ["ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin"] operations_module_name = "custom_operations_module_name" ``` -- [`ariadne_codegen.contrib.extract_operations.NoReimportsPlugin`](ariadne_codegen/contrib/no_reimports.py) - This plugin removes content of generated `__init__.py`. This is useful in scenarios where generated plugins contain so many Pydantic models that client's eager initialization of entire package on first import is very slow. +- [`ariadne_codegen.contrib.no_reimports.NoReimportsPlugin`](ariadne_codegen/contrib/no_reimports.py) - This plugin removes content of generated `__init__.py`. This is useful in scenarios where generated plugins contain so many Pydantic models that client's eager initialization of entire package on first import is very slow. ## Using generated client From 2d080c170dd6f9e073a68f86d86dfa20f689b57c Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Tue, 9 Jan 2024 18:08:28 +0100 Subject: [PATCH 10/34] review comments --- CHANGELOG.md | 4 ++-- .../client_generators/dependencies/async_base_client.py | 2 +- .../dependencies/async_base_client_open_telemetry.py | 4 ++-- .../custom_config_file/expected_client/async_base_client.py | 2 +- .../custom_files_names/expected_client/async_base_client.py | 2 +- .../custom_scalars/expected_client/async_base_client.py | 2 +- .../main/clients/example/expected_client/async_base_client.py | 2 +- .../extended_models/expected_client/async_base_client.py | 2 +- .../expected_client/async_base_client.py | 2 +- .../inline_fragments/expected_client/async_base_client.py | 2 +- .../multiple_fragments/expected_client/async_base_client.py | 2 +- .../expected_client/async_base_client.py | 2 +- .../clients/operations/expected_client/async_base_client.py | 2 +- .../remote_schema/expected_client/async_base_client.py | 2 +- .../shorter_results/expected_client/async_base_client.py | 2 +- 15 files changed, 17 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46335a11..9cb8e7f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,8 @@ # CHANGELOG -## UNRELEASED +## 0.12.0 (UNRELEASED) -- Improved `graphql-transport-ws` protocol compliance for `connection_ack` messages +- Fixed `graphql-transport-ws` protocol implementation not waiting for the `connection_ack` message on new connection. ## 0.11.0 (2023-12-05) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index a771269a..5358ced6 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index bd479e98..65d56458 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -438,7 +438,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: @@ -677,7 +677,7 @@ async def _handle_ws_message_with_telemetry( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/operations/expected_client/async_base_client.py b/tests/main/clients/operations/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/operations/expected_client/async_base_client.py +++ b/tests/main/clients/operations/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index a771269a..5358ced6 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -348,7 +348,7 @@ async def _handle_ws_message( if expected_type and expected_type != type_: raise GraphQLClientInvalidMessageFormat( - f"Invalid message type - expected {expected_type.value}" + f"Invalid message received. Expected: {expected_type.value}" ) if type_ == GraphQLTransportWSMessageType.NEXT: From f50795b67cdc9c899cb8c6868c37a4018fa8d9b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:16:38 +0100 Subject: [PATCH 11/34] Support different schema formats in graphqlschema strategy --- .../graphql_schema_generators/schema.py | 7 +++- ariadne_codegen/main.py | 23 ++++++---- ariadne_codegen/settings.py | 42 ++++++++++++++++--- 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/ariadne_codegen/graphql_schema_generators/schema.py b/ariadne_codegen/graphql_schema_generators/schema.py index a1587b14..bc5808c1 100644 --- a/ariadne_codegen/graphql_schema_generators/schema.py +++ b/ariadne_codegen/graphql_schema_generators/schema.py @@ -3,6 +3,7 @@ from graphql import GraphQLSchema from graphql.type.schema import TypeMap +from graphql import print_schema from ..codegen import ( generate_ann_assign, @@ -23,7 +24,11 @@ from .utils import get_optional_named_type -def generate_graphql_schema_file( +def generate_graphql_schema_graphql_file(schema: GraphQLSchema, target_file_path: str): + Path(target_file_path).write_text(print_schema(schema), encoding="UTF-8") + + +def generate_graphql_schema_python_file( schema: GraphQLSchema, target_file_path: str, type_map_name: str, diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 8f7531fa..57eb60c5 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -5,7 +5,10 @@ from .client_generators.package import get_package_generator from .config import get_client_settings, get_config_dict, get_graphql_schema_settings -from .graphql_schema_generators.schema import generate_graphql_schema_file +from .graphql_schema_generators.schema import ( + generate_graphql_schema_graphql_file, + generate_graphql_schema_python_file, +) from .plugins.explorer import get_plugins_types from .plugins.manager import PluginManager from .schema import ( @@ -99,9 +102,15 @@ def graphql_schema(config_dict): sys.stdout.write(settings.used_settings_message) - generate_graphql_schema_file( - schema=schema, - target_file_path=settings.target_file_path, - type_map_name=settings.type_map_variable_name, - schema_variable_name=settings.schema_variable_name, - ) + if settings.target_file_format == "py": + generate_graphql_schema_python_file( + schema=schema, + target_file_path=settings.target_file_path, + type_map_name=settings.type_map_variable_name, + schema_variable_name=settings.schema_variable_name, + ) + else: + generate_graphql_schema_graphql_file( + schema=schema, + target_file_path=settings.target_file_path, + ) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 96c967a0..190a5f7e 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -195,6 +195,8 @@ class GraphQLSchemaSettings(BaseSettings): def __post_init__(self): super().__post_init__() + + assert_string_is_valid_schema_target_filename(self.target_file_path) assert_string_is_valid_python_identifier(self.schema_variable_name) assert_string_is_valid_python_identifier(self.type_map_variable_name) @@ -206,17 +208,32 @@ def used_settings_message(self): if self.plugins else "No plugin is being used." ) + + if self.target_file_format == "py": + return dedent( + f"""\ + Selected strategy: {Strategy.GRAPHQL_SCHEMA} + Using schema from {self.schema_path or self.remote_schema_url} + Saving graphql schema to: {self.target_file_path} + Using {self.schema_variable_name} as variable name for schema. + Using {self.type_map_variable_name} as variable name for type map. + {plugins_msg} + """ + ) + return dedent( f"""\ Selected strategy: {Strategy.GRAPHQL_SCHEMA} - Using schema from '{self.schema_path or self.remote_schema_url}'. - Saving graphql schema to: {self.target_file_path}. - Using {self.schema_variable_name} as variable name for schema. - Using {self.type_map_variable_name} as variable name for type map. + Using schema from {self.schema_path or self.remote_schema_url} + Saving graphql schema to: {self.target_file_path} {plugins_msg} """ ) + @property + def target_file_format(self): + return Path(self.target_file_path).suffix[1:].lower() + def assert_path_exists(path: str): if not Path(path).exists(): @@ -233,10 +250,25 @@ def assert_path_is_valid_file(path: str): raise InvalidConfiguration(f"Provided path {path} isn't a file.") +def assert_string_is_valid_schema_target_filename(filename: str): + file_type = Path(filename).suffix + if not file_type: + raise InvalidConfiguration( + f"Provided file name {filename} is missing a file type." + ) + + file_type = file_type[1:].lower() + if file_type not in ("py", "graphql", "gql"): + raise InvalidConfiguration( + f"Provided file name {filename} has an invalid type {file_type}." + " Valid types are py, graphql and gql." + ) + + def assert_string_is_valid_python_identifier(name: str): if not name.isidentifier() and not iskeyword(name): raise InvalidConfiguration( - f"Provided name {name} cannot be used as python indetifier" + f"Provided name {name} cannot be used as python identifier." ) From 078eaaa1a73b4ffd1f54cb213ad9463a9ff2cd74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:30:28 +0100 Subject: [PATCH 12/34] Update changelog and readme --- CHANGELOG.md | 5 +++++ README.md | 34 ++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ce3a35..d3bec3c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## 0.12.0 (UNRELEASED) + +- Added support to `graphqlschema` for saving schema as a GraphQL file. + + ## 0.11.0 (2023-12-05) - Removed `model_rebuild` calls for generated input, fragment and result models. diff --git a/README.md b/README.md index 57513eb6..d28e66cb 100644 --- a/README.md +++ b/README.md @@ -323,23 +323,45 @@ Example with simple schema and few queries and mutations is available [here](htt ## Generating graphql schema's python representation -Instead of generating client, you can generate file with a copy of GraphQL schema as `GraphQLSchema` declaration. To do this call `ariadne-codegen` with `graphqlschema` argument: +Instead of generating a client, you can generate a file with a copy of a GraphQL schema. To do this call `ariadne-codegen` with `graphqlschema` argument: + ``` ariadne-codegen graphqlschema ``` -`graphqlschema` mode reads configuration from the same place as [`client`](#configuration) but uses only `schema_path`, `remote_schema_url`, `remote_schema_headers`, `remote_schema_verify_ssl` and `plugins` options with addition to some extra options specific to it: +`graphqlschema` mode reads configuration from the same place as [`client`](#configuration) but uses only `schema_path`, `remote_schema_url`, `remote_schema_headers`, `remote_schema_verify_ssl` options to retrieve the schema and `plugins` option to load plugins. + +In addition to the above, `graphqlschema` mode also accepts additional settings specific to it: + + +### `target_file_path` -- `target_file_path` (defaults to `"schema.py"`) - destination path for generated file -- `schema_variable_name` (defaults to `"schema"`) - name for schema variable, must be valid python identifier -- `type_map_variable_name` (defaults to `"type_map"`) - name for type map variable, must be valid python identifier +A string with destination path for generated file. Must be either a Python (`.py`), or GraphQL (`.graphql` or `.gql`) file. -Generated file contains: +Defaults to `schema.py`. + +Generated Python file will contain: - Necessary imports - Type map declaration `{type_map_variable_name}: TypeMap = {...}` - Schema declaration `{schema_variable_name}: GraphQLSchema = GraphQLSchema(...)` +Generated GraphQL file will contain a formatted output of the `print_schema` function from the `graphql-core` package. + + +### `schema_variable_name` + +A string with a name for schema variable, must be valid python identifier. + +Defaults to `"schema"`. Used only if target is a Python file. + + +### `type_map_variable_name` + +A string with a name for type map variable, must be valid python identifier. + +Defaults to `"type_map"`. Used only if target is a Python file. + ## Contributing From bbae81ea2734e2b79a25cd05b1c87c3d1b3f28b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:34:53 +0100 Subject: [PATCH 13/34] Fix tests --- .../graphql_schema_generators/test_schema.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/graphql_schema_generators/test_schema.py b/tests/graphql_schema_generators/test_schema.py index 7b7e0980..31d7aee5 100644 --- a/tests/graphql_schema_generators/test_schema.py +++ b/tests/graphql_schema_generators/test_schema.py @@ -1,9 +1,10 @@ import ast -from graphql import Undefined, build_schema +from graphql import Undefined, build_schema, print_schema from ariadne_codegen.graphql_schema_generators.schema import ( - generate_graphql_schema_file, + generate_graphql_schema_graphql_file, + generate_graphql_schema_python_file, generate_schema, generate_schema_module, generate_type_map, @@ -28,11 +29,24 @@ """ -def test_generate_graphql_schema_file_creates_file_with_variables(tmp_path): +def test_generate_graphql_schema_graphql_file_creates_file_printed_schema(tmp_path): + schema = build_schema(SCHEMA_STR) + file_path = tmp_path / "test_schema.graphql" + + generate_graphql_schema_graphql_file(schema, file_path.as_posix()) + + assert file_path.exists() + assert file_path.is_file() + with file_path.open() as file_: + content = file_.read() + assert content == print_schema(schema) + + +def test_generate_graphql_schema_python_file_creates_file_with_variables(tmp_path): schema = build_schema(SCHEMA_STR) file_path = tmp_path / "test_schema.py" - generate_graphql_schema_file(schema, file_path.as_posix(), "type_map", "schema") + generate_graphql_schema_python_file(schema, file_path.as_posix(), "type_map", "schema") assert file_path.exists() assert file_path.is_file() From 279da88d9ce4ebf085407ed47de28120d5277eee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:35:39 +0100 Subject: [PATCH 14/34] Fix tests --- ariadne_codegen/settings.py | 2 +- tests/graphql_schema_generators/test_schema.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 190a5f7e..fbd8f427 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -257,7 +257,7 @@ def assert_string_is_valid_schema_target_filename(filename: str): f"Provided file name {filename} is missing a file type." ) - file_type = file_type[1:].lower() + file_type = file_type[1:].lower() if file_type not in ("py", "graphql", "gql"): raise InvalidConfiguration( f"Provided file name {filename} has an invalid type {file_type}." diff --git a/tests/graphql_schema_generators/test_schema.py b/tests/graphql_schema_generators/test_schema.py index 31d7aee5..b7d9bc2c 100644 --- a/tests/graphql_schema_generators/test_schema.py +++ b/tests/graphql_schema_generators/test_schema.py @@ -29,7 +29,9 @@ """ -def test_generate_graphql_schema_graphql_file_creates_file_printed_schema(tmp_path): +def test_generate_graphql_schema_graphql_file_creates_file_with_printed_schema( + tmp_path, +): schema = build_schema(SCHEMA_STR) file_path = tmp_path / "test_schema.graphql" @@ -42,11 +44,15 @@ def test_generate_graphql_schema_graphql_file_creates_file_printed_schema(tmp_pa assert content == print_schema(schema) -def test_generate_graphql_schema_python_file_creates_file_with_variables(tmp_path): +def test_generate_graphql_schema_python_file_creates_py_file_with_variables( + tmp_path, +): schema = build_schema(SCHEMA_STR) file_path = tmp_path / "test_schema.py" - generate_graphql_schema_python_file(schema, file_path.as_posix(), "type_map", "schema") + generate_graphql_schema_python_file( + schema, file_path.as_posix(), "type_map", "schema" + ) assert file_path.exists() assert file_path.is_file() From f7ee398dcfc830f735cb95d0963879d0eabbe37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:41:12 +0100 Subject: [PATCH 15/34] More tests --- .../graphql_schema_generators/schema.py | 3 +- tests/test_settings.py | 55 +++++++++++++++++++ 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/ariadne_codegen/graphql_schema_generators/schema.py b/ariadne_codegen/graphql_schema_generators/schema.py index bc5808c1..9de87871 100644 --- a/ariadne_codegen/graphql_schema_generators/schema.py +++ b/ariadne_codegen/graphql_schema_generators/schema.py @@ -1,9 +1,8 @@ import ast from pathlib import Path -from graphql import GraphQLSchema +from graphql import GraphQLSchema, print_schema from graphql.type.schema import TypeMap -from graphql import print_schema from ..codegen import ( generate_ann_assign, diff --git a/tests/test_settings.py b/tests/test_settings.py index 3a629287..decebdd9 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -226,6 +226,45 @@ def test_graphq_schema_settings_without_schema_path_with_remote_schema_url_is_va assert not settings.schema_path +def test_graphql_schema_settings_with_target_file_path_with_py_extension_is_valid(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.py", + ) + + assert settings.target_file_path == "schema_file.py" + assert settings.target_file_format == "py" + + +def test_graphql_schema_settings_with_target_file_path_with_graphql_extension_is_valid(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.graphql", + ) + + assert settings.target_file_path == "schema_file.graphql" + assert settings.target_file_format == "graphql" + + +def test_graphql_schema_settings_with_target_file_path_with_graphql_extension_is_valid(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.gql", + ) + + assert settings.target_file_path == "schema_file.gql" + assert settings.target_file_format == "gql" + + +def test_graphql_schema_settings_target_file_format_is_lowercased(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.GQL", + ) + + assert settings.target_file_format == "gql" + + def test_graphq_schema_settings_without_schema_path_or_remote_schema_url_is_not_valid(): with pytest.raises(InvalidConfiguration): GraphQLSchemaSettings() @@ -236,6 +275,22 @@ def test_graphql_schema_settings_raises_invalid_configuration_for_invalid_schema GraphQLSchemaSettings(schema_path="not_exisitng.graphql") +def test_graphql_schema_settings_with_target_file_path_missing_extension_raises_exception(): + with pytest.raises(InvalidConfiguration): + GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file", + ) + + +def test_graphql_schema_settings_with_target_file_path_invalid_extension_raises_exception(): + with pytest.raises(InvalidConfiguration): + GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.invalid", + ) + + def test_graphql_schema_settings_with_invalid_schema_variable_name_raises_exception(): with pytest.raises(InvalidConfiguration): GraphQLSchemaSettings( From 4efe2d2dc3ae7b9af69c7272e6846791da996b88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:53:19 +0100 Subject: [PATCH 16/34] More tests --- tests/main/test_main.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 569ef416..8dc755ae 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -315,6 +315,22 @@ def test_main_can_read_config_from_provided_file(tmp_path): "schema.py", GRAPHQL_SCHEMAS_PATH / "all_types" / "expected_schema.py", ), + ( + ( + GRAPHQL_SCHEMAS_PATH / "example" / "pyproject-schema-graphql.toml", + (GRAPHQL_SCHEMAS_PATH / "example" / "schema.graphql",), + ), + "expected_schema.graphql", + GRAPHQL_SCHEMAS_PATH / "example" / "expected_schema.graphql", + ), + ( + ( + GRAPHQL_SCHEMAS_PATH / "example" / "pyproject-schema-gql.toml", + (GRAPHQL_SCHEMAS_PATH / "example" / "schema.graphql",), + ), + "expected_schema.gql", + GRAPHQL_SCHEMAS_PATH / "example" / "expected_schema.gql", + ), ], indirect=["project_dir"], ) From 640f22ea2d59def58d112b2e5cada6df208afa79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:55:04 +0100 Subject: [PATCH 17/34] commit missing test files --- .../example/expected_schema.gql | 59 +++++++++++++++++++ .../example/expected_schema.graphql | 59 +++++++++++++++++++ .../example/pyproject-schema-gql.toml | 3 + .../example/pyproject-schema-graphql.toml | 3 + 4 files changed, 124 insertions(+) create mode 100644 tests/main/graphql_schemas/example/expected_schema.gql create mode 100644 tests/main/graphql_schemas/example/expected_schema.graphql create mode 100644 tests/main/graphql_schemas/example/pyproject-schema-gql.toml create mode 100644 tests/main/graphql_schemas/example/pyproject-schema-graphql.toml diff --git a/tests/main/graphql_schemas/example/expected_schema.gql b/tests/main/graphql_schemas/example/expected_schema.gql new file mode 100644 index 00000000..2e19a47b --- /dev/null +++ b/tests/main/graphql_schemas/example/expected_schema.gql @@ -0,0 +1,59 @@ +type Query { + users(country: String): [User!]! +} + +type Mutation { + userCreate(userData: UserCreateInput!): User + userPreferences(data: UserPreferencesInput): Boolean! +} + +input UserCreateInput { + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: LocationInput +} + +input LocationInput { + city: String + country: String +} + +type User { + id: ID! + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: Location +} + +type Location { + city: String + country: String +} + +enum Color { + BLACK + WHITE + RED + GREEN + BLUE + YELLOW +} + +input UserPreferencesInput { + luckyNumber: Int = 7 + favouriteWord: String = "word" + colorOpacity: Float = 1 + excludedTags: [String!] = ["offtop", "tag123"] + notificationsPreferences: NotificationsPreferencesInput! = {receiveMails: true, receivePushNotifications: true, receiveSms: false, title: "Mr"} +} + +input NotificationsPreferencesInput { + receiveMails: Boolean! + receivePushNotifications: Boolean! + receiveSms: Boolean! + title: String! +} \ No newline at end of file diff --git a/tests/main/graphql_schemas/example/expected_schema.graphql b/tests/main/graphql_schemas/example/expected_schema.graphql new file mode 100644 index 00000000..2e19a47b --- /dev/null +++ b/tests/main/graphql_schemas/example/expected_schema.graphql @@ -0,0 +1,59 @@ +type Query { + users(country: String): [User!]! +} + +type Mutation { + userCreate(userData: UserCreateInput!): User + userPreferences(data: UserPreferencesInput): Boolean! +} + +input UserCreateInput { + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: LocationInput +} + +input LocationInput { + city: String + country: String +} + +type User { + id: ID! + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: Location +} + +type Location { + city: String + country: String +} + +enum Color { + BLACK + WHITE + RED + GREEN + BLUE + YELLOW +} + +input UserPreferencesInput { + luckyNumber: Int = 7 + favouriteWord: String = "word" + colorOpacity: Float = 1 + excludedTags: [String!] = ["offtop", "tag123"] + notificationsPreferences: NotificationsPreferencesInput! = {receiveMails: true, receivePushNotifications: true, receiveSms: false, title: "Mr"} +} + +input NotificationsPreferencesInput { + receiveMails: Boolean! + receivePushNotifications: Boolean! + receiveSms: Boolean! + title: String! +} \ No newline at end of file diff --git a/tests/main/graphql_schemas/example/pyproject-schema-gql.toml b/tests/main/graphql_schemas/example/pyproject-schema-gql.toml new file mode 100644 index 00000000..557b014f --- /dev/null +++ b/tests/main/graphql_schemas/example/pyproject-schema-gql.toml @@ -0,0 +1,3 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +target_file_path = "expected_schema.gql" diff --git a/tests/main/graphql_schemas/example/pyproject-schema-graphql.toml b/tests/main/graphql_schemas/example/pyproject-schema-graphql.toml new file mode 100644 index 00000000..c9083813 --- /dev/null +++ b/tests/main/graphql_schemas/example/pyproject-schema-graphql.toml @@ -0,0 +1,3 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +target_file_path = "expected_schema.graphql" From 04de31c169f4e4fb9df15a153b5f5d7187d06ea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:57:40 +0100 Subject: [PATCH 18/34] Shutup pylint about test names being too long --- tests/test_settings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index decebdd9..771ace47 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,3 +1,4 @@ +# pylint: disable=line-too-long import os from pathlib import Path from textwrap import dedent @@ -246,7 +247,7 @@ def test_graphql_schema_settings_with_target_file_path_with_graphql_extension_is assert settings.target_file_format == "graphql" -def test_graphql_schema_settings_with_target_file_path_with_graphql_extension_is_valid(): +def test_graphql_schema_settings_with_target_file_path_with_gql_extension_is_valid(): settings = GraphQLSchemaSettings( remote_schema_url="http://testserver/graphq/", target_file_path="schema_file.gql", From fa30489184d46509ebe87075b7a3e62d805207f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Thu, 18 Jan 2024 16:58:39 +0100 Subject: [PATCH 19/34] Shorten test names --- tests/test_settings.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 771ace47..11d03523 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,4 +1,3 @@ -# pylint: disable=line-too-long import os from pathlib import Path from textwrap import dedent @@ -237,7 +236,7 @@ def test_graphql_schema_settings_with_target_file_path_with_py_extension_is_vali assert settings.target_file_format == "py" -def test_graphql_schema_settings_with_target_file_path_with_graphql_extension_is_valid(): +def test_graphql_schema_settings_with_target_file_with_graphql_extension_is_valid(): settings = GraphQLSchemaSettings( remote_schema_url="http://testserver/graphq/", target_file_path="schema_file.graphql", @@ -276,7 +275,7 @@ def test_graphql_schema_settings_raises_invalid_configuration_for_invalid_schema GraphQLSchemaSettings(schema_path="not_exisitng.graphql") -def test_graphql_schema_settings_with_target_file_path_missing_extension_raises_exception(): +def test_graphql_schema_settings_with_target_file_missing_extension_raises_exception(): with pytest.raises(InvalidConfiguration): GraphQLSchemaSettings( remote_schema_url="http://testserver/graphq/", @@ -284,7 +283,7 @@ def test_graphql_schema_settings_with_target_file_path_missing_extension_raises_ ) -def test_graphql_schema_settings_with_target_file_path_invalid_extension_raises_exception(): +def test_graphql_schema_settings_with_target_file_invalid_extension_raises_exception(): with pytest.raises(InvalidConfiguration): GraphQLSchemaSettings( remote_schema_url="http://testserver/graphq/", From 9fe3660f1aa355ff6b39eac53482eb9f47c6daf0 Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Thu, 18 Jan 2024 18:08:27 +0100 Subject: [PATCH 20/34] Fix tests --- .../dependencies/test_websockets_open_telemetry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/client_generators/dependencies/test_websockets_open_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py index 48c6fb48..6f14b813 100644 --- a/tests/client_generators/dependencies/test_websockets_open_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -38,7 +38,7 @@ def mocked_faulty_websocket(mocked_ws_connect): @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket ): async for _ in AsyncBaseClientOpenTelemetry(ws_url="ws://test_url").execute_ws(""): pass @@ -49,7 +49,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_url( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket ): async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -62,7 +62,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_origin( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket ): async for _ in AsyncBaseClientOpenTelemetry(ws_origin="test_origin").execute_ws(""): pass @@ -73,7 +73,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_origin( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_headers( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket ): async for _ in AsyncBaseClientOpenTelemetry( ws_headers={"test_key": "test_value"} From 822a22d0860364a33c966fd27f0b014b0b161677 Mon Sep 17 00:00:00 2001 From: Damjan Kuznar Date: Thu, 18 Jan 2024 18:18:47 +0100 Subject: [PATCH 21/34] pylint: disable=unused-argument --- .../dependencies/test_websockets_open_telemetry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/client_generators/dependencies/test_websockets_open_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py index 6f14b813..b75a61bc 100644 --- a/tests/client_generators/dependencies/test_websockets_open_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -38,7 +38,7 @@ def mocked_faulty_websocket(mocked_ws_connect): @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( - mocked_ws_connect, mocked_websocket + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry(ws_url="ws://test_url").execute_ws(""): pass @@ -49,7 +49,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_url( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( - mocked_ws_connect, mocked_websocket + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -62,7 +62,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_origin( - mocked_ws_connect, mocked_websocket + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry(ws_origin="test_origin").execute_ws(""): pass @@ -73,7 +73,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_origin( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_headers( - mocked_ws_connect, mocked_websocket + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry( ws_headers={"test_key": "test_value"} From 2f062303cf59316c9c8ba2f97d216aecde477d5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Fri, 19 Jan 2024 17:50:30 +0100 Subject: [PATCH 22/34] Add regression test for bug #256 --- tests/test_config.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index b728af64..1070b320 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -112,6 +112,16 @@ def test_get_client_settings_returns_client_settings_object(tmp_path): ) } + # Regression test for #256 + assert config_dict["tool"]["ariadne-codegen"]["scalars"] == { + "ID": { + "type": "str", + "parse": "parse_id", + "serialize": "serialize_id", + "import": ".custom_scalars", + }, + } + def test_get_client_settings_without_section_raises_missing_configuration_exception(): config_dict = {"invalid-section": {"schema_path": "."}} From 2387ebf5c0695cb1b781285d7a8d188197a02731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Fri, 19 Jan 2024 17:51:23 +0100 Subject: [PATCH 23/34] Improve comment --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 1070b320..2157fe02 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -112,7 +112,7 @@ def test_get_client_settings_returns_client_settings_object(tmp_path): ) } - # Regression test for #256 + # Regression test for #256: don't mutate config_dict's scalars assert config_dict["tool"]["ariadne-codegen"]["scalars"] == { "ID": { "type": "str", From 71482054521454ea1f83e21a6048ba9e5f650d83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Pito=C5=84?= Date: Fri, 19 Jan 2024 17:52:00 +0100 Subject: [PATCH 24/34] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c7e3622..fca48a85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## 0.12.0 (UNRELEASED) - Fixed `graphql-transport-ws` protocol implementation not waiting for the `connection_ack` message on new connection. +- Fixed `get_client_settings` mutating `config_dict` instance. - Added support to `graphqlschema` for saving schema as a GraphQL file. From b099166cecd344fb49ab5c0fa4a18993eb867e21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 22 Dec 2023 14:32:29 +0100 Subject: [PATCH 25/34] Restore model_rebuild calls to top level fragments --- .../client_generators/constants.py | 1 + .../client_generators/fragments.py | 32 ++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index 0dfe53a2..5cecf209 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -56,6 +56,7 @@ MODEL_VALIDATE_METHOD = "model_validate" PLAIN_SERIALIZER = "PlainSerializer" BEFORE_VALIDATOR = "BeforeValidator" +MODEL_REBUILD_METHOD = "model_rebuild" ENUM_MODULE = "enum" ENUM_CLASS = "Enum" diff --git a/ariadne_codegen/client_generators/fragments.py b/ariadne_codegen/client_generators/fragments.py index cba04810..3ad901ce 100644 --- a/ariadne_codegen/client_generators/fragments.py +++ b/ariadne_codegen/client_generators/fragments.py @@ -3,9 +3,9 @@ from graphql import FragmentDefinitionNode, GraphQLSchema -from ..codegen import generate_module +from ..codegen import generate_expr, generate_method_call, generate_module from ..plugins.manager import PluginManager -from .constants import BASE_MODEL_IMPORT +from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD from .result_types import ResultTypesGenerator from .scalars import ScalarData @@ -36,6 +36,7 @@ def __init__( def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: class_defs_dict: Dict[str, List[ast.ClassDef]] = {} imports: List[ast.ImportFrom] = [] + top_level_class_names: List[str] = [] dependencies_dict: Dict[str, Set[str]] = {} names_to_exclude = exclude_names or set() @@ -53,7 +54,10 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: plugin_manager=self.plugin_manager, ) imports.extend(generator.get_imports()) - class_defs_dict[name] = generator.get_classes() + class_defs = generator.get_classes() + class_defs_dict[name] = class_defs + if class_defs: + top_level_class_names.append(class_defs[0].name) dependencies_dict[name] = generator.get_fragments_used_as_mixins() self._generated_public_names.extend(generator.get_generated_public_names()) self._used_enums.extend(generator.get_used_enums()) @@ -62,7 +66,15 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict ) module = generate_module( - body=cast(List[ast.stmt], imports) + cast(List[ast.stmt], sorted_class_defs) + body=cast(List[ast.stmt], imports) + + cast(List[ast.stmt], sorted_class_defs) + + cast( + List[ast.stmt], + self._get_model_rebuild_calls( + top_level_fragments_names=top_level_class_names, + class_defs=sorted_class_defs, + ), + ) ) if self.plugin_manager: module = self.plugin_manager.generate_fragments_module( @@ -108,3 +120,15 @@ def visit(name): visit(name) return sorted_names + + def _get_model_rebuild_calls( + self, top_level_fragments_names: List[str], class_defs: List[ast.ClassDef] + ) -> List[ast.Call]: + class_names = [c.name for c in class_defs] + sorted_fragments_names = sorted( + top_level_fragments_names, key=lambda n: class_names.index(n) + ) + return [ + generate_expr(generate_method_call(name, MODEL_REBUILD_METHOD)) + for name in sorted_fragments_names + ] From 0e6754cb837bfa2e0cad92e913528334a2c6d49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 22 Dec 2023 14:32:48 +0100 Subject: [PATCH 26/34] Update tests --- tests/main/clients/example/expected_client/fragments.py | 4 ++++ .../clients/extended_models/expected_client/fragments.py | 5 +++++ .../expected_client/fragments.py | 4 ++++ .../clients/inline_fragments/expected_client/fragments.py | 5 +++++ .../multiple_fragments/expected_client/fragments.py | 7 +++++++ .../expected_client/fragments.py | 4 ++++ tests/main/clients/operations/expected_client/fragments.py | 4 ++++ .../expected_client/shorter_results_fragments.py | 4 ++++ 8 files changed, 37 insertions(+) diff --git a/tests/main/clients/example/expected_client/fragments.py b/tests/main/clients/example/expected_client/fragments.py index 4795d454..e1cee688 100644 --- a/tests/main/clients/example/expected_client/fragments.py +++ b/tests/main/clients/example/expected_client/fragments.py @@ -13,3 +13,7 @@ class BasicUser(BaseModel): class UserPersonalData(BaseModel): first_name: Optional[str] = Field(alias="firstName") last_name: Optional[str] = Field(alias="lastName") + + +BasicUser.model_rebuild() +UserPersonalData.model_rebuild() diff --git a/tests/main/clients/extended_models/expected_client/fragments.py b/tests/main/clients/extended_models/expected_client/fragments.py index 30a907d0..77a32fea 100644 --- a/tests/main/clients/extended_models/expected_client/fragments.py +++ b/tests/main/clients/extended_models/expected_client/fragments.py @@ -20,3 +20,8 @@ class GetQueryAFragment(BaseModel): class GetQueryAFragmentQueryA(BaseModel, MixinA, CommonMixin): field_a: int = Field(alias="fieldA") + + +FragmentA.model_rebuild() +FragmentB.model_rebuild() +GetQueryAFragment.model_rebuild() diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py b/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py index 1206200f..e20413f0 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py @@ -11,3 +11,7 @@ class FragmentA(BaseModel): class FragmentB(BaseModel): id: str value_b: str = Field(alias="valueB") + + +FragmentA.model_rebuild() +FragmentB.model_rebuild() diff --git a/tests/main/clients/inline_fragments/expected_client/fragments.py b/tests/main/clients/inline_fragments/expected_client/fragments.py index 16fcc46a..4770eb21 100644 --- a/tests/main/clients/inline_fragments/expected_client/fragments.py +++ b/tests/main/clients/inline_fragments/expected_client/fragments.py @@ -57,3 +57,8 @@ class FragmentOnQueryWithUnionQueryUTypeC(BaseModel): class UnusedFragmentOnTypeA(BaseModel): id: str field_a: str = Field(alias="fieldA") + + +FragmentOnQueryWithInterface.model_rebuild() +FragmentOnQueryWithUnion.model_rebuild() +UnusedFragmentOnTypeA.model_rebuild() diff --git a/tests/main/clients/multiple_fragments/expected_client/fragments.py b/tests/main/clients/multiple_fragments/expected_client/fragments.py index c4d0d139..7f29eb44 100644 --- a/tests/main/clients/multiple_fragments/expected_client/fragments.py +++ b/tests/main/clients/multiple_fragments/expected_client/fragments.py @@ -40,3 +40,10 @@ class MinimalA(BaseModel): class MinimalAFieldB(MinimalB): pass + + +CompleteA.model_rebuild() +FullB.model_rebuild() +FullA.model_rebuild() +MinimalB.model_rebuild() +MinimalA.model_rebuild() diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py index 4bad3235..25a2a806 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py @@ -8,3 +8,7 @@ class FragmentG(BaseModel): class FragmentGG(BaseModel): val: EnumGG + + +FragmentG.model_rebuild() +FragmentGG.model_rebuild() diff --git a/tests/main/clients/operations/expected_client/fragments.py b/tests/main/clients/operations/expected_client/fragments.py index 5adc6550..6836e301 100644 --- a/tests/main/clients/operations/expected_client/fragments.py +++ b/tests/main/clients/operations/expected_client/fragments.py @@ -9,3 +9,7 @@ class FragmentB(BaseModel): class FragmentY(BaseModel): value_y: int = Field(alias="valueY") + + +FragmentB.model_rebuild() +FragmentY.model_rebuild() diff --git a/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py b/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py index f7e5e395..4b103a44 100644 --- a/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py +++ b/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py @@ -22,3 +22,7 @@ class ListAnimalsFragment(BaseModel): class ListAnimalsFragmentListAnimals(BaseModel): typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename") name: str + + +FragmentWithSingleField.model_rebuild() +ListAnimalsFragment.model_rebuild() From c54bb83cf9fd9fa85b167e0598acc35630b8be55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 22 Dec 2023 14:33:25 +0100 Subject: [PATCH 27/34] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fca48a85..dcf06dfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Fixed `graphql-transport-ws` protocol implementation not waiting for the `connection_ack` message on new connection. - Fixed `get_client_settings` mutating `config_dict` instance. - Added support to `graphqlschema` for saving schema as a GraphQL file. +- Restored `model_rebuild` calls for top level fragment models. ## 0.11.0 (2023-12-05) From 62e49e4433602e089dd119779617302294b2e4a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sopi=C5=84ski?= Date: Fri, 22 Dec 2023 14:33:55 +0100 Subject: [PATCH 28/34] Update example --- EXAMPLE.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/EXAMPLE.md b/EXAMPLE.md index 8646df5c..a465e31c 100644 --- a/EXAMPLE.md +++ b/EXAMPLE.md @@ -501,6 +501,10 @@ class BasicUser(BaseModel): class UserPersonalData(BaseModel): first_name: Optional[str] = Field(alias="firstName") last_name: Optional[str] = Field(alias="lastName") + + +BasicUser.model_rebuild() +UserPersonalData.model_rebuild() ``` ### Init file From f2cd15de9b44b2b442a31dba0815c70871e37296 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 23 Jan 2024 09:28:59 +0100 Subject: [PATCH 29/34] Fix `pylint` violation --- ariadne_codegen/client_generators/fragments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ariadne_codegen/client_generators/fragments.py b/ariadne_codegen/client_generators/fragments.py index 3ad901ce..3a7d2ecc 100644 --- a/ariadne_codegen/client_generators/fragments.py +++ b/ariadne_codegen/client_generators/fragments.py @@ -126,7 +126,7 @@ def _get_model_rebuild_calls( ) -> List[ast.Call]: class_names = [c.name for c in class_defs] sorted_fragments_names = sorted( - top_level_fragments_names, key=lambda n: class_names.index(n) + top_level_fragments_names, key=class_names.index ) return [ generate_expr(generate_method_call(name, MODEL_REBUILD_METHOD)) From 5ae332a3e2b8cc24f3d7f83fb8cde31c82bec3cb Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 23 Jan 2024 11:39:51 +0100 Subject: [PATCH 30/34] Add test to ensure models are rebuilt properly --- tests/main/test_model_rebuild_validation.py | 59 +++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/main/test_model_rebuild_validation.py diff --git a/tests/main/test_model_rebuild_validation.py b/tests/main/test_model_rebuild_validation.py new file mode 100644 index 00000000..d457ca31 --- /dev/null +++ b/tests/main/test_model_rebuild_validation.py @@ -0,0 +1,59 @@ +""" +To ensure all models with nested dependencies are fully rebuilt this test will +create an instance of the query from `multiple_fragments` containing a `FullA` +(extended from `ExampleQuery2ExampleQuery`) which in turn holds a `FullB` +(extended from `FullAFieldB`). + +If this model is not rebuilt with `FullA.model_rebuild()` `ExampleQuery2` will +not be fully defined and we will get a `PydanticUserError`. + +Reference to Pydantic documentation about when and why we need to call +`model_rebuild` on our types: +https://errors.pydantic.dev/2.5/u/class-not-fully-defined +""" + +import pytest +from pydantic_core import ValidationError + +from .clients.multiple_fragments.expected_client.example_query_2 import ( + ExampleQuery2, + ExampleQuery2ExampleQuery, +) +from .clients.multiple_fragments.expected_client.fragments import FullA + + +def test_model_rebuild_validate(): + # Perform some sanity checks on the schema for `ExampleQuery2` to test that + # it confirms to the fields of `FullA` and that it references `FullB`. + json_schema = ExampleQuery2.model_json_schema() + assert all( + x in json_schema["$defs"] for x in ["ExampleQuery2ExampleQuery", "FullAFieldB"] + ) + + query_props = json_schema["$defs"]["ExampleQuery2ExampleQuery"]["properties"] + assert all(x in query_props for x in ["id", "value", "fieldB"]) + assert query_props["fieldB"]["$ref"] == "#/$defs/FullAFieldB" + + # Assert we cannot validate a faulty type. + field_b = {"id": "321", "value": 13.37} + field_a = {"id": "123", "value": "A", "field_b": field_b} + + with pytest.raises(ValidationError): + ExampleQuery2.model_validate(field_a) + + # However it should work with the correct type and the type extending the + # correct type. + try: + FullA.model_validate(field_a) + ExampleQuery2ExampleQuery.model_validate(field_a) + except ValidationError as e: + assert False, f"model_valiadte failed: {e}" + + # And since the model is rebuilt we should be able to construct a full + # `ExampleQuery2`. + example_query_2 = {"example_query": field_a} + + try: + ExampleQuery2.model_validate(example_query_2) + except ValidationError as e: + assert False, f"model validation failed: {e}" From 52699208fe3d0316571537d3658567d349e03bb9 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Tue, 23 Jan 2024 21:58:40 +0100 Subject: [PATCH 31/34] Update docstring, separate assert, split test --- tests/main/test_model_rebuild_validation.py | 37 ++++++++++++--------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/main/test_model_rebuild_validation.py b/tests/main/test_model_rebuild_validation.py index d457ca31..7caefbc6 100644 --- a/tests/main/test_model_rebuild_validation.py +++ b/tests/main/test_model_rebuild_validation.py @@ -1,8 +1,8 @@ """ -To ensure all models with nested dependencies are fully rebuilt this test will -create an instance of the query from `multiple_fragments` containing a `FullA` -(extended from `ExampleQuery2ExampleQuery`) which in turn holds a `FullB` -(extended from `FullAFieldB`). +To ensure all models with nested dependencies are fully rebuilt this test +creates an instance of the query from `multiple_fragments` containing the +`FullA` fragment (used by the `ExampleQuery2ExampleQuery`) which itself includes +a field of type `FullAFieldB` that extends the `FullB` fragment. If this model is not rebuilt with `FullA.model_rebuild()` `ExampleQuery2` will not be fully defined and we will get a `PydanticUserError`. @@ -22,35 +22,40 @@ from .clients.multiple_fragments.expected_client.fragments import FullA -def test_model_rebuild_validate(): - # Perform some sanity checks on the schema for `ExampleQuery2` to test that - # it confirms to the fields of `FullA` and that it references `FullB`. +def test_json_schema_contains_all_properties(): json_schema = ExampleQuery2.model_json_schema() - assert all( - x in json_schema["$defs"] for x in ["ExampleQuery2ExampleQuery", "FullAFieldB"] - ) + assert "ExampleQuery2ExampleQuery" in json_schema["$defs"] + assert "FullAFieldB" in json_schema["$defs"] query_props = json_schema["$defs"]["ExampleQuery2ExampleQuery"]["properties"] - assert all(x in query_props for x in ["id", "value", "fieldB"]) + assert "id" in query_props + assert "value" in query_props + assert "fieldB" in query_props assert query_props["fieldB"]["$ref"] == "#/$defs/FullAFieldB" - # Assert we cannot validate a faulty type. + +@pytest.fixture +def field_a(): field_b = {"id": "321", "value": 13.37} field_a = {"id": "123", "value": "A", "field_b": field_b} + return field_a + + +def test_validate_field_a_on_faulty_model(field_a): with pytest.raises(ValidationError): ExampleQuery2.model_validate(field_a) - # However it should work with the correct type and the type extending the - # correct type. + +def test_validate_field_a_on_correct_model(field_a): try: FullA.model_validate(field_a) ExampleQuery2ExampleQuery.model_validate(field_a) except ValidationError as e: assert False, f"model_valiadte failed: {e}" - # And since the model is rebuilt we should be able to construct a full - # `ExampleQuery2`. + +def test_validate_field_a_in_example_query(field_a): example_query_2 = {"example_query": field_a} try: From 5aa4fad02a7048fb58e51dbada68516cc23b0057 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Wed, 24 Jan 2024 13:50:45 +0100 Subject: [PATCH 32/34] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- tests/main/test_model_rebuild_validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/main/test_model_rebuild_validation.py b/tests/main/test_model_rebuild_validation.py index 7caefbc6..15cb36d7 100644 --- a/tests/main/test_model_rebuild_validation.py +++ b/tests/main/test_model_rebuild_validation.py @@ -1,11 +1,11 @@ """ -To ensure all models with nested dependencies are fully rebuilt this test -creates an instance of the query from `multiple_fragments` containing the +To ensure all models with nested dependencies are fully rebuilt those tests +create an instance of the query from `multiple_fragments` containing the `FullA` fragment (used by the `ExampleQuery2ExampleQuery`) which itself includes a field of type `FullAFieldB` that extends the `FullB` fragment. If this model is not rebuilt with `FullA.model_rebuild()` `ExampleQuery2` will -not be fully defined and we will get a `PydanticUserError`. +not be fully defined and we will raise a `PydanticUserError`. Reference to Pydantic documentation about when and why we need to call `model_rebuild` on our types: From 5d6e3f8fa06a606032148a9c43bce7956199239e Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Wed, 24 Jan 2024 13:52:23 +0100 Subject: [PATCH 33/34] Rename field_a fixture --- tests/main/test_model_rebuild_validation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/main/test_model_rebuild_validation.py b/tests/main/test_model_rebuild_validation.py index 15cb36d7..1e82ff64 100644 --- a/tests/main/test_model_rebuild_validation.py +++ b/tests/main/test_model_rebuild_validation.py @@ -35,28 +35,28 @@ def test_json_schema_contains_all_properties(): @pytest.fixture -def field_a(): +def field_a_data(): field_b = {"id": "321", "value": 13.37} field_a = {"id": "123", "value": "A", "field_b": field_b} return field_a -def test_validate_field_a_on_faulty_model(field_a): +def test_validate_field_a_on_faulty_model(field_a_data): with pytest.raises(ValidationError): - ExampleQuery2.model_validate(field_a) + ExampleQuery2.model_validate(field_a_data) -def test_validate_field_a_on_correct_model(field_a): +def test_validate_field_a_on_correct_model(field_a_data): try: - FullA.model_validate(field_a) - ExampleQuery2ExampleQuery.model_validate(field_a) + FullA.model_validate(field_a_data) + ExampleQuery2ExampleQuery.model_validate(field_a_data) except ValidationError as e: assert False, f"model_valiadte failed: {e}" -def test_validate_field_a_in_example_query(field_a): - example_query_2 = {"example_query": field_a} +def test_validate_field_a_in_example_query(field_a_data): + example_query_2 = {"example_query": field_a_data} try: ExampleQuery2.model_validate(example_query_2) From da102b58525cea5f86498d37e1c4bbad34fbd513 Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Wed, 24 Jan 2024 19:57:18 +0100 Subject: [PATCH 34/34] Remove try/except --- tests/main/test_model_rebuild_validation.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/main/test_model_rebuild_validation.py b/tests/main/test_model_rebuild_validation.py index 1e82ff64..973b3104 100644 --- a/tests/main/test_model_rebuild_validation.py +++ b/tests/main/test_model_rebuild_validation.py @@ -48,17 +48,10 @@ def test_validate_field_a_on_faulty_model(field_a_data): def test_validate_field_a_on_correct_model(field_a_data): - try: - FullA.model_validate(field_a_data) - ExampleQuery2ExampleQuery.model_validate(field_a_data) - except ValidationError as e: - assert False, f"model_valiadte failed: {e}" + FullA.model_validate(field_a_data) + ExampleQuery2ExampleQuery.model_validate(field_a_data) def test_validate_field_a_in_example_query(field_a_data): example_query_2 = {"example_query": field_a_data} - - try: - ExampleQuery2.model_validate(example_query_2) - except ValidationError as e: - assert False, f"model validation failed: {e}" + ExampleQuery2.model_validate(example_query_2)