Skip to content

Commit

Permalink
Re-introduce merging enumeration unions
Browse files Browse the repository at this point in the history
This was supposed to be fixed
  • Loading branch information
tefra committed May 15, 2020
1 parent 99fc56d commit 125fae1
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 78 deletions.
25 changes: 17 additions & 8 deletions tests/fixtures/defxmlschema/chapter08/chapter08.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Union
from typing import List


class SmlxsizeType(Enum):
Expand All @@ -16,6 +16,21 @@ class SmlxsizeType(Enum):
EXTRA_LARGE = "extra large"


class XsmlxsizeType(Enum):
"""
:cvar SMALL:
:cvar MEDIUM:
:cvar LARGE:
:cvar EXTRA_LARGE:
:cvar EXTRA_SMALL:
"""
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
EXTRA_LARGE = "extra large"
EXTRA_SMALL = "extra small"


@dataclass
class SizesType:
"""
Expand Down Expand Up @@ -74,7 +89,7 @@ class SizesType:
max_occurs=9223372036854775807
)
)
xsmlx_size: List[Union[SmlxsizeType, "SizesType.Value"]] = field(
xsmlx_size: List[XsmlxsizeType] = field(
default_factory=list,
metadata=dict(
name="xsmlxSize",
Expand All @@ -85,12 +100,6 @@ class SizesType:
)
)

class Value(Enum):
"""
:cvar EXTRA_SMALL:
"""
EXTRA_SMALL = "extra small"


@dataclass
class Sizes(SizesType):
Expand Down
31 changes: 11 additions & 20 deletions tests/fixtures/defxmlschema/chapter08/example0812.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
from enum import Enum
from dataclasses import dataclass, field
from typing import Optional, Union
from tests.fixtures.defxmlschema.chapter08.example0809 import (
SmlxsizeType,
)


@dataclass
class XsmlxsizeType:
class XsmlxsizeType(Enum):
"""
:ivar value:
:cvar SMALL:
:cvar MEDIUM:
:cvar LARGE:
:cvar EXTRA_LARGE:
:cvar EXTRA_SMALL:
"""
class Meta:
name = "XSMLXSizeType"

value: Optional[Union[SmlxsizeType, "XsmlxsizeType.Value"]] = field(
default=None,
)

class Value(Enum):
"""
:cvar EXTRA_SMALL:
"""
EXTRA_SMALL = "extra small"
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
EXTRA_LARGE = "extra large"
EXTRA_SMALL = "extra small"
25 changes: 10 additions & 15 deletions tests/fixtures/defxmlschema/chapter10/example1010.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from dataclasses import dataclass, field
from typing import List
from tests.fixtures.defxmlschema.chapter08.example0810 import (
SmlsizeType,
)
from enum import Enum


@dataclass
class AvailableSizesType:
class AvailableSizesType(Enum):
"""
:ivar value:
:cvar SMALL:
:cvar MEDIUM:
:cvar LARGE:
:cvar EXTRA_LARGE:
"""
value: List[SmlsizeType] = field(
default_factory=list,
metadata=dict(
min_occurs=0,
max_occurs=9223372036854775807
)
)
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
EXTRA_LARGE = "extra large"
29 changes: 7 additions & 22 deletions tests/fixtures/defxmlschema/chapter10/example1013.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
from enum import Enum
from dataclasses import dataclass, field
from typing import List


@dataclass
class AvailableSizesType:
class AvailableSizesType(Enum):
"""
:ivar value:
:cvar SMALL:
:cvar MEDIUM:
:cvar LARGE:
"""
value: List["AvailableSizesType.Value"] = field(
default_factory=list,
metadata=dict(
min_occurs=0,
max_occurs=9223372036854775807
)
)

class Value(Enum):
"""
:cvar SMALL:
:cvar MEDIUM:
:cvar LARGE:
"""
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
23 changes: 10 additions & 13 deletions tests/fixtures/defxmlschema/chapter10/example1016.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from enum import Enum
from dataclasses import dataclass, field
from typing import List


class SmlxsizeType(Enum):
class AvailableSizesType(Enum):
"""
:cvar SMALL:
:cvar MEDIUM:
Expand All @@ -16,15 +14,14 @@ class SmlxsizeType(Enum):
EXTRA_LARGE = "extra large"


@dataclass
class AvailableSizesType:
class SmlxsizeType(Enum):
"""
:ivar value:
:cvar SMALL:
:cvar MEDIUM:
:cvar LARGE:
:cvar EXTRA_LARGE:
"""
value: List[SmlxsizeType] = field(
default_factory=list,
metadata=dict(
min_occurs=0,
max_occurs=9223372036854775807
)
)
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
EXTRA_LARGE = "extra large"
48 changes: 48 additions & 0 deletions tests/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,14 @@ def test_find_simple_class(self):
@mock.patch.object(ClassAnalyzer, "create_mixed_attribute")
@mock.patch.object(ClassAnalyzer, "add_substitution_attrs")
@mock.patch.object(ClassAnalyzer, "flatten_attribute_types")
@mock.patch.object(ClassAnalyzer, "flatten_enumeration_unions")
@mock.patch.object(ClassAnalyzer, "flatten_extension")
@mock.patch.object(ClassAnalyzer, "expand_attribute_group")
def test_flatten_class(
self,
mock_expand_attribute_group,
mock_flatten_extension,
mock_flatten_enumeration_unions,
mock_flatten_attribute_types,
mock_add_substitution_attrs,
mock_create_mixed_attribute,
Expand All @@ -267,6 +269,10 @@ def test_flatten_class(
]
)

mock_flatten_enumeration_unions.assert_has_calls(
[mock.call(target), mock.call(target.inner[0]),]
)

mock_flatten_attribute_types.assert_has_calls(
[mock.call(target, target.attrs[0]), mock.call(target, target.attrs[1])]
)
Expand Down Expand Up @@ -557,6 +563,48 @@ def test_expand_attribute_group_with_unknown_source(self):

self.assertEqual("Group attribute not found: `{foo}bar`", str(cm.exception))

def test_flatten_enumeration_unions(self):
root_enum = ClassFactory.enumeration(2)
inner_enum = ClassFactory.enumeration(2)
target = ClassFactory.create(
type=Element,
attrs=[
AttrFactory.create(
name="value",
types=[
AttrTypeFactory.create(name=root_enum.name),
AttrTypeFactory.create(name=inner_enum.name, forward_ref=True),
AttrTypeFactory.xs_int(),
],
),
AttrFactory.create(),
],
)
target.inner.append(inner_enum)
self.analyzer.create_class_index([root_enum, target])

# Target has more than one attribute
self.analyzer.flatten_enumeration_unions(target)
self.assertFalse(target.is_enumeration)

# Target has one attribute but is not a simple type
target.attrs.pop()
self.analyzer.flatten_enumeration_unions(target)
self.assertFalse(target.is_enumeration)

# Attribute is not a union of enumerations only
target.type = SimpleType
self.analyzer.flatten_enumeration_unions(target)
self.assertFalse(target.is_enumeration)

# Winner: single attr named with a union of enum types
target.attrs[0].types.pop()
self.analyzer.flatten_enumeration_unions(target)
self.assertTrue(target.is_enumeration)

self.assertEqual(root_enum.attrs + inner_enum.attrs, target.attrs)
self.assertEqual(0, len(target.inner))

def test_flatten_attribute_types_when_type_is_native(self):
xs_bool = AttrTypeFactory.xs_bool()
xs_decimal = AttrTypeFactory.xs_decimal()
Expand Down
24 changes: 24 additions & 0 deletions xsdata/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
Expand Down Expand Up @@ -202,6 +203,8 @@ def flatten_class(self, target: Class):
for extension in reversed(target.extensions):
self.flatten_extension(target, extension)

self.flatten_enumeration_unions(target)

for attr in list(target.attrs):
self.flatten_attribute_types(target, attr)

Expand All @@ -216,6 +219,27 @@ def flatten_class(self, target: Class):
if id(inner) not in self.processed:
self.flatten_class(inner)

def flatten_enumeration_unions(self, target: Class):
"""Convert simple types with a single field which is a union of enums
to a standalone enumeration."""
if len(target.attrs) == 1 and target.is_simple:
enums: List[Any] = list()
attr = target.attrs[0]
for attr_type in attr.types:
if attr_type.forward_ref:
enums.extend(target.inner)
elif not attr_type.native:
enums.append(self.find_attr_type(target, attr_type))
else:
enums.append(None)

merge = all(isinstance(x, Class) and x.is_enumeration for x in enums)
if merge:
target.attrs.clear()
target.inner.clear()
for enum in enums:
target.attrs.extend(enum.attrs)

def flatten_extension(self, target: Class, extension: Extension):
"""
Flatten target class extension based on the extension type.
Expand Down

0 comments on commit 125fae1

Please sign in to comment.