Skip to content

Commit

Permalink
Merge pull request #26 from YosefLab/copy-tree
Browse files Browse the repository at this point in the history
Copy tree
  • Loading branch information
colganwi authored Aug 21, 2024
2 parents d1c0461 + e349315 commit 4ac765a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ and this project adheres to [Semantic Versioning][].

### Added

### Changed

- `obst` and `vart` create local copy of `nx.DiGraphs` that are added (#26)

### Fixed

- Fixed bug which caused key to be listed twice in `tree_label` column after value update in `obst` or `vart` (#26)

## [0.0.2] - 2024-06-18

### Changed
Expand Down
10 changes: 5 additions & 5 deletions src/treedata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def _validate_tree(self, tree: nx.DiGraph, key: str) -> nx.DiGraph:
def _update_tree_labels(self):
if self.parent._tree_label is not None:
if self.parent.allow_overlap:
mapping = {k: ",".join(map(str, v)) for k, v in self._leaf_to_tree.items()}
mapping = {k: ",".join(map(str, sorted(v))) for k, v in self._leaf_to_tree.items()}
else:
mapping = {k: v[0] for k, v in self._leaf_to_tree.items()}
mapping = {k: next(iter(v)) for k, v in self._leaf_to_tree.items()}
getattr(self.parent, self.dim)[self.parent._tree_label] = getattr(self.parent, f"{self.dim}_names").map(
mapping
)
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
self._axis = axis
self._data = {}
self._tree_to_leaf = defaultdict(set)
self._leaf_to_tree = defaultdict(list)
self._leaf_to_tree = defaultdict(set)
if vals is not None:
self.update(vals)

Expand All @@ -156,13 +156,13 @@ def __setitem__(self, key: str, value: nx.DiGraph):
value, leaves = self._validate_tree(value, key)

for leaf in leaves:
self._leaf_to_tree[leaf].append(key)
self._leaf_to_tree[leaf].add(key)
self._tree_to_leaf[key] = leaves

if not self.parent.is_view:
self._update_tree_labels()

self._data[key] = value
self._data[key] = value.copy()

def __delitem__(self, key: str):
for leaf in self._tree_to_leaf[key]:
Expand Down
2 changes: 1 addition & 1 deletion src/treedata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def dict_to_digraph(graph_dict: dict) -> nx.DiGraph:
return G


def make_serializable(data: dict) -> dict:
def make_serializable(data) -> dict:
"""Make a graph dictionary serializable."""
if isinstance(data, dict):
return {k: make_serializable(v) for k, v in data.items()}
Expand Down
4 changes: 4 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def test_tree_label(X, tree, dim):
df = pd.DataFrame({"tree": ["bad", "bad", "bad"]})
with pytest.warns(UserWarning):
tdata = td.TreeData(X, label="tree", obs=df, var=df)
# Test tree label with updata
tdata = td.TreeData(X, obst={"0": tree, "1": tree}, label="tree", vart={"0": tree, "1": tree}, allow_overlap=True)
tdata.obst["0"] = tree
assert getattr(tdata, dim).loc["0", "tree"] == "0,1"


def test_tree_overlap(X, tree):
Expand Down

0 comments on commit 4ac765a

Please sign in to comment.