diff --git a/tests/fixtures/defxmlschema/chapter08/chapter08.py b/tests/fixtures/defxmlschema/chapter08/chapter08.py index 484be3320..82d61cef6 100644 --- a/tests/fixtures/defxmlschema/chapter08/chapter08.py +++ b/tests/fixtures/defxmlschema/chapter08/chapter08.py @@ -1,6 +1,6 @@ from enum import Enum from dataclasses import dataclass, field -from typing import List, Union +from typing import List class SmlxsizeType(Enum): @@ -16,6 +16,21 @@ class SmlxsizeType(Enum): EXTRA_LARGE = "extra large" +class XsmlxsizeType(Enum): + """ + :cvar SMALL: + :cvar MEDIUM: + :cvar LARGE: + :cvar EXTRA_LARGE: + :cvar EXTRA_SMALL: + """ + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + EXTRA_LARGE = "extra large" + EXTRA_SMALL = "extra small" + + @dataclass class SizesType: """ @@ -74,7 +89,7 @@ class SizesType: max_occurs=9223372036854775807 ) ) - xsmlx_size: List[Union[SmlxsizeType, "SizesType.Value"]] = field( + xsmlx_size: List[XsmlxsizeType] = field( default_factory=list, metadata=dict( name="xsmlxSize", @@ -85,12 +100,6 @@ class SizesType: ) ) - class Value(Enum): - """ - :cvar EXTRA_SMALL: - """ - EXTRA_SMALL = "extra small" - @dataclass class Sizes(SizesType): diff --git a/tests/fixtures/defxmlschema/chapter08/example0812.py b/tests/fixtures/defxmlschema/chapter08/example0812.py index 67aa0d18c..34e8c1053 100644 --- a/tests/fixtures/defxmlschema/chapter08/example0812.py +++ b/tests/fixtures/defxmlschema/chapter08/example0812.py @@ -1,25 +1,16 @@ from enum import Enum -from dataclasses import dataclass, field -from typing import Optional, Union -from tests.fixtures.defxmlschema.chapter08.example0809 import ( - SmlxsizeType, -) -@dataclass -class XsmlxsizeType: +class XsmlxsizeType(Enum): """ - :ivar value: + :cvar SMALL: + :cvar MEDIUM: + :cvar LARGE: + :cvar EXTRA_LARGE: + :cvar EXTRA_SMALL: """ - class Meta: - name = "XSMLXSizeType" - - value: Optional[Union[SmlxsizeType, "XsmlxsizeType.Value"]] = field( - default=None, - ) - - class Value(Enum): - """ - :cvar EXTRA_SMALL: - """ - EXTRA_SMALL = "extra small" + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + EXTRA_LARGE = "extra large" + EXTRA_SMALL = "extra small" diff --git a/tests/fixtures/defxmlschema/chapter10/example1010.py b/tests/fixtures/defxmlschema/chapter10/example1010.py index cfe41de26..f72582dbe 100644 --- a/tests/fixtures/defxmlschema/chapter10/example1010.py +++ b/tests/fixtures/defxmlschema/chapter10/example1010.py @@ -1,19 +1,14 @@ -from dataclasses import dataclass, field -from typing import List -from tests.fixtures.defxmlschema.chapter08.example0810 import ( - SmlsizeType, -) +from enum import Enum -@dataclass -class AvailableSizesType: +class AvailableSizesType(Enum): """ - :ivar value: + :cvar SMALL: + :cvar MEDIUM: + :cvar LARGE: + :cvar EXTRA_LARGE: """ - value: List[SmlsizeType] = field( - default_factory=list, - metadata=dict( - min_occurs=0, - max_occurs=9223372036854775807 - ) - ) + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + EXTRA_LARGE = "extra large" diff --git a/tests/fixtures/defxmlschema/chapter10/example1013.py b/tests/fixtures/defxmlschema/chapter10/example1013.py index 9c222a053..0899c3fef 100644 --- a/tests/fixtures/defxmlschema/chapter10/example1013.py +++ b/tests/fixtures/defxmlschema/chapter10/example1013.py @@ -1,27 +1,12 @@ from enum import Enum -from dataclasses import dataclass, field -from typing import List -@dataclass -class AvailableSizesType: +class AvailableSizesType(Enum): """ - :ivar value: + :cvar SMALL: + :cvar MEDIUM: + :cvar LARGE: """ - value: List["AvailableSizesType.Value"] = field( - default_factory=list, - metadata=dict( - min_occurs=0, - max_occurs=9223372036854775807 - ) - ) - - class Value(Enum): - """ - :cvar SMALL: - :cvar MEDIUM: - :cvar LARGE: - """ - SMALL = "small" - MEDIUM = "medium" - LARGE = "large" + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" diff --git a/tests/fixtures/defxmlschema/chapter10/example1016.py b/tests/fixtures/defxmlschema/chapter10/example1016.py index c6e51a7ab..f918c1a6e 100644 --- a/tests/fixtures/defxmlschema/chapter10/example1016.py +++ b/tests/fixtures/defxmlschema/chapter10/example1016.py @@ -1,9 +1,7 @@ from enum import Enum -from dataclasses import dataclass, field -from typing import List -class SmlxsizeType(Enum): +class AvailableSizesType(Enum): """ :cvar SMALL: :cvar MEDIUM: @@ -16,15 +14,14 @@ class SmlxsizeType(Enum): EXTRA_LARGE = "extra large" -@dataclass -class AvailableSizesType: +class SmlxsizeType(Enum): """ - :ivar value: + :cvar SMALL: + :cvar MEDIUM: + :cvar LARGE: + :cvar EXTRA_LARGE: """ - value: List[SmlxsizeType] = field( - default_factory=list, - metadata=dict( - min_occurs=0, - max_occurs=9223372036854775807 - ) - ) + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + EXTRA_LARGE = "extra large" diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 173bf633d..0b0683417 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -239,12 +239,14 @@ def test_find_simple_class(self): @mock.patch.object(ClassAnalyzer, "create_mixed_attribute") @mock.patch.object(ClassAnalyzer, "add_substitution_attrs") @mock.patch.object(ClassAnalyzer, "flatten_attribute_types") + @mock.patch.object(ClassAnalyzer, "flatten_enumeration_unions") @mock.patch.object(ClassAnalyzer, "flatten_extension") @mock.patch.object(ClassAnalyzer, "expand_attribute_group") def test_flatten_class( self, mock_expand_attribute_group, mock_flatten_extension, + mock_flatten_enumeration_unions, mock_flatten_attribute_types, mock_add_substitution_attrs, mock_create_mixed_attribute, @@ -267,6 +269,10 @@ def test_flatten_class( ] ) + mock_flatten_enumeration_unions.assert_has_calls( + [mock.call(target), mock.call(target.inner[0]),] + ) + mock_flatten_attribute_types.assert_has_calls( [mock.call(target, target.attrs[0]), mock.call(target, target.attrs[1])] ) @@ -557,6 +563,48 @@ def test_expand_attribute_group_with_unknown_source(self): self.assertEqual("Group attribute not found: `{foo}bar`", str(cm.exception)) + def test_flatten_enumeration_unions(self): + root_enum = ClassFactory.enumeration(2) + inner_enum = ClassFactory.enumeration(2) + target = ClassFactory.create( + type=Element, + attrs=[ + AttrFactory.create( + name="value", + types=[ + AttrTypeFactory.create(name=root_enum.name), + AttrTypeFactory.create(name=inner_enum.name, forward_ref=True), + AttrTypeFactory.xs_int(), + ], + ), + AttrFactory.create(), + ], + ) + target.inner.append(inner_enum) + self.analyzer.create_class_index([root_enum, target]) + + # Target has more than one attribute + self.analyzer.flatten_enumeration_unions(target) + self.assertFalse(target.is_enumeration) + + # Target has one attribute but is not a simple type + target.attrs.pop() + self.analyzer.flatten_enumeration_unions(target) + self.assertFalse(target.is_enumeration) + + # Attribute is not a union of enumerations only + target.type = SimpleType + self.analyzer.flatten_enumeration_unions(target) + self.assertFalse(target.is_enumeration) + + # Winner: single attr named with a union of enum types + target.attrs[0].types.pop() + self.analyzer.flatten_enumeration_unions(target) + self.assertTrue(target.is_enumeration) + + self.assertEqual(root_enum.attrs + inner_enum.attrs, target.attrs) + self.assertEqual(0, len(target.inner)) + def test_flatten_attribute_types_when_type_is_native(self): xs_bool = AttrTypeFactory.xs_bool() xs_decimal = AttrTypeFactory.xs_decimal() diff --git a/xsdata/analyzer.py b/xsdata/analyzer.py index 74434ebd6..5761d5a0f 100644 --- a/xsdata/analyzer.py +++ b/xsdata/analyzer.py @@ -1,6 +1,7 @@ from collections import defaultdict from dataclasses import dataclass from dataclasses import field +from typing import Any from typing import Callable from typing import Dict from typing import List @@ -202,6 +203,8 @@ def flatten_class(self, target: Class): for extension in reversed(target.extensions): self.flatten_extension(target, extension) + self.flatten_enumeration_unions(target) + for attr in list(target.attrs): self.flatten_attribute_types(target, attr) @@ -216,6 +219,27 @@ def flatten_class(self, target: Class): if id(inner) not in self.processed: self.flatten_class(inner) + def flatten_enumeration_unions(self, target: Class): + """Convert simple types with a single field which is a union of enums + to a standalone enumeration.""" + if len(target.attrs) == 1 and target.is_simple: + enums: List[Any] = list() + attr = target.attrs[0] + for attr_type in attr.types: + if attr_type.forward_ref: + enums.extend(target.inner) + elif not attr_type.native: + enums.append(self.find_attr_type(target, attr_type)) + else: + enums.append(None) + + merge = all(isinstance(x, Class) and x.is_enumeration for x in enums) + if merge: + target.attrs.clear() + target.inner.clear() + for enum in enums: + target.attrs.extend(enum.attrs) + def flatten_extension(self, target: Class, extension: Extension): """ Flatten target class extension based on the extension type.