Skip to content

Commit

Permalink
Cleanup & docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed May 14, 2020
1 parent d5553d7 commit 267148a
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.0a2
rev: 3.8.1
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear, flake8-annotations]
Expand Down
24 changes: 4 additions & 20 deletions tests/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,24 +292,22 @@ def test_flatten_extension_with_native_type(self, mock_flatten_extension_native)
mock_flatten_extension_native.assert_called_once_with(target, extension)

@mock.patch.object(ClassAnalyzer, "create_default_attribute")
def test_flatten_extension_native_and_target_no_enumeration(
self, mock_create_default_attribute
):
def test_flatten_extension_native(self, mock_create_default_attribute):
extension = ExtensionFactory.create()
target = ClassFactory.elements(1)

self.analyzer.flatten_extension_native(target, extension)
mock_create_default_attribute.assert_called_once_with(target, extension)

@mock.patch.object(ClassAnalyzer, "create_default_attribute")
@mock.patch.object(ClassAnalyzer, "copy_extension_type")
def test_flatten_extension_native_and_target_enumeration(
self, mock_create_default_attribute
self, mock_copy_extension_type
):
extension = ExtensionFactory.create()
target = ClassFactory.enumeration(1)

self.analyzer.flatten_extension_native(target, extension)
self.assertEqual(0, mock_create_default_attribute.call_count)
mock_copy_extension_type.assert_called_once_with(target, extension)

@mock.patch.object(ClassAnalyzer, "flatten_extension_simple")
@mock.patch.object(ClassAnalyzer, "find_simple_class")
Expand Down Expand Up @@ -597,20 +595,6 @@ def test_flatten_attribute_types_ignores_forward_types(

mock_flatten_attribute_type.assert_called_once_with(parent, attr, type_a)

def test_flatten_attribute_types_with_enumeration_target(self):
target = ClassFactory.enumeration(
1,
extensions=[
ExtensionFactory.create(type=AttrTypeFactory.xs_bool()),
ExtensionFactory.create(type=AttrTypeFactory.xs_int()),
],
)

self.assertEqual(1, len(target.attrs[0].types))

self.analyzer.flatten_attribute_types(target, target.attrs[0])
self.assertEqual(3, len(target.attrs[0].types))

def test_flatten_attribute_types_filters_duplicate_types(self):
target = ClassFactory.create(
attrs=[
Expand Down
11 changes: 11 additions & 0 deletions tests/utils/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,14 @@ def test_copy_inner_classes(self):
self.assertTrue(attr.types[0].self_ref)
self.assertFalse(attr.types[1].self_ref)
self.assertFalse(attr.types[2].self_ref)

def test_copy_extension_type(self):
extension = ExtensionFactory.create()
target = ClassFactory.elements(2)
target.extensions.append(extension)

ClassUtils.copy_extension_type(target, extension)

self.assertEqual(extension.type, target.attrs[0].types[1])
self.assertEqual(extension.type, target.attrs[1].types[1])
self.assertEqual(0, len(target.extensions))
56 changes: 44 additions & 12 deletions xsdata/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,10 @@ def remove_invalid_classes(self, classes: List[Class]):

def fetch_classes_for_generation(self) -> List[Class]:
"""
Return the qualified classes to continue for code generation. Return
all of them if there are no classes derived from xs:element or
xs:complexType.
Return the qualified classes for code generation.
Qualifications:
* not an abstract
* type: element | complexType | simpleType with enumerations
Return all if no classes are derived from xs:element or
xs:complexType.
"""
classes = [item for values in self.class_index.values() for item in values]
if any(item.is_complex for item in classes):
Expand All @@ -128,6 +125,7 @@ def create_substitutions_index(self):
self.substitutions_index[qname].append(attr)

def find_attr_type(self, source: Class, attr_type: AttrType) -> Optional[Class]:
"""Find the source class for the given class and attribute type."""
qname = source.source_qname(attr_type.name)
return self.find_class(qname)

Expand All @@ -142,6 +140,8 @@ def attr_type_is_missing(self, source: Class, attr_type: AttrType) -> bool:
def find_attr_simple_type(
self, source: Class, attr_type: AttrType
) -> Optional[Class]:
"""Find the source class for the given class and attribute type,
excluding enumerations, complex types and self references."""
qname = source.source_qname(attr_type.name)
return self.find_class(
qname,
Expand All @@ -151,13 +151,16 @@ def find_attr_simple_type(
)

def find_simple_class(self, qname: QName) -> Optional[Class]:
"""Find an enumeration or simple type source class for the given
qualified name."""
return self.find_class(
qname, condition=lambda x: x.is_enumeration or x.is_simple,
)

def find_class(
self, qname: QName, condition: Optional[Callable] = None
) -> Optional[Class]:
"""Find the flattened source class for the given qualified name."""
candidates = list(filter(condition, self.class_index.get(qname, [])))
if candidates:
candidate = candidates.pop(0)
Expand All @@ -171,6 +174,7 @@ def find_class(
return None

def flatten_classes(self):
"""Flatten the class index objects once."""
for classes in self.class_index.values():
for obj in classes:
if id(obj) not in self.processed:
Expand Down Expand Up @@ -237,9 +241,12 @@ def flatten_extension(self, target: Class, extension: Extension):
logger.warning("Missing extension type: %s", extension.type.name)
target.extensions.remove(extension)

def flatten_extension_native(self, target: Class, ext: Extension) -> None:
if not target.is_enumeration:
return self.create_default_attribute(target, ext)
def flatten_extension_native(self, target: Class, extension: Extension):
"""Native type flatten extension handler, ignore enumerations."""
if target.is_enumeration:
self.copy_extension_type(target, extension)
else:
self.create_default_attribute(target, extension)

def flatten_extension_simple(self, source: Class, target: Class, ext: Extension):
"""
Expand Down Expand Up @@ -333,9 +340,6 @@ def flatten_attribute_types(self, target: Class, attr: Attr):
elif not current_type.forward_ref:
self.flatten_attribute_type(target, attr, current_type)

if target.is_enumeration:
attr.types.extend(ext.type for ext in target.extensions)

attr.types = unique_sequence(attr.types, key="name")

def flatten_attribute_type(self, target: Class, attr: Attr, attr_type: AttrType):
Expand Down Expand Up @@ -375,11 +379,24 @@ def add_substitution_attrs(self, target: Class, attr: Attr):
self.add_substitution_attrs(target, clone)

def sanitize_classes(self):
"""Sanitize the class index objects."""
for classes in self.class_index.values():
for target in classes:
self.sanitize_class(target)

def sanitize_class(self, target: Class):
"""
Sanitize the attributes of the given class. After applying all the
flattening handlers the attributes need to be further sanitized to
squash common issues like duplicate attribute names.
Steps:
1. Sanitize inner classes
2. Sanitize attributes default value
3. Sanitize attributes name
4. Sanitize attributes sequential flag
5. Sanitize duplicate attribute names
"""
for inner in target.inner:
self.sanitize_class(inner)

Expand All @@ -394,6 +411,14 @@ def sanitize_class(self, target: Class):
self.sanitize_duplicate_attribute_names(target.attrs)

def sanitize_attribute_default_value(self, target: Class, attr: Attr):
"""
Sanitize attribute default value.
Cases:
1. List fields can not have a fixed value.
2. Optional fields or xsi:type can not have a default or fixed value.
3. Convert string literal default value for enum fields.
"""
if attr.is_list:
attr.fixed = False

Expand All @@ -405,6 +430,13 @@ def sanitize_attribute_default_value(self, target: Class, attr: Attr):
self.sanitize_attribute_default_enum(target, attr)

def sanitize_attribute_default_enum(self, target: Class, attr: Attr):
"""
Convert string literal default value for enum fields.
Loop through all attributes types and search for enum sources.
If an enum source exist map the default string literal value to
a qualified name. Inner enum references are ignored.
"""
for attr_type in attr.types:
if attr_type.native:
continue
Expand Down
22 changes: 11 additions & 11 deletions xsdata/formats/dataclass/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,29 @@

@functools.lru_cache(maxsize=50)
def class_name(name: str) -> str:
"""Apply python conventions for class names."""
return text.pascal_case(utils.safe_snake(name, "type"))


@functools.lru_cache(maxsize=50)
def attribute_name(name: str) -> str:
"""Apply python conventions for instance variable names."""
return text.snake_case(utils.safe_snake(text.suffix(name)))


@functools.lru_cache(maxsize=50)
def constant_name(name: str) -> str:
"""Apply python conventions for constant names."""
return text.snake_case(utils.safe_snake(name)).upper()


def type_name(attr_type: AttrType) -> str:
"""Return native python type name or apply class name conventions."""
return attr_type.native_name or class_name(text.suffix(attr_type.name))


def attribute_metadata(attr: Attr, parent_namespace: Optional[str]) -> Dict:
"""Return a metadata dictionary for the given attribute."""
metadata = dict(
name=None
if attr.is_nameless or attr.local_name == attribute_name(attr.name)
Expand All @@ -59,6 +64,8 @@ def attribute_metadata(attr: Attr, parent_namespace: Optional[str]) -> Dict:


def format_arguments(data: Dict) -> str:
"""Format given dictionary as keyword arguments."""

def prep(key: str, value: Any) -> str:
if isinstance(value, str):
value = f'''"{value.replace('"', "'")}"'''
Expand All @@ -70,6 +77,7 @@ def prep(key: str, value: Any) -> str:


def class_docstring(obj: Class, enum: bool = False) -> str:
"""Generate docstring for the given class and the constructor arguments."""
lines = []
if obj.help:
lines.append(obj.help)
Expand All @@ -85,6 +93,7 @@ def class_docstring(obj: Class, enum: bool = False) -> str:


def default_imports(output: str) -> str:
"""Generate the default imports for the given package output."""
result = []

if "Decimal" in output:
Expand Down Expand Up @@ -113,7 +122,7 @@ def default_imports(output: str) -> str:


def attribute_default(attr: Attr, ns_map: Optional[Dict] = None) -> Any:
"""Normalize default value/factory by the attribute type."""
"""Generate the field default value/factory for the given attribute."""
if attr.is_list:
return "list"
if attr.is_map:
Expand Down Expand Up @@ -160,16 +169,7 @@ def attribute_default(attr: Attr, ns_map: Optional[Dict] = None) -> Any:


def attribute_type(attr: Attr, parents: List[str]) -> str:
"""
Normalize attribute type.
Steps:
* If type alias is present use class name normalization
* Otherwise use the type name normalization
* Prepend outer class names and quote result for forward references
* Wrap the result with List if the field accepts a list of values
* Wrap the result with Optional if the field default value is None
"""
"""Generate type hints for the given attribute."""

type_names: List[str] = []
for attr_type in attr.types:
Expand Down
Loading

0 comments on commit 267148a

Please sign in to comment.