Skip to content

Commit

Permalink
Add ability to pass explicit localns (and globalns) to class_schema
Browse files Browse the repository at this point in the history
When class_schema is called, it doesn't need the caller's whole stack
frame.  What it really wants is a `localns` to pass to
`typing.get_type_hints` to be used to resolve type references.

Here we add the ability to pass an explicit `localns` parameter to
`class_schema`.  We also add the ability to pass an explicit
`globalns`, because ... might as well — it might come in useful.
(Since we need these only to pass to `get_type_hints`, we might
as well match `get_type_hints` API as closely as possible.)
  • Loading branch information
dairiki committed Jan 19, 2023
1 parent fd04f8c commit 9093446
Showing 1 changed file with 44 additions and 13 deletions.
57 changes: 44 additions & 13 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,36 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
return decorator(_cls, stacklevel=stacklevel + 1)


@overload
def class_schema(
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
*,
globalns: Optional[Dict[str, Any]] = None,
localns: Optional[Dict[str, Any]] = None,
) -> Type[marshmallow.Schema]:
...


@overload
def class_schema(
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
clazz_frame: Optional[types.FrameType] = None,
*,
globalns: Optional[Dict[str, Any]] = None,
) -> Type[marshmallow.Schema]:
...


def class_schema(
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
# FIXME: delete clazz_frame from API?
clazz_frame: Optional[types.FrameType] = None,
*,
globalns: Optional[Dict[str, Any]] = None,
localns: Optional[Dict[str, Any]] = None,
) -> Type[marshmallow.Schema]:
"""
Convert a class to a marshmallow schema
Expand Down Expand Up @@ -398,24 +424,26 @@ def class_schema(
"""
if not dataclasses.is_dataclass(clazz):
clazz = dataclasses.dataclass(clazz)
if not clazz_frame:
clazz_frame = _maybe_get_callers_frame(clazz)

with _SchemaContext(clazz_frame):
if localns is None:
if clazz_frame is None:
clazz_frame = _maybe_get_callers_frame(clazz)
if clazz_frame is not None:
localns = clazz_frame.f_locals
with _SchemaContext(globalns, localns):
return _internal_class_schema(clazz, base_schema)


class _SchemaContext:
"""Global context for an invocation of class_schema."""

def __init__(self, frame: Optional[types.FrameType]):
def __init__(
self,
globalns: Optional[Dict[str, Any]] = None,
localns: Optional[Dict[str, Any]] = None,
):
self.seen_classes: Dict[type, str] = {}
self.frame = frame

def get_type_hints(self, cls: Type) -> Dict[str, Any]:
frame = self.frame
localns = frame.f_locals if frame is not None else None
return get_type_hints(cls, localns=localns)
self.globalns = globalns
self.localns = localns

def __enter__(self) -> "_SchemaContext":
_schema_ctx_stack.push(self)
Expand Down Expand Up @@ -486,7 +514,9 @@ def _internal_class_schema(
}

# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = schema_ctx.get_type_hints(clazz)
type_hints = get_type_hints(
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)
attributes.update(
(
field.name,
Expand Down Expand Up @@ -670,6 +700,7 @@ def field_for_schema(
default: Any = marshmallow.missing,
metadata: Optional[Mapping[str, Any]] = None,
base_schema: Optional[Type[marshmallow.Schema]] = None,
# FIXME: delete typ_frame from API?
typ_frame: Optional[types.FrameType] = None,
) -> marshmallow.fields.Field:
"""
Expand All @@ -692,7 +723,7 @@ def field_for_schema(
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
<class 'marshmallow.fields.Url'>
"""
with _SchemaContext(typ_frame):
with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None):
return _field_for_schema(typ, default, metadata, base_schema)


Expand Down

0 comments on commit 9093446

Please sign in to comment.