Skip to content

Commit

Permalink
support async_client with custom operations
Browse files Browse the repository at this point in the history
  • Loading branch information
DamianCzajkowski committed Jul 26, 2024
1 parent a064d22 commit 48f3cdd
Show file tree
Hide file tree
Showing 17 changed files with 1,521 additions and 42 deletions.
132 changes: 93 additions & 39 deletions ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,28 @@ def create_build_operation_ast_method(self):
return_type=generate_name("DocumentNode"),
)

def create_execute_custom_operation_method(self):
def create_execute_custom_operation_method(self, async_client: bool):
execute_call = generate_call(
func=generate_attribute(value=generate_name("self"), attr="execute"),
args=[
generate_call(
func=generate_name("print_ast"),
args=[generate_name("operation_ast")],
)
],
keywords=[
generate_keyword(
arg="variables", value=generate_name('combined_variables["values"]')
),
generate_keyword(
arg="operation_name", value=generate_name("operation_name")
),
],
)
response_value = (
generate_await(value=execute_call) if async_client else execute_call
)

method_body = [
generate_assign(
targets=["selections"],
Expand Down Expand Up @@ -549,54 +570,31 @@ def create_execute_custom_operation_method(self):
],
),
),
generate_assign(
targets=["response"],
value=generate_await(
value=generate_call(
func=generate_attribute(
value=generate_name("self"),
attr="execute",
),
args=[
generate_call(
func=generate_name("print_ast"),
args=[generate_name("operation_ast")],
)
],
keywords=[
generate_keyword(
arg="variables",
value=generate_name('combined_variables["values"]'),
),
generate_keyword(
arg="operation_name",
value=generate_name("operation_name"),
),
],
)
),
),
generate_assign(targets=["response"], value=response_value),
generate_return(
value=generate_call(
func=generate_attribute(
value=generate_name("self"),
attr="get_data",
value=generate_name("self"), attr="get_data"
),
args=[generate_name("response")],
)
),
]
return generate_async_method_definition(

method_definition = (
generate_async_method_definition
if async_client
else generate_method_definition
)

return method_definition(
name="execute_custom_operation",
arguments=generate_arguments(
args=[
generate_arg("self"),
generate_arg("*fields", annotation=generate_name("GraphQLField")),
generate_arg(
"operation_type",
annotation=generate_name(
"OperationType",
),
"operation_type", annotation=generate_name("OperationType")
),
generate_arg("operation_name", annotation=generate_name("str")),
]
Expand Down Expand Up @@ -655,7 +653,7 @@ def create_build_selection_set(self):
),
)

def add_execute_custom_operation_method(self):
def add_execute_custom_operation_method(self, async_client: bool):
self._add_import(
generate_import_from(
[
Expand All @@ -679,13 +677,20 @@ def add_execute_custom_operation_method(self):
)
self._add_import(generate_import_from([DICT, TUPLE, LIST, ANY], "typing"))

self._class_def.body.append(self.create_execute_custom_operation_method())
self._class_def.body.append(
self.create_execute_custom_operation_method(async_client)
)
self._class_def.body.append(self.create_combine_variables_method())
self._class_def.body.append(self.create_build_variable_definitions_method())
self._class_def.body.append(self.create_build_operation_ast_method())
self._class_def.body.append(self.create_build_selection_set())

def create_custom_operation_method(self, name, operation_type):
def create_custom_operation_method(
self,
name: str,
operation_type: str,
async_client: bool,
):
self._add_import(
generate_import_from(
[
Expand All @@ -694,6 +699,55 @@ def create_custom_operation_method(self, name, operation_type):
GRAPHQL_MODULE,
)
)
if async_client:
def_query = self._create_async_operation_method(name, operation_type)
else:
def_query = self._create_sync_operation_method(name, operation_type)
self._class_def.body.append(def_query)

def _create_sync_operation_method(self, name: str, operation_type: str):
body_return = generate_return(
value=generate_call(
func=generate_attribute(
value=generate_name("self"),
attr="execute_custom_operation",
),
args=[
generate_name("*fields"),
],
keywords=[
generate_keyword(
arg="operation_type",
value=generate_attribute(
value=generate_name("OperationType"),
attr=operation_type,
),
),
generate_keyword(
arg="operation_name", value=generate_name("operation_name")
),
],
)
)

def_query = generate_method_definition(
name=name,
arguments=generate_arguments(
args=[
generate_arg("self"),
generate_arg("*fields", annotation=generate_name("GraphQLField")),
generate_arg("operation_name", annotation=generate_name("str")),
],
),
body=[body_return],
return_type=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("Any")]),
),
)
return def_query

def _create_async_operation_method(self, name: str, operation_type: str):
body_return = generate_return(
value=generate_await(
value=generate_call(
Expand Down Expand Up @@ -734,7 +788,7 @@ def create_custom_operation_method(self, name, operation_type):
generate_tuple([generate_name("str"), generate_name("Any")]),
),
)
self._class_def.body.append(async_def_query)
return async_def_query

def get_variable_names(self, arguments: ast.arguments) -> Dict[str, str]:
mapped_variable_names = [
Expand Down
6 changes: 3 additions & 3 deletions ariadne_codegen/client_generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,16 @@ def generate(self) -> List[str]:
if self.enable_custom_operations:
self._generate_custom_fields_typing()
self._generate_custom_fields()
self.client_generator.add_execute_custom_operation_method()
self.client_generator.add_execute_custom_operation_method(self.async_client)
if self.custom_query_generator:
self._generate_custom_queries()
self.client_generator.create_custom_operation_method(
"query", OperationType.QUERY.value.upper()
"query", OperationType.QUERY.value.upper(), self.async_client
)
if self.custom_mutation_generator:
self._generate_custom_mutations()
self.client_generator.create_custom_operation_method(
"mutation", OperationType.MUTATION.value.upper()
"mutation", OperationType.MUTATION.value.upper(), self.async_client
)

self._generate_client()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .base_client import BaseClient
from .base_model import BaseModel, Upload
from .client import Client
from .enums import MetadataErrorCode
from .exceptions import (
GraphQLClientError,
GraphQLClientGraphQLError,
GraphQLClientGraphQLMultiError,
GraphQLClientHttpError,
GraphQLClientInvalidResponseError,
)

__all__ = [
"BaseClient",
"BaseModel",
"Client",
"GraphQLClientError",
"GraphQLClientGraphQLError",
"GraphQLClientGraphQLMultiError",
"GraphQLClientHttpError",
"GraphQLClientInvalidResponseError",
"MetadataErrorCode",
"Upload",
]
Loading

0 comments on commit 48f3cdd

Please sign in to comment.