Skip to content

Commit

Permalink
feat(rspr): Save cluster tree in newick format (#189)
Browse files Browse the repository at this point in the history
* Save cluster tree file in newick format

* refactor: Change write_tree to use Phylo.write

Signed-off-by: jvfe <[email protected]>

---------

Signed-off-by: jvfe <[email protected]>
Co-authored-by: KARTIK KAKADIYA <[email protected]>
  • Loading branch information
jvfe and ktkakadiya28 authored Apr 16, 2024
1 parent e47e851 commit cc5c53f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
22 changes: 16 additions & 6 deletions bin/rspr_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns
import os
import os
import json
from Bio import Phylo
import re
Expand Down Expand Up @@ -42,6 +42,12 @@ def parse_args(args=None):
dest="CLUSTER_OUTPUT",
help="Cluster probability output file name",
)
parser.add_argument(
"-cfo",
"--cluster_file_output",
dest="CLUSTER_FILE_OUTPUT",
help="Cluster probability output newick tree file name",
)
parser.add_argument(
"-mnher",
"--min_heatmap_exact_rspr",
Expand Down Expand Up @@ -231,7 +237,7 @@ def generate_cluster_network(lst_tree_clusters, refer_tree):
print("Generating cluster network")
if not refer_tree:
return

lst_leaves = [leave.name for leave in refer_tree.get_terminals()]
leaf_mapping = {leaf: i for i, leaf in enumerate(lst_leaves)}
dict_clstr_map = defaultdict(int)
Expand Down Expand Up @@ -288,13 +294,15 @@ def generate_cluster_heatmap(lst_tree_clusters, cluster_heatmap_path):
plt.ylabel("Leaves")
plt.savefig(cluster_heatmap_path)


def read_tree(input_path):
    """Load a Newick tree from ``input_path`` and return the parsed tree.

    The file text is pre-processed before parsing: every run that starts
    with ``;``, continues with one or more non-``:`` characters and ends
    with ``:`` is collapsed to a single ``:``.
    NOTE(review): presumably this strips extra label text that the Newick
    parser would not accept -- confirm against the tree producer.

    Args:
        input_path: Path of the tree file to read.

    Returns:
        The tree object produced by ``Phylo.read`` in ``"newick"`` format.
    """
    with open(input_path, "r") as handle:
        raw_newick = handle.read()

    # Parse from an in-memory buffer so the cleaned text never touches disk.
    cleaned_newick = re.sub(r";[^:]+:", ":", raw_newick)
    buffer = io.StringIO(cleaned_newick)
    return Phylo.read(buffer, "newick")


def write_tree(output_path, data):
    """Write ``data`` to ``output_path`` in Newick format.

    Args:
        output_path: Destination file path. The ``--cluster_file_output``
            option that feeds it is declared without a default, so the
            value may be ``None``; in that case (or for an empty string)
            nothing is written.
        data: Tree object(s) to serialise, passed through to
            ``Phylo.write``.

    Returns:
        None.
    """
    # Skip the export when no destination was supplied, rather than handing
    # a missing path to Phylo.write and aborting the rest of the run.
    if not output_path:
        return
    Phylo.write(data, output_path, "newick")

def get_fig_size(refer_tree):
max_fig_size = 100
Expand All @@ -306,7 +314,7 @@ def get_fig_size(refer_tree):


#endregion

def main(args=None):
args = parse_args(args)

Expand All @@ -315,7 +323,7 @@ def main(args=None):
# Generate standard heatmap
results["exact_drSPR"] = pd.to_numeric(results["exact_drSPR"])
make_heatmap(
results,
results,
args.OUTPUT,
args.MIN_HEATMAP_RSPR_DISTANCE,
args.MAX_HEATMAP_RSPR_DISTANCE
Expand All @@ -337,10 +345,12 @@ def main(args=None):
lst_tree_clusters.append(json.loads(str_clstr))

cluster_tree_path = args.CLUSTER_OUTPUT
cluster_file_path = args.CLUSTER_FILE_OUTPUT
refer_tree_path = os.path.join("rooted_reference_tree/core_gene_alignment.tre")
refer_tree = read_tree(refer_tree_path)
if refer_tree:
generate_cluster_network(lst_tree_clusters, refer_tree)
write_tree(cluster_file_path, refer_tree)

plt.rcParams['font.size'] = '12'
fig_size = get_fig_size(refer_tree)
Expand Down
3 changes: 3 additions & 0 deletions modules/local/rspr/heatmap.nf
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ process RSPR_HEATMAP {
path "exact_output.png", emit: png
path "exact_group_output.png", emit: exact_group_output
path "cluster_tree_output.png", emit: cluster_tree_output
path "cluster_file_output.nwk", emit: cluster_file_output

when:
task.ext.when == null || task.ext.when
Expand All @@ -30,6 +31,7 @@ process RSPR_HEATMAP {
-o exact_output.png \\
-go exact_group_output.png \\
-co cluster_tree_output.png \\
-cfo cluster_file_output.nwk \\
--min_heatmap_exact_rspr $min_heatmap_exact_rspr \\
--max_heatmap_exact_rspr $max_heatmap_exact_rspr \\
$args
Expand All @@ -39,5 +41,6 @@ process RSPR_HEATMAP {
touch exact_output.png
touch exact_group_output.png
touch cluster_tree_output.png
touch cluster_file_output.nwk
"""
}

0 comments on commit cc5c53f

Please sign in to comment.