From 2a6ad2223a3ae4cc605a49a92b6e71e0acdafcad Mon Sep 17 00:00:00 2001 From: colganwi Date: Wed, 21 Aug 2024 09:10:28 -0400 Subject: [PATCH 1/3] ensure obst contains copy of digraph --- src/treedata/_core/aligned_mapping.py | 2 +- src/treedata/_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/treedata/_core/aligned_mapping.py b/src/treedata/_core/aligned_mapping.py index a0ed464..837d128 100755 --- a/src/treedata/_core/aligned_mapping.py +++ b/src/treedata/_core/aligned_mapping.py @@ -162,7 +162,7 @@ def __setitem__(self, key: str, value: nx.DiGraph): if not self.parent.is_view: self._update_tree_labels() - self._data[key] = value + self._data[key] = value.copy() def __delitem__(self, key: str): for leaf in self._tree_to_leaf[key]: diff --git a/src/treedata/_utils.py b/src/treedata/_utils.py index 1b06768..bd6797d 100755 --- a/src/treedata/_utils.py +++ b/src/treedata/_utils.py @@ -63,7 +63,7 @@ def dict_to_digraph(graph_dict: dict) -> nx.DiGraph: return G -def make_serializable(data: dict) -> dict: +def make_serializable(data) -> dict: """Make a graph dictionary serializable.""" if isinstance(data, dict): return {k: make_serializable(v) for k, v in data.items()} From 5998540a12a20a6dcebe21cff45d68f22bd2bd70 Mon Sep 17 00:00:00 2001 From: colganwi Date: Wed, 21 Aug 2024 09:36:11 -0400 Subject: [PATCH 2/3] fixed tree label bug --- CHANGELOG.md | 8 ++++++++ src/treedata/_core/aligned_mapping.py | 6 +++--- tests/test_base.py | 4 ++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0024437..b844926 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ and this project adheres to [Semantic Versioning][]. ### Added +### Changed + +- `obst` and `vart` create local copy of `nx.DiGraphs` that are added (#26) + +### Fixed + +- Fixed bug which caused key to be listed twice in `tree_label` column after value update in `obst` or `vart` (#26) + ## [0.0.2] - 2024-06-18 ### Changed diff --git a/src/treedata/_core/aligned_mapping.py b/src/treedata/_core/aligned_mapping.py index 837d128..08abb23 100755 --- a/src/treedata/_core/aligned_mapping.py +++ b/src/treedata/_core/aligned_mapping.py @@ -83,7 +83,7 @@ def _update_tree_labels(self): if self.parent.allow_overlap: mapping = {k: ",".join(map(str, v)) for k, v in self._leaf_to_tree.items()} else: - mapping = {k: v[0] for k, v in self._leaf_to_tree.items()} + mapping = {k: next(iter(v)) for k, v in self._leaf_to_tree.items()} getattr(self.parent, self.dim)[self.parent._tree_label] = getattr(self.parent, f"{self.dim}_names").map( mapping ) @@ -144,7 +144,7 @@ def __init__( self._axis = axis self._data = {} self._tree_to_leaf = defaultdict(set) - self._leaf_to_tree = defaultdict(list) + self._leaf_to_tree = defaultdict(set) if vals is not None: self.update(vals) @@ -156,7 +156,7 @@ def __setitem__(self, key: str, value: nx.DiGraph): value, leaves = self._validate_tree(value, key) for leaf in leaves: - self._leaf_to_tree[leaf].append(key) + self._leaf_to_tree[leaf].add(key) self._tree_to_leaf[key] = leaves if not self.parent.is_view: diff --git a/tests/test_base.py b/tests/test_base.py index af11085..5580dfd 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -88,6 +88,10 @@ def test_tree_label(X, tree, dim): df = pd.DataFrame({"tree": ["bad", "bad", "bad"]}) with pytest.warns(UserWarning): tdata = td.TreeData(X, label="tree", obs=df, var=df) + # Test tree label with updata + tdata = td.TreeData(X, obst={"0": tree, "1": tree}, label="tree", vart={"0": tree, "1": tree}, allow_overlap=True) + tdata.obst["0"] = tree + assert getattr(tdata, dim).loc["0", "tree"] == "0,1" def test_tree_overlap(X, tree): From e34931513485f9d29a2197c23e5a818442019332 Mon Sep 17 00:00:00 2001 From: colganwi Date: Wed, 21 Aug 2024 09:44:56 -0400 Subject: [PATCH 3/3] sort tree label --- src/treedata/_core/aligned_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/treedata/_core/aligned_mapping.py b/src/treedata/_core/aligned_mapping.py index 08abb23..af6a05f 100755 --- a/src/treedata/_core/aligned_mapping.py +++ b/src/treedata/_core/aligned_mapping.py @@ -81,7 +81,7 @@ def _validate_tree(self, tree: nx.DiGraph, key: str) -> nx.DiGraph: def _update_tree_labels(self): if self.parent._tree_label is not None: if self.parent.allow_overlap: - mapping = {k: ",".join(map(str, v)) for k, v in self._leaf_to_tree.items()} + mapping = {k: ",".join(map(str, sorted(v))) for k, v in self._leaf_to_tree.items()} else: mapping = {k: next(iter(v)) for k, v in self._leaf_to_tree.items()} getattr(self.parent, self.dim)[self.parent._tree_label] = getattr(self.parent, f"{self.dim}_names").map(