diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ce3a35..dcf06dfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # CHANGELOG +## 0.12.0 (UNRELEASED) + +- Fixed `graphql-transport-ws` protocol implementation not waiting for the `connection_ack` message on new connection. +- Fixed `get_client_settings` mutating `config_dict` instance. +- Added support to `graphqlschema` for saving schema as a GraphQL file. +- Restored `model_rebuild` calls for top level fragment models. + + ## 0.11.0 (2023-12-05) - Removed `model_rebuild` calls for generated input, fragment and result models. diff --git a/EXAMPLE.md b/EXAMPLE.md index 8646df5c..a465e31c 100644 --- a/EXAMPLE.md +++ b/EXAMPLE.md @@ -501,6 +501,10 @@ class BasicUser(BaseModel): class UserPersonalData(BaseModel): first_name: Optional[str] = Field(alias="firstName") last_name: Optional[str] = Field(alias="lastName") + + +BasicUser.model_rebuild() +UserPersonalData.model_rebuild() ``` ### Init file diff --git a/README.md b/README.md index d8945be2..d28e66cb 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ plugins = ["ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin"] operations_module_name = "custom_operations_module_name" ``` -- [`ariadne_codegen.contrib.extract_operations.NoReimportsPlugin`](ariadne_codegen/contrib/no_reimports.py) - This plugin removes content of generated `__init__.py`. This is useful in scenarios where generated plugins contain so many Pydantic models that client's eager initialization of entire package on first import is very slow. +- [`ariadne_codegen.contrib.no_reimports.NoReimportsPlugin`](ariadne_codegen/contrib/no_reimports.py) - This plugin removes content of generated `__init__.py`. This is useful in scenarios where generated plugins contain so many Pydantic models that client's eager initialization of entire package on first import is very slow. ## Using generated client @@ -323,23 +323,45 @@ Example with simple schema and few queries and mutations is available [here](htt ## Generating graphql schema's python representation -Instead of generating client, you can generate file with a copy of GraphQL schema as `GraphQLSchema` declaration. To do this call `ariadne-codegen` with `graphqlschema` argument: +Instead of generating a client, you can generate a file with a copy of a GraphQL schema. To do this call `ariadne-codegen` with `graphqlschema` argument: + ``` ariadne-codegen graphqlschema ``` -`graphqlschema` mode reads configuration from the same place as [`client`](#configuration) but uses only `schema_path`, `remote_schema_url`, `remote_schema_headers`, `remote_schema_verify_ssl` and `plugins` options with addition to some extra options specific to it: +`graphqlschema` mode reads configuration from the same place as [`client`](#configuration) but uses only `schema_path`, `remote_schema_url`, `remote_schema_headers`, `remote_schema_verify_ssl` options to retrieve the schema and `plugins` option to load plugins. + +In addition to the above, `graphqlschema` mode also accepts additional settings specific to it: + + +### `target_file_path` -- `target_file_path` (defaults to `"schema.py"`) - destination path for generated file -- `schema_variable_name` (defaults to `"schema"`) - name for schema variable, must be valid python identifier -- `type_map_variable_name` (defaults to `"type_map"`) - name for type map variable, must be valid python identifier +A string with destination path for generated file. Must be either a Python (`.py`), or GraphQL (`.graphql` or `.gql`) file. -Generated file contains: +Defaults to `schema.py`. + +Generated Python file will contain: - Necessary imports - Type map declaration `{type_map_variable_name}: TypeMap = {...}` - Schema declaration `{schema_variable_name}: GraphQLSchema = GraphQLSchema(...)` +Generated GraphQL file will contain a formatted output of the `print_schema` function from the `graphql-core` package. + + +### `schema_variable_name` + +A string with a name for schema variable, must be valid python identifier. + +Defaults to `"schema"`. Used only if target is a Python file. + + +### `type_map_variable_name` + +A string with a name for type map variable, must be valid python identifier. + +Defaults to `"type_map"`. Used only if target is a Python file. + ## Contributing diff --git a/ariadne_codegen/client_generators/constants.py b/ariadne_codegen/client_generators/constants.py index 0dfe53a2..5cecf209 100644 --- a/ariadne_codegen/client_generators/constants.py +++ b/ariadne_codegen/client_generators/constants.py @@ -56,6 +56,7 @@ MODEL_VALIDATE_METHOD = "model_validate" PLAIN_SERIALIZER = "PlainSerializer" BEFORE_VALIDATOR = "BeforeValidator" +MODEL_REBUILD_METHOD = "model_rebuild" ENUM_MODULE = "enum" ENUM_CLASS = "Enum" diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client.py b/ariadne_codegen/client_generators/dependencies/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py index ecd786f5..65d56458 100644 --- a/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py +++ b/ariadne_codegen/client_generators/dependencies/async_base_client_open_telemetry.py @@ -373,6 +373,12 @@ async def _execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -414,7 +420,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -427,6 +436,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) @@ -563,6 +577,13 @@ async def _execute_ws_with_telemetry( root_span=root_span, websocket=websocket, ) + # wait for connection_ack from server + await self._handle_ws_message_with_telemetry( + root_span=root_span, + message=await websocket.recv(), + websocket=websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe_with_telemetry( root_span=root_span, websocket=websocket, @@ -628,7 +649,11 @@ async def _send_subscribe_with_telemetry( ) async def _handle_ws_message_with_telemetry( - self, root_span: Span, message: Data, websocket: WebSocketClientProtocol + self, + root_span: Span, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: with self.tracer.start_as_current_span( # type: ignore "received message", context=set_span_in_context(root_span) @@ -650,6 +675,11 @@ async def _handle_ws_message_with_telemetry( }: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/ariadne_codegen/client_generators/fragments.py b/ariadne_codegen/client_generators/fragments.py index cba04810..3a7d2ecc 100644 --- a/ariadne_codegen/client_generators/fragments.py +++ b/ariadne_codegen/client_generators/fragments.py @@ -3,9 +3,9 @@ from graphql import FragmentDefinitionNode, GraphQLSchema -from ..codegen import generate_module +from ..codegen import generate_expr, generate_method_call, generate_module from ..plugins.manager import PluginManager -from .constants import BASE_MODEL_IMPORT +from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD from .result_types import ResultTypesGenerator from .scalars import ScalarData @@ -36,6 +36,7 @@ def __init__( def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: class_defs_dict: Dict[str, List[ast.ClassDef]] = {} imports: List[ast.ImportFrom] = [] + top_level_class_names: List[str] = [] dependencies_dict: Dict[str, Set[str]] = {} names_to_exclude = exclude_names or set() @@ -53,7 +54,10 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: plugin_manager=self.plugin_manager, ) imports.extend(generator.get_imports()) - class_defs_dict[name] = generator.get_classes() + class_defs = generator.get_classes() + class_defs_dict[name] = class_defs + if class_defs: + top_level_class_names.append(class_defs[0].name) dependencies_dict[name] = generator.get_fragments_used_as_mixins() self._generated_public_names.extend(generator.get_generated_public_names()) self._used_enums.extend(generator.get_used_enums()) @@ -62,7 +66,15 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module: class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict ) module = generate_module( - body=cast(List[ast.stmt], imports) + cast(List[ast.stmt], sorted_class_defs) + body=cast(List[ast.stmt], imports) + + cast(List[ast.stmt], sorted_class_defs) + + cast( + List[ast.stmt], + self._get_model_rebuild_calls( + top_level_fragments_names=top_level_class_names, + class_defs=sorted_class_defs, + ), + ) ) if self.plugin_manager: module = self.plugin_manager.generate_fragments_module( @@ -108,3 +120,15 @@ def visit(name): visit(name) return sorted_names + + def _get_model_rebuild_calls( + self, top_level_fragments_names: List[str], class_defs: List[ast.ClassDef] + ) -> List[ast.Call]: + class_names = [c.name for c in class_defs] + sorted_fragments_names = sorted( + top_level_fragments_names, key=class_names.index + ) + return [ + generate_expr(generate_method_call(name, MODEL_REBUILD_METHOD)) + for name in sorted_fragments_names + ] diff --git a/ariadne_codegen/config.py b/ariadne_codegen/config.py index d6a7b8d8..7aa4d792 100644 --- a/ariadne_codegen/config.py +++ b/ariadne_codegen/config.py @@ -34,7 +34,7 @@ def get_config_dict(config_file_name: Optional[str] = None) -> Dict: def get_client_settings(config_dict: Dict) -> ClientSettings: """Parse configuration dict and return ClientSettings instance.""" - section = get_section(config_dict) + section = get_section(config_dict).copy() settings_fields_names = {f.name for f in fields(ClientSettings)} try: section["scalars"] = { diff --git a/ariadne_codegen/graphql_schema_generators/schema.py b/ariadne_codegen/graphql_schema_generators/schema.py index a1587b14..9de87871 100644 --- a/ariadne_codegen/graphql_schema_generators/schema.py +++ b/ariadne_codegen/graphql_schema_generators/schema.py @@ -1,7 +1,7 @@ import ast from pathlib import Path -from graphql import GraphQLSchema +from graphql import GraphQLSchema, print_schema from graphql.type.schema import TypeMap from ..codegen import ( @@ -23,7 +23,11 @@ from .utils import get_optional_named_type -def generate_graphql_schema_file( +def generate_graphql_schema_graphql_file(schema: GraphQLSchema, target_file_path: str): + Path(target_file_path).write_text(print_schema(schema), encoding="UTF-8") + + +def generate_graphql_schema_python_file( schema: GraphQLSchema, target_file_path: str, type_map_name: str, diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index 8f7531fa..57eb60c5 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -5,7 +5,10 @@ from .client_generators.package import get_package_generator from .config import get_client_settings, get_config_dict, get_graphql_schema_settings -from .graphql_schema_generators.schema import generate_graphql_schema_file +from .graphql_schema_generators.schema import ( + generate_graphql_schema_graphql_file, + generate_graphql_schema_python_file, +) from .plugins.explorer import get_plugins_types from .plugins.manager import PluginManager from .schema import ( @@ -99,9 +102,15 @@ def graphql_schema(config_dict): sys.stdout.write(settings.used_settings_message) - generate_graphql_schema_file( - schema=schema, - target_file_path=settings.target_file_path, - type_map_name=settings.type_map_variable_name, - schema_variable_name=settings.schema_variable_name, - ) + if settings.target_file_format == "py": + generate_graphql_schema_python_file( + schema=schema, + target_file_path=settings.target_file_path, + type_map_name=settings.type_map_variable_name, + schema_variable_name=settings.schema_variable_name, + ) + else: + generate_graphql_schema_graphql_file( + schema=schema, + target_file_path=settings.target_file_path, + ) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 96c967a0..fbd8f427 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -195,6 +195,8 @@ class GraphQLSchemaSettings(BaseSettings): def __post_init__(self): super().__post_init__() + + assert_string_is_valid_schema_target_filename(self.target_file_path) assert_string_is_valid_python_identifier(self.schema_variable_name) assert_string_is_valid_python_identifier(self.type_map_variable_name) @@ -206,17 +208,32 @@ def used_settings_message(self): if self.plugins else "No plugin is being used." ) + + if self.target_file_format == "py": + return dedent( + f"""\ + Selected strategy: {Strategy.GRAPHQL_SCHEMA} + Using schema from {self.schema_path or self.remote_schema_url} + Saving graphql schema to: {self.target_file_path} + Using {self.schema_variable_name} as variable name for schema. + Using {self.type_map_variable_name} as variable name for type map. + {plugins_msg} + """ + ) + return dedent( f"""\ Selected strategy: {Strategy.GRAPHQL_SCHEMA} - Using schema from '{self.schema_path or self.remote_schema_url}'. - Saving graphql schema to: {self.target_file_path}. - Using {self.schema_variable_name} as variable name for schema. - Using {self.type_map_variable_name} as variable name for type map. + Using schema from {self.schema_path or self.remote_schema_url} + Saving graphql schema to: {self.target_file_path} {plugins_msg} """ ) + @property + def target_file_format(self): + return Path(self.target_file_path).suffix[1:].lower() + def assert_path_exists(path: str): if not Path(path).exists(): @@ -233,10 +250,25 @@ def assert_path_is_valid_file(path: str): raise InvalidConfiguration(f"Provided path {path} isn't a file.") +def assert_string_is_valid_schema_target_filename(filename: str): + file_type = Path(filename).suffix + if not file_type: + raise InvalidConfiguration( + f"Provided file name {filename} is missing a file type." + ) + + file_type = file_type[1:].lower() + if file_type not in ("py", "graphql", "gql"): + raise InvalidConfiguration( + f"Provided file name {filename} has an invalid type {file_type}." + " Valid types are py, graphql and gql." + ) + + def assert_string_is_valid_python_identifier(name: str): if not name.isidentifier() and not iskeyword(name): raise InvalidConfiguration( - f"Provided name {name} cannot be used as python indetifier" + f"Provided name {name} cannot be used as python identifier." ) diff --git a/tests/client_generators/dependencies/test_websockets.py b/tests/client_generators/dependencies/test_websockets.py index 0b3f8843..b82639da 100644 --- a/tests/client_generators/dependencies/test_websockets.py +++ b/tests/client_generators/dependencies/test_websockets.py @@ -24,12 +24,19 @@ def mocked_websocket(mocked_ws_connect): websocket.__aiter__.return_value = [ json.dumps({"type": "connection_ack"}), ] + websocket.recv.return_value = json.dumps({"type": "connection_ack"}) + return websocket + + +@pytest.fixture +def mocked_faulty_websocket(mocked_ws_connect): + websocket = mocked_ws_connect.return_value.__aenter__.return_value return websocket @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient(ws_url="ws://test_url").execute_ws(""): pass @@ -40,7 +47,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_url( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient().execute_ws(""): pass @@ -53,7 +60,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_origin( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient(ws_origin="test_origin").execute_ws(""): pass @@ -64,7 +71,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_origin( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_headers( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient(ws_headers={"test_key": "test_value"}).execute_ws( "" @@ -79,7 +86,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient( ws_headers={"Client-A": "client_value_a", "Client-B": "client_value_b"} @@ -98,7 +105,7 @@ async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_passed_kwargs( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClient().execute_ws("", open_timeout=15, close_timeout=30): pass @@ -250,3 +257,16 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type with pytest.raises(GraphQLClientGraphQLMultiError): async for _ in AsyncBaseClient().execute_ws(""): pass + + +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_missing_ack_after_init( + mocked_faulty_websocket, +): + mocked_faulty_websocket.recv.return_value = json.dumps( + {"type": "next", "payload": {"data": "test_data"}} + ) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClient().execute_ws(""): + pass diff --git a/tests/client_generators/dependencies/test_websockets_open_telemetry.py b/tests/client_generators/dependencies/test_websockets_open_telemetry.py index e674a2f1..b75a61bc 100644 --- a/tests/client_generators/dependencies/test_websockets_open_telemetry.py +++ b/tests/client_generators/dependencies/test_websockets_open_telemetry.py @@ -26,12 +26,19 @@ def mocked_websocket(mocked_ws_connect): websocket.__aiter__.return_value = [ json.dumps({"type": "connection_ack"}), ] + websocket.recv.return_value = json.dumps({"type": "connection_ack"}) + return websocket + + +@pytest.fixture +def mocked_faulty_websocket(mocked_ws_connect): + websocket = mocked_ws_connect.return_value.__aenter__.return_value return websocket @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_url( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry(ws_url="ws://test_url").execute_ws(""): pass @@ -42,7 +49,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_url( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): pass @@ -55,7 +62,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_subprotocol( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_origin( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry(ws_origin="test_origin").execute_ws(""): pass @@ -66,7 +73,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_origin( @pytest.mark.asyncio async def test_execute_ws_creates_websocket_connection_with_correct_headers( - mocked_ws_connect, + mocked_ws_connect, mocked_websocket # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry( ws_headers={"test_key": "test_value"} @@ -82,7 +89,7 @@ async def test_execute_ws_creates_websocket_connection_with_correct_headers( @pytest.mark.asyncio @pytest.mark.parametrize("tracer", ["tracer name", None]) async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers( - mocked_ws_connect, tracer + mocked_ws_connect, mocked_websocket, tracer # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry( ws_headers={"Client-A": "client_value_a", "Client-B": "client_value_b"}, @@ -103,7 +110,7 @@ async def test_execute_ws_creates_websocket_connection_with_passed_extra_headers @pytest.mark.asyncio @pytest.mark.parametrize("tracer", ["tracer name", None]) async def test_execute_ws_creates_websocket_connection_with_passed_kwargs( - mocked_ws_connect, tracer + mocked_ws_connect, mocked_websocket, tracer # pylint: disable=unused-argument ): async for _ in AsyncBaseClientOpenTelemetry(tracer=tracer).execute_ws( "", open_timeout=15, close_timeout=30 @@ -259,6 +266,19 @@ async def test_execute_ws_raises_graphql_multi_error_for_message_with_error_type pass +@pytest.mark.asyncio +async def test_execute_ws_raises_invalid_message_format_for_missing_ack_after_init( + mocked_faulty_websocket, +): + mocked_faulty_websocket.recv.return_value = json.dumps( + {"type": "next", "payload": {"data": "test_data"}} + ) + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in AsyncBaseClientOpenTelemetry().execute_ws(""): + pass + + @pytest.fixture def mocked_start_as_current_span(mocker): mocker_get_tracer = mocker.patch( @@ -394,3 +414,22 @@ async def test_execute_ws_creates_span_for_received_error_message( mocked_start_as_current_span.assert_any_call("received message", context=ANY) with mocked_start_as_current_span.return_value as span: span.set_attribute.assert_any_call("type", "error") + + +@pytest.mark.asyncio +async def test_execute_ws_ws_creates_span_for_missing_ack_after_init( + mocked_faulty_websocket, mocked_start_as_current_span +): + mocked_faulty_websocket.recv.return_value = json.dumps( + {"type": "next", "payload": {"data": "test_data"}} + ) + + client = AsyncBaseClientOpenTelemetry(ws_url="ws://test_url", tracer="tracker") + + with pytest.raises(GraphQLClientInvalidMessageFormat): + async for _ in client.execute_ws(""): + pass + + mocked_start_as_current_span.assert_any_call("received message", context=ANY) + with mocked_start_as_current_span.return_value as span: + span.set_attribute.assert_any_call("type", "next") diff --git a/tests/graphql_schema_generators/test_schema.py b/tests/graphql_schema_generators/test_schema.py index 7b7e0980..b7d9bc2c 100644 --- a/tests/graphql_schema_generators/test_schema.py +++ b/tests/graphql_schema_generators/test_schema.py @@ -1,9 +1,10 @@ import ast -from graphql import Undefined, build_schema +from graphql import Undefined, build_schema, print_schema from ariadne_codegen.graphql_schema_generators.schema import ( - generate_graphql_schema_file, + generate_graphql_schema_graphql_file, + generate_graphql_schema_python_file, generate_schema, generate_schema_module, generate_type_map, @@ -28,11 +29,30 @@ """ -def test_generate_graphql_schema_file_creates_file_with_variables(tmp_path): +def test_generate_graphql_schema_graphql_file_creates_file_with_printed_schema( + tmp_path, +): + schema = build_schema(SCHEMA_STR) + file_path = tmp_path / "test_schema.graphql" + + generate_graphql_schema_graphql_file(schema, file_path.as_posix()) + + assert file_path.exists() + assert file_path.is_file() + with file_path.open() as file_: + content = file_.read() + assert content == print_schema(schema) + + +def test_generate_graphql_schema_python_file_creates_py_file_with_variables( + tmp_path, +): schema = build_schema(SCHEMA_STR) file_path = tmp_path / "test_schema.py" - generate_graphql_schema_file(schema, file_path.as_posix(), "type_map", "schema") + generate_graphql_schema_python_file( + schema, file_path.as_posix(), "type_map", "schema" + ) assert file_path.exists() assert file_path.is_file() diff --git a/tests/main/clients/custom_config_file/expected_client/async_base_client.py b/tests/main/clients/custom_config_file/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/custom_config_file/expected_client/async_base_client.py +++ b/tests/main/clients/custom_config_file/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/custom_files_names/expected_client/async_base_client.py b/tests/main/clients/custom_files_names/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/custom_files_names/expected_client/async_base_client.py +++ b/tests/main/clients/custom_files_names/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/custom_scalars/expected_client/async_base_client.py b/tests/main/clients/custom_scalars/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/custom_scalars/expected_client/async_base_client.py +++ b/tests/main/clients/custom_scalars/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/example/expected_client/async_base_client.py b/tests/main/clients/example/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/example/expected_client/async_base_client.py +++ b/tests/main/clients/example/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/example/expected_client/fragments.py b/tests/main/clients/example/expected_client/fragments.py index 4795d454..e1cee688 100644 --- a/tests/main/clients/example/expected_client/fragments.py +++ b/tests/main/clients/example/expected_client/fragments.py @@ -13,3 +13,7 @@ class BasicUser(BaseModel): class UserPersonalData(BaseModel): first_name: Optional[str] = Field(alias="firstName") last_name: Optional[str] = Field(alias="lastName") + + +BasicUser.model_rebuild() +UserPersonalData.model_rebuild() diff --git a/tests/main/clients/extended_models/expected_client/async_base_client.py b/tests/main/clients/extended_models/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/extended_models/expected_client/async_base_client.py +++ b/tests/main/clients/extended_models/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/extended_models/expected_client/fragments.py b/tests/main/clients/extended_models/expected_client/fragments.py index 30a907d0..77a32fea 100644 --- a/tests/main/clients/extended_models/expected_client/fragments.py +++ b/tests/main/clients/extended_models/expected_client/fragments.py @@ -20,3 +20,8 @@ class GetQueryAFragment(BaseModel): class GetQueryAFragmentQueryA(BaseModel, MixinA, CommonMixin): field_a: int = Field(alias="fieldA") + + +FragmentA.model_rebuild() +FragmentB.model_rebuild() +GetQueryAFragment.model_rebuild() diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py b/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py index 1206200f..e20413f0 100644 --- a/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py +++ b/tests/main/clients/fragments_on_abstract_types/expected_client/fragments.py @@ -11,3 +11,7 @@ class FragmentA(BaseModel): class FragmentB(BaseModel): id: str value_b: str = Field(alias="valueB") + + +FragmentA.model_rebuild() +FragmentB.model_rebuild() diff --git a/tests/main/clients/inline_fragments/expected_client/async_base_client.py b/tests/main/clients/inline_fragments/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/inline_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/inline_fragments/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/inline_fragments/expected_client/fragments.py b/tests/main/clients/inline_fragments/expected_client/fragments.py index 16fcc46a..4770eb21 100644 --- a/tests/main/clients/inline_fragments/expected_client/fragments.py +++ b/tests/main/clients/inline_fragments/expected_client/fragments.py @@ -57,3 +57,8 @@ class FragmentOnQueryWithUnionQueryUTypeC(BaseModel): class UnusedFragmentOnTypeA(BaseModel): id: str field_a: str = Field(alias="fieldA") + + +FragmentOnQueryWithInterface.model_rebuild() +FragmentOnQueryWithUnion.model_rebuild() +UnusedFragmentOnTypeA.model_rebuild() diff --git a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/multiple_fragments/expected_client/async_base_client.py +++ b/tests/main/clients/multiple_fragments/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/multiple_fragments/expected_client/fragments.py b/tests/main/clients/multiple_fragments/expected_client/fragments.py index c4d0d139..7f29eb44 100644 --- a/tests/main/clients/multiple_fragments/expected_client/fragments.py +++ b/tests/main/clients/multiple_fragments/expected_client/fragments.py @@ -40,3 +40,10 @@ class MinimalA(BaseModel): class MinimalAFieldB(MinimalB): pass + + +CompleteA.model_rebuild() +FullB.model_rebuild() +FullA.model_rebuild() +MinimalB.model_rebuild() +MinimalA.model_rebuild() diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py b/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py index 4bad3235..25a2a806 100644 --- a/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py +++ b/tests/main/clients/only_used_inputs_and_enums/expected_client/fragments.py @@ -8,3 +8,7 @@ class FragmentG(BaseModel): class FragmentGG(BaseModel): val: EnumGG + + +FragmentG.model_rebuild() +FragmentGG.model_rebuild() diff --git a/tests/main/clients/operations/expected_client/async_base_client.py b/tests/main/clients/operations/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/operations/expected_client/async_base_client.py +++ b/tests/main/clients/operations/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/operations/expected_client/fragments.py b/tests/main/clients/operations/expected_client/fragments.py index 5adc6550..6836e301 100644 --- a/tests/main/clients/operations/expected_client/fragments.py +++ b/tests/main/clients/operations/expected_client/fragments.py @@ -9,3 +9,7 @@ class FragmentB(BaseModel): class FragmentY(BaseModel): value_y: int = Field(alias="valueY") + + +FragmentB.model_rebuild() +FragmentY.model_rebuild() diff --git a/tests/main/clients/remote_schema/expected_client/async_base_client.py b/tests/main/clients/remote_schema/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/remote_schema/expected_client/async_base_client.py +++ b/tests/main/clients/remote_schema/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/shorter_results/expected_client/async_base_client.py b/tests/main/clients/shorter_results/expected_client/async_base_client.py index d3ad17ef..5358ced6 100644 --- a/tests/main/clients/shorter_results/expected_client/async_base_client.py +++ b/tests/main/clients/shorter_results/expected_client/async_base_client.py @@ -163,6 +163,12 @@ async def execute_ws( **merged_kwargs, ) as websocket: await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) await self._send_subscribe( websocket, operation_id=operation_id, @@ -324,7 +330,10 @@ async def _send_subscribe( await websocket.send(json.dumps(payload)) async def _handle_ws_message( - self, message: Data, websocket: WebSocketClientProtocol + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, ) -> Optional[Dict[str, Any]]: try: message_dict = json.loads(message) @@ -337,6 +346,11 @@ async def _handle_ws_message( if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: raise GraphQLClientInvalidMessageFormat(message=message) + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) diff --git a/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py b/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py index f7e5e395..4b103a44 100644 --- a/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py +++ b/tests/main/clients/shorter_results/expected_client/shorter_results_fragments.py @@ -22,3 +22,7 @@ class ListAnimalsFragment(BaseModel): class ListAnimalsFragmentListAnimals(BaseModel): typename__: Literal["Animal", "Cat", "Dog"] = Field(alias="__typename") name: str + + +FragmentWithSingleField.model_rebuild() +ListAnimalsFragment.model_rebuild() diff --git a/tests/main/graphql_schemas/example/expected_schema.gql b/tests/main/graphql_schemas/example/expected_schema.gql new file mode 100644 index 00000000..2e19a47b --- /dev/null +++ b/tests/main/graphql_schemas/example/expected_schema.gql @@ -0,0 +1,59 @@ +type Query { + users(country: String): [User!]! +} + +type Mutation { + userCreate(userData: UserCreateInput!): User + userPreferences(data: UserPreferencesInput): Boolean! +} + +input UserCreateInput { + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: LocationInput +} + +input LocationInput { + city: String + country: String +} + +type User { + id: ID! + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: Location +} + +type Location { + city: String + country: String +} + +enum Color { + BLACK + WHITE + RED + GREEN + BLUE + YELLOW +} + +input UserPreferencesInput { + luckyNumber: Int = 7 + favouriteWord: String = "word" + colorOpacity: Float = 1 + excludedTags: [String!] = ["offtop", "tag123"] + notificationsPreferences: NotificationsPreferencesInput! = {receiveMails: true, receivePushNotifications: true, receiveSms: false, title: "Mr"} +} + +input NotificationsPreferencesInput { + receiveMails: Boolean! + receivePushNotifications: Boolean! + receiveSms: Boolean! + title: String! +} \ No newline at end of file diff --git a/tests/main/graphql_schemas/example/expected_schema.graphql b/tests/main/graphql_schemas/example/expected_schema.graphql new file mode 100644 index 00000000..2e19a47b --- /dev/null +++ b/tests/main/graphql_schemas/example/expected_schema.graphql @@ -0,0 +1,59 @@ +type Query { + users(country: String): [User!]! +} + +type Mutation { + userCreate(userData: UserCreateInput!): User + userPreferences(data: UserPreferencesInput): Boolean! +} + +input UserCreateInput { + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: LocationInput +} + +input LocationInput { + city: String + country: String +} + +type User { + id: ID! + firstName: String + lastName: String + email: String! + favouriteColor: Color + location: Location +} + +type Location { + city: String + country: String +} + +enum Color { + BLACK + WHITE + RED + GREEN + BLUE + YELLOW +} + +input UserPreferencesInput { + luckyNumber: Int = 7 + favouriteWord: String = "word" + colorOpacity: Float = 1 + excludedTags: [String!] = ["offtop", "tag123"] + notificationsPreferences: NotificationsPreferencesInput! = {receiveMails: true, receivePushNotifications: true, receiveSms: false, title: "Mr"} +} + +input NotificationsPreferencesInput { + receiveMails: Boolean! + receivePushNotifications: Boolean! + receiveSms: Boolean! + title: String! +} \ No newline at end of file diff --git a/tests/main/graphql_schemas/example/pyproject-schema-gql.toml b/tests/main/graphql_schemas/example/pyproject-schema-gql.toml new file mode 100644 index 00000000..557b014f --- /dev/null +++ b/tests/main/graphql_schemas/example/pyproject-schema-gql.toml @@ -0,0 +1,3 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +target_file_path = "expected_schema.gql" diff --git a/tests/main/graphql_schemas/example/pyproject-schema-graphql.toml b/tests/main/graphql_schemas/example/pyproject-schema-graphql.toml new file mode 100644 index 00000000..c9083813 --- /dev/null +++ b/tests/main/graphql_schemas/example/pyproject-schema-graphql.toml @@ -0,0 +1,3 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +target_file_path = "expected_schema.graphql" diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 569ef416..8dc755ae 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -315,6 +315,22 @@ def test_main_can_read_config_from_provided_file(tmp_path): "schema.py", GRAPHQL_SCHEMAS_PATH / "all_types" / "expected_schema.py", ), + ( + ( + GRAPHQL_SCHEMAS_PATH / "example" / "pyproject-schema-graphql.toml", + (GRAPHQL_SCHEMAS_PATH / "example" / "schema.graphql",), + ), + "expected_schema.graphql", + GRAPHQL_SCHEMAS_PATH / "example" / "expected_schema.graphql", + ), + ( + ( + GRAPHQL_SCHEMAS_PATH / "example" / "pyproject-schema-gql.toml", + (GRAPHQL_SCHEMAS_PATH / "example" / "schema.graphql",), + ), + "expected_schema.gql", + GRAPHQL_SCHEMAS_PATH / "example" / "expected_schema.gql", + ), ], indirect=["project_dir"], ) diff --git a/tests/main/test_model_rebuild_validation.py b/tests/main/test_model_rebuild_validation.py new file mode 100644 index 00000000..973b3104 --- /dev/null +++ b/tests/main/test_model_rebuild_validation.py @@ -0,0 +1,57 @@ +""" +To ensure all models with nested dependencies are fully rebuilt those tests +create an instance of the query from `multiple_fragments` containing the +`FullA` fragment (used by the `ExampleQuery2ExampleQuery`) which itself includes +a field of type `FullAFieldB` that extends the `FullB` fragment. + +If this model is not rebuilt with `FullA.model_rebuild()` `ExampleQuery2` will +not be fully defined and we will raise a `PydanticUserError`. + +Reference to Pydantic documentation about when and why we need to call +`model_rebuild` on our types: +https://errors.pydantic.dev/2.5/u/class-not-fully-defined +""" + +import pytest +from pydantic_core import ValidationError + +from .clients.multiple_fragments.expected_client.example_query_2 import ( + ExampleQuery2, + ExampleQuery2ExampleQuery, +) +from .clients.multiple_fragments.expected_client.fragments import FullA + + +def test_json_schema_contains_all_properties(): + json_schema = ExampleQuery2.model_json_schema() + assert "ExampleQuery2ExampleQuery" in json_schema["$defs"] + assert "FullAFieldB" in json_schema["$defs"] + + query_props = json_schema["$defs"]["ExampleQuery2ExampleQuery"]["properties"] + assert "id" in query_props + assert "value" in query_props + assert "fieldB" in query_props + assert query_props["fieldB"]["$ref"] == "#/$defs/FullAFieldB" + + +@pytest.fixture +def field_a_data(): + field_b = {"id": "321", "value": 13.37} + field_a = {"id": "123", "value": "A", "field_b": field_b} + + return field_a + + +def test_validate_field_a_on_faulty_model(field_a_data): + with pytest.raises(ValidationError): + ExampleQuery2.model_validate(field_a_data) + + +def test_validate_field_a_on_correct_model(field_a_data): + FullA.model_validate(field_a_data) + ExampleQuery2ExampleQuery.model_validate(field_a_data) + + +def test_validate_field_a_in_example_query(field_a_data): + example_query_2 = {"example_query": field_a_data} + ExampleQuery2.model_validate(example_query_2) diff --git a/tests/test_config.py b/tests/test_config.py index b728af64..2157fe02 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -112,6 +112,16 @@ def test_get_client_settings_returns_client_settings_object(tmp_path): ) } + # Regression test for #256: don't mutate config_dict's scalars + assert config_dict["tool"]["ariadne-codegen"]["scalars"] == { + "ID": { + "type": "str", + "parse": "parse_id", + "serialize": "serialize_id", + "import": ".custom_scalars", + }, + } + def test_get_client_settings_without_section_raises_missing_configuration_exception(): config_dict = {"invalid-section": {"schema_path": "."}} diff --git a/tests/test_settings.py b/tests/test_settings.py index 3a629287..11d03523 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -226,6 +226,45 @@ def test_graphq_schema_settings_without_schema_path_with_remote_schema_url_is_va assert not settings.schema_path +def test_graphql_schema_settings_with_target_file_path_with_py_extension_is_valid(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.py", + ) + + assert settings.target_file_path == "schema_file.py" + assert settings.target_file_format == "py" + + +def test_graphql_schema_settings_with_target_file_with_graphql_extension_is_valid(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.graphql", + ) + + assert settings.target_file_path == "schema_file.graphql" + assert settings.target_file_format == "graphql" + + +def test_graphql_schema_settings_with_target_file_path_with_gql_extension_is_valid(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.gql", + ) + + assert settings.target_file_path == "schema_file.gql" + assert settings.target_file_format == "gql" + + +def test_graphql_schema_settings_target_file_format_is_lowercased(): + settings = GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.GQL", + ) + + assert settings.target_file_format == "gql" + + def test_graphq_schema_settings_without_schema_path_or_remote_schema_url_is_not_valid(): with pytest.raises(InvalidConfiguration): GraphQLSchemaSettings() @@ -236,6 +275,22 @@ def test_graphql_schema_settings_raises_invalid_configuration_for_invalid_schema GraphQLSchemaSettings(schema_path="not_exisitng.graphql") +def test_graphql_schema_settings_with_target_file_missing_extension_raises_exception(): + with pytest.raises(InvalidConfiguration): + GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file", + ) + + +def test_graphql_schema_settings_with_target_file_invalid_extension_raises_exception(): + with pytest.raises(InvalidConfiguration): + GraphQLSchemaSettings( + remote_schema_url="http://testserver/graphq/", + target_file_path="schema_file.invalid", + ) + + def test_graphql_schema_settings_with_invalid_schema_variable_name_raises_exception(): with pytest.raises(InvalidConfiguration): GraphQLSchemaSettings(