Skip to content

Commit

Permalink
Move function
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanx749 committed Aug 28, 2024
1 parent 2b00721 commit 7f0edc9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
16 changes: 2 additions & 14 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,15 @@

from data import Vocab
from model import *
from utils import Peptides, fasta2df, predict
from utils import Peptides, fasta2df, predict, dict2df

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(args.model, map_location=device)
vocab = Vocab(max_len=25)
df_in = fasta2df(args.input)
vocab = Vocab(max_len=25)
dataset = Peptides(df_in, vocab)
dataloader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
lst = predict(dataloader, model)


def dict2df(lst: dict) -> pd.DataFrame:
label = dict(zip(range(3), ["A", "E", "M"]))
data = {
"score": lst["score_p"],
"epitope": lst["prediction_p"],
"Ig": [label[e] for e in lst["prediction_ig"]],
}
return pd.DataFrame(data)


df_out = dict2df(lst)
df = pd.concat([df_in, df_out], axis=1)
df.to_csv(args.output)
9 changes: 9 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,12 @@ def predict(data: DataLoader, model: nn.Module):
for k in output.keys():
lst[k].extend(output[k].tolist())
return lst

def dict2df(lst: dict) -> pd.DataFrame:
label = dict(zip(range(3), ["A", "E", "M"]))
data = {
"score": lst["score_p"],
"epitope": lst["prediction_p"],
"Ig": [label[e] for e in lst["prediction_ig"]],
}
return pd.DataFrame(data)

0 comments on commit 7f0edc9

Please sign in to comment.