From b834d1440612f441904eea4c2077872a2dc51331 Mon Sep 17 00:00:00 2001
From: "shibo.19"
Date: Mon, 30 Dec 2024 18:19:50 +0800
Subject: [PATCH] fix msa result dir

---
Review notes (text between the "---" line and the first "diff --git" line
is commentary and is not part of the commit):

* What the patch does: inference_jsons() and msa() now pass out_dir straight
  to msa_search_update() instead of first appending "msa_res_<idx>" or a
  random uuid sub-directory; msa_search_update() nests the search output
  under the entry's "name" (falling back to "seq_<idx>") and logs the start
  and the end of the update.
* Fixed in this revision: the added os.path.join() call had the two adjacent
  literals "msa_res" f"msa_seq_{seq_idx}" with no comma between them, which
  Python concatenates into the single path component "msa_resmsa_seq_<idx>".
  They are now separate arguments, giving <out_dir>/<name>/msa_res/msa_seq_<idx>.
* Not fixed here: the context line
  `assert len(msa_res_subdirs) == len(msa_res_subdirs)` compares a value with
  itself and can never fail; it presumably should compare against
  len(protein_seqs). It is hunk context, so it must keep matching the target
  file and needs its own follow-up change.
* NOTE(review): msa() no longer calls os.makedirs(out_dir); confirm that the
  callee creates the directory before writing into it.
* NOTE(review): the leading whitespace of the hunk lines below was
  reconstructed from a whitespace-collapsed copy of this patch; verify it
  against the target files before applying. The post-image blob id on the
  runner/msa_search.py "index" line predates the comma fix.

 runner/batch_inference.py | 6 +-----
 runner/msa_search.py      | 6 +++++-
 setup.py                  | 2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/runner/batch_inference.py b/runner/batch_inference.py
index acebea4..762744e 100644
--- a/runner/batch_inference.py
+++ b/runner/batch_inference.py
@@ -215,9 +215,7 @@ def inference_jsons(
     for idx, infer_json in enumerate(tqdm.tqdm(infer_jsons)):
         try:
             if use_msa_server:
-                infer_json = msa_search_update(
-                    infer_json, os.path.join(out_dir, f"msa_res_{idx}")
-                )
+                infer_json = msa_search_update(infer_json, out_dir)
             elif not contain_msa_res(infer_json):
                 raise RuntimeError(f"can not find msa for {infer_json}")
             configs["input_json_path"] = infer_json
@@ -395,8 +393,6 @@ def msa(input, out_dir) -> Union[str, dict]:
     :return:
     """
     init_logging()
-    out_dir = os.path.join(out_dir, uuid.uuid4().hex)
-    os.makedirs(out_dir, exist_ok=True)
     logger.info(f"run msa with input={input}, out_dir={out_dir}")
     if input.endswith(".json"):
         msa_input_json = msa_search_update(input, out_dir)
diff --git a/runner/msa_search.py b/runner/msa_search.py
index 4c1008c..c9dfbd2 100644
--- a/runner/msa_search.py
+++ b/runner/msa_search.py
@@ -79,15 +79,18 @@ def msa_search_update(json_file: str, out_dir: str) -> str:
         return json_file
     with open(json_file, "r") as f:
         input_json_data = json.load(f)
+    logger.info(f"starting to update msa result for {json_file}")
     for seq_idx, seq in enumerate(input_json_data):
         protein_seqs = []
+        seq_name = seq.get("name", f"seq_{seq_idx}")
         for sequence in seq["sequences"]:
             if "proteinChain" in sequence.keys():
                 protein_seqs.append(sequence["proteinChain"]["sequence"])
         if len(protein_seqs) > 0:
             protein_seqs = sorted(protein_seqs)
             msa_res_subdirs = msa_search(
-                protein_seqs, os.path.join(out_dir, f"msa_seq_{seq_idx}")
+                protein_seqs,
+                os.path.join(out_dir, seq_name, "msa_res", f"msa_seq_{seq_idx}"),
             )
             assert len(msa_res_subdirs) == len(msa_res_subdirs), "msa search failed"
             update_msa_res(seq, dict(zip(protein_seqs, msa_res_subdirs)))
@@ -97,4 +100,5 @@ def msa_search_update(json_file: str, out_dir: str) -> str:
     )
     with open(msa_input_json, "w") as f:
         json.dump(input_json_data, f, indent=4)
+    logger.info(f"update msa result success and save to {msa_input_json}")
     return msa_input_json
diff --git a/setup.py b/setup.py
index 0f1bc51..8353bcd 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 setup(
     name="protenix",
     python_requires=">=3.10",
-    version="0.3.5",
+    version="0.3.6",
     description="A trainable PyTorch reproduction of AlphaFold 3.",
     author="Bytedance Inc.",
     url="https://github.com/bytedance/Protenix",