diff --git a/bin/rspr_heatmap.py b/bin/rspr_heatmap.py index c608fd9..0cffc92 100755 --- a/bin/rspr_heatmap.py +++ b/bin/rspr_heatmap.py @@ -250,6 +250,34 @@ def generate_cluster_network(lst_tree_clusters, refer_tree): update_cluster_probability(refer_tree.root, dict_clstr_map, total_trees, leaf_mapping) +##################################################################### +### FUNCTION REPLACE_CLUSTER_PROBABILITY +### Replace branch support with cluster probability +### node: current node +##################################################################### + +def replace_cluster_probability(node): + if not node: + return + node.comment = node.cluster_probability + if not node.is_terminal(): + for child in node.clades: + replace_cluster_probability(child) + + +##################################################################### +### FUNCTION SAVE_CLUSTER_TREE +### Save cluster network tree +### cluster_file_path: path to store cliuster tree +### refer_tree: reference tree +##################################################################### + +def save_cluster_tree(cluster_file_path, refer_tree): + print("Saving cluster tree") + replace_cluster_probability(refer_tree.root) + Phylo.write(refer_tree, cluster_file_path, "newick") + + ##################################################################### ### FUNCTION GENERATE_CLUSTER_HEATMAP ### Generate cluster heatmap @@ -301,8 +329,6 @@ def read_tree(input_path): formatted = re.sub(r";[^:]+:", ":", tree_string) return Phylo.read(io.StringIO(formatted), "newick") -def write_tree(output_path, data): - Phylo.write(data, output_path, "newick") def get_fig_size(refer_tree): max_fig_size = 100 @@ -350,7 +376,6 @@ def main(args=None): refer_tree = read_tree(refer_tree_path) if refer_tree: generate_cluster_network(lst_tree_clusters, refer_tree) - write_tree(cluster_file_path, refer_tree) plt.rcParams['font.size'] = '12' fig_size = get_fig_size(refer_tree) @@ -361,5 +386,7 @@ def main(args=None): plt.title("Cluster network") plt.savefig(cluster_tree_path, format="png") + save_cluster_tree(cluster_file_path, refer_tree) + if __name__ == "__main__": sys.exit(main())