From a65c13a4137033e2ed42c8498ece310398be3107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A1lint=20Aradi?= Date: Mon, 29 May 2023 16:17:36 +0200 Subject: [PATCH] Implement HSD-wrappers to manipulate nested content --- .gitignore | 2 + docs/conf.py | 2 +- docs/introduction.rst | 81 +++++++ src/hsd/__init__.py | 6 +- src/hsd/common.py | 6 - src/hsd/dict.py | 29 ++- src/hsd/io.py | 3 +- src/hsd/wrappers.py | 518 ++++++++++++++++++++++++++++++++++++++++++ test/test_wrappers.py | 190 ++++++++++++++++ 9 files changed, 815 insertions(+), 22 deletions(-) create mode 100644 src/hsd/wrappers.py create mode 100644 test/test_wrappers.py diff --git a/.gitignore b/.gitignore index 15eec4f..ac41eba 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ *~ .idea +.env +.vscode *.pyc dist build diff --git a/docs/conf.py b/docs/conf.py index 4705164..89c864f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,7 +56,7 @@ # a list of builtin themes. # # html_theme = 'alabaster' -html_theme = 'sphinx_rtd_theme' +html_theme = 'sphinx_book_theme' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/introduction.rst b/docs/introduction.rst index 38fd346..57260ba 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -36,6 +36,9 @@ or into the user space issueing :: Quick tutorial ============== +The basics +---------- + A typical, self-explaining input written in HSD looks like :: driver { @@ -117,3 +120,81 @@ Python :: and then stored again in HSD format :: hsd.dump(hsdinput, "test2.hsd") + + + +Accesing nested data structures via wrappers +-------------------------------------------- + +The hsd module contains lightweight wrappers (``HsdDict``, ``HsdList`` and +``HsdValue``), which offer convenient access to entries in nested data +structures. With the help of these wrappers, nested nodes and values can be +directly accessed using paths. When accessing nested content via wrappers, the +resulting objects will be wrappers themself, wrapping the appropriate parts of +the data structure (and inheriting certain properties of the original wrapper). + +For example, reading and wrapping the example above:: + + import hsd + hsdinp = hsd.wrap(hsd.load("test.hsd")) + +creates an ``HsdDict`` wrapper instance (``hsdinp``), which can be used to query +encapsulated information in the structure:: + + # Reading out the value directly (100) + maxsteps = hsdinp["driver", "conjugate_gradients", "max_steps"].value + + # Storing wrapper (HsdValue) instance and reading out value and the attribute + temp = hsdinp["hamiltonian / dftb / filling / fermi / temperature"] + temp_value = temp.value + temp_unit = temp.attrib + + # Getting a default value, if a given path does not exists: + pot = hsdinp.get_item("hamiltonian / dftb / bias", default=hsd.HsdValue(100, attrib="V")) + + # Setting a value for given path by creating missing parents + hsdinp.set_item("analysis / calculate_forces", True, parents=True) + +As demonstrated above, paths can be specified as tuples or as slash (``/``) joined strings. + +The wrappers also support case-insensitive access. Let's have a look at a +mixed-case example file ``test2.hsd``:: + + Driver { + ConjugateGradients { + MovedAtoms = 1 2 "7:19" + MaxSteps = 100 + } + +We now make copy of the data structure before wrapping it, and make sure that +all keys are converted to lower case, but the original names are saved as +HSD-attributes:: + + hsdinp = hsd.copy(hsd.load("test2.hsd"), lower_names=True, save_names=True) + +This way, paths passed to the Hsd-wrapper are treated in a case-insensitive +way:: + + maxsteps = hsdinp["driver", "CONJUGATEGRADIENTS", "MAXSTEPS"].value + +When adding new items, the access is and remains case in-sensitive, but the +actual form of the name of the new node will be saved. The code snippet:: + + hsdinp["driver", "conjugategradients", "MaxForce"] = hsd.HsdValue(1e-4, attrib="au") + maxforceval = hsdinp["driver", "conjugategradients", "maxforce"] + print(f"{maxforceval.value} {maxforceval.attrib}") + print(hsd.dump_string(hsdinp.value, use_hsd_attribs=True)) + +will result in :: + + 0.0001 au + Driver { + ConjugateGradients { + MovedAtoms = 1 2 "7:19" + MaxSteps = 100 + MaxForce [au] = 0.0001 + } + } + +where the case-convention for ``MaxForce`` is identical to the one used when the +item was created. diff --git a/src/hsd/__init__.py b/src/hsd/__init__.py index 7cad801..f16660b 100644 --- a/src/hsd/__init__.py +++ b/src/hsd/__init__.py @@ -7,12 +7,12 @@ """ Toolbox for reading, writing and manipulating HSD-data. """ -from hsd.common import HSD_ATTRIB_LINE, HSD_ATTRIB_EQUAL, HSD_ATTRIB_SUFFIX,\ - HSD_ATTRIB_NAME, HsdError -from hsd.dict import HsdDictBuilder, HsdDictWalker +from hsd.common import HSD_ATTRIB_LINE, HSD_ATTRIB_EQUAL, HSD_ATTRIB_NAME, HsdError +from hsd.dict import ATTRIB_KEY_SUFFIX, HSD_ATTRIB_KEY_SUFFIX, HsdDictBuilder, HsdDictWalker from hsd.eventhandler import HsdEventHandler, HsdEventPrinter from hsd.formatter import HsdFormatter from hsd.io import load, load_string, dump, dump_string from hsd.parser import HsdParser +from hsd.wrappers import HsdDict, HsdList, HsdValue, copy, wrap __version__ = '0.1' diff --git a/src/hsd/common.py b/src/hsd/common.py index 0f84911..1e0c832 100644 --- a/src/hsd/common.py +++ b/src/hsd/common.py @@ -28,12 +28,6 @@ def unquote(txt): # Name for default attribute (when attribute name is not specified) DEFAULT_ATTRIBUTE = "unit" -# Suffix to mark attribute -ATTRIB_SUFFIX = ".attrib" - -# Suffix to mark hsd processing attributes -HSD_ATTRIB_SUFFIX = ".hsdattrib" - # HSD attribute containing the original tag name HSD_ATTRIB_NAME = "name" diff --git a/src/hsd/dict.py b/src/hsd/dict.py index 847aab0..3e613d8 100644 --- a/src/hsd/dict.py +++ b/src/hsd/dict.py @@ -9,10 +9,17 @@ """ import re from typing import List, Tuple, Union -from hsd.common import HSD_ATTRIB_NAME, np, ATTRIB_SUFFIX, HSD_ATTRIB_SUFFIX, HsdError,\ - QUOTING_CHARS, SPECIAL_CHARS +from hsd.common import HSD_ATTRIB_NAME, np, HsdError, QUOTING_CHARS, SPECIAL_CHARS from hsd.eventhandler import HsdEventHandler, HsdEventPrinter + +# Dictionary key suffix to mark attribute +ATTRIB_KEY_SUFFIX = ".attrib" + +# Dictionary keysuffix to mark hsd processing attributes +HSD_ATTRIB_KEY_SUFFIX = ".hsdattrib" + + _ItemType = Union[float, complex, int, bool, str] _DataType = Union[_ItemType, List[_ItemType]] @@ -130,26 +137,26 @@ def close_tag(self, tagname): parentblock[key] = [{None: prevcont}, self._curblock] if attrib and prevcont is None: - parentblock[key + ATTRIB_SUFFIX] = attrib + parentblock[key + ATTRIB_KEY_SUFFIX] = attrib elif prevcont is not None: - prevattrib = parentblock.get(key + ATTRIB_SUFFIX) + prevattrib = parentblock.get(key + ATTRIB_KEY_SUFFIX) if isinstance(prevattrib, list): prevattrib.append(attrib) else: - parentblock[key + ATTRIB_SUFFIX] = [prevattrib, attrib] + parentblock[key + ATTRIB_KEY_SUFFIX] = [prevattrib, attrib] if self._include_hsd_attribs: if self._lower_tag_names: hsdattrib = {} if hsdattrib is None else hsdattrib hsdattrib[HSD_ATTRIB_NAME] = tagname if prevcont is None: - parentblock[key + HSD_ATTRIB_SUFFIX] = hsdattrib + parentblock[key + HSD_ATTRIB_KEY_SUFFIX] = hsdattrib else: - prevhsdattrib = parentblock.get(key + HSD_ATTRIB_SUFFIX) + prevhsdattrib = parentblock.get(key + HSD_ATTRIB_KEY_SUFFIX) if isinstance(prevhsdattrib, list): prevhsdattrib.append(hsdattrib) else: - parentblock[key + HSD_ATTRIB_SUFFIX] = [prevhsdattrib, hsdattrib] + parentblock[key + HSD_ATTRIB_KEY_SUFFIX] = [prevhsdattrib, hsdattrib] self._curblock = parentblock self._data = None @@ -219,11 +226,11 @@ def walk(self, dictobj): for key, value in dictobj.items(): - if key.endswith(ATTRIB_SUFFIX) or key.endswith(HSD_ATTRIB_SUFFIX): + if key.endswith(ATTRIB_KEY_SUFFIX) or key.endswith(HSD_ATTRIB_KEY_SUFFIX): continue - hsdattrib = dictobj.get(key + HSD_ATTRIB_SUFFIX) - attrib = dictobj.get(key + ATTRIB_SUFFIX) + hsdattrib = dictobj.get(key + HSD_ATTRIB_KEY_SUFFIX) + attrib = dictobj.get(key + ATTRIB_KEY_SUFFIX) if isinstance(value, dict): diff --git a/src/hsd/io.py b/src/hsd/io.py index fe751bc..077bd5c 100644 --- a/src/hsd/io.py +++ b/src/hsd/io.py @@ -6,6 +6,7 @@ """ Provides functionality to dump Python structures to HSD """ +from collections.abc import Mapping import io from typing import Union, TextIO from hsd.dict import HsdDictWalker, HsdDictBuilder @@ -155,7 +156,7 @@ def dump(data: dict, hsdfile: Union[TextIO, str], See :func:`hsd.load_string` for an example. """ - if not isinstance(data, dict): + if not isinstance(data, Mapping): msg = "Invalid object type" raise TypeError(msg) if isinstance(hsdfile, str): diff --git a/src/hsd/wrappers.py b/src/hsd/wrappers.py new file mode 100644 index 0000000..6e6adf6 --- /dev/null +++ b/src/hsd/wrappers.py @@ -0,0 +1,518 @@ +# ------------------------------------------------------------------------------------------------ # +# hsd-python: package for manipulating HSD-formatted data in Python # +# Copyright (C) 2011 - 2023 DFTB+ developers group # +# Licensed under the BSD 2-clause license. # +# ------------------------------------------------------------------------------------------------ # +# +""" +Contains wrappers to make HSD handling comfortable. +""" +from collections.abc import Mapping, MutableMapping, MutableSequence, Sequence +import re +from reprlib import repr +from hsd.dict import ATTRIB_KEY_SUFFIX, HSD_ATTRIB_KEY_SUFFIX + + +_HSD_PATH_SEP_PATTERN = re.compile(r"\s*/\s*") + +_HSD_LIST_INDEX_PATTERN = re.compile(r"^-?\d+$") + +_HSD_PATH_NAME_PATTERN = re.compile(r"^\s*([a-zA-Z]\w*)\s*$") + +_HSD_PATH_SEP = "/" + + +class HsdNode: + """Represents a HSD node with value, attribute and hsd attributes. + + Attributes: + value: Value of the node (read-only). + attrib: Attribute(s) of the node (read-only). + hsdattrib: HSD-attribute(s) of the node (read-only). + """ + + def __init__(self, value, attrib=None, hsdattrib=None): + """Initializes the instance. + + Args: + value: Value to represent (mapping, sequence or leaf node value) + attrib: Attribute(s) associated with the value. + hsdattrib: HSD-attribute(s) associated with the value. + """ + if isinstance(value, self.__class__): + self._value = value._value + else: + self._value = value + self._attrib = attrib + self._hsdattrib = hsdattrib + + @property + def value(self): + return self._value + + @property + def attrib(self): + return self._attrib + + @property + def hsdattrib(self): + return self._hsdattrib + + def __repr__(self): + clsname = self.__class__.__name__ + strrep = ( + f"{clsname}(value={repr(self._value)}, attrib={repr(self._attrib)}, " + f"hsdattrib={repr(self._hsdattrib)})" + ) + return strrep + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return ( + self.attrib == other.attrib + and self.hsdattrib == other.hsdattrib + and self.value == other.value + ) + + +class HsdContainer(HsdNode): + """Represents a container-like HSD-node (mapping or sequence). + + Args: + """ + + def __init__( + self, value, attrib=None, hsdattrib=None, lower_names=False, save_names=False + ): + """Initializes the instance. + + Args: + value: Value to represent (mapping, sequence or leaf node value) + attrib: Attribute(s) associated with the value. + hsdattrib: HSD-attribute(s) associated with the value. + lower_names: Whether names in paths should be lowered when looked up (in order to + support case insensitive searcher). Note, this option only effects queries invoked + directly via this instance. + save_names: If the lower_names setting is active, and an item is set, which name differs + from its lowered form, the original name will be saved as an HSD-attribute, if this + option is turned on. Note, this option only effects queries invoked directly via + this instance. + """ + super().__init__(value, attrib=attrib, hsdattrib=hsdattrib) + self._lower_names = lower_names + self._save_names = save_names + + def __getitem__(self, key): + path = _path_from_key(key) + normpath = _normalized_path(path, self._lower_names) + nodes = _find_path(self._value, normpath) + attrib, hsdattrib = _get_last_attributes(nodes, normpath) + return _hsd_from_node(nodes[-1], attrib, hsdattrib) + + def __setitem__(self, key, value): + self.set_item(key, value, parents=False) + + def __delitem__(self, key): + path = _path_from_key(key) + normpath = _normalized_path(path, self._lower_names) + nodes = _find_path(self._value, normpath[:-1]) + _del_value(nodes, normpath) + + def __len__(self): + return len(self._value) + + def __iter__(self): + return iter(self._value) + + def set_item(self, key, value, parents=False): + """Sets an item at a given key/path. + + Args: + key: Path/key where the item should be stored. + value: Item to store. + parents: Whether missing parents should be created. Creating missing parents is only + possible if the missing path does not contain list indices. + """ + path = _path_from_key(key) + normpath = _normalized_path(path, self._lower_names) + nodes = _find_path(self._value, normpath[:-1], return_partial=parents) + # Note: _find_path() returns root + all nodes in the passed path + if len(nodes) != len(normpath): + for ind in range(len(nodes), len(normpath)): + if isinstance(path[ind - 1], int): + raise KeyError( + "Missing path components must not contain list indices" + ) + newvalue = {} + _set_value( + nodes, + path[:ind], + normpath[:ind], + newvalue, + None, + None, + self._save_names, + ) + nodes.append(newvalue) + value, attrib, hsdattrib = _value_and_attribs(value) + _set_value(nodes, path, normpath, value, attrib, hsdattrib, self._save_names) + + def get_item(self, key, default=None): + """Returns a key or a default value if not found. + + Args: + key: Key / path to look for. + default: Default value to return if no item was found. + + Returns: + Item at given key/path or the default value. + """ + return self.get(key, default=default) + + @classmethod + def copy(cls, source, lower_names=False, save_names=False): + """Creates a new object instance by copying the entries of a suitable object. + + Note: This is a 'semi-deep' copy as all keys are created new, but values are shallow-copied. + + Args: + lower_names: Whether all keys should be converted to lower case during the copy (to + support case insensitive searches). The resulting instance will also be created with + identical lower_names option. + save_names: Whether the original of the converted keys should be stored as + HSD-attributes. The resulting instance will also be created with identical + save_names option. + """ + srcdict, srcattrib, srchsdattrib = _value_and_attribs(source) + value = _hsd_copy(srcdict, lower_names, save_names) + return cls( + value, + attrib=srcattrib, + hsdattrib=srchsdattrib, + lower_names=lower_names, + save_names=save_names, + ) + + +class HsdDict(HsdContainer, MutableMapping): + """HSD wrapper around dictionaries.""" + + @classmethod + def copy(cls, source, lower_names=False, save_names=False): + """Creates a new object instance by copying the entries of a suitable object. + + Note: This is a 'semi-deep' copy as all keys are created new, but values are shallow-copied. + + Args: + lower_names: Whether all keys should be converted to lower case during the copy (to + support case insensitive searches). The resulting instance will also be created with + identical lower_names option. + save_names: Whether the original of the converted keys should be stored as + HSD-attributes. The resulting instance will also be created with identical + save_names option. + """ + if not _is_hsd_dict_compatible(source): + raise TypeError("Source must be of HsdDict compatible type") + return super(cls, cls).copy(source, lower_names=lower_names, save_names=save_names) + + +class HsdList(HsdContainer, MutableSequence): + """HSD wrapper around lists""" + + @classmethod + def copy(cls, source, lower_names=False, save_names=False): + """Creates a new object instance by copying the entries of a suitable object. + + Note: This is a 'semi-deep' copy as all keys are created new, but values are shallow-copied. + + Args: + lower_names: Whether all keys should be converted to lower case during the copy (to + support case insensitive searches). The resulting instance will also be created with + identical lower_names option. + save_names: Whether the original of the converted keys should be stored as + HSD-attributes. The resulting instance will also be created with identical + save_names option. + """ + if not _is_hsd_list_compatible(source): + raise TypeError("Source must be of HsdList compatible type") + return super(cls, cls).copy(source, lower_names=lower_names, save_names=save_names) + + def __setitem__(self, ind, value): + if not _is_hsd_dict_compatible(value): + raise TypeError("HsdList might only contain HsdDict compatible elements") + super().__setitem__(ind, value) + + def insert(self, ind, value): + if not _is_hsd_dict_compatible(value): + raise TypeError("Only HsdDict compatible items can be inserted into HsdLists") + value, attrib, hsdattrib = _value_and_attribs(value) + if self._attrib is None and attrib is not None: + raise ValueError( + "HsdList without attribute can not be extended with a value with attribute" + ) + elif self._attrib is not None: + self._attrib.insert(ind, attrib) + if self._hsdattrib is None and hsdattrib is not None: + raise ValueError( + "HsdList without HSD-attribute can not be extended with a value with HSD-attribute" + ) + elif self._hsdattrib is not None: + self._hsdattrib.insert(ind, hsdattrib) + self._value.insert(ind, value) + + +class HsdValue(HsdNode): + """HSD wrapper around values (leaf nodes)""" + + +def copy(source, lower_names=False, save_names=False): + """Makes a copy of the source and wraps it with an appropriate HSD-container. + + Note: This is a 'semi-deep' copy as all keys are created new, but values are shallow-copied. + + Args: + lower_names: Whether all keys should be converted to lower case during the copy (to + support case insensitive searches). The resulting instance will also be created with + identical lower_names option. + save_names: Whether the original of the converted keys should be stored as + HSD-attributes. The resulting instance will also be created with identical + save_names option. + """ + if _is_hsd_dict_compatible(source): + return HsdDict.copy(source, lower_names=lower_names, save_names=save_names) + elif _is_hsd_list_compatible(source): + return HsdList.copy(source, lower_names=lower_names, save_names=save_names) + raise TypeError("Only HsdDict and HsdList compatible types can be copied") + + +def wrap(source, lower_names=False, save_names=False): + """Wraps an object with an appropriate HSD-container. + + Note: This is a 'semi-deep' copy as all keys are created new, but values are shallow-copied. + + Args: + lower_names: Whether names in paths should be lowered when looked up (in order to + support case insensitive searcher). Note, this option only effects queries invoked + directly via this instance. + save_names: If the lower_names setting is active, and an item is set, which name differs + from its lowered form, the original name will be saved as an HSD-attribute, if this + option is turned on. Note, this option only effects queries invoked directly via + this instance. + """ + if _is_hsd_dict_compatible(source): + return HsdDict(source, lower_names=lower_names, save_names=save_names) + elif _is_hsd_list_compatible(source): + return HsdList(source, lower_names=lower_names, save_names=save_names) + raise TypeError("Only HsdDict and HsdList compatible types can be wrapped") + + +def _path_from_key(key): + """Returns an iterable path from the provided key.""" + if isinstance(key, str): + pathcomps = _HSD_PATH_SEP_PATTERN.split(key) + pathcomps = [ + int(pc) if _HSD_LIST_INDEX_PATTERN.match(pc) else pc for pc in pathcomps + ] + elif isinstance(key, Sequence): + pathcomps = key + if len(pathcomps) == 0: + raise ValueError(f"Key sequence has zero length") + else: + raise TypeError(f"Key '{key}' has invalid type") + + path = [] + for pc in pathcomps: + if isinstance(pc, int): + path.append(pc) + elif match := _HSD_PATH_NAME_PATTERN.match(pc): + path.append(match.group(1)) + else: + raise ValueError(f"Invalid path component {pc!r} in path {key!r}") + return path + + +def _find_path(root, path, return_partial=False): + """Returns list of nodes corresponding to a path relative to (and starting from) root.""" + nodes = [root] + node = root + found = True + key = None + for key in path: + parent = node + try: + node = parent[key] + except (KeyError, IndexError): + found = False + break + nodes.append(node) + if found or return_partial: + return nodes + failedpath = _HSD_PATH_SEP.join([str(p) for p in path[: len(nodes)]]) + # Raise key error, so that the "in" operator can be used with HsdDicts. + raise KeyError(f"Could not find item '{key}' at '{failedpath}'") + + +def _get_last_attributes(nodes, path): + """Returns attribute and hsdattributes associated with the last node in a path.""" + if isinstance(nodes[-2], Sequence): + parent = nodes[-3] + key, ind = path[-2], path[-1] + attrib = parent.get(key + ATTRIB_KEY_SUFFIX) + hsdattrib = parent.get(key + HSD_ATTRIB_KEY_SUFFIX) + attrib = attrib[ind] if attrib is not None else None + hsdattrib = hsdattrib[ind] if hsdattrib is not None else None + else: + parent = nodes[-2] + key = path[-1] + attrib = parent.get(key + ATTRIB_KEY_SUFFIX) + hsdattrib = parent.get(key + HSD_ATTRIB_KEY_SUFFIX) + return attrib, hsdattrib + + +def _hsd_from_node(node, attrib, hsdattrib): + """Returns a HsdNode subclass wrapper depending on the node type.""" + if _is_hsd_dict_compatible(node): + return HsdDict(node, attrib, hsdattrib) + elif _is_hsd_list_compatible(node): + return HsdList(node, attrib, hsdattrib) + else: + return HsdValue(node, attrib, hsdattrib) + + +def _value_and_attribs(obj): + """Returns value, attribute and hsdattributes corresponding to an object.""" + if isinstance(obj, HsdNode): + attrib = obj._attrib + hsdattrib = obj._hsdattrib + value = obj._value + else: + value = obj + attrib = hsdattrib = None + return value, attrib, hsdattrib + + +def _set_value(nodes, path, normpath, value, attrib, hsdattrib, save_names): + """Sets the value of the last node in a path.""" + node = nodes[-1] + node[normpath[-1]] = value + if isinstance(node, Sequence): + ind, normkey = normpath[-1], normpath[-2] + key = path[-2] + parent = nodes[-2] + if attrib is not None: + parent[normkey + ATTRIB_KEY_SUFFIX][ind] = attrib + if save_names: + if normkey != key: + hsdattrib = {} if hsdattrib is None else hsdattrib.copy() + hsdattrib["name"] = key + elif hsdattrib is not None and "name" in hsdattrib: + del hsdattrib["name"] + if hsdattrib is not None: + parent[normkey + HSD_ATTRIB_KEY_SUFFIX][ind] = hsdattrib + elif parent[normkey + HSD_ATTRIB_KEY_SUFFIX] is not None: + parent[normkey + HSD_ATTRIB_KEY_SUFFIX][ind] = None + else: + key = path[-1] + normkey = normpath[-1] + if attrib is not None: + node[normkey + ATTRIB_KEY_SUFFIX] = attrib + if save_names: + if isinstance(value, Sequence): + if normkey != key: + # The container hsdattrib may be None, or it may be a list, which may have None + # elements itself. We handle both cases and return a list of dicts. + tmphsdattrib = [None] * len(value) if hsdattrib is None else hsdattrib + hsdattrib = [ + {} if hsddict is None else hsddict.copy() for hsddict in tmphsdattrib + ] + for hsddict in hsdattrib: + hsddict["name"] = key + elif hsdattrib is not None: + hsdattrib = [ + None if hsddict is None else hsddict.copy() + for hsddict in hsdattrib + ] + for hsddict in hsdattrib: + if hsddict is not None and "name" in hsddict: + del hsddict["name"] + else: + if normkey != key: + hsdattrib = {} if hsdattrib is None else hsdattrib.copy() + hsdattrib["name"] = key + elif hsdattrib is not None and "name" in hsdattrib: + del hsdattrib["name"] + + if hsdattrib is not None: + node[normkey + HSD_ATTRIB_KEY_SUFFIX] = hsdattrib + elif normkey + HSD_ATTRIB_KEY_SUFFIX in node: + del node[normkey + HSD_ATTRIB_KEY_SUFFIX] + + +def _del_value(nodes, path): + """Deletes the last node in a path.""" + node = nodes[-1] + del node[path[-1]] + if isinstance(node, Sequence): + ind, key = path[-1], path[-2] + parent = nodes[-2] + if key + ATTRIB_KEY_SUFFIX in parent: + del parent[key + ATTRIB_KEY_SUFFIX][ind] + if key + HSD_ATTRIB_KEY_SUFFIX in parent: + del parent[key + HSD_ATTRIB_KEY_SUFFIX][ind] + else: + key = path[-1] + if key + ATTRIB_KEY_SUFFIX in node: + del node[key + ATTRIB_KEY_SUFFIX] + if key + HSD_ATTRIB_KEY_SUFFIX in node: + del node[key + HSD_ATTRIB_KEY_SUFFIX] + + +def _hsd_copy(source, lower_names, save_names): + """Copies a HSD-tree recursively (by creating new containers and keys).""" + if _is_hsd_dict_compatible(source): + result = {} + for key, value in source.items(): + if key.endswith(ATTRIB_KEY_SUFFIX) or key.endswith(HSD_ATTRIB_KEY_SUFFIX): + continue + attrib = source.get(key + ATTRIB_KEY_SUFFIX) + hsdattrib = source.get(key + HSD_ATTRIB_KEY_SUFFIX) + newkey = key.lower() if lower_names else key + if save_names and newkey != key: + if _is_hsd_list_compatible(value): + if hsdattrib is None: + hsdattrib = [{} for _ in range(len(value))] + for dd in hsdattrib: + dd["name"] = key + else: + hsdattrib = hsdattrib if hsdattrib is not None else {} + hsdattrib["name"] = key + if attrib is not None: + result[newkey + ATTRIB_KEY_SUFFIX] = attrib + if hsdattrib is not None: + result[newkey + HSD_ATTRIB_KEY_SUFFIX] = hsdattrib + result[newkey] = _hsd_copy( + value, lower_names=lower_names, save_names=save_names + ) + elif _is_hsd_list_compatible(source): + result = [_hsd_copy(item, lower_names, save_names) for item in source] + else: + result = source + return result + + +def _normalized_path(path, lower_names): + """Returns a normalized path.""" + return [ + name.lower() if lower_names and isinstance(name, str) else name for name in path + ] + + +def _is_hsd_dict_compatible(obj): + """Whether an object can be wrapped as HsdDict.""" + return isinstance(obj, Mapping) + + +def _is_hsd_list_compatible(obj): + """Whether an object can be wrapped as HsdList.""" + return isinstance(obj, Sequence) and all([isinstance(item, Mapping) for item in obj]) diff --git a/test/test_wrappers.py b/test/test_wrappers.py new file mode 100644 index 0000000..e84c6c5 --- /dev/null +++ b/test/test_wrappers.py @@ -0,0 +1,190 @@ +#!/bin/env python3 +# ------------------------------------------------------------------------------------------------ # +# hsd-python: package for manipulating HSD-formatted data in Python # +# Copyright (C) 2011 - 2023 DFTB+ developers group # +# Licensed under the BSD 2-clause license. # +# ------------------------------------------------------------------------------------------------ # +# +"""Tests for the hsdwrappers module""" + +import pytest +import numpy as np +import hsd + +_DICT = { + "Ham": { + "Dftb": { + "Scc": True, + "Filling": { + "Fermi": { + "Temp": 100, + "Temp.attrib": "K", + } + }, + "EField": { + "PCharges": [ + {"Coords": np.array([0.0, 1.0, 2.0, 3.0])}, + {"Coords": np.array([0.0, -1.0, 2.0, 3.0])}, + ], + "PCharges.attrib": ["Pointy", "Smeared"], + }, + }, + }, +} + +_HSD_DICT = hsd.HsdDict.copy(_DICT) + +_HSD_DICT_LOW = hsd.HsdDict.copy(_DICT, lower_names=True, save_names=True) + + +def test_tuple_path_access(): + assert _HSD_DICT["Ham", "Dftb", "Scc"].value == True + coords = _HSD_DICT["Ham", "Dftb", "EField", "PCharges", 1, "Coords"].value + assert np.all(np.isclose(coords, np.array([0.0, -1.0, 2.0, 3.0]))) + + +def test_string_path_access(): + assert _HSD_DICT["Ham / Dftb / Scc"].value == True + coords = _HSD_DICT["Ham / Dftb / EField / PCharges / 1 / Coords"].value + assert np.all(np.isclose(coords, np.array([0.0, -1.0, 2.0, 3.0]))) + + +def test_path_failure(): + with pytest.raises(KeyError) as exc: + _HSD_DICT["Ham / dftb / Scc"] + with pytest.raises(KeyError) as exc: + _HSD_DICT["Ham / Dftb / EField / PCharges / 9 / Coords"].value + + +def test_self_equality(): + assert _HSD_DICT == _HSD_DICT + assert _HSD_DICT_LOW == _HSD_DICT_LOW + + +def test_lowered_unequal_original(): + assert _HSD_DICT_LOW != _DICT + + +def test_lowered_access(): + assert _HSD_DICT_LOW["ham", "dftb", "scc"].value == True + assert _HSD_DICT_LOW["Ham", "Dftb", "Scc"].value == True + coords = _HSD_DICT_LOW["ham", "dftb", "efield", "pcharges", 1, "coords"].value + assert np.all(np.isclose(coords, np.array([0.0, -1.0, 2.0, 3.0]))) + coords = _HSD_DICT_LOW["Ham", "Dftb", "EField", "PCharges", 1, "Coords"].value + assert np.all(np.isclose(coords, np.array([0.0, -1.0, 2.0, 3.0]))) + + +def test_attrib(): + assert _HSD_DICT_LOW["Ham", "Dftb", "Filling", "Fermi", "Temp"].attrib == "K" + attribs = _HSD_DICT_LOW["ham / dftb / efield / pcharges"].attrib + assert attribs == ["Pointy", "Smeared"] + assert _HSD_DICT_LOW["ham / dftb / efield / pcharges / 0"].attrib == "Pointy" + assert _HSD_DICT_LOW["ham / dftb / efield / pcharges / 1"].attrib == "Smeared" + + +def test_hsdattrib_name(): + name = _HSD_DICT_LOW["ham"].hsdattrib["name"] + assert name == "Ham" + hattrs = _HSD_DICT_LOW["HAM", "DFTB", "EFIELD", "PCHARGES"].hsdattrib + assert len(hattrs) == 2 + assert hattrs[0]["name"] == "PCharges" + assert hattrs[1]["name"] == "PCharges" + + +def test_setting_value(): + hdict = hsd.HsdDict.copy({"a1": {"b1": 1}}) + hdict["a1 / b1"] = 9 + val = hdict["a1 / b1"] + assert val.value == 9 + assert val.attrib is None + + +def test_setting_hsdvalue(): + hdict = hsd.HsdDict.copy({"a1": {"b1": 1}}) + hdict["a1 / b1"] = hsd.HsdValue(9, "kg") + val = hdict["a1 / b1"] + assert val.value == 9 + assert val.attrib == "kg" + assert hdict.value["a1"]["b1"] == 9 + assert hdict.value["a1"]["b1.attrib"] == "kg" + + +def test_del(): + inp = { + "a1": { + "b1": 1, + "b1.attrib": "K", + "b1.hsdattrib": {"name": "B1"}, + "b2": 2, + }, + } + hdict = hsd.HsdDict.copy(inp) + del hdict["a1 / b1"] + assert hdict == hsd.HsdDict.copy({"a1": {"b2": 2}}) + + del hdict["a1 / b2"] + assert hdict == hsd.HsdDict.copy({"a1": {}}) + + del hdict["a1"] + assert hdict == hsd.HsdDict.copy({}) + + +def test_insert(): + inp = { + "a1": [ + {"b1": 1}, + {"b3": 3}, + ], + "a1.attrib": ["cm", "km"], + "a1.hsdattrib": [{"name": "A1"}, {"name": "A1"}], + } + out = { + "a1": [ + {"b1": 1}, + {"b2": 2}, + {"b3": 3}, + ], + "a1.attrib": ["cm", "pc", "km"], + "a1.hsdattrib": [{"name": "A1"}, {"name": "A1"}, {"name": "A1"}], + } + + hdict = hsd.HsdDict.copy(inp, lower_names=True, save_names=True) + newitem = hsd.HsdDict( + {}, attrib="pc", hsdattrib={"name": "A1"}, lower_names=True, save_names=True + ) + newitem["b2"] = hsd.HsdValue(2) + a1list = hdict["A1"] + a1list.insert(1, newitem) + assert hdict == hsd.HsdDict.copy(out, lower_names=True, save_names=True) + + +def test_list_name_rewriting(): + inp = hsd.HsdDict({}, lower_names=True, save_names=True) + out = { + "a1": [{"b1": 1}, {"b2": 2}], + "a1.hsdattrib": [{}, {}], + } + hsdlist = hsd.HsdList( + [{"b1": 1}, {"b2": 2}], hsdattrib=[{"name": "A1"}, {"name": "A1"}] + ) + inp["a1"] = hsdlist + assert inp == hsd.HsdDict(out) + + +def test_get_item(): + inp = {"a": 1} + hdict = hsd.HsdDict.copy(inp) + assert hdict.get_item("a").value == 1 + assert hdict.get_item("b", default=hsd.HsdValue(23)).value == 23 + assert hdict == hsd.HsdDict(inp) + + +def test_set_item(): + inp = {"a": 1} + hdict = hsd.HsdDict(inp) + hdict.set_item("b", 2) + assert hdict == hsd.HsdDict({"a": 1, "b": 2}) + with pytest.raises(KeyError): + hdict.set_item("c / d", 3) + hdict.set_item("c / d", 3, parents=True) + assert hdict == hsd.HsdDict({"a": 1, "b": 2, "c": {"d": 3}})