Skip to content

Commit

Permalink
Merge branch 'mirumee:main' into test_plugin_scalar_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alzex3 authored Jan 28, 2024
2 parents badc2da + 9edc90b commit 6519723
Show file tree
Hide file tree
Showing 42 changed files with 751 additions and 57 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# CHANGELOG

## 0.12.0 (UNRELEASED)

- Fixed `graphql-transport-ws` protocol implementation not waiting for the `connection_ack` message on new connection.
- Fixed `get_client_settings` mutating `config_dict` instance.
- Added support to `graphqlschema` for saving schema as a GraphQL file.
- Restored `model_rebuild` calls for top level fragment models.


## 0.11.0 (2023-12-05)

- Removed `model_rebuild` calls for generated input, fragment and result models.
Expand Down
4 changes: 4 additions & 0 deletions EXAMPLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ class BasicUser(BaseModel):
class UserPersonalData(BaseModel):
first_name: Optional[str] = Field(alias="firstName")
last_name: Optional[str] = Field(alias="lastName")


BasicUser.model_rebuild()
UserPersonalData.model_rebuild()
```

### Init file
Expand Down
36 changes: 29 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ plugins = ["ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin"]
operations_module_name = "custom_operations_module_name"
```

- [`ariadne_codegen.contrib.extract_operations.NoReimportsPlugin`](ariadne_codegen/contrib/no_reimports.py) - This plugin removes content of generated `__init__.py`. This is useful in scenarios where generated plugins contain so many Pydantic models that client's eager initialization of entire package on first import is very slow.
- [`ariadne_codegen.contrib.no_reimports.NoReimportsPlugin`](ariadne_codegen/contrib/no_reimports.py) - This plugin removes content of generated `__init__.py`. This is useful in scenarios where generated plugins contain so many Pydantic models that client's eager initialization of entire package on first import is very slow.


## Using generated client
Expand Down Expand Up @@ -323,23 +323,45 @@ Example with simple schema and few queries and mutations is available [here](htt

## Generating graphql schema's python representation

Instead of generating client, you can generate file with a copy of GraphQL schema as `GraphQLSchema` declaration. To do this call `ariadne-codegen` with `graphqlschema` argument:
Instead of generating a client, you can generate a file with a copy of a GraphQL schema. To do this call `ariadne-codegen` with `graphqlschema` argument:

```
ariadne-codegen graphqlschema
```

`graphqlschema` mode reads configuration from the same place as [`client`](#configuration) but uses only `schema_path`, `remote_schema_url`, `remote_schema_headers`, `remote_schema_verify_ssl` and `plugins` options with addition to some extra options specific to it:
`graphqlschema` mode reads configuration from the same place as [`client`](#configuration) but uses only `schema_path`, `remote_schema_url`, `remote_schema_headers`, `remote_schema_verify_ssl` options to retrieve the schema and `plugins` option to load plugins.

In addition to the above, `graphqlschema` mode also accepts additional settings specific to it:


### `target_file_path`

- `target_file_path` (defaults to `"schema.py"`) - destination path for generated file
- `schema_variable_name` (defaults to `"schema"`) - name for schema variable, must be valid python identifier
- `type_map_variable_name` (defaults to `"type_map"`) - name for type map variable, must be valid python identifier
A string with destination path for generated file. Must be either a Python (`.py`), or GraphQL (`.graphql` or `.gql`) file.

Generated file contains:
Defaults to `schema.py`.

Generated Python file will contain:

- Necessary imports
- Type map declaration `{type_map_variable_name}: TypeMap = {...}`
- Schema declaration `{schema_variable_name}: GraphQLSchema = GraphQLSchema(...)`

Generated GraphQL file will contain a formatted output of the `print_schema` function from the `graphql-core` package.


### `schema_variable_name`

A string with a name for schema variable, must be valid python identifier.

Defaults to `"schema"`. Used only if target is a Python file.


### `type_map_variable_name`

A string with a name for type map variable, must be valid python identifier.

Defaults to `"type_map"`. Used only if target is a Python file.


## Contributing

Expand Down
1 change: 1 addition & 0 deletions ariadne_codegen/client_generators/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
MODEL_VALIDATE_METHOD = "model_validate"
PLAIN_SERIALIZER = "PlainSerializer"
BEFORE_VALIDATOR = "BeforeValidator"
MODEL_REBUILD_METHOD = "model_rebuild"

ENUM_MODULE = "enum"
ENUM_CLASS = "Enum"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ async def execute_ws(
**merged_kwargs,
) as websocket:
await self._send_connection_init(websocket)
# wait for connection_ack from server
await self._handle_ws_message(
await websocket.recv(),
websocket,
expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK,
)
await self._send_subscribe(
websocket,
operation_id=operation_id,
Expand Down Expand Up @@ -324,7 +330,10 @@ async def _send_subscribe(
await websocket.send(json.dumps(payload))

async def _handle_ws_message(
self, message: Data, websocket: WebSocketClientProtocol
self,
message: Data,
websocket: WebSocketClientProtocol,
expected_type: Optional[GraphQLTransportWSMessageType] = None,
) -> Optional[Dict[str, Any]]:
try:
message_dict = json.loads(message)
Expand All @@ -337,6 +346,11 @@ async def _handle_ws_message(
if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}:
raise GraphQLClientInvalidMessageFormat(message=message)

if expected_type and expected_type != type_:
raise GraphQLClientInvalidMessageFormat(
f"Invalid message received. Expected: {expected_type.value}"
)

if type_ == GraphQLTransportWSMessageType.NEXT:
if "data" not in payload:
raise GraphQLClientInvalidMessageFormat(message=message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,12 @@ async def _execute_ws(
**merged_kwargs,
) as websocket:
await self._send_connection_init(websocket)
# wait for connection_ack from server
await self._handle_ws_message(
await websocket.recv(),
websocket,
expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK,
)
await self._send_subscribe(
websocket,
operation_id=operation_id,
Expand Down Expand Up @@ -414,7 +420,10 @@ async def _send_subscribe(
await websocket.send(json.dumps(payload))

async def _handle_ws_message(
self, message: Data, websocket: WebSocketClientProtocol
self,
message: Data,
websocket: WebSocketClientProtocol,
expected_type: Optional[GraphQLTransportWSMessageType] = None,
) -> Optional[Dict[str, Any]]:
try:
message_dict = json.loads(message)
Expand All @@ -427,6 +436,11 @@ async def _handle_ws_message(
if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}:
raise GraphQLClientInvalidMessageFormat(message=message)

if expected_type and expected_type != type_:
raise GraphQLClientInvalidMessageFormat(
f"Invalid message received. Expected: {expected_type.value}"
)

if type_ == GraphQLTransportWSMessageType.NEXT:
if "data" not in payload:
raise GraphQLClientInvalidMessageFormat(message=message)
Expand Down Expand Up @@ -563,6 +577,13 @@ async def _execute_ws_with_telemetry(
root_span=root_span,
websocket=websocket,
)
# wait for connection_ack from server
await self._handle_ws_message_with_telemetry(
root_span=root_span,
message=await websocket.recv(),
websocket=websocket,
expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK,
)
await self._send_subscribe_with_telemetry(
root_span=root_span,
websocket=websocket,
Expand Down Expand Up @@ -628,7 +649,11 @@ async def _send_subscribe_with_telemetry(
)

async def _handle_ws_message_with_telemetry(
self, root_span: Span, message: Data, websocket: WebSocketClientProtocol
self,
root_span: Span,
message: Data,
websocket: WebSocketClientProtocol,
expected_type: Optional[GraphQLTransportWSMessageType] = None,
) -> Optional[Dict[str, Any]]:
with self.tracer.start_as_current_span( # type: ignore
"received message", context=set_span_in_context(root_span)
Expand All @@ -650,6 +675,11 @@ async def _handle_ws_message_with_telemetry(
}:
raise GraphQLClientInvalidMessageFormat(message=message)

if expected_type and expected_type != type_:
raise GraphQLClientInvalidMessageFormat(
f"Invalid message received. Expected: {expected_type.value}"
)

if type_ == GraphQLTransportWSMessageType.NEXT:
if "data" not in payload:
raise GraphQLClientInvalidMessageFormat(message=message)
Expand Down
32 changes: 28 additions & 4 deletions ariadne_codegen/client_generators/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from graphql import FragmentDefinitionNode, GraphQLSchema

from ..codegen import generate_module
from ..codegen import generate_expr, generate_method_call, generate_module
from ..plugins.manager import PluginManager
from .constants import BASE_MODEL_IMPORT
from .constants import BASE_MODEL_IMPORT, MODEL_REBUILD_METHOD
from .result_types import ResultTypesGenerator
from .scalars import ScalarData

Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(
def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict: Dict[str, List[ast.ClassDef]] = {}
imports: List[ast.ImportFrom] = []
top_level_class_names: List[str] = []
dependencies_dict: Dict[str, Set[str]] = {}

names_to_exclude = exclude_names or set()
Expand All @@ -53,7 +54,10 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
plugin_manager=self.plugin_manager,
)
imports.extend(generator.get_imports())
class_defs_dict[name] = generator.get_classes()
class_defs = generator.get_classes()
class_defs_dict[name] = class_defs
if class_defs:
top_level_class_names.append(class_defs[0].name)
dependencies_dict[name] = generator.get_fragments_used_as_mixins()
self._generated_public_names.extend(generator.get_generated_public_names())
self._used_enums.extend(generator.get_used_enums())
Expand All @@ -62,7 +66,15 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
class_defs_dict=class_defs_dict, dependencies_dict=dependencies_dict
)
module = generate_module(
body=cast(List[ast.stmt], imports) + cast(List[ast.stmt], sorted_class_defs)
body=cast(List[ast.stmt], imports)
+ cast(List[ast.stmt], sorted_class_defs)
+ cast(
List[ast.stmt],
self._get_model_rebuild_calls(
top_level_fragments_names=top_level_class_names,
class_defs=sorted_class_defs,
),
)
)
if self.plugin_manager:
module = self.plugin_manager.generate_fragments_module(
Expand Down Expand Up @@ -108,3 +120,15 @@ def visit(name):
visit(name)

return sorted_names

def _get_model_rebuild_calls(
self, top_level_fragments_names: List[str], class_defs: List[ast.ClassDef]
) -> List[ast.Call]:
class_names = [c.name for c in class_defs]
sorted_fragments_names = sorted(
top_level_fragments_names, key=class_names.index
)
return [
generate_expr(generate_method_call(name, MODEL_REBUILD_METHOD))
for name in sorted_fragments_names
]
2 changes: 1 addition & 1 deletion ariadne_codegen/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_config_dict(config_file_name: Optional[str] = None) -> Dict:

def get_client_settings(config_dict: Dict) -> ClientSettings:
"""Parse configuration dict and return ClientSettings instance."""
section = get_section(config_dict)
section = get_section(config_dict).copy()
settings_fields_names = {f.name for f in fields(ClientSettings)}
try:
section["scalars"] = {
Expand Down
8 changes: 6 additions & 2 deletions ariadne_codegen/graphql_schema_generators/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from pathlib import Path

from graphql import GraphQLSchema
from graphql import GraphQLSchema, print_schema
from graphql.type.schema import TypeMap

from ..codegen import (
Expand All @@ -23,7 +23,11 @@
from .utils import get_optional_named_type


def generate_graphql_schema_file(
def generate_graphql_schema_graphql_file(schema: GraphQLSchema, target_file_path: str):
Path(target_file_path).write_text(print_schema(schema), encoding="UTF-8")


def generate_graphql_schema_python_file(
schema: GraphQLSchema,
target_file_path: str,
type_map_name: str,
Expand Down
23 changes: 16 additions & 7 deletions ariadne_codegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from .client_generators.package import get_package_generator
from .config import get_client_settings, get_config_dict, get_graphql_schema_settings
from .graphql_schema_generators.schema import generate_graphql_schema_file
from .graphql_schema_generators.schema import (
generate_graphql_schema_graphql_file,
generate_graphql_schema_python_file,
)
from .plugins.explorer import get_plugins_types
from .plugins.manager import PluginManager
from .schema import (
Expand Down Expand Up @@ -99,9 +102,15 @@ def graphql_schema(config_dict):

sys.stdout.write(settings.used_settings_message)

generate_graphql_schema_file(
schema=schema,
target_file_path=settings.target_file_path,
type_map_name=settings.type_map_variable_name,
schema_variable_name=settings.schema_variable_name,
)
if settings.target_file_format == "py":
generate_graphql_schema_python_file(
schema=schema,
target_file_path=settings.target_file_path,
type_map_name=settings.type_map_variable_name,
schema_variable_name=settings.schema_variable_name,
)
else:
generate_graphql_schema_graphql_file(
schema=schema,
target_file_path=settings.target_file_path,
)
Loading

0 comments on commit 6519723

Please sign in to comment.