From 267148a9064f847fe9f59add5a121dabbaff41e2 Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 14 May 2020 18:34:12 +0300 Subject: [PATCH] Cleanup & docstrings --- .pre-commit-config.yaml | 2 +- tests/test_analyzer.py | 24 ++-------- tests/utils/test_classes.py | 11 +++++ xsdata/analyzer.py | 56 ++++++++++++++++++----- xsdata/formats/dataclass/filters.py | 22 ++++----- xsdata/utils/classes.py | 69 ++++++++++++++++++++++++++++- 6 files changed, 139 insertions(+), 45 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ca02cc2e..7c670b214 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.0a2 + rev: 3.8.1 hooks: - id: flake8 additional_dependencies: [flake8-bugbear, flake8-annotations] diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index bd2c0d6d2..173bf633d 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -292,24 +292,22 @@ def test_flatten_extension_with_native_type(self, mock_flatten_extension_native) mock_flatten_extension_native.assert_called_once_with(target, extension) @mock.patch.object(ClassAnalyzer, "create_default_attribute") - def test_flatten_extension_native_and_target_no_enumeration( - self, mock_create_default_attribute - ): + def test_flatten_extension_native(self, mock_create_default_attribute): extension = ExtensionFactory.create() target = ClassFactory.elements(1) self.analyzer.flatten_extension_native(target, extension) mock_create_default_attribute.assert_called_once_with(target, extension) - @mock.patch.object(ClassAnalyzer, "create_default_attribute") + @mock.patch.object(ClassAnalyzer, "copy_extension_type") def test_flatten_extension_native_and_target_enumeration( - self, mock_create_default_attribute + self, mock_copy_extension_type ): extension = ExtensionFactory.create() target = ClassFactory.enumeration(1) self.analyzer.flatten_extension_native(target, extension) - 
self.assertEqual(0, mock_create_default_attribute.call_count) + mock_copy_extension_type.assert_called_once_with(target, extension) @mock.patch.object(ClassAnalyzer, "flatten_extension_simple") @mock.patch.object(ClassAnalyzer, "find_simple_class") @@ -597,20 +595,6 @@ def test_flatten_attribute_types_ignores_forward_types( mock_flatten_attribute_type.assert_called_once_with(parent, attr, type_a) - def test_flatten_attribute_types_with_enumeration_target(self): - target = ClassFactory.enumeration( - 1, - extensions=[ - ExtensionFactory.create(type=AttrTypeFactory.xs_bool()), - ExtensionFactory.create(type=AttrTypeFactory.xs_int()), - ], - ) - - self.assertEqual(1, len(target.attrs[0].types)) - - self.analyzer.flatten_attribute_types(target, target.attrs[0]) - self.assertEqual(3, len(target.attrs[0].types)) - def test_flatten_attribute_types_filters_duplicate_types(self): target = ClassFactory.create( attrs=[ diff --git a/tests/utils/test_classes.py b/tests/utils/test_classes.py index 49b8ed202..64bbb5e39 100644 --- a/tests/utils/test_classes.py +++ b/tests/utils/test_classes.py @@ -468,3 +468,14 @@ def test_copy_inner_classes(self): self.assertTrue(attr.types[0].self_ref) self.assertFalse(attr.types[1].self_ref) self.assertFalse(attr.types[2].self_ref) + + def test_copy_extension_type(self): + extension = ExtensionFactory.create() + target = ClassFactory.elements(2) + target.extensions.append(extension) + + ClassUtils.copy_extension_type(target, extension) + + self.assertEqual(extension.type, target.attrs[0].types[1]) + self.assertEqual(extension.type, target.attrs[1].types[1]) + self.assertEqual(0, len(target.extensions)) diff --git a/xsdata/analyzer.py b/xsdata/analyzer.py index 73e76572a..74434ebd6 100644 --- a/xsdata/analyzer.py +++ b/xsdata/analyzer.py @@ -95,13 +95,10 @@ def remove_invalid_classes(self, classes: List[Class]): def fetch_classes_for_generation(self) -> List[Class]: """ - Return the qualified classes to continue for code generation. 
Return - all of them if there are no classes derived from xs:element or - xs:complexType. + Return the qualified classes for code generation. - Qualifications: - * not an abstract - * type: element | complexType | simpleType with enumerations + Return all if no classes are derived from xs:element or + xs:complexType. """ classes = [item for values in self.class_index.values() for item in values] if any(item.is_complex for item in classes): @@ -128,6 +125,7 @@ def create_substitutions_index(self): self.substitutions_index[qname].append(attr) def find_attr_type(self, source: Class, attr_type: AttrType) -> Optional[Class]: + """Find the source class for the given class and attribute type.""" qname = source.source_qname(attr_type.name) return self.find_class(qname) @@ -142,6 +140,8 @@ def attr_type_is_missing(self, source: Class, attr_type: AttrType) -> bool: def find_attr_simple_type( self, source: Class, attr_type: AttrType ) -> Optional[Class]: + """Find the source class for the given class and attribute type, + excluding enumerations, complex types and self references.""" qname = source.source_qname(attr_type.name) return self.find_class( qname, @@ -151,6 +151,8 @@ def find_attr_simple_type( ) def find_simple_class(self, qname: QName) -> Optional[Class]: + """Find an enumeration or simple type source class for the given + qualified name.""" return self.find_class( qname, condition=lambda x: x.is_enumeration or x.is_simple, ) @@ -158,6 +160,7 @@ def find_simple_class(self, qname: QName) -> Optional[Class]: def find_class( self, qname: QName, condition: Optional[Callable] = None ) -> Optional[Class]: + """Find the flattened source class for the given qualified name.""" candidates = list(filter(condition, self.class_index.get(qname, []))) if candidates: candidate = candidates.pop(0) @@ -171,6 +174,7 @@ def find_class( return None def flatten_classes(self): + """Flatten the class index objects once.""" for classes in self.class_index.values(): for obj in classes: if 
id(obj) not in self.processed: @@ -237,9 +241,12 @@ def flatten_extension(self, target: Class, extension: Extension): logger.warning("Missing extension type: %s", extension.type.name) target.extensions.remove(extension) - def flatten_extension_native(self, target: Class, ext: Extension) -> None: - if not target.is_enumeration: - return self.create_default_attribute(target, ext) + def flatten_extension_native(self, target: Class, extension: Extension): + """Native type flatten extension handler; enumerations only copy the type.""" + if target.is_enumeration: + self.copy_extension_type(target, extension) + else: + self.create_default_attribute(target, extension) def flatten_extension_simple(self, source: Class, target: Class, ext: Extension): """ @@ -333,9 +340,6 @@ def flatten_attribute_types(self, target: Class, attr: Attr): elif not current_type.forward_ref: self.flatten_attribute_type(target, attr, current_type) - if target.is_enumeration: - attr.types.extend(ext.type for ext in target.extensions) - attr.types = unique_sequence(attr.types, key="name") def flatten_attribute_type(self, target: Class, attr: Attr, attr_type: AttrType): @@ -375,11 +379,24 @@ def add_substitution_attrs(self, target: Class, attr: Attr): self.add_substitution_attrs(target, clone) def sanitize_classes(self): + """Sanitize the class index objects.""" for classes in self.class_index.values(): for target in classes: self.sanitize_class(target) def sanitize_class(self, target: Class): + """ + Sanitize the attributes of the given class. After applying all the + flattening handlers the attributes need to be further sanitized to + squash common issues like duplicate attribute names. + + Steps: + 1. Sanitize inner classes + 2. Sanitize attributes default value + 3. Sanitize attributes name + 4. Sanitize attributes sequential flag + 5. 
Sanitize duplicate attribute names + """ for inner in target.inner: self.sanitize_class(inner) @@ -394,6 +411,14 @@ def sanitize_class(self, target: Class): self.sanitize_duplicate_attribute_names(target.attrs) def sanitize_attribute_default_value(self, target: Class, attr: Attr): + """ + Sanitize attribute default value. + + Cases: + 1. List fields can not have a fixed value. + 2. Optional fields or xsi:type can not have a default or fixed value. + 3. Convert string literal default value for enum fields. + """ if attr.is_list: attr.fixed = False @@ -405,6 +430,13 @@ def sanitize_attribute_default_value(self, target: Class, attr: Attr): self.sanitize_attribute_default_enum(target, attr) def sanitize_attribute_default_enum(self, target: Class, attr: Attr): + """ + Convert string literal default value for enum fields. + + Loop through all attribute types and search for enum sources. + If an enum source exists, map the default string literal value to + a qualified name. Inner enum references are ignored. 
+ """ for attr_type in attr.types: if attr_type.native: continue diff --git a/xsdata/formats/dataclass/filters.py b/xsdata/formats/dataclass/filters.py index 02589e7a2..aa46048f0 100644 --- a/xsdata/formats/dataclass/filters.py +++ b/xsdata/formats/dataclass/filters.py @@ -22,24 +22,29 @@ @functools.lru_cache(maxsize=50) def class_name(name: str) -> str: + """Apply python conventions for class names.""" return text.pascal_case(utils.safe_snake(name, "type")) @functools.lru_cache(maxsize=50) def attribute_name(name: str) -> str: + """Apply python conventions for instance variable names.""" return text.snake_case(utils.safe_snake(text.suffix(name))) @functools.lru_cache(maxsize=50) def constant_name(name: str) -> str: + """Apply python conventions for constant names.""" return text.snake_case(utils.safe_snake(name)).upper() def type_name(attr_type: AttrType) -> str: + """Return native python type name or apply class name conventions.""" return attr_type.native_name or class_name(text.suffix(attr_type.name)) def attribute_metadata(attr: Attr, parent_namespace: Optional[str]) -> Dict: + """Return a metadata dictionary for the given attribute.""" metadata = dict( name=None if attr.is_nameless or attr.local_name == attribute_name(attr.name) @@ -59,6 +64,8 @@ def attribute_metadata(attr: Attr, parent_namespace: Optional[str]) -> Dict: def format_arguments(data: Dict) -> str: + """Format given dictionary as keyword arguments.""" + def prep(key: str, value: Any) -> str: if isinstance(value, str): value = f'''"{value.replace('"', "'")}"''' @@ -70,6 +77,7 @@ def prep(key: str, value: Any) -> str: def class_docstring(obj: Class, enum: bool = False) -> str: + """Generate docstring for the given class and the constructor arguments.""" lines = [] if obj.help: lines.append(obj.help) @@ -85,6 +93,7 @@ def class_docstring(obj: Class, enum: bool = False) -> str: def default_imports(output: str) -> str: + """Generate the default imports for the given package output.""" result = [] if 
"Decimal" in output: @@ -113,7 +122,7 @@ def default_imports(output: str) -> str: def attribute_default(attr: Attr, ns_map: Optional[Dict] = None) -> Any: - """Normalize default value/factory by the attribute type.""" + """Generate the field default value/factory for the given attribute.""" if attr.is_list: return "list" if attr.is_map: @@ -160,16 +169,7 @@ def attribute_default(attr: Attr, ns_map: Optional[Dict] = None) -> Any: def attribute_type(attr: Attr, parents: List[str]) -> str: - """ - Normalize attribute type. - - Steps: - * If type alias is present use class name normalization - * Otherwise use the type name normalization - * Prepend outer class names and quote result for forward references - * Wrap the result with List if the field accepts a list of values - * Wrap the result with Optional if the field default value is None - """ + """Generate type hints for the given attribute.""" type_names: List[str] = [] for attr_type in attr.types: diff --git a/xsdata/utils/classes.py b/xsdata/utils/classes.py index 36c142e40..df0abb264 100644 --- a/xsdata/utils/classes.py +++ b/xsdata/utils/classes.py @@ -27,6 +27,8 @@ class ClassUtils: @classmethod def compare_attributes(cls, source: Class, target: Class) -> int: + """Compare the attributes of the two classes and return whether the + source includes all, some or none of the target attributes.""" if source is target: return cls.INCLUDES_ALL @@ -46,6 +48,8 @@ def compare_attributes(cls, source: Class, target: Class) -> int: @classmethod def sanitize_attribute_restrictions(cls, attr: Attr): + """Sanitize attribute required flag by comparing the min/max + occurrences restrictions.""" restrictions = attr.restrictions min_occurs = restrictions.min_occurs or 0 max_occurs = restrictions.max_occurs or 0 @@ -85,6 +89,17 @@ def sanitize_attribute_sequence(cls, attrs: List[Attr], index: int): @classmethod def sanitize_attribute_name(cls, attr: Attr): + """ + Sanitize attribute name in preparation for duplicate attribute names 
+ handler. + + Steps: + 1. Remove non alpha numerical values + 2. Handle Enum negative numerical values + 3. Remove namespace prefixes + 4. Ensure name not empty + 5. Ensure name starts with a letter + """ if attr.is_enumeration: attr.name = attr.default if re.match(r"^-\d*\.?\d+$", attr.name): @@ -101,6 +116,18 @@ def sanitize_attribute_name(cls, attr: Attr): @classmethod def sanitize_duplicate_attribute_names(cls, attrs: List[Attr]) -> None: + """ + Sanitize duplicate attribute names that might exist by applying rename + strategies. + + Steps: + 1. If more than two attributes share the same name or if they are + enumerations append a numerical index to the attribute names. + 2. If one of the two fields has a specific namespace prepend it to the + name. If possible rename the second field. + 3. Append the xml type to the name of one of the two attributes. If + possible rename the second field or the field with xml type `attribute`. + """ grouped: Dict[str, List[Attr]] = dict() for attr in attrs: grouped.setdefault(attr.name.lower(), []).append(attr) @@ -133,7 +160,8 @@ def merge_duplicate_attributes(cls, target: Class): Flatten duplicate attributes. Remove duplicate fields in case of attributes or enumerations - otherwise convert fields to lists. + otherwise convert fields to lists. Two attributes are considered + equal if they have the same name and types and namespace. """ if not target.attrs: @@ -163,6 +191,13 @@ def merge_duplicate_attributes(cls, target: Class): @classmethod def copy_attributes(cls, source: Class, target: Class, extension: Extension): + """ + Copy the attributes from the source class to the target class and + remove the extension that links the two classes together. + + The new attributes are prepended to the list unless they are + supposed to be last in a sequence. 
+ """ prefix = text.prefix(extension.type.name) target.extensions.remove(extension) target_attr_names = {text.suffix(attr.name) for attr in target.attrs} @@ -185,6 +220,13 @@ def copy_attributes(cls, source: Class, target: Class, extension: Extension): def clone_attribute( cls, attr: Attr, restrictions: Restrictions, prefix: Optional[str] = None ) -> Attr: + """ + Clone the given attribute and merge its restrictions with the given + instance. + + Prepend the given namespace prefix to the attribute name if + available. + """ clone = attr.clone() clone.restrictions.merge(restrictions) if prefix: @@ -198,6 +240,14 @@ def clone_attribute( def merge_attribute_type( cls, source: Class, target: Class, attr: Attr, attr_type: AttrType ): + """ + Replace the given attribute type with the types of the single field + source class. + + If the source class has more than one or no fields a warning + will be logged and the target attribute type will change to + simple string. + """ if len(source.attrs) != 1: logger.warning("Missing implementation: %s", source.type.__name__) cls.reset_attribute_type(attr_type) @@ -234,6 +284,15 @@ def copy_inner_classes(cls, source: Class, target: Class): elif not any(existing.name == inner.name for existing in target.inner): target.inner.append(inner) + @classmethod + def copy_extension_type(cls, target: Class, extension: Extension): + """Add the given extension type to all target attributes types and + remove it from the target class extensions.""" + + for attr in target.attrs: + attr.types.append(extension.type) + target.extensions.remove(extension) + @classmethod def merge_redefined_classes(cls, classes: List[Class]): """Merge original and redefined classes.""" @@ -285,6 +344,8 @@ def update_abstract_classes(cls, classes: List[Class]): @classmethod def create_mixed_attribute(cls, target: Class): + """Add an xs:anyType attribute to the given class if it supports mixed + content and doesn't have a wildcard attribute yet.""" if not target.mixed or 
target.has_wild_attr: return @@ -300,6 +361,8 @@ def create_mixed_attribute(cls, target: Class): @classmethod def create_default_attribute(cls, item: Class, extension: Extension): + """Add a default value field to the given class based on the extension + type.""" if extension.type.native_code == DataType.ANY_TYPE.code: attr = Attr( name="any_element", @@ -327,6 +390,8 @@ def create_default_attribute(cls, item: Class, extension: Extension): @classmethod def create_reference_attribute(cls, source: Class, qname: QName) -> Attr: + """Create an attribute with type that refers to the given source class + and namespaced qualified name.""" prefix = None if qname.namespace != source.source_namespace: prefix = source.source_prefix @@ -344,6 +409,7 @@ def create_reference_attribute(cls, source: Class, qname: QName) -> Attr: @classmethod def find_attribute(cls, attrs: List[Attr], attr: Attr) -> int: + """Return the position of the given attribute in the list.""" try: return attrs.index(attr) except ValueError: @@ -351,6 +417,7 @@ def find_attribute(cls, attrs: List[Attr], attr: Attr) -> int: @classmethod def reset_attribute_type(cls, attr_type: AttrType): + """Reset the attribute type to native string.""" attr_type.name = DataType.STRING.code attr_type.native = True attr_type.self_ref = False