Skip to content

Commit

Permalink
refactor: Save cluster tree with cluster prob (#190)
Browse files Browse the repository at this point in the history
- Cluster probability instead of branch support

Co-authored-by: KARTIK KAKADIYA <[email protected]>
  • Loading branch information
jvfe and ktkakadiya28 authored May 6, 2024
1 parent cc5c53f commit 7fd1ddc
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions bin/rspr_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,34 @@ def generate_cluster_network(lst_tree_clusters, refer_tree):
update_cluster_probability(refer_tree.root, dict_clstr_map, total_trees, leaf_mapping)


#####################################################################
### FUNCTION REPLACE_CLUSTER_PROBABILITY
### Replace branch support with cluster probability
### node: current node
#####################################################################

def replace_cluster_probability(node):
if not node:
return
node.comment = node.cluster_probability
if not node.is_terminal():
for child in node.clades:
replace_cluster_probability(child)


#####################################################################
### FUNCTION SAVE_CLUSTER_TREE
### Save cluster network tree
### cluster_file_path: path to store cliuster tree
### refer_tree: reference tree
#####################################################################

def save_cluster_tree(cluster_file_path, refer_tree):
print("Saving cluster tree")
replace_cluster_probability(refer_tree.root)
Phylo.write(refer_tree, cluster_file_path, "newick")


#####################################################################
### FUNCTION GENERATE_CLUSTER_HEATMAP
### Generate cluster heatmap
Expand Down Expand Up @@ -301,8 +329,6 @@ def read_tree(input_path):
formatted = re.sub(r";[^:]+:", ":", tree_string)
return Phylo.read(io.StringIO(formatted), "newick")

def write_tree(output_path, data):
Phylo.write(data, output_path, "newick")

def get_fig_size(refer_tree):
max_fig_size = 100
Expand Down Expand Up @@ -350,7 +376,6 @@ def main(args=None):
refer_tree = read_tree(refer_tree_path)
if refer_tree:
generate_cluster_network(lst_tree_clusters, refer_tree)
write_tree(cluster_file_path, refer_tree)

plt.rcParams['font.size'] = '12'
fig_size = get_fig_size(refer_tree)
Expand All @@ -361,5 +386,7 @@ def main(args=None):
plt.title("Cluster network")
plt.savefig(cluster_tree_path, format="png")

save_cluster_tree(cluster_file_path, refer_tree)

if __name__ == "__main__":
sys.exit(main())

0 comments on commit 7fd1ddc

Please sign in to comment.