From 9dd8d220ddcd43990ee11bb39897ea957a8e2704 Mon Sep 17 00:00:00 2001 From: Jiyang Date: Mon, 4 Dec 2023 13:18:52 -0600 Subject: [PATCH] Add model training code --- python/.gitignore | 31 ++ python/deltr/coditT5/CodeT5.py | 457 +++++++++++++++++++++ python/deltr/coditT5/prediction.py | 88 ++++ python/deltr/coditT5/save_pretrained.py | 75 ++++ python/deltr/coditT5/utils.py | 509 ++++++++++++++++++++++++ 5 files changed, 1160 insertions(+) create mode 100644 python/.gitignore create mode 100644 python/deltr/coditT5/CodeT5.py create mode 100644 python/deltr/coditT5/prediction.py create mode 100644 python/deltr/coditT5/save_pretrained.py create mode 100644 python/deltr/coditT5/utils.py diff --git a/python/.gitignore b/python/.gitignore new file mode 100644 index 0000000..eb0f88f --- /dev/null +++ b/python/.gitignore @@ -0,0 +1,31 @@ +# Temp files + +*~ +\#*\# +.DS_Store +*.class +*.pyc +.pdf + +experiments.log + +target/ + +.idea/ +*.iml + +/raw_data + +# logs +tacc-logs/ + +*.code-workspace + +# downloads dir +_downloads/ +data/ + +repo-data/ + +models/ +.vscode/settings.json diff --git a/python/deltr/coditT5/CodeT5.py b/python/deltr/coditT5/CodeT5.py new file mode 100644 index 0000000..98db972 --- /dev/null +++ b/python/deltr/coditT5/CodeT5.py @@ -0,0 +1,457 @@ +import transformers +from transformers import ( + RobertaTokenizer, + T5ForConditionalGeneration, + T5EncoderModel, +) +from typing import List, Tuple, Dict, Optional, Union, Sequence +from jsonargparse.typing import Path_dc, Path_drw +import os +from pathlib import Path +from seutil import LoggingUtils +import torch +import torch.utils.data +import pytorch_lightning as pl +from pytorch_lightning.utilities.cli import ( + LR_SCHEDULER_REGISTRY, + OPTIMIZER_REGISTRY, + instantiate_class, + SaveConfigCallback, +) +import collections +import numpy as np + +from .utils import ( + DefaultLightningCLI, + ExampleDataset, + PredictDataset, + Prediction, +) +from deltr.Macros import Macros +from deltr.eval.evaluate import compute_bleu_scores +from deltr.collector.diff_utils import EDIT_TOKENS + +from deltr.coditT5.prediction import PredictionWriter + +logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) + +MAX_LENGTH = 512 + + +class CodeT5DataModule(pl.LightningDataModule): + def __init__( + self, + dataset: str = "java2cs", + model: str = "CodeT5", + infer_data: str = "test", + batch_size: int = 2, + eval_batch_size: int = 8, + ): + """ + :model_outputs: {model_name: {train: Path, test: Path}} + """ + super().__init__() + + pl.seed_everything(42) + self.data_dir = Macros.data_dir / model / dataset + self.dataset = dataset + self.infer_data = infer_data + self.model = model + self.save_hyperparameters() + logger.info(f"Data Module params: \n{self.hparams}") + + def setup(self, stage: Optional[str] = None): + """Load and encode train/valid/test dataset""" + + self.tokenizer = self.trainer.lightning_module.tokenizer + self.stage = stage + if stage == "fit" or stage is None: + # Process training data + train_source_file = self.data_dir / f"train.{self.dataset}.src" + train_target_file = self.data_dir / f"train.{self.dataset}.tgt" + self.train_dataset = ExampleDataset(train_source_file, train_target_file) + + # Process validatoin data + valid_source_file = self.data_dir / f"valid.{self.dataset}.src" + valid_target_file = self.data_dir / f"valid.{self.dataset}.tgt" + self.valid_dataset = ExampleDataset(valid_source_file, valid_target_file) + + if stage == "predict": + test_source_file = self.data_dir / 
f"{self.infer_data}.{self.dataset}.src" + test_target_file = self.data_dir / f"{self.infer_data}.{self.dataset}.tgt" + logger.info("Start to process prediction data...") + self.test_dataset = PredictDataset(test_source_file, test_target_file) + + if stage == "validate": + valid_source_file = self.data_dir / f"valid.{self.dataset}.src" + valid_target_file = self.data_dir / f"valid.{self.dataset}.tgt" + self.valid_dataset = ExampleDataset(valid_source_file, valid_target_file) + + def tokenizer_collate_fn( + self, batch_data: List[Tuple[str, str]] + ) -> Sequence[torch.Tensor]: + """Customize collate function""" + source_batch = [self.tokenize_sequence(t[0]) for t in batch_data] + target_batch = [self.tokenize_sequence(t[1]) for t in batch_data] + max_length = MAX_LENGTH + batch_size = len(source_batch) + + batched_input_ids, batched_labels_ids = [], [] + for i in range(batch_size): + batched_input_ids.append( + self.tokenizer.encode( + source_batch[i], + max_length=max_length, + truncation=True, + padding="max_length", + ) + ) + batched_labels_ids.append( + self.tokenizer.encode( + target_batch[i], + max_length=max_length, + truncation=True, + padding="max_length", + ) + ) + + return ( + torch.LongTensor(batched_input_ids), + torch.LongTensor(batched_labels_ids), + ) + + def tokenize_collate_fn_predict(self, batch_data: List[Tuple[str, str, int]]): + + source_batch = [self.tokenize_sequence(t[0]) for t in batch_data] + target_batch = [self.tokenize_sequence(t[1]) for t in batch_data] + index_batch = [t[2] for t in batch_data] + max_length = MAX_LENGTH + batch_size = len(source_batch) + + batched_input_ids, batched_labels_ids, = ( + [], + [], + ) + for i in range(batch_size): + batched_input_ids.append( + self.tokenizer.encode( + source_batch[i], + max_length=max_length, + truncation=True, + padding="longest", + ) + ) + batched_labels_ids.append( + self.tokenizer.encode( + target_batch[i], + max_length=max_length, + truncation=True, + padding="longest", + ) + ) + + return ( + torch.LongTensor(batched_input_ids), + torch.LongTensor(batched_labels_ids), + index_batch, + ) + + def tokenize_sequence(self, seq: str) -> List[str]: + """Given string sequence should be able to be split by space.""" + + space_split_tokens = seq.split() + new_subtokens = [] + for token in space_split_tokens: + new_subtokens += self.tokenizer.tokenize(" " + token) + return new_subtokens + + def train_dataloader(self): + return torch.utils.data.DataLoader( + self.train_dataset, + shuffle=True, + batch_size=self.hparams.batch_size, + num_workers=16, + collate_fn=self.tokenizer_collate_fn, + persistent_workers=True, + ) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.valid_dataset, + shuffle=False, + batch_size=self.hparams.batch_size, + num_workers=1, + collate_fn=self.tokenizer_collate_fn, + persistent_workers=True, + ) + + def test_dataloader(self): + return torch.utils.data.DataLoader( + self.test_dataset, + shuffle=False, + batch_size=self.hparams.eval_batch_size, + num_workers=0, + collate_fn=self.tokenizer_collate_fn, + ) + + def predict_dataloader(self): + return torch.utils.data.DataLoader( + self.test_dataset, + shuffle=False, + batch_size=self.hparams.eval_batch_size, + num_workers=0, + collate_fn=self.tokenize_collate_fn_predict, + ) + + +class CodeT5Module(pl.LightningModule): + + # Instantiate the model + def __init__( + self, + pretrained_tokenizer: Union[Path_drw, str], + pretrained_model: Union[Path_drw, str], + optimizer_init: dict, + lr_scheduler_init: dict, + output_dir=None, 
+ skip_special_token_when_generate: bool = True, + beam_size=5, + num_return_sequences=1, + ): + super(CodeT5Module, self).__init__() + + pl.seed_everything(42) + if isinstance(pretrained_tokenizer, Path_drw): + pretrained_tokenizer = os.path.relpath( + Path(pretrained_tokenizer.abs_path), Path.cwd() + ) + if isinstance(pretrained_model, Path_drw): + pretrained_model = os.path.relpath( + Path(pretrained_model.abs_path), Path.cwd() + ) + + self.save_hyperparameters() + self.beam_size = beam_size + self.num_return_sequences = num_return_sequences + + self.tokenizer = RobertaTokenizer.from_pretrained( + self.hparams.pretrained_tokenizer + ) + + self.model = T5ForConditionalGeneration.from_pretrained( + self.hparams.pretrained_model + ) + self.skip_special_token_when_generate = skip_special_token_when_generate + self.model.resize_token_embeddings(len(self.tokenizer)) + logger.info(f"Model Module params: \n{self.hparams}") + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def configure_optimizers(self): + if "weight_decay" in self.hparams.optimizer_init["init_args"]: + no_decay = ["bias", "LayerNorm.weight"] + parameters = [ + { + "params": [ + p + for n, p in self.named_parameters() + if not any(nd in n for nd in no_decay) + ], + "weight_decay": self.hparams.optimizer_init["init_args"][ + "weight_decay" + ], + }, + { + "params": [ + p + for n, p in self.named_parameters() + if any(nd in n for nd in no_decay) + ], + "weight_decay": 0.0, + }, + ] + else: + parameters = self.parameters() + optimizer = instantiate_class(parameters, self.hparams.optimizer_init) + lr_scheduler = instantiate_class(optimizer, self.hparams.lr_scheduler_init) + return { + "optimizer": optimizer, + "lr_scheduler": lr_scheduler, + } + + def training_step(self, batch: List[torch.Tensor], batch_idx=-1): + inputs, labels = batch + attention_masks = ~(inputs == self.tokenizer.pad_token_id) + outputs = self.model( + inputs, labels=labels, attention_mask=attention_masks, return_dict=True + ) + train_loss = outputs.loss + self.log_dict({"loss/train": train_loss.item()}, on_step=True) + + return train_loss + + def validation_step(self, batch: List[torch.Tensor], batch_idx=-1): + inputs, labels = batch + attention_masks = ~(inputs == self.tokenizer.pad_token_id) + batch_size = inputs.shape[0] + outputs = self.model( + inputs, attention_mask=attention_masks, labels=labels, return_dict=True + ) + val_loss = outputs.loss + output_sequences = self.model.generate( + input_ids=inputs, + attention_mask=attention_masks, + num_beams=5, + num_return_sequences=self.num_return_sequences, + max_length=MAX_LENGTH, + ) + pred_sequences = [] + target_sequences = [] + srcs = [] + for input_ids, output_ids, label in zip(inputs, output_sequences, labels): + pred = self.detokenize(output_ids) + if pred == "": + pred = "" + target = self.detokenize(label) + pred_sequences.append(pred) + target_sequences.append(target) + _, bleu_score_list = compute_bleu_scores(target_sequences, pred_sequences) + if self.trainer.datamodule.stage == "validate": + return pred_sequences + metrics_list = {"bleu/val": bleu_score_list} + metrics_list["loss/val"] = [val_loss.item()] * batch_size + + # log the prediction of model + s = "" + for i in range(batch_size): + s += f"# Example {i}\n\n" + s += f"- gold\n```\n{target_sequences[i]}\n```\n\n" + s += f"- pred\n```\n{pred_sequences[i]}\n```\n\n" + s += f"- metrics\n\n" + for k, v in metrics_list.items(): + s += f"{k}: {v[i]}\n" + s += "\n" + + self.logger.experiment.add_text("examples/val", 
s, global_step=self.global_step) + # self.logger.log_text( + # key="validation", + # columns=["examples/val"], + # data=[[s]], + # step=self.global_step, + # ) + + return metrics_list + + def predict_step(self, batch: List[torch.Tensor], batch_idx=-1): + inputs, labels, indexs = batch + attention_masks = ~(inputs == self.tokenizer.pad_token_id) + batch_size = inputs.shape[0] + pred_sequences = [] + + output_sequences = self.model.generate( + input_ids=inputs, + attention_mask=attention_masks, + num_beams=self.beam_size, + num_return_sequences=self.num_return_sequences, + max_length=MAX_LENGTH, + ) + + for index, output_ids in zip(indexs, output_sequences): + pred = self.tokenizer.convert_tokens_to_string( + self.post_process_edit_sequences( + self.tokenizer.convert_ids_to_tokens( + output_ids, + skip_special_tokens=self.skip_special_token_when_generate, + ) + ) + ) + pred_sequences.append(Prediction(index, pred)) + + return pred_sequences + + def validation_epoch_end(self, outputs: Union[List[Dict], List[List[str]]]): + dataset_name = self.trainer.datamodule.dataset + if self.trainer.datamodule.stage == "validate": + all_valid_preds = [] + for batch_pred in outputs: + all_valid_preds.extend(batch_pred) + output_file = ( + f"valid.{dataset_name}.hyp" + if self.num_return_sequences == 1 + else f"valid.{dataset_name}.{self.num_return_sequences}.hyp" + ) + with open(f"{self.hparams.output_dir}/{output_file}", "w") as f: + for pred in all_valid_preds: + f.write(f"{pred}\n") + return + metrics_list = collections.defaultdict(list) + for o in outputs: + for k in o: + metrics_list[k] += o[k] + metrics = summarize_metrics(metrics_list) + self.log_dict(metrics) + + def detokenize(self, output_ids: torch.Tensor) -> str: + pred = ( + self.tokenizer.convert_tokens_to_string( + self.post_process_edit_sequences( + self.tokenizer.convert_ids_to_tokens( + output_ids, + skip_special_tokens=self.skip_special_token_when_generate, + ) + ) + ) + .replace("", "") + .replace("", "") + .replace("", "") + ) + return pred + + def save_pretrained(self, save_dir: Union[str, Path, Path_drw, Path_dc]): + if isinstance(save_dir, (Path_drw, Path_dc)): + save_dir = Path(save_dir.abs_path) + self.model.save_pretrained(save_dir) + self.tokenizer.save_pretrained(save_dir) + + def post_process_edit_sequences(self, token_list: List[str]) -> List[str]: + """Post process token list with edit keywords, manually add space.""" + token_list_after_process = [] + for tk in token_list: + if tk in self.tokenizer.additional_special_tokens or tk in EDIT_TOKENS: + token_list_after_process.append(f"Ġ{tk}Ġ") + else: + token_list_after_process.append(tk) + return token_list_after_process + + +def summarize_metrics( + metrics: Dict[str, Union[float, List[float]]], +) -> Dict[str, float]: + metrics_summary = {} + for k, v in metrics.items(): + if isinstance(v, list): + metrics_summary[k] = float(np.mean([float(x) for x in v])) + else: + metrics_summary[k] = float(v) + return metrics_summary + + +if __name__ == "__main__": + LoggingUtils.setup(LoggingUtils.INFO, Macros.log_file) + + OPTIMIZER_REGISTRY.register_classes( + transformers.optimization, torch.optim.Optimizer, override=True + ) + LR_SCHEDULER_REGISTRY.register_classes( + transformers.optimization, torch.optim.lr_scheduler._LRScheduler, override=True + ) + + DefaultLightningCLI( + CodeT5Module, + CodeT5DataModule, + save_config_callback=SaveConfigCallback, + prediction_writer=PredictionWriter, + optimizers=[(None, "optimizer", "model.optimizer_init")], + lr_schedulers=[(None, 
"lr_scheduler", "model.lr_scheduler_init")], + ) diff --git a/python/deltr/coditT5/prediction.py b/python/deltr/coditT5/prediction.py new file mode 100644 index 0000000..1d1db6a --- /dev/null +++ b/python/deltr/coditT5/prediction.py @@ -0,0 +1,88 @@ +import os +import torch +from pathlib import Path +from jsonargparse.typing import Path_dc, Path_drw +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union, Any + +import pytorch_lightning as pl +import seutil as su +from pytorch_lightning.callbacks import BasePredictionWriter + +from deltr.eval.evaluate import run_evaluation + +logger = su.LoggingUtils.get_logger(__name__, su.LoggingUtils.DEBUG) + + +class PredictionWriter(BasePredictionWriter): + def __init__( + self, + output_dir: Union[Path, str], + no_compute_metrics: bool = True, + dataset: str = "", + model: str = "", + infer_data: str = "test", + ): + super().__init__(write_interval="epoch") + self.no_compute_metrics = no_compute_metrics + self.output_dir = Path(output_dir) + su.io.mkdir(self.output_dir) + self.temp_dir = self.output_dir / "temp" + su.io.mkdir(self.temp_dir) + self.dataset = dataset + self.model_name = model + self.infer_data = infer_data + + def write_on_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + results: List[List[List[Any]]], + batch_indices: Optional[Sequence[Sequence[Sequence[int]]]], + ): + # Collect preds, and put into a file according to current global rank + + preds: List[str] = [] + for dl_batch_preds in results: + for batch_preds in dl_batch_preds: + if isinstance(batch_preds, list): + for pred in batch_preds: + preds.append(pred) + else: + preds.append(batch_preds) + + su.io.dump( + self.temp_dir / f"{pl_module.global_rank}.pkl", + preds, + ) + + # Wait all processes to finish prediction + trainer.training_type_plugin.barrier("prediction") + + if pl_module.global_rank == 0: + id2pred = {} + for rank in range(trainer.world_size): + for pred in su.io.load(self.temp_dir / f"{rank}.pkl"): + id2pred[pred.id] = pred.data + if sorted(id2pred.keys()) != list(range(len(id2pred))): + logger.warning(f"Prediction ids are not continuous") + preds = [id2pred[i] for i in sorted(id2pred.keys())] + + # Dump predictions + logger.info("Saving predictions") + with open( + self.output_dir / f"{self.infer_data}.{self.dataset}.hyp", "w+" + ) as f: + for pred in preds: + f.write(f"{pred}\n") + + if not self.no_compute_metrics: + # Compute metrics + logger.info("Computing and saving metrics") + + run_evaluation( + dataset=self.dataset, + model=self.model_name, + ) + + # Delete temp directory + su.io.rmdir(self.temp_dir) diff --git a/python/deltr/coditT5/save_pretrained.py b/python/deltr/coditT5/save_pretrained.py new file mode 100644 index 0000000..5091f8d --- /dev/null +++ b/python/deltr/coditT5/save_pretrained.py @@ -0,0 +1,75 @@ +from pathlib import Path +from typing import Optional, Union +from jsonargparse import CLI +from jsonargparse.typing import Path_dc, Path_drw, Path_fr +from seutil import LoggingUtils + +from deltr.Macros import Macros + +logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) + + +def locate_ckpt(ckpt_dir: Path) -> Optional[Path]: + ckpt_files = list(ckpt_dir.glob("*.ckpt")) + if len(ckpt_files) == 0: + ckpt_file = None + logger.info(f"No checkpoint found in {ckpt_dir}") + elif len(ckpt_files) == 1: + ckpt_file = ckpt_files[0] + logger.info(f"Found one checkpoint in {ckpt_dir}: {ckpt_file.name}") + else: + ckpt_files = [f for f in ckpt_files if f.name != "last.ckpt"] + ckpt_file = 
sorted(ckpt_files, key=lambda x: x.stat().st_mtime)[-1] + logger.warning( + f"Multiple checkpoints found in {ckpt_dir}: {[x.name for x in ckpt_files]}; picking the latest modified: {ckpt_file.name}" + ) + return ckpt_file + + +def add_tokens_to_tokenizer(): + + from transformers import RobertaTokenizer + from deltr.collector.diff_utils import EDIT_TOKENS + + lowercase_edit_tokens = [tk.lower() for tk in EDIT_TOKENS] + tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base") + special_tokens_dict = { + "additional_special_tokens": EDIT_TOKENS + lowercase_edit_tokens + } + tokenizer.add_special_tokens(special_tokens_dict) + print(f"Size of codeT5 tokenizer is {len(tokenizer)}") + tokenizer.save_pretrained(f"{Macros.model_dir}/EditModelTokenizer") + + +def save_pretrained( + model_cls: str, + ckpt_dir: Path_drw, + ckpt_name: str = None, + output_dir: Optional[Union[Path_drw, Path_dc]] = None, +): + ckpt_dir = Path_drw(ckpt_dir) + ckpt_dir = Path(ckpt_dir.abs_path) + if ckpt_name: + ckpt_path = ckpt_dir / ckpt_name + else: + ckpt_path = locate_ckpt(ckpt_dir) + if output_dir is not None: + output_dir = Path(output_dir.abs_path) + else: + output_dir = ckpt_dir + if model_cls == "CodeT5": + from deltr.coditT5.CodeT5 import CodeT5Module + + model = CodeT5Module.load_from_checkpoint(ckpt_path) + model.save_pretrained(output_dir) + elif model_cls == "T5Encoder": + from deltr.coditT5.SeqClassifier import SeqClassifierModule + + model = SeqClassifierModule.load_from_checkpoint(ckpt_path) + model.save_pretrained(output_dir) + else: + raise ValueError(f"Unknown model class: {model_cls}") + + +if __name__ == "__main__": + CLI(add_tokens_to_tokenizer, as_positional=False) diff --git a/python/deltr/coditT5/utils.py b/python/deltr/coditT5/utils.py new file mode 100644 index 0000000..6a71005 --- /dev/null +++ b/python/deltr/coditT5/utils.py @@ -0,0 +1,509 @@ +import os +import datetime +import time +from pathlib import Path +from typing import ( + Any, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) +from tqdm import tqdm +import torch +import numpy as np +from jsonargparse.typing import Path_dc, Path_drw, Path_dw, Path_fc, Path_fr +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.utilities.cli import ( + LR_SCHEDULER_REGISTRY, + OPTIMIZER_REGISTRY, + LightningArgumentParser, + LightningCLI, +) +import pytorch_lightning as pl +from recordclass import RecordClass + + +from seutil.LoggingUtils import LoggingUtils +import seutil as su + + +logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) + + +class DefaultLightningCLI(LightningCLI): + def __init__( + self, + *args, + optimizers: Optional[ + List[Tuple[Optional[Union[Type, List[Type]]], str, str]] + ] = None, + lr_schedulers: Optional[ + List[Tuple[Optional[Union[Type, List[Type]]], str, str]] + ] = None, + prediction_writer: Optional[Callback] = None, + **kwargs, + ): + self.optimizers = optimizers + self.lr_schedulers = lr_schedulers + self.prediction_writer = prediction_writer + kwargs.setdefault("save_config_overwrite", True) + super().__init__(*args, **kwargs) + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_arguments_to_parser(parser) + parser.add_argument( + "--exp_dir", + required=True, + help="Path to experiment directory", + type=Union[Path_drw, Path_dc], + ) + + parser.add_argument( + "--resume", + required=False, + help="When training, what to do if a checkpoint already 
exists: unset (default) = error; True = resume; False = remove (all existing checkpoints)", + type=bool, + ) + + parser.add_argument( + "--ckpt_name", + required=False, + help="The checkpoint file name to load (under regular ckpt directory); if unset, the latest checkpoint will be loaded", + type=str, + ) + + parser.add_argument( + "--no_compute_metrics", + required=False, + help="When predicting, do not compute metrics and only collect predictions", + type=bool, + default=True, + ) + + parser.add_argument( + "--no_ckpt_ok", + required=False, + help="When predicting, what to do if no checkpoint exists: False (default) = error; True = predict from scratch", + type=bool, + default=False, + ) + + parser.add_argument( + "--output_dir", + required=False, + help="Path to the output directory during prediction", + type=Path_dc, + ) + + parser.add_lightning_class_args(ModelCheckpoint, "ckpt") + parser.set_defaults( + { + "ckpt.save_last": True, + "ckpt.verbose": True, + } + ) + + if self.optimizers is not None: + for types, nested_key, link_to in self.optimizers: + if types is None: + types = OPTIMIZER_REGISTRY.classes + parser.add_optimizer_args(types, nested_key, link_to) + + if self.lr_schedulers is not None: + for types, nested_key, link_to in self.lr_schedulers: + if types is None: + types = LR_SCHEDULER_REGISTRY.classes + parser.add_lr_scheduler_args(types, nested_key, link_to) + + def before_instantiate_classes(self) -> None: + super().before_instantiate_classes() + config = self.config[self.config["subcommand"]] + # In ddp mode, default disable find_unused_parameters + if config["trainer"]["strategy"] == "ddp": + config["trainer"]["strategy"] = pl.plugins.DDPPlugin( + find_unused_parameters=False, + ) + + # # Don't save config in non-fit mode + if self.config["subcommand"] != "fit": + self.save_config_callback = None + + # Set up experiment directory and logger + exp_dir = Path(config["exp_dir"].abs_path).resolve() + + config["trainer"]["default_root_dir"] = os.path.relpath(exp_dir, Path.cwd()) + ckpt_dir = exp_dir / "model" + su.io.mkdir(ckpt_dir) + config["ckpt"]["dirpath"] = os.path.relpath(ckpt_dir, Path.cwd()) + + # locate checkpoint file + if config["ckpt_path"] is None: + if config["ckpt_name"] is not None: + ckpt_file = ckpt_dir / config["ckpt_name"] + else: + ckpt_file = self.locate_ckpt(ckpt_dir, self.config["subcommand"]) + else: + ckpt_file = Path(os.path.abspath(config["ckpt_path"])).resolve() + + if self.config["subcommand"] == "fit": + # If a checkpoint path is specified, assume we want to resume from it + if config["ckpt_path"] is not None or config["ckpt_name"] is not None: + config.setdefault("resume", True) + + # If there is a checkpoint, we must decide what to do with it + if ckpt_file is not None: + if config["resume"] is None: + raise RuntimeError( + f"A checkpoint is present at {ckpt_file}, but I'm not sure what to do with it. Either set `--resume True` to use it or `--resume False` to overwrite it." 
+ ) + elif config["resume"] is True: + logger.info(f"Resuming from checkpoint {ckpt_file}") + config["ckpt_path"] = str(ckpt_file.resolve()) + else: + logger.info(f"Removing checkpoints under {ckpt_dir}") + su.io.mkdir(ckpt_dir, fresh=True) + config["ckpt_path"] = None + + if ( + self.config["subcommand"] == "predict" + or self.config["subcommand"] == "validate" + or self.config["subcommand"] == "test" + ): + if ( + self.config["subcommand"] == "test" + or self.config["subcommand"] == "validate" + ): + config["trainer"]["gpus"] = 1 + config["model"]["output_dir"] = os.path.relpath(exp_dir, Path.cwd()) + if ckpt_file is not None: + config["ckpt_path"] = str(ckpt_file.resolve()) + print("Checkpoint path", config["ckpt_path"]) + else: + if config["no_ckpt_ok"] is False: + raise RuntimeError( + f"No checkpoint found, cannot predict (unless using `--no_ckpt_ok True` to allow predicting from scratch)" + ) + else: + logger.info("No checkpoint found, predicting from scratch") + + if self.prediction_writer is None: + logger.warning( + "No prediction writer specified. " + "Will not write predictions to disk." + ) + elif config["model"]["output_dir"] is None: + logger.warning( + "No output directory specified." + "Will not write predictions to disk." + ) + elif self.config["subcommand"] == "predict": + config["trainer"]["callbacks"].append( + self.prediction_writer( + config["model"]["output_dir"], + config["no_compute_metrics"], + config["data"]["dataset"], + config["data"]["model"], + config["data"]["infer_data"], + ) + ) + + (exp_dir / "logs").mkdir(parents=True, exist_ok=True) + logger_save_dir = exp_dir / "logs" / self.config["subcommand"] + logger_version = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + while (logger_save_dir / logger_version).exists(): + time.sleep(1) + logger_version = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + su.io.mkdir(logger_save_dir) + config["trainer"]["logger"] = { + "class_path": "pytorch_lightning.loggers.tensorboard.TensorBoardLogger", + "init_args": { + "save_dir": os.path.relpath(logger_save_dir, Path.cwd()), + "name": None, + # "project": "delta-translation", + "version": logger_version, + }, + } + + @classmethod + def locate_ckpt(cls, ckpt_dir: Path, mode: str) -> Optional[Path]: + ckpt_files = list(ckpt_dir.glob("*.ckpt")) + if len(ckpt_files) == 0: + ckpt_file = None + logger.info(f"No checkpoint found in {ckpt_dir}") + elif len(ckpt_files) == 1: + ckpt_file = ckpt_files[0] + logger.info(f"Found one checkpoint in {ckpt_dir}: {ckpt_file.name}") + else: + if (ckpt_dir / "last.ckpt").is_file() and mode == "fit": + ckpt_file = ckpt_dir / "last.ckpt" + logger.info( + f"Found the last checkpoint in {ckpt_dir}: {ckpt_file.name}" + ) + else: + for f in ckpt_files: + if f.name == "last.ckpt": + ckpt_files.remove(f) + ckpt_file = sorted(ckpt_files, key=lambda x: x.stat().st_mtime)[-1] + logger.warning( + f"Multiple checkpoints found in {ckpt_dir}: {[x.name for x in ckpt_files]}; picking the latest modified: {ckpt_file.name}" + ) + return ckpt_file + + +class SequenceLabelingDataset(torch.utils.data.Dataset): + "Characterizes a dataset for PyTorch" + + def __init__( + self, + source_file_path: Path, + context_file_path: Path, + label_file_path: Path, + tokenizer: Any, + ): + """Read data from jsonl files.""" + self.source_code = [ + code.strip() + for code in open(source_file_path, "r", encoding="utf-8").readlines() + ] + self.context = [ + ctx.strip() + for ctx in open(context_file_path, "r", encoding="utf-8").readlines() + ] + self.labels = [ + 
[int(label) for label in lb.strip().split()] + for lb in open(label_file_path, "r", encoding="utf-8").readlines() + ] + self.tokenized_labels = tokenize_and_align_labels( + self.source_code, self.labels, tokenizer + ) + + def __len__(self): + + return len(self.source_code) + + def __getitem__(self, index: int): + + return { + "code": self.source_code[index], + "context": self.context[index], + "labels": self.tokenized_labels[index], + } + + +class SequenceLabelingChunkDataset(torch.utils.data.Dataset): + """Dataset for sequence labeling and chunk the data""" + + def __init__( + self, + source_file_path: Path, + context_file_path: Path, + label_file_path: Path, + tokenizer: Any, + ): + """Read data from jsonl files.""" + + self.JAVA_CHUNK_LEN = 240 + self.CS_CHUNK_LEN = 255 + self.tokenizer = tokenizer + source_code = [ + code.strip() + for code in open(source_file_path, "r", encoding="utf-8").readlines() + ] + context = [ + ctx.strip() + for ctx in open(context_file_path, "r", encoding="utf-8").readlines() + ] + labels = [ + [int(label) for label in lb.strip().split()] + for lb in open(label_file_path, "r", encoding="utf-8").readlines() + ] + tokenized_labels = tokenize_and_align_labels(source_code, labels, tokenizer) + self.__split_data_to_chunks__(source_code, context, tokenized_labels) + + def __len__(self): + + return len(self.tokenized_code_input) + + def __split_data_to_chunks__(self, source_code, context, tokenized_labels): + """Split examples into chunks if too long.""" + + self.tokenized_code_input = [] + self.tokenized_context_input = [] + self.data_index = [] + self.labels = [] + too_long_context = 0 + + for index in tqdm(range(len(source_code)), total=len(source_code)): + tokenized_code = self.tokenizer.tokenize(source_code[index]) + tokenized_context = self.tokenizer.tokenize(context[index]) + tokenized_label = tokenized_labels[index] + assert len(tokenized_code) == len(tokenized_labels[index]) + if ( + len(tokenized_code) + len(tokenized_context) + 1 + > self.tokenizer.model_max_length + ): + # context_length = min(self.MAX_CTX_LEN, len(tokenized_context)) + # if context_length == self.MAX_CTX_LEN: + too_long_context += 1 + # start to cut + code_start_id, code_end_id = 0, 0 + context_start_id, context_end_id = 0, 0 + while code_start_id < len(tokenized_code): + code_end_id = self.CS_CHUNK_LEN + code_start_id + context_end_id = self.JAVA_CHUNK_LEN + context_start_id + self.tokenized_code_input.append( + tokenized_code[code_start_id:code_end_id] + ) + self.tokenized_context_input.append( + tokenized_context[context_start_id:context_end_id] + ) + self.labels.append(tokenized_label[code_start_id:code_end_id]) + self.data_index.append(index) + code_start_id = code_end_id + context_start_id = context_end_id + + else: + self.tokenized_code_input.append(tokenized_code) + self.tokenized_context_input.append(tokenized_context) + self.data_index.append(index) + self.labels.append(tokenized_label) + + + return + + def __getitem__(self, index: int): + + return { + "code": self.tokenized_code_input[index], + "context": self.tokenized_context_input[index], + "labels": self.labels[index], + "index": self.data_index[index], + } + + +def tokenize_and_align_labels( + source_code: List[str], labels: List[int], tokenizer: Any +) -> List[List[int]]: + + tokenized_labels = [] + + for code, label in zip(source_code, labels): + tokenized_inputs = tokenizer( + code.split(), is_split_into_words=True, add_special_tokens=False + ) + word_ids = tokenized_inputs.word_ids() + previous_word_idx = None + 
label_ids = [] + for word_idx in word_ids: + # Special tokens have a word id that is None. We set the label to -100 so they are automatically + # ignored in the loss function. + if word_idx is None: + label_ids.append(-100) + # We set the label for the first token of each word. + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + # For the other tokens in a word, we set the label to either the current label or -100, depending on + # the label_all_tokens flag. + else: + label_ids.append(-100) + previous_word_idx = word_idx + + tokenized_labels.append(label_ids) + + return tokenized_labels + + +class ExampleDataset(torch.utils.data.Dataset): + def __init__(self, source_file_path: Path, target_file_path: Path): + self.source_file_path = source_file_path + self.target_file_path = target_file_path + self.source_offset = [] + self.target_offset = [] + self.n_data = 0 + + with open(source_file_path, "rb") as fp: + self.source_offset = [0] + while fp.readline(): + self.source_offset.append(fp.tell()) + self.source_offset = self.source_offset[:-1] + + with open(target_file_path, "rb") as fp: + self.target_offset = [0] + while fp.readline(): + self.target_offset.append(fp.tell()) + self.target_offset = self.target_offset[:-1] + + assert len(self.target_offset) == len(self.source_offset) + + self.n_data = len(self.target_offset) + + def __len__(self) -> int: + return self.n_data + + def __getitem__(self, index: int) -> Tuple: + + if index < 0: + index = self.n_data + index + + with open(self.source_file_path, "r", errors="replace") as sf, open( + self.target_file_path, "r", errors="replace" + ) as tf: + sf.seek(self.source_offset[index]) + source_line = sf.readline() + tf.seek(self.target_offset[index]) + target_line = tf.readline() + + return (source_line.strip(), target_line.strip()) + + +class Prediction(RecordClass): + """Prediction at one data""" + + id: int = -1 + data: str = "" + + +class PredictDataset(torch.utils.data.Dataset): + def __init__(self, source_file_path: Path, target_file_path: Path): + self.source_file_path = source_file_path + self.target_file_path = target_file_path + self.source_offset = [] + self.target_offset = [] + self.n_data = 0 + + with open(source_file_path, "rb") as fp: + self.source_offset = [0] + while fp.readline(): + self.source_offset.append(fp.tell()) + self.source_offset = self.source_offset[:-1] + + with open(target_file_path, "rb") as fp: + self.target_offset = [0] + while fp.readline(): + self.target_offset.append(fp.tell()) + self.target_offset = self.target_offset[:-1] + + assert len(self.target_offset) == len(self.source_offset) + + self.n_data = len(self.target_offset) + + def __len__(self) -> int: + return self.n_data + + def __getitem__(self, index: int) -> Tuple: + + if index < 0: + index = self.n_data + index + + with open(self.source_file_path, "r", errors="replace") as sf, open( + self.target_file_path, "r", errors="replace" + ) as tf: + sf.seek(self.source_offset[index]) + source_line = sf.readline() + tf.seek(self.target_offset[index]) + target_line = tf.readline() + return (source_line.strip(), target_line.strip(), index)
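
Note on the data-loading pattern used above: `ExampleDataset` and `PredictDataset` avoid holding whole `.src`/`.tgt` files in memory by recording the byte offset of every line once, then seeking to the requested offset in `__getitem__`. The following standalone sketch (not part of the patch; the temp file and its contents are made up, though the `train.java2cs.src` name follows the convention used by `CodeT5DataModule`) illustrates that offset-indexing idea in isolation:

```python
# Minimal sketch of the byte-offset indexing used by ExampleDataset/PredictDataset.
# File name and contents are illustrative only.
from pathlib import Path
import tempfile
from typing import List


def build_offsets(path: Path) -> List[int]:
    """Record the byte offset of each line so examples can be read lazily."""
    offsets = [0]
    with open(path, "rb") as fp:
        while fp.readline():
            offsets.append(fp.tell())
    # The last entry points past the final line; drop it, mirroring the dataset code.
    return offsets[:-1]


def read_line_at(path: Path, offsets: List[int], index: int) -> str:
    """Seek directly to the index-th line without reading the whole file."""
    with open(path, "r", errors="replace") as f:
        f.seek(offsets[index])
        return f.readline().strip()


if __name__ == "__main__":
    tmp = Path(tempfile.mkdtemp()) / "train.java2cs.src"
    tmp.write_text("public int foo ( ) { }\npublic int bar ( ) { }\n")
    offs = build_offsets(tmp)
    print(read_line_at(tmp, offs, 1))  # -> "public int bar ( ) { }"
```

The datasets return raw example strings; batching, subword tokenization, and padding happen later in `CodeT5DataModule.tokenizer_collate_fn`, which keeps the per-example storage cheap even for large training files.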