From 8e44a4cfb00ec1d124c73aab5d1921f3f96510d0 Mon Sep 17 00:00:00 2001 From: Qiuhui Liu Date: Mon, 16 Aug 2021 11:42:28 +0800 Subject: [PATCH] August 2021 update --- README.md | 8 +- adv/predict/doc/para/predict_doc_para.py | 2 - adv/predict/predict_ape.py | 1 - adv/train/doc/para/train_doc_para.py | 50 +- adv/train/mulang/eff/train_m2o.py | 379 ++++++++++++ adv/train/mulang/eff/train_mulang.py | 379 ++++++++++++ adv/train/mulang/eff/train_mulang_robt.py | 410 +++++++++++++ adv/train/train_ape.py | 48 +- adv/train/train_dynb.py | 52 +- adv/train/train_probe.py | 48 +- cnfg/README.md | 4 + cnfg/base.py | 2 + cnfg/hyp.py | 6 + cnfg/ihyp.py | 52 +- cnfg/mulang.py | 8 + loss/base.py | 8 +- loss/mulang.py | 40 ++ modules/TA.py | 42 +- modules/aan.py | 63 ++ modules/act.py | 11 +- modules/attn/rap.py | 19 +- modules/attn/res.py | 19 +- modules/base.py | 576 +++++++++++++++++-- modules/cpp/act/act.cpp | 7 + modules/cpp/act/act_func.cpp | 29 + modules/cpp/act/act_func.h | 12 + modules/cpp/act/base.h | 45 ++ modules/cpp/act/setup.py | 6 + modules/cpp/base/attn/attn.cpp | 6 + modules/cpp/base/attn/attn_func.cpp | 98 ++++ modules/cpp/base/attn/base.h | 25 + modules/cpp/base/attn/common.cpp | 177 ++++++ modules/cpp/base/attn/cross/attn.cpp | 6 + modules/cpp/base/attn/cross/attn_func.cpp | 58 ++ modules/cpp/base/attn/cross/setup.py | 6 + modules/cpp/base/attn/self/attn.cpp | 6 + modules/cpp/base/attn/self/attn_func.cpp | 90 +++ modules/cpp/base/attn/self/setup.py | 6 + modules/cpp/base/attn/setup.py | 6 + modules/cpp/base/ffn/pff.cpp | 6 + modules/cpp/base/ffn/pff_func.cpp | 45 ++ modules/cpp/base/ffn/pff_func.h | 11 + modules/cpp/base/ffn/setup.py | 6 + modules/cpp/base/resattn/attn.cpp | 6 + modules/cpp/base/resattn/attn_func.cpp | 3 + modules/cpp/base/resattn/common.cpp | 205 +++++++ modules/cpp/base/resattn/cross/attn.cpp | 6 + modules/cpp/base/resattn/cross/attn_func.cpp | 3 + modules/cpp/base/resattn/cross/setup.py | 6 + modules/cpp/base/resattn/self/attn.cpp | 6 + modules/cpp/base/resattn/self/attn_func.cpp | 3 + modules/cpp/base/resattn/self/setup.py | 6 + modules/cpp/base/resattn/setup.py | 6 + modules/cpp/group/group.cpp | 6 + modules/cpp/group/group_func.cpp | 55 ++ modules/cpp/group/group_func.h | 10 + modules/cpp/group/setup.py | 6 + modules/{hplstm/cpp => cpp/hplstm}/lgate.cpp | 14 +- modules/{hplstm/cpp => cpp/hplstm}/setup.py | 2 +- modules/group/base.py | 48 ++ modules/hplstm/LGate.py | 12 +- modules/hplstm/base.py | 51 +- modules/hplstm/hfn.py | 43 +- modules/mulang/__init__.py | 1 + modules/mulang/eff/__init__.py | 1 + modules/mulang/eff/base.py | 143 +++++ modules/noise.py | 66 ++- parallel/base.py | 233 +++++--- parallel/parallelMT.py | 14 +- predict.py | 1 - requirements.txt | 2 +- scripts/README.md | 2 + scripts/ape/mktest.sh | 31 +- scripts/bpe/mk.sh | 2 +- scripts/doc/para/mktest.sh | 29 +- scripts/mktest.sh | 29 +- scripts/mulang/mktest.sh | 62 ++ scripts/mulang/mktrain.sh | 56 ++ scripts/spm/clean.sh | 75 +++ scripts/spm/mk.sh | 69 +++ tools/ape/mkiodata.py | 35 +- tools/check/avg_bsize.py | 37 +- tools/check/doc/para/epoch_steps.py | 39 +- tools/check/dynb/report_dynb.py | 51 +- tools/check/epoch_steps.py | 37 +- tools/check/mulang/cnfg | 1 + tools/check/mulang/eff/cnfg | 1 + tools/check/mulang/eff/epoch_steps.py | 35 ++ tools/check/mulang/eff/utils | 1 + tools/check/mulang/fbindexes.py | 42 ++ tools/check/mulang/utils | 1 + tools/check/para.py | 7 +- tools/check/probe/merge_probe.py | 10 +- tools/check/tspeed.py | 1 - tools/clean/sampler/eff_sampler.py | 7 +- tools/clean/sampler/strict_sampler.py | 9 +- tools/clean/token_repeat.py | 4 +- tools/doc/para/mkiodata.py | 43 +- tools/doc/para/mktest.py | 33 +- tools/doc/sort.py | 4 +- tools/h5/compress.py | 8 +- tools/lsort/merge.py | 11 +- tools/lsort/partsort.py | 10 +- tools/mkiodata.py | 33 +- tools/mktest.py | 27 +- tools/mulang/cnfg | 1 + tools/mulang/eff/cnfg | 1 + tools/mulang/eff/mkiodata.py | 51 ++ tools/mulang/eff/mktest.py | 46 ++ tools/mulang/eff/sort.py | 51 ++ tools/mulang/eff/utils | 1 + tools/mulang/share_vocab.py | 39 ++ tools/mulang/utils | 1 + tools/mulang/vocab.py | 26 + tools/prune_model_vocab.py | 2 +- tools/restore.py | 8 +- tools/shuffle.py | 3 +- tools/sort.py | 4 +- tools/spm/decode.py | 50 ++ tools/spm/encode.py | 65 +++ tools/spm/train.py | 9 + train.py | 56 +- transformer/AGG/HierDecoder.py | 3 +- transformer/AGG/HierEncoder.py | 3 +- transformer/APE/Decoder.py | 43 +- transformer/APE/Encoder.py | 17 +- transformer/APE/NMT.py | 2 +- transformer/AvgDecoder.py | 18 +- transformer/Decoder.py | 51 +- transformer/Doc/Para/Base/Decoder.py | 32 +- transformer/Doc/Para/Base/Encoder.py | 3 +- transformer/Doc/Para/Base/NMT.py | 2 +- transformer/Encoder.py | 32 +- transformer/HPLSTM/Decoder.py | 14 +- transformer/HPLSTM/FNDecoder.py | 15 +- transformer/LD/Decoder.py | 23 +- transformer/LD/Encoder.py | 8 +- transformer/LD/NMT.py | 2 +- transformer/MuLang/Eff/Base/Decoder.py | 276 +++++++++ transformer/MuLang/Eff/Base/Encoder.py | 67 +++ transformer/MuLang/Eff/Base/NMT.py | 45 ++ transformer/MuLang/Eff/Base/__init__.py | 1 + transformer/MuLang/Eff/__init__.py | 1 + transformer/MuLang/__init__.py | 1 + transformer/NMT.py | 2 +- transformer/Probe/Decoder.py | 36 +- transformer/Probe/NMT.py | 2 +- transformer/Probe/ReDecoder.py | 22 +- transformer/Probe/ReNMT.py | 2 +- transformer/RealFormer/Decoder.py | 36 +- transformer/RealFormer/Encoder.py | 15 +- transformer/SC/Decoder.py | 15 +- transformer/SC/NMT.py | 2 +- transformer/TA/Encoder.py | 11 +- translator.py | 2 - utils/aan.py | 2 +- utils/base.py | 88 ++- utils/contpara.py | 130 +++++ utils/cpp/base.h | 93 +++ utils/dynbatch.py | 21 +- utils/fmt/base.py | 37 +- utils/fmt/mulang/__init__.py | 1 + utils/fmt/mulang/eff/__init__.py | 1 + utils/fmt/mulang/eff/dual.py | 54 ++ utils/fmt/mulang/eff/single.py | 49 ++ utils/h5serial.py | 40 +- utils/mulang.py | 42 ++ utils/pyctorch.py | 24 + utils/random.py | 37 ++ utils/torch.py | 87 +++ 170 files changed, 5585 insertions(+), 1017 deletions(-) create mode 100644 adv/train/mulang/eff/train_m2o.py create mode 100644 adv/train/mulang/eff/train_mulang.py create mode 100644 adv/train/mulang/eff/train_mulang_robt.py create mode 100644 cnfg/mulang.py create mode 100644 loss/mulang.py create mode 100644 modules/aan.py create mode 100644 modules/cpp/act/act.cpp create mode 100644 modules/cpp/act/act_func.cpp create mode 100644 modules/cpp/act/act_func.h create mode 100644 modules/cpp/act/base.h create mode 100644 modules/cpp/act/setup.py create mode 100644 modules/cpp/base/attn/attn.cpp create mode 100644 modules/cpp/base/attn/attn_func.cpp create mode 100644 modules/cpp/base/attn/base.h create mode 100644 modules/cpp/base/attn/common.cpp create mode 100644 modules/cpp/base/attn/cross/attn.cpp create mode 100644 modules/cpp/base/attn/cross/attn_func.cpp create mode 100644 modules/cpp/base/attn/cross/setup.py create mode 100644 modules/cpp/base/attn/self/attn.cpp create mode 100644 modules/cpp/base/attn/self/attn_func.cpp create mode 100644 modules/cpp/base/attn/self/setup.py create mode 100644 modules/cpp/base/attn/setup.py create mode 100644 modules/cpp/base/ffn/pff.cpp create mode 100644 modules/cpp/base/ffn/pff_func.cpp create mode 100644 modules/cpp/base/ffn/pff_func.h create mode 100644 modules/cpp/base/ffn/setup.py create mode 100644 modules/cpp/base/resattn/attn.cpp create mode 100644 modules/cpp/base/resattn/attn_func.cpp create mode 100644 modules/cpp/base/resattn/common.cpp create mode 100644 modules/cpp/base/resattn/cross/attn.cpp create mode 100644 modules/cpp/base/resattn/cross/attn_func.cpp create mode 100644 modules/cpp/base/resattn/cross/setup.py create mode 100644 modules/cpp/base/resattn/self/attn.cpp create mode 100644 modules/cpp/base/resattn/self/attn_func.cpp create mode 100644 modules/cpp/base/resattn/self/setup.py create mode 100644 modules/cpp/base/resattn/setup.py create mode 100644 modules/cpp/group/group.cpp create mode 100644 modules/cpp/group/group_func.cpp create mode 100644 modules/cpp/group/group_func.h create mode 100644 modules/cpp/group/setup.py rename modules/{hplstm/cpp => cpp/hplstm}/lgate.cpp (89%) rename modules/{hplstm/cpp => cpp/hplstm}/setup.py (60%) create mode 100644 modules/mulang/__init__.py create mode 100644 modules/mulang/eff/__init__.py create mode 100644 modules/mulang/eff/base.py create mode 100644 scripts/mulang/mktest.sh create mode 100644 scripts/mulang/mktrain.sh create mode 100644 scripts/spm/clean.sh create mode 100644 scripts/spm/mk.sh create mode 120000 tools/check/mulang/cnfg create mode 120000 tools/check/mulang/eff/cnfg create mode 100644 tools/check/mulang/eff/epoch_steps.py create mode 120000 tools/check/mulang/eff/utils create mode 100644 tools/check/mulang/fbindexes.py create mode 120000 tools/check/mulang/utils create mode 120000 tools/mulang/cnfg create mode 120000 tools/mulang/eff/cnfg create mode 100644 tools/mulang/eff/mkiodata.py create mode 100644 tools/mulang/eff/mktest.py create mode 100644 tools/mulang/eff/sort.py create mode 120000 tools/mulang/eff/utils create mode 100644 tools/mulang/share_vocab.py create mode 120000 tools/mulang/utils create mode 100644 tools/mulang/vocab.py create mode 100644 tools/spm/decode.py create mode 100644 tools/spm/encode.py create mode 100644 tools/spm/train.py create mode 100644 transformer/MuLang/Eff/Base/Decoder.py create mode 100644 transformer/MuLang/Eff/Base/Encoder.py create mode 100644 transformer/MuLang/Eff/Base/NMT.py create mode 100644 transformer/MuLang/Eff/Base/__init__.py create mode 100644 transformer/MuLang/Eff/__init__.py create mode 100644 transformer/MuLang/__init__.py create mode 100644 utils/contpara.py create mode 100644 utils/cpp/base.h create mode 100644 utils/fmt/mulang/__init__.py create mode 100644 utils/fmt/mulang/eff/__init__.py create mode 100644 utils/fmt/mulang/eff/dual.py create mode 100644 utils/fmt/mulang/eff/single.py create mode 100644 utils/mulang.py create mode 100644 utils/pyctorch.py create mode 100644 utils/random.py diff --git a/README.md b/README.md index 46adc7a..9505fb1 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ We provide scripts to apply Byte-Pair Encoding (BPE) under `scripts/bpe/`. ### convert plain text to tensors for training -Generate training data for `train.py` with `bash scripts/mktrain.sh`, [configure variables](scripts/README.md#mktrainsh) in `scripts/mktrain.sh` for your usage (the other variables shall comply with those in `scripts/mkbpe.sh`). +Generate training data for `train.py` with `bash scripts/mktrain.sh`, [configure variables](scripts/README.md#mktrainsh) in `scripts/mktrain.sh` for your usage (the other variables shall comply with those in `scripts/bpe/mk.sh`). ## Configuration for training and testing @@ -120,9 +120,3 @@ Details of this project can be found [here](https://arxiv.org/abs/1903.07402), a pdf = {https://arxiv.org/pdf/1903.07402} } ``` - -## Contributor(s) - -## Need more? - -Every details are in those codes, just explore them and make commits ;-) diff --git a/adv/predict/doc/para/predict_doc_para.py b/adv/predict/doc/para/predict_doc_para.py index 4e53378..6ddcf56 100644 --- a/adv/predict/doc/para/predict_doc_para.py +++ b/adv/predict/doc/para/predict_doc_para.py @@ -63,9 +63,7 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) #num_prev_sent = cnfg.num_prev_sent - beam_size = cnfg.beam_size - length_penalty = cnfg.length_penalty ens = "\n".encode("utf-8") diff --git a/adv/predict/predict_ape.py b/adv/predict/predict_ape.py index 1284f3c..70846e4 100644 --- a/adv/predict/predict_ape.py +++ b/adv/predict/predict_ape.py @@ -63,7 +63,6 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) beam_size = cnfg.beam_size - length_penalty = cnfg.length_penalty ens = "\n".encode("utf-8") diff --git a/adv/train/doc/para/train_doc_para.py b/adv/train/doc/para/train_doc_para.py index 1bfc768..14c8082 100644 --- a/adv/train/doc/para/train_doc_para.py +++ b/adv/train/doc/para/train_doc_para.py @@ -11,6 +11,7 @@ from utils.base import * from utils.init import init_model_params +from utils.contpara import get_model_parameters from utils.h5serial import h5save, h5load from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb @@ -22,9 +23,6 @@ from tqdm import tqdm -from os import makedirs -from os.path import exists as p_check - import h5py import cnfg.docpara as cnfg @@ -77,7 +75,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _done_tokens += wd_add if _done_tokens >= tokens_optm: - optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -90,7 +88,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -124,7 +122,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -183,32 +181,23 @@ def load_fixing(module): module.fix_load() rid = cnfg.run_id - earlystop = cnfg.earlystop - maxrun = cnfg.maxrun - tokens_optm = cnfg.tokens_optm - done_tokens = 0 - batch_report = cnfg.batch_report report_eva = cnfg.report_eva - use_ams = cnfg.use_ams - save_optm_state = cnfg.save_optm_state - +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva save_every = cnfg.save_every start_chkp_save = cnfg.epoch_start_checkpoint_save - epoch_save = cnfg.epoch_save - remain_steps = cnfg.training_steps wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) -if not p_check(wkdir): - makedirs(wkdir) +mkdir(wkdir) chkpf = None chkpof = None @@ -270,7 +259,7 @@ def load_fixing(module): lossf.to(cuda_device) optimizer = Optimizer(filter_para_grad(mymodel.parameters()), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) -optimizer.zero_grad(set_to_none=True) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) use_amp = cnfg.use_amp and use_cuda scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None @@ -279,12 +268,11 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -if multi_gpu_optimizer: - optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - mymodel.zero_grad(set_to_none=True) +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) else: - optimizer = Optimizer((mymodel.module if multi_gpu else mymodel).parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - optimizer.zero_grad(set_to_none=True) + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) fine_tune_state = cnfg.fine_tune_state if fine_tune_state is not None: @@ -302,16 +290,16 @@ def load_fixing(module): logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: - save_model(mymodel, wkdir + "init.h5", multi_gpu, logger) + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) logger.info("Initial model saved") else: cnt_states = cnfg.train_statesf - if (cnt_states is not None) and p_check(cnt_states): + if cnt_states is not None: logger.info("Continue last epoch") tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, vl, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) vloss, vprec = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) - save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec)) logger.info("New best model saved") @@ -340,7 +328,7 @@ def load_fixing(module): logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): - save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) logger.info("New best model saved") @@ -355,11 +343,11 @@ def load_fixing(module): else: if terr < tminerr: tminerr = terr - save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) elif epoch_save: - save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info) namin += 1 if namin >= earlystop: @@ -385,7 +373,7 @@ def load_fixing(module): if done_tokens > 0: optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) -save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "last.optm.h5") logger.info("model saved") diff --git a/adv/train/mulang/eff/train_m2o.py b/adv/train/mulang/eff/train_m2o.py new file mode 100644 index 0000000..5fd4a75 --- /dev/null +++ b/adv/train/mulang/eff/train_m2o.py @@ -0,0 +1,379 @@ +#encoding: utf-8 + +import torch +from torch.cuda.amp import autocast, GradScaler + +from torch.optim import Adam as Optimizer + +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT +from parallel.optm import MultiGPUGradScaler + +from utils.base import * +from utils.init import init_model_params +from utils.contpara import get_model_parameters +from utils.h5serial import h5save, h5load +from utils.fmt.base import tostr, save_states, load_states, pad_id +from utils.fmt.base4torch import parse_cuda, load_emb +from utils.mulang import data_sampler + +from lrsch import GoogleLR as LRScheduler +from loss.base import LabelSmoothingLoss + +from random import shuffle + +from tqdm import tqdm + +import h5py + +import cnfg.mulang as cnfg +from cnfg.ihyp import * + +from transformer.NMT import NMT + +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): + + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + model.train() + cur_b, _ls = 1, {} if save_loss else None + for i_d, taskid in tqdm(tl, mininterval=tqdm_mininterval): + task_grp = td[str(taskid)] + seq_batch = torch.from_numpy(task_grp["src"][i_d][:]) + seq_o = torch.from_numpy(task_grp["tgt"][i_d][:]) + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() + + oi = seq_o.narrow(1, 0, lo) + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=_use_amp): + output = model(seq_batch, oi) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + loss_add = loss.data.item() + + if scaler is None: + loss.backward() + else: + scaler.scale(loss).backward() + + wd_add = ot.ne(pad_id).int().sum().item() + loss = output = oi = ot = seq_batch = seq_o = None + sum_loss += loss_add + if save_loss: + _ls[(i_d, t_d)] = loss_add / wd_add + sum_wd += wd_add + _done_tokens += wd_add + + if _done_tokens >= tokens_optm: + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) + _done_tokens = 0 + if _cur_rstep is not None: + if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + _cur_rstep -= 1 + if _cur_rstep <= 0: + break + lrsch.step() + + if nreport is not None: + part_loss += loss_add + part_wd += wd_add + if cur_b % nreport == 0: + if report_eva: + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) + logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + free_cache(mv_device) + model.train() + else: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + part_loss = 0.0 + part_wd = 0 + + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + cur_b += 1 + if part_wd != 0.0: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls + +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 + sum_loss = 0.0 + model.eval() + with torch.no_grad(): + for i_d, taskid in tqdm(nd, mininterval=tqdm_mininterval): + task_grp = ed[str(taskid)] + seq_batch = torch.from_numpy(task_grp["src"][i_d][:]) + seq_o = torch.from_numpy(task_grp["tgt"][i_d][:]) + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=use_amp): + output = model(seq_batch, seq_o.narrow(1, 0, lo)) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) + sum_loss += loss.data.item() + data_mask = ot.ne(pad_id) + correct = (trans.eq(ot) & data_mask).int() + w += data_mask.int().sum().item() + r += correct.sum().item() + correct = data_mask = trans = loss = output = ot = seq_batch = seq_o = None + w = float(w) + return sum_loss / w, (w - r) / w * 100.0 + +def hook_lr_update(optm, flags=None): + + reset_Adam(optm, flags) + +def init_fixing(module): + + if hasattr(module, "fix_init"): + module.fix_init() + +def load_fixing(module): + + if hasattr(module, "fix_load"): + module.fix_load() + +rid = cnfg.run_id +earlystop = cnfg.earlystop +maxrun = cnfg.maxrun +tokens_optm = cnfg.tokens_optm +done_tokens = 0 +batch_report = cnfg.batch_report +report_eva = cnfg.report_eva +use_ams = cnfg.use_ams +save_optm_state = cnfg.save_optm_state +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva +save_every = cnfg.save_every +start_chkp_save = cnfg.epoch_start_checkpoint_save +epoch_save = cnfg.epoch_save +remain_steps = cnfg.training_steps + +wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) +mkdir(wkdir) + +chkpf = None +chkpof = None +statesf = None +if save_every is not None: + chkpf = wkdir + "checkpoint.h5" + if save_optm_state: + chkpof = wkdir + "checkpoint.optm.h5" + if cnfg.save_train_state: + statesf = wkdir + "checkpoint.states" + +logger = get_logger(wkdir + "train.log") + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +multi_gpu_optimizer = multi_gpu and cnfg.multi_gpu_optimizer + +set_random_seed(cnfg.seed, use_cuda) + +td = h5py.File(cnfg.train_data, "r") +vd = h5py.File(cnfg.dev_data, "r") + +ntrain = td["ndata"][:].tolist() +nvalid = vd["ndata"][:].tolist() +nword = td["nword"][:].tolist() +nwordi, ntask, nwordt = nword[0], nword[1], nword[-1] + +logger.info("Design models with seed: %d" % torch.initial_seed()) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + +fine_tune_m = cnfg.fine_tune_m +task_weight, task_weight_T = cnfg.task_weight, cnfg.task_weight_T +if task_weight_T is None or task_weight_T == 1.0: + tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][:].tolist()) for i in range(_nd)] + train_sampler = None +else: + train_taskorder = td["taskorder"][:].tolist() + _tnd = dict(zip(train_taskorder, ntrain)) + train_taskorder.sort() + ntrain = [_tnd[i] for i in train_taskorder] + _tnd = None + train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain)) +nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][:].tolist()) for i in range(_nd)] + +mymodel = init_model_params(mymodel) +mymodel.apply(init_fixing) +if fine_tune_m is not None: + logger.info("Load pre-trained model from: " + fine_tune_m) + mymodel = load_model_cpu(fine_tune_m, mymodel) + mymodel.apply(load_fixing) + +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) + +if cnfg.src_emb is not None: + logger.info("Load source embedding from: " + cnfg.src_emb) + load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb) +if cnfg.tgt_emb is not None: + logger.info("Load target embedding from: " + cnfg.tgt_emb) + load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) + +if cuda_device: + mymodel.to(cuda_device) + lossf.to(cuda_device) + +use_amp = cnfg.use_amp and use_cuda +scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None + +if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) + +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) +else: + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) + +fine_tune_state = cnfg.fine_tune_state +if fine_tune_state is not None: + logger.info("Load optimizer state from: " + fine_tune_state) + optimizer.load_state_dict(h5load(fine_tune_state)) + +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) + +num_checkpoint = cnfg.num_checkpoint +cur_checkid = 0 + +tminerr = inf_default + +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) + +if fine_tune_m is None: + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) + logger.info("Initial model saved") +else: + cnt_states = cnfg.train_statesf + if cnt_states is not None: + logger.info("Continue last epoch") + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,)) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,)) + logger.info("New best model saved") + +if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0: + dss_ws = int(cnfg.dss_ws * sum(ntrain)) + _Dws = {} + _prev_Dws = {} + _crit_inc = {} + if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0: + dss_rm = int(cnfg.dss_rm * sum(ntrain) * (1.0 - cnfg.dss_ws)) + else: + dss_rm = 0 +else: + dss_ws = 0 + dss_rm = 0 + _Dws = None + +namin = 0 + +for i in range(1, maxrun + 1): + if train_sampler is None: + shuffle(tl) + else: + tl = train_sampler.generate() + free_cache(use_cuda) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,)) + + if (vprec <= minerr) or (vloss <= minloss): + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) + logger.info("New best model saved") + + namin = 0 + + if vprec < minerr: + minerr = vprec + if vloss < minloss: + minloss = vloss + + else: + if terr < tminerr: + tminerr = terr + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) + elif epoch_save: + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info) + + namin += 1 + if namin >= earlystop: + if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + done_tokens = 0 + logger.info("early stop") + break + + if remain_steps is not None and remain_steps <= 0: + logger.info("Last training step reached") + break + + if dss_ws > 0: + if _prev_Dws and (train_sampler is None): + for _key, _value in _Dws.items(): + if _key in _prev_Dws: + _ploss = _prev_Dws[_key] + _crit_inc[_key] = (_ploss - _value) / _ploss + tl = dynamic_sample(_crit_inc, dss_ws, dss_rm) + _prev_Dws = _Dws + +if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) +if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "last.optm.h5") +logger.info("model saved") + +td.close() +vd.close() diff --git a/adv/train/mulang/eff/train_mulang.py b/adv/train/mulang/eff/train_mulang.py new file mode 100644 index 0000000..ff74a62 --- /dev/null +++ b/adv/train/mulang/eff/train_mulang.py @@ -0,0 +1,379 @@ +#encoding: utf-8 + +import torch +from torch.cuda.amp import autocast, GradScaler + +from torch.optim import Adam as Optimizer + +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT +from parallel.optm import MultiGPUGradScaler + +from utils.base import * +from utils.init import init_model_params +from utils.contpara import get_model_parameters +from utils.h5serial import h5save, h5load +from utils.fmt.base import tostr, save_states, load_states, pad_id +from utils.fmt.base4torch import parse_cuda, load_emb +from utils.mulang import data_sampler + +from lrsch import GoogleLR as LRScheduler +from loss.base import MultiLabelSmoothingLoss as LabelSmoothingLoss + +from random import shuffle + +from tqdm import tqdm + +import h5py + +import cnfg.mulang as cnfg +from cnfg.ihyp import * + +from transformer.MuLang.Eff.Base.NMT import NMT + +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): + + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + model.train() + cur_b, _ls = 1, {} if save_loss else None + for i_d, taskid in tqdm(tl, mininterval=tqdm_mininterval): + task_grp = td[str(taskid)] + seq_batch = torch.from_numpy(task_grp["src"][i_d][:]) + seq_o = torch.from_numpy(task_grp["tgt"][i_d][:]) + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() + + oi = seq_o.narrow(1, 0, lo) + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=_use_amp): + output = model(seq_batch, oi, taskid=taskid) + loss = lossf(output, ot, lang_id=taskid) + if multi_gpu: + loss = loss.sum() + loss_add = loss.data.item() + + if scaler is None: + loss.backward() + else: + scaler.scale(loss).backward() + + wd_add = ot.ne(pad_id).int().sum().item() + loss = output = oi = ot = seq_batch = seq_o = None + sum_loss += loss_add + if save_loss: + _ls[(i_d, t_d)] = loss_add / wd_add + sum_wd += wd_add + _done_tokens += wd_add + + if _done_tokens >= tokens_optm: + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) + _done_tokens = 0 + if _cur_rstep is not None: + if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + _cur_rstep -= 1 + if _cur_rstep <= 0: + break + lrsch.step() + + if nreport is not None: + part_loss += loss_add + part_wd += wd_add + if cur_b % nreport == 0: + if report_eva: + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) + logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + free_cache(mv_device) + model.train() + else: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + part_loss = 0.0 + part_wd = 0 + + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + cur_b += 1 + if part_wd != 0.0: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls + +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 + sum_loss = 0.0 + model.eval() + with torch.no_grad(): + for i_d, taskid in tqdm(nd, mininterval=tqdm_mininterval): + task_grp = ed[str(taskid)] + seq_batch = torch.from_numpy(task_grp["src"][i_d][:]) + seq_o = torch.from_numpy(task_grp["tgt"][i_d][:]) + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=use_amp): + output = model(seq_batch, seq_o.narrow(1, 0, lo), taskid=taskid) + loss = lossf(output, ot, lang_id=taskid) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) + sum_loss += loss.data.item() + data_mask = ot.ne(pad_id) + correct = (trans.eq(ot) & data_mask).int() + w += data_mask.int().sum().item() + r += correct.sum().item() + correct = data_mask = trans = loss = output = ot = seq_batch = seq_o = None + w = float(w) + return sum_loss / w, (w - r) / w * 100.0 + +def hook_lr_update(optm, flags=None): + + reset_Adam(optm, flags) + +def init_fixing(module): + + if hasattr(module, "fix_init"): + module.fix_init() + +def load_fixing(module): + + if hasattr(module, "fix_load"): + module.fix_load() + +rid = cnfg.run_id +earlystop = cnfg.earlystop +maxrun = cnfg.maxrun +tokens_optm = cnfg.tokens_optm +done_tokens = 0 +batch_report = cnfg.batch_report +report_eva = cnfg.report_eva +use_ams = cnfg.use_ams +save_optm_state = cnfg.save_optm_state +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva +save_every = cnfg.save_every +start_chkp_save = cnfg.epoch_start_checkpoint_save +epoch_save = cnfg.epoch_save +remain_steps = cnfg.training_steps + +wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) +mkdir(wkdir) + +chkpf = None +chkpof = None +statesf = None +if save_every is not None: + chkpf = wkdir + "checkpoint.h5" + if save_optm_state: + chkpof = wkdir + "checkpoint.optm.h5" + if cnfg.save_train_state: + statesf = wkdir + "checkpoint.states" + +logger = get_logger(wkdir + "train.log") + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +multi_gpu_optimizer = multi_gpu and cnfg.multi_gpu_optimizer + +set_random_seed(cnfg.seed, use_cuda) + +td = h5py.File(cnfg.train_data, "r") +vd = h5py.File(cnfg.dev_data, "r") + +ntrain = td["ndata"][:].tolist() +nvalid = vd["ndata"][:].tolist() +nword = td["nword"][:].tolist() +nwordi, ntask, nwordt = nword[0], nword[1], nword[-1] + +logger.info("Design models with seed: %d" % torch.initial_seed()) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask)#, ngroup=cnfg.ngroup + +fine_tune_m = cnfg.fine_tune_m +task_weight, task_weight_T = cnfg.task_weight, cnfg.task_weight_T +if task_weight_T is None or task_weight_T == 1.0: + tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][:].tolist()) for i in range(_nd)] + train_sampler = None +else: + train_taskorder = td["taskorder"][:].tolist() + _tnd = dict(zip(train_taskorder, ntrain)) + train_taskorder.sort() + ntrain = [_tnd[i] for i in train_taskorder] + _tnd = None + train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain)) +nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][:].tolist()) for i in range(_nd)] + +mymodel = init_model_params(mymodel) +mymodel.apply(init_fixing) +if fine_tune_m is not None: + logger.info("Load pre-trained model from: " + fine_tune_m) + mymodel = load_model_cpu(fine_tune_m, mymodel) + mymodel.apply(load_fixing) + +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) + +if cnfg.src_emb is not None: + logger.info("Load source embedding from: " + cnfg.src_emb) + load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb) +if cnfg.tgt_emb is not None: + logger.info("Load target embedding from: " + cnfg.tgt_emb) + load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) + +if cuda_device: + mymodel.to(cuda_device) + lossf.to(cuda_device) + +use_amp = cnfg.use_amp and use_cuda +scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None + +if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) + +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) +else: + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) + +fine_tune_state = cnfg.fine_tune_state +if fine_tune_state is not None: + logger.info("Load optimizer state from: " + fine_tune_state) + optimizer.load_state_dict(h5load(fine_tune_state)) + +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) + +num_checkpoint = cnfg.num_checkpoint +cur_checkid = 0 + +tminerr = inf_default + +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) + +if fine_tune_m is None: + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) + logger.info("Initial model saved") +else: + cnt_states = cnfg.train_statesf + if cnt_states is not None: + logger.info("Continue last epoch") + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,)) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,)) + logger.info("New best model saved") + +if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0: + dss_ws = int(cnfg.dss_ws * sum(ntrain)) + _Dws = {} + _prev_Dws = {} + _crit_inc = {} + if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0: + dss_rm = int(cnfg.dss_rm * sum(ntrain) * (1.0 - cnfg.dss_ws)) + else: + dss_rm = 0 +else: + dss_ws = 0 + dss_rm = 0 + _Dws = None + +namin = 0 + +for i in range(1, maxrun + 1): + if train_sampler is None: + shuffle(tl) + else: + tl = train_sampler.generate() + free_cache(use_cuda) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,)) + + if (vprec <= minerr) or (vloss <= minloss): + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) + logger.info("New best model saved") + + namin = 0 + + if vprec < minerr: + minerr = vprec + if vloss < minloss: + minloss = vloss + + else: + if terr < tminerr: + tminerr = terr + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) + elif epoch_save: + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info) + + namin += 1 + if namin >= earlystop: + if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + done_tokens = 0 + logger.info("early stop") + break + + if remain_steps is not None and remain_steps <= 0: + logger.info("Last training step reached") + break + + if dss_ws > 0: + if _prev_Dws and (train_sampler is None): + for _key, _value in _Dws.items(): + if _key in _prev_Dws: + _ploss = _prev_Dws[_key] + _crit_inc[_key] = (_ploss - _value) / _ploss + tl = dynamic_sample(_crit_inc, dss_ws, dss_rm) + _prev_Dws = _Dws + +if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) +if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "last.optm.h5") +logger.info("model saved") + +td.close() +vd.close() diff --git a/adv/train/mulang/eff/train_mulang_robt.py b/adv/train/mulang/eff/train_mulang_robt.py new file mode 100644 index 0000000..6f33077 --- /dev/null +++ b/adv/train/mulang/eff/train_mulang_robt.py @@ -0,0 +1,410 @@ +#encoding: utf-8 + +import torch +from torch.cuda.amp import autocast, GradScaler + +from torch.optim import Adam as Optimizer + +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT +from parallel.optm import MultiGPUGradScaler + +from utils.base import * +from utils.init import init_model_params +from utils.contpara import get_model_parameters +from utils.h5serial import h5save, h5load +from utils.fmt.base import tostr, save_states, load_states, pad_id +from utils.fmt.base4torch import parse_cuda, load_emb +from utils.mulang import data_sampler + +from lrsch import GoogleLR as LRScheduler +from loss.base import MultiLabelSmoothingLoss as LabelSmoothingLoss + +from random import shuffle, randint + +from tqdm import tqdm + +import h5py + +import cnfg.mulang as cnfg +from cnfg.ihyp import * + +from transformer.MuLang.Eff.Base.NMT import NMT + +def back_translate(model, seq_in, taskid, beam_size, multi_gpu, enable_autocast=False, step_bsize=32, step_ntok=640, pivot_bt=True): + + rs = [] + bsize, seql = seq_in.size() + _step_bsize = min(step_bsize, step_ntok // seql) + _max_len = min(255, seql + 16) + if multi_gpu: + _g_out = model.gather_output + model.gather_output = True + sind = 0 + with torch.no_grad(), autocast(enabled=enable_autocast): + while sind < bsize: + num_narrow = min(_step_bsize, bsize - sind) + if pivot_bt and (taskid != 0): + out = model.decode(seq_in.narrow(0, sind, num_narrow), taskid=0, beam_size=beam_size, max_len=_max_len).detach() + rs.append(model.decode(torch.cat((torch.ones(num_narrow, 1, dtype=out.dtype, device=out.device), out,), dim=1), taskid=taskid, beam_size=beam_size, max_len=_max_len).detach()) + out = _tid = None + else: + rs.append(model.decode(seq_in.narrow(0, sind, num_narrow), taskid=taskid, beam_size=beam_size, max_len=_max_len).detach()) + sind += num_narrow + if multi_gpu: + model.gather_output = _g_out + rs = torch.cat(pad_tensors(rs), dim=0) + rs = torch.cat((torch.ones(bsize, 1, dtype=rs.dtype, device=rs.device), rs,), dim=1) + + return rs + +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): + + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + model.train() + cur_b, _ls = 1, {} if save_loss else None + global ntask, ro_beam_size + t_sample_max_id = ntask - 2 + for i_d, taskid in tqdm(tl, mininterval=tqdm_mininterval): + seq_o = torch.from_numpy(td[str(taskid)]["tgt"][i_d][:]) + lo = seq_o.size(1) - 1 + if mv_device: + seq_o = seq_o.to(mv_device) + seq_o = seq_o.long() + + _bt_taskid = randint(0, t_sample_max_id) + if _bt_taskid >= taskid: + _bt_taskid += 1 + seq_batch = back_translate(model, seq_o, _bt_taskid, ro_beam_size, multi_gpu, enable_autocast=_use_amp) + oi = seq_o.narrow(1, 0, lo) + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=_use_amp): + output = model(seq_batch, oi, taskid=taskid) + loss = lossf(output, ot, lang_id=taskid) + if multi_gpu: + loss = loss.sum() + loss_add = loss.data.item() + + if scaler is None: + loss.backward() + else: + scaler.scale(loss).backward() + + wd_add = ot.ne(pad_id).int().sum().item() + loss = output = oi = ot = seq_batch = seq_o = None + sum_loss += loss_add + if save_loss: + _ls[(i_d, t_d)] = loss_add / wd_add + sum_wd += wd_add + _done_tokens += wd_add + + if _done_tokens >= tokens_optm: + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) + _done_tokens = 0 + if _cur_rstep is not None: + if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + _cur_rstep -= 1 + if _cur_rstep <= 0: + break + lrsch.step() + + if nreport is not None: + part_loss += loss_add + part_wd += wd_add + if cur_b % nreport == 0: + if report_eva: + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) + logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + free_cache(mv_device) + model.train() + else: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + part_loss = 0.0 + part_wd = 0 + + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + cur_b += 1 + if part_wd != 0.0: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls + +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 + sum_loss = 0.0 + model.eval() + with torch.no_grad(): + for i_d, taskid in tqdm(nd, mininterval=tqdm_mininterval): + task_grp = ed[str(taskid)] + seq_batch = torch.from_numpy(task_grp["src"][i_d][:]) + seq_o = torch.from_numpy(task_grp["tgt"][i_d][:]) + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=use_amp): + output = model(seq_batch, seq_o.narrow(1, 0, lo), taskid=taskid) + loss = lossf(output, ot, lang_id=taskid) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) + sum_loss += loss.data.item() + data_mask = ot.ne(pad_id) + correct = (trans.eq(ot) & data_mask).int() + w += data_mask.int().sum().item() + r += correct.sum().item() + correct = data_mask = trans = loss = output = ot = seq_batch = seq_o = None + w = float(w) + return sum_loss / w, (w - r) / w * 100.0 + +def hook_lr_update(optm, flags=None): + + reset_Adam(optm, flags) + +def init_fixing(module): + + if hasattr(module, "fix_init"): + module.fix_init() + +def load_fixing(module): + + if hasattr(module, "fix_load"): + module.fix_load() + +rid = cnfg.run_id +earlystop = cnfg.earlystop +maxrun = cnfg.maxrun +tokens_optm = cnfg.tokens_optm +done_tokens = 0 +batch_report = cnfg.batch_report +report_eva = cnfg.report_eva +use_ams = cnfg.use_ams +save_optm_state = cnfg.save_optm_state +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva +save_every = cnfg.save_every +start_chkp_save = cnfg.epoch_start_checkpoint_save +epoch_save = cnfg.epoch_save +remain_steps = cnfg.training_steps +ro_beam_size = cnfg.robt_beam_size + +wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) +mkdir(wkdir) + +chkpf = None +chkpof = None +statesf = None +if save_every is not None: + chkpf = wkdir + "checkpoint.h5" + if save_optm_state: + chkpof = wkdir + "checkpoint.optm.h5" + if cnfg.save_train_state: + statesf = wkdir + "checkpoint.states" + +logger = get_logger(wkdir + "train.log") + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +multi_gpu_optimizer = multi_gpu and cnfg.multi_gpu_optimizer + +set_random_seed(cnfg.seed, use_cuda) + +td = h5py.File(cnfg.train_data, "r") +vd = h5py.File(cnfg.dev_data, "r") + +ntrain = td["ndata"][:].tolist() +nvalid = vd["ndata"][:].tolist() +nword = td["nword"][:].tolist() +nwordi, ntask, nwordt = nword[0], nword[1], nword[-1] + +logger.info("Design models with seed: %d" % torch.initial_seed()) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask)#, ngroup=cnfg.ngroup + +fine_tune_m = cnfg.fine_tune_m +task_weight, task_weight_T = cnfg.task_weight, cnfg.task_weight_T +if task_weight_T is None or task_weight_T == 1.0: + tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][:].tolist()) for i in range(_nd)] + train_sampler = None +else: + train_taskorder = td["taskorder"][:].tolist() + _tnd = dict(zip(train_taskorder, ntrain)) + train_taskorder.sort() + ntrain = [_tnd[i] for i in train_taskorder] + _tnd = None + train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain)) +nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][:].tolist()) for i in range(_nd)] + +mymodel = init_model_params(mymodel) +mymodel.apply(init_fixing) +if fine_tune_m is not None: + logger.info("Load pre-trained model from: " + fine_tune_m) + mymodel = load_model_cpu(fine_tune_m, mymodel) + mymodel.apply(load_fixing) + +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) + +if cnfg.src_emb is not None: + logger.info("Load source embedding from: " + cnfg.src_emb) + load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb) +if cnfg.tgt_emb is not None: + logger.info("Load target embedding from: " + cnfg.tgt_emb) + load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) + +if cuda_device: + mymodel.to(cuda_device) + lossf.to(cuda_device) + +use_amp = cnfg.use_amp and use_cuda +scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None + +if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) + +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) +else: + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) + +fine_tune_state = cnfg.fine_tune_state +if fine_tune_state is not None: + logger.info("Load optimizer state from: " + fine_tune_state) + optimizer.load_state_dict(h5load(fine_tune_state)) + +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) + +num_checkpoint = cnfg.num_checkpoint +cur_checkid = 0 + +tminerr = inf_default + +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) + +if fine_tune_m is None: + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) + logger.info("Initial model saved") +else: + cnt_states = cnfg.train_statesf + if cnt_states is not None: + logger.info("Continue last epoch") + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,)) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,)) + logger.info("New best model saved") + +if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0: + dss_ws = int(cnfg.dss_ws * sum(ntrain)) + _Dws = {} + _prev_Dws = {} + _crit_inc = {} + if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0: + dss_rm = int(cnfg.dss_rm * sum(ntrain) * (1.0 - cnfg.dss_ws)) + else: + dss_rm = 0 +else: + dss_ws = 0 + dss_rm = 0 + _Dws = None + +namin = 0 + +for i in range(1, maxrun + 1): + if train_sampler is None: + shuffle(tl) + else: + tl = train_sampler.generate() + free_cache(use_cuda) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,)) + + if (vprec <= minerr) or (vloss <= minloss): + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) + logger.info("New best model saved") + + namin = 0 + + if vprec < minerr: + minerr = vprec + if vloss < minloss: + minloss = vloss + + else: + if terr < tminerr: + tminerr = terr + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) + elif epoch_save: + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info) + + namin += 1 + if namin >= earlystop: + if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + done_tokens = 0 + logger.info("early stop") + break + + if remain_steps is not None and remain_steps <= 0: + logger.info("Last training step reached") + break + + if dss_ws > 0: + if _prev_Dws and (train_sampler is None): + for _key, _value in _Dws.items(): + if _key in _prev_Dws: + _ploss = _prev_Dws[_key] + _crit_inc[_key] = (_ploss - _value) / _ploss + tl = dynamic_sample(_crit_inc, dss_ws, dss_rm) + _prev_Dws = _Dws + +if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) +if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "last.optm.h5") +logger.info("model saved") + +td.close() +vd.close() diff --git a/adv/train/train_ape.py b/adv/train/train_ape.py index 3fb1042..2e6de86 100644 --- a/adv/train/train_ape.py +++ b/adv/train/train_ape.py @@ -11,6 +11,7 @@ from utils.base import * from utils.init import init_model_params +from utils.contpara import get_model_parameters from utils.h5serial import h5save, h5load from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb @@ -22,9 +23,6 @@ from tqdm import tqdm -from os import makedirs -from os.path import exists as p_check - import h5py import cnfg.base as cnfg @@ -74,7 +72,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _done_tokens += wd_add if _done_tokens >= tokens_optm: - optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -87,7 +85,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -121,7 +119,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -181,32 +179,23 @@ def load_fixing(module): module.fix_load() rid = cnfg.run_id - earlystop = cnfg.earlystop - maxrun = cnfg.maxrun - tokens_optm = cnfg.tokens_optm - done_tokens = 0 - batch_report = cnfg.batch_report report_eva = cnfg.report_eva - use_ams = cnfg.use_ams - save_optm_state = cnfg.save_optm_state - +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva save_every = cnfg.save_every start_chkp_save = cnfg.epoch_start_checkpoint_save - epoch_save = cnfg.epoch_save - remain_steps = cnfg.training_steps wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) -if not p_check(wkdir): - makedirs(wkdir) +mkdir(wkdir) chkpf = None chkpof = None @@ -267,12 +256,11 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -if multi_gpu_optimizer: - optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - mymodel.zero_grad(set_to_none=True) +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) else: - optimizer = Optimizer((mymodel.module if multi_gpu else mymodel).parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - optimizer.zero_grad(set_to_none=True) + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) fine_tune_state = cnfg.fine_tune_state if fine_tune_state is not None: @@ -290,16 +278,16 @@ def load_fixing(module): logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: - save_model(mymodel, wkdir + "init.h5", multi_gpu, logger) + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) logger.info("Initial model saved") else: cnt_states = cnfg.train_statesf - if (cnt_states is not None) and p_check(cnt_states): + if cnt_states is not None: logger.info("Continue last epoch") tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) - save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec)) logger.info("New best model saved") @@ -328,7 +316,7 @@ def load_fixing(module): logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): - save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) logger.info("New best model saved") @@ -343,11 +331,11 @@ def load_fixing(module): else: if terr < tminerr: tminerr = terr - save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) elif epoch_save: - save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info) namin += 1 if namin >= earlystop: @@ -373,7 +361,7 @@ def load_fixing(module): if done_tokens > 0: optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) -save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "last.optm.h5") logger.info("model saved") diff --git a/adv/train/train_dynb.py b/adv/train/train_dynb.py index 1cb927b..bc68e2b 100644 --- a/adv/train/train_dynb.py +++ b/adv/train/train_dynb.py @@ -11,6 +11,7 @@ from utils.base import * from utils.init import init_model_params +from utils.contpara import get_model_parameters from utils.dynbatch import GradientMonitor from utils.h5serial import h5save, h5load from utils.fmt.base import tostr, save_states, load_states, pad_id, parse_double_value_tuple @@ -24,9 +25,6 @@ from tqdm import tqdm -from os import makedirs -from os.path import exists as p_check - import h5py import cnfg.dynb as cnfg @@ -96,8 +94,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if _do_optm_step: if multi_gpu: model.collect_gradients() - optm_step(optm, scaler) - optm.zero_grad(set_to_none=True) + optm_step(optm, scaler, zero_grad_none=optm_step_zero_grad_set_none) + optm.zero_grad(set_to_none=optm_step_zero_grad_set_none) if multi_gpu: model.update_replicas() lrsch.step() @@ -105,7 +103,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if multi_gpu: model.reset_grad() else: - optm.zero_grad(set_to_none=True) + optm.zero_grad(set_to_none=optm_step_zero_grad_set_none) _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -118,7 +116,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -152,7 +150,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -207,32 +205,23 @@ def load_fixing(module): module.fix_load() rid = cnfg.run_id - earlystop = cnfg.earlystop - maxrun = cnfg.maxrun - tokens_optm = cnfg.tokens_optm - done_tokens = 0 - batch_report = cnfg.batch_report report_eva = cnfg.report_eva - use_ams = cnfg.use_ams - save_optm_state = cnfg.save_optm_state - +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva save_every = cnfg.save_every start_chkp_save = cnfg.epoch_start_checkpoint_save - epoch_save = cnfg.epoch_save - remain_steps = cnfg.training_steps wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) -if not p_check(wkdir): - makedirs(wkdir) +mkdir(wkdir) chkpf = None chkpof = None @@ -293,12 +282,11 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -if multi_gpu_optimizer: - optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - mymodel.zero_grad(set_to_none=True) +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) else: - optimizer = Optimizer((mymodel.module if multi_gpu else mymodel).parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - optimizer.zero_grad(set_to_none=True) + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) fine_tune_state = cnfg.fine_tune_state if fine_tune_state is not None: @@ -316,16 +304,16 @@ def load_fixing(module): logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: - save_model(mymodel, wkdir + "init.h5", multi_gpu, logger) + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) logger.info("Initial model saved") else: cnt_states = cnfg.train_statesf - if (cnt_states is not None) and p_check(cnt_states): + if cnt_states is not None: logger.info("Continue last epoch") tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) - save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec)) logger.info("New best model saved") @@ -354,7 +342,7 @@ def load_fixing(module): logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): - save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) logger.info("New best model saved") @@ -369,11 +357,11 @@ def load_fixing(module): else: if terr < tminerr: tminerr = terr - save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) elif epoch_save: - save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info) namin += 1 if namin >= earlystop: @@ -399,7 +387,7 @@ def load_fixing(module): if done_tokens > 0: optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) -save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "last.optm.h5") logger.info("model saved") diff --git a/adv/train/train_probe.py b/adv/train/train_probe.py index e3be185..ef94388 100644 --- a/adv/train/train_probe.py +++ b/adv/train/train_probe.py @@ -11,6 +11,7 @@ from utils.base import * from utils.init import init_model_params +from utils.contpara import get_model_parameters from utils.h5serial import h5save, h5load from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb @@ -22,9 +23,6 @@ from tqdm import tqdm -from os import makedirs -from os.path import exists as p_check - import h5py import cnfg.probe as cnfg @@ -78,7 +76,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _done_tokens += wd_add if _done_tokens >= tokens_optm: - optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -91,7 +89,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model.module.dec if multi_gpu else model.dec, _chkpf, False, logger) + save_model(model.module.dec if multi_gpu else model.dec, _chkpf, False, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -125,7 +123,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model.module.dec if multi_gpu else model.dec, _chkpf, False, logger) + save_model(model.module.dec if multi_gpu else model.dec, _chkpf, False, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -187,32 +185,23 @@ def load_fixing(module): module.fix_load() rid = cnfg.run_id - earlystop = cnfg.earlystop - maxrun = cnfg.maxrun - tokens_optm = cnfg.tokens_optm - done_tokens = 0 - batch_report = cnfg.batch_report report_eva = cnfg.report_eva - use_ams = cnfg.use_ams - save_optm_state = cnfg.save_optm_state - +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva save_every = cnfg.save_every start_chkp_save = cnfg.epoch_start_checkpoint_save - epoch_save = cnfg.epoch_save - remain_steps = cnfg.training_steps wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) -if not p_check(wkdir): - makedirs(wkdir) +mkdir(wkdir) chkpf = None chkpof = None @@ -283,12 +272,11 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -if multi_gpu_optimizer: - optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - mymodel.zero_grad(set_to_none=True) +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) else: - optimizer = Optimizer((mymodel.module if multi_gpu else mymodel).parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - optimizer.zero_grad(set_to_none=True) + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) fine_tune_state = cnfg.fine_tune_state if fine_tune_state is not None: @@ -306,16 +294,16 @@ def load_fixing(module): logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: - save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "init.h5", False, logger) + save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "init.h5", False, print_func=logger.info) logger.info("Initial model saved") else: cnt_states = cnfg.train_statesf - if (cnt_states is not None) and p_check(cnt_states): + if cnt_states is not None: logger.info("Continue last epoch") tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) - save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), False, logger) + save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), False, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec)) logger.info("New best model saved") @@ -344,7 +332,7 @@ def load_fixing(module): logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): - save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), False, logger) + save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), False, print_func=logger.info, mtyp="eva" if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) logger.info("New best model saved") @@ -359,11 +347,11 @@ def load_fixing(module): else: if terr < tminerr: tminerr = terr - save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), False, logger) + save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), False, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) elif epoch_save: - save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), False, logger) + save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), False, print_func=logger.info) namin += 1 if namin >= earlystop: @@ -389,7 +377,7 @@ def load_fixing(module): if done_tokens > 0: optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) -save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "last.h5", False, logger) +save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "last.h5", False, print_func=logger.info) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "last.optm.h5") logger.info("model saved") diff --git a/cnfg/README.md b/cnfg/README.md index f2c0f80..5fa3778 100644 --- a/cnfg/README.md +++ b/cnfg/README.md @@ -27,6 +27,10 @@ fine_tune_m = None # add 3 to forbidden_indexes if there are tokens in data forbidden_indexes = [0, 1] +# automatically remove the previous best train/validation model when saving the new best. +save_auto_clean = True +# allow the best performing model on the training set to overwrite the best performing model on the development set. +overwrite_eva=False # after how much step save a checkpoint which you can fine tune with. save_every = 1500 # maximum number of checkpoint models saved, useful for average or ensemble. diff --git a/cnfg/base.py b/cnfg/base.py index e99a17e..1e9a65c 100644 --- a/cnfg/base.py +++ b/cnfg/base.py @@ -20,6 +20,8 @@ # add 3 to forbidden_indexes if there are tokens in data forbidden_indexes = [0, 1] +save_auto_clean = True +overwrite_eva = False save_every = 1500 num_checkpoint = 4 epoch_start_checkpoint_save = 3 diff --git a/cnfg/hyp.py b/cnfg/hyp.py index ea07a8d..2a7882b 100644 --- a/cnfg/hyp.py +++ b/cnfg/hyp.py @@ -39,3 +39,9 @@ # prune with length penalty in each beam decoding step clip_beam_with_lp = True + +# use C backend. Disabling it leads to better performance. +use_c_backend = False + +# accelerate optimizer by using contigous parameters and gradients. Disabling it leads to better performance. +contiguous_parameters = False diff --git a/cnfg/ihyp.py b/cnfg/ihyp.py index 09726ec..8fadc14 100644 --- a/cnfg/ihyp.py +++ b/cnfg/ihyp.py @@ -8,29 +8,39 @@ from utils.fmt.base import parse_none, parse_double_value_tuple +# C backend +if use_c_backend is None: + use_c_backend_attn = use_c_backend_selfattn = use_c_backend_crossattn = use_c_backend_pff = use_c_backend_group = True + use_c_backend_act_func = False +else: + use_c_backend_attn = use_c_backend_selfattn = use_c_backend_crossattn = use_c_backend_group = use_c_backend_pff = use_c_backend_act_func = use_c_backend +use_c_backend_mhattn = use_c_backend_attn or use_c_backend_selfattn or use_c_backend_crossattn +bind_c_forward = use_c_backend + +# biases enable_prev_ln_bias_default = enable_proj_bias_default = not ease_optimization -enable_ln_parameters = True +# computation order +norm_residual_default = not (computation_order.lower() == "v2") -use_adv_act_default = custom_act_Sigmoid = custom_act_Swish = custom_act_Mish = use_norm_Swish = False -if advance_activation_function is not None: - use_adv_act_default = True - _adv_act = advance_activation_function.lower() - use_norm_Swish = (_adv_act == "normswish") - if _adv_act == "sigmoid": - custom_act_Sigmoid = True - elif _adv_act == "swish": - custom_act_Swish = True - elif _adv_act == "mish": - custom_act_Mish = True +# Layer Norm +enable_ln_parameters = True -inplace_after_Custom_Act = use_adv_act_default and (not custom_act_Sigmoid) +# activation fucntion +use_adv_act_default = advance_activation_function is not None +adv_act = advance_activation_function.lower() if use_adv_act_default else None +inplace_after_Custom_Act = use_adv_act_default and (adv_act not in set(["sigmoid"])) -norm_residual_default = not (computation_order.lower() == "v2") +# relative position encoding +use_k_relative_position_encoder, use_k_relative_position_decoder = parse_double_value_tuple(use_k_relative_position) +rel_pos_enabled = (max(use_k_relative_position_encoder, use_k_relative_position_decoder) > 0) +disable_std_pemb_encoder, disable_std_pemb_decoder = parse_double_value_tuple(disable_std_pemb) +relpos_reduction_with_zeros = True -# override by the GoogleLR in most case +# learning rate, override by the GoogleLR in most case init_lr = 1e-4 +# hyper-parameters inf_default = inf ieps_default = 1e-9 @@ -39,18 +49,20 @@ ieps_ln_default = parse_none(ieps_ln_default, ieps_default) ieps_adam_default = parse_none(ieps_adam_default, ieps_default) ieps_noise_default = ieps_ln_default +ieps_upper_bound_default = ieps_default +ieps_dropout_multinomial_default = ieps_default adam_betas_default = (0.9, 0.98,) -use_k_relative_position_encoder, use_k_relative_position_decoder = parse_double_value_tuple(use_k_relative_position) -rel_pos_enabled = (max(use_k_relative_position_encoder, use_k_relative_position_decoder) > 0) -disable_std_pemb_encoder, disable_std_pemb_decoder = parse_double_value_tuple(disable_std_pemb) -relpos_reduction_with_zeros = True - +# HDF5 serialization h5datawargs = {} if hdf5_data_compression is None else {"compression": hdf5_data_compression, "compression_opts": hdf5_data_compression_level, "shuffle":True} h5modelwargs = {} if hdf5_model_compression is None else {"compression": hdf5_model_compression, "compression_opts": hdf5_model_compression_level, "shuffle":True} h5zipargs = {"compression": "gzip", "compression_opts": 9, "shuffle":True} list_key_func = str +# tqdm tqdm_mininterval = 1.0 + +# optimizer step zero_grad +optm_step_zero_grad_set_none = not contiguous_parameters diff --git a/cnfg/mulang.py b/cnfg/mulang.py new file mode 100644 index 0000000..1d7f1e1 --- /dev/null +++ b/cnfg/mulang.py @@ -0,0 +1,8 @@ +#encoding: utf-8 + +from cnfg.base import * + +task_weight_T = None +task_weight = None + +robt_beam_size = 1 diff --git a/loss/base.py b/loss/base.py index 460c89d..8407ece 100644 --- a/loss/base.py +++ b/loss/base.py @@ -1,8 +1,7 @@ #encoding: utf-8 import torch -from torch.nn.modules.loss import _Loss -from torch.nn.modules.loss import NLLLoss as NLLLossBase +from torch.nn.modules.loss import _Loss, NLLLoss as NLLLossBase from torch.nn.functional import kl_div, nll_loss @@ -27,10 +26,11 @@ def forward(self, input, target): smooth_loss = -input.sum(dim=-1, keepdim=True) if isinstance(self.ignore_index, (list, tuple)): pad_mask = eq_indexes(_target, self.ignore_index) - nll_loss.masked_fill_(pad_mask, 0.0) - smooth_loss.masked_fill_(pad_mask, 0.0) elif self.ignore_index >= 0: pad_mask = (_target == self.ignore_index) + else: + pad_mask = None + if pad_mask is not None: nll_loss.masked_fill_(pad_mask, 0.0) smooth_loss.masked_fill_(pad_mask, 0.0) if self.reduction != "none": diff --git a/loss/mulang.py b/loss/mulang.py new file mode 100644 index 0000000..ba3c72a --- /dev/null +++ b/loss/mulang.py @@ -0,0 +1,40 @@ +#encoding: utf-8 + +import torch +from torch.nn.functional import kl_div + +from utils.base import eq_indexes + +from loss.base import MultiLabelSmoothingLoss as MultiLabelSmoothingLossBase + +class MultiLabelSmoothingLoss(MultiLabelSmoothingLossBase): + + def __init__(self, *inputs, **kwargs): + + super(MultiLabelSmoothingLoss, self).__init__(*inputs, **kwargs) + self.register_buffer("weight", self.weight.squeeze(1)) + + def forward(self, input, target, tinput): + + _rsize = list(input.size()) + _nclass = _rsize[-1] + _mpvsize = [1 for i in range(len(_rsize))] + _mpvsize[0] = _rsize[0] + _mpvsize[-1] = _nclass + _rsize[0] = 1 + _rsize[-1] = 1 + + _input = input.view(-1, _nclass) if input.dim() > 2 else input + _target = target.view(-1, 1) + + model_prob = self.weight.index_select(0, tinput).view(_mpvsize).repeat(*_rsize).view(-1, _nclass) + model_prob.scatter_(1, _target, self.conf) + + if isinstance(self.ignore_index, (list, tuple)): + model_prob.masked_fill_(eq_indexes(_target, self.ignore_index), 0.0) + elif self.ignore_index >= 0: + model_prob.masked_fill_(_target == self.ignore_index, 0.0) + + rs = kl_div(_input, model_prob, reduction=self.reduction) + + return rs.view(input.size()) if self.reduction == 'none' and target.dim() > 1 else rs diff --git a/modules/TA.py b/modules/TA.py index 7860ff8..667df66 100644 --- a/modules/TA.py +++ b/modules/TA.py @@ -1,9 +1,49 @@ #encoding: utf-8 -from modules.base import PositionwiseFF as PositionwiseFFBase +from modules.base import ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase from cnfg.ihyp import * +class ResSelfAttn(ResSelfAttnBase): + + def forward(self, iQ, *inputs, **kwargs): + + outs = self.net(iQ, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return self.normer(_out + iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return self.normer(outs + iQ) + +class ResCrossAttn(ResCrossAttnBase): + + def forward(self, iQ, iK, *inputs, **kwargs): + + outs = self.net(iQ, iK, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return self.normer(_out + iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return self.normer(outs + iQ) + class PositionwiseFF(PositionwiseFFBase): # isize: input dimension diff --git a/modules/aan.py b/modules/aan.py new file mode 100644 index 0000000..b7f572b --- /dev/null +++ b/modules/aan.py @@ -0,0 +1,63 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules.base import Linear, Dropout, Custom_Act + +from cnfg.ihyp import * + +# Average Attention is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) +class AverageAttn(nn.Module): + + # isize: input size of Feed-forward NN + # hsize: hidden size of Feed-forward NN + # dropout: dropout rate for Feed-forward NN + # enable_ffn: using FFN to process the average bag-of-words representation + # num_pos: maximum length of sentence cached, extended length will be generated while needed and droped immediately after that + + def __init__(self, isize, hsize=None, dropout=0.0, enable_ffn=False, num_pos=cache_len_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default): + + super(AverageAttn, self).__init__() + + _hsize = isize if hsize is None else hsize + + self.num_pos = num_pos + self.register_buffer('w', torch.Tensor(num_pos, 1)) + + if enable_ffn: + self.ffn = nn.Sequential(Linear(isize, _hsize, bias=enable_bias), nn.LayerNorm(_hsize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_Custom_Act), Linear(_hsize, isize, bias=enable_proj_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize, bias=enable_bias), nn.LayerNorm(_hsize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize, bias=enable_proj_bias)) + else: + self.ffn = None + + self.gw = Linear(isize * 2, isize * 2) + + self.reset_parameters() + + # iQ: keys (bsize, seql, vsize) for training, (bsize, 1, vsize) for decoding + # iV: values (bsize, seql, vsize) + # decoding: training state or decoding state + + def forward(self, iQ, iV, decoding=False): + + if decoding: + avg = iV + else: + seql = iV.size(1) + + # avg: (bsize, seql, vsize) + avg = iV.cumsum(dim=1) * (self.get_ext(seql) if seql > self.num_pos else self.w.narrow(0, 0, seql)) + + if self.ffn is not None: + avg = self.ffn(avg) + + igate, fgate = self.gw(torch.cat((iQ, avg), -1)).sigmoid().chunk(2, -1) + + return igate * iQ + fgate * avg + + def reset_parameters(self): + + self.w = self.get_ext(self.num_pos) + + def get_ext(self, npos): + + return (torch.arange(1, npos + 1, dtype=self.w.dtype, device=self.w.device).reciprocal_()).unsqueeze(-1) diff --git a/modules/act.py b/modules/act.py index bf34992..8f803b5 100644 --- a/modules/act.py +++ b/modules/act.py @@ -46,7 +46,7 @@ def forward(self, x): # GELU is nonmonotonic function that has a shape similar to Swish with beta = 1.4 (https://arxiv.org/abs/1710.05941). class Swish(nn.Module): - def __init__(self, beta=1.0, freeze_beta=True, isize=None, dim=-1 if use_norm_Swish else None, eps=ieps_default): + def __init__(self, beta=1.0, freeze_beta=True, isize=None, dim=-1 if adv_act == "normswish" else None, eps=ieps_default): super(Swish, self).__init__() @@ -80,14 +80,7 @@ def forward(self, x): return x * nnFunc.softplus(x).tanh() -if custom_act_Swish: - Custom_Act = Swish -elif custom_act_Sigmoid: - Custom_Act = nn.Sigmoid -elif custom_act_Mish: - Custom_Act = Mish -else: - Custom_Act = GELU +Custom_Act = {"swish": Swish, "normswish": Swish, "sigmoid": nn.Sigmoid, "mish": Mish}.get(adv_act, GELU) # SparseMax (https://arxiv.org/pdf/1602.02068) borrowed form OpenNMT-py( https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py) class SparsemaxFunction(Function): diff --git a/modules/attn/rap.py b/modules/attn/rap.py index 54075b7..fa20e2b 100644 --- a/modules/attn/rap.py +++ b/modules/attn/rap.py @@ -6,8 +6,7 @@ from torch import nn from torch.autograd import Function -from modules.base import CrossAttn as CrossAttnBase -from modules.base import SelfAttn as SelfAttnBase +from modules.base import CrossAttn as CrossAttnBase, SelfAttn as SelfAttnBase, ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase from cnfg.ihyp import * @@ -119,3 +118,19 @@ def load_base(self, base_module): self.normer = base_module.normer self.drop = base_module.drop + +class ResSelfAttn(ResSelfAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.net = SelfAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + +class ResCrossAttn(ResCrossAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.net = CrossAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) diff --git a/modules/attn/res.py b/modules/attn/res.py index cb61ebd..c2be122 100644 --- a/modules/attn/res.py +++ b/modules/attn/res.py @@ -4,8 +4,7 @@ from torch.nn import functional as nnFunc from math import sqrt -from modules.base import SelfAttn as SelfAttnBase -from modules.base import CrossAttn as CrossAttnBase +from modules.base import SelfAttn as SelfAttnBase, CrossAttn as CrossAttnBase, ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase from cnfg.ihyp import * @@ -94,3 +93,19 @@ def forward(self, iQ, iK, mask=None, resin=None): scores = self.drop(scores) return self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)), resout + +class ResSelfAttn(ResSelfAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.net = SelfAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + +class ResCrossAttn(ResCrossAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.net = CrossAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) diff --git a/modules/base.py b/modules/base.py index bc0067d..c9ac75f 100644 --- a/modules/base.py +++ b/modules/base.py @@ -5,12 +5,13 @@ from torch import nn from torch.nn import functional as nnFunc from torch.autograd import Function +from torch.utils.cpp_extension import load from utils.base import reduce_model_list, repeat_bsize_for_beam_tensor -from modules.act import Custom_Act -from modules.act import reduce_model as reduce_model_act -from modules.dropout import Dropout -from modules.dropout import reduce_model as reduce_model_drop +from modules.act import Custom_Act, reduce_model as reduce_model_act +from modules.dropout import Dropout, reduce_model as reduce_model_drop + +from utils.pyctorch import transfer_CNone_tuple from cnfg.ihyp import * @@ -33,6 +34,9 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_d self.norm_residual = norm_residual + if self.c_available(): + self.c_init() + def forward(self, x): _out = self.normer(x) @@ -43,6 +47,56 @@ def forward(self, x): return out + def c_available(self): + + return use_c_backend_pff and (type(self) == PositionwiseFF) + + def c_init(self, bind=bind_c_forward): + + try: + import pff_cpp + except Exception as e: + pff_cpp = load(name="pff_cpp", sources=['modules/cpp/base/ffn/pff.cpp', 'modules/cpp/base/ffn/pff_func.cpp', 'modules/cpp/act/act_func.cpp']) + try: + import act_cpp + except Exception as e: + act_cpp = load(name="act_cpp", sources=['modules/cpp/act/act.cpp', 'modules/cpp/act/act_func.cpp']) + self.c_forward_func = pff_cpp.forward + self.c_act_func = act_cpp.get_func(adv_act if use_adv_act_default else "relu") + self.c_build_cache() + if bind: + PositionwiseFF.forward = PositionwiseFF.c_forward + + def c_forward(self, x): + + return self.c_forward_func(*self.c_build_inputs(x)) + + def c_build_cache(self): + + self.bargs = {"net.1.inplace": (self.net[1].inplace if hasattr(self.net[1], "inplace") else False), "norm_residual": self.norm_residual} + dargs = {"normer.eps": self.normer.eps} + if len(self.net) > 3: + self.bargs["net.2.inplace"] = self.net[2].inplace + self.bargs["net.4.inplace"] = self.net[4].inplace + dargs["net.2.p"] = self.net[2].p + else: + dargs["net.2.p"] = 0.0 + self.aargs = (self.c_act_func, dargs, self.normer.normalized_shape) + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, x): + + i_d = self.targs.copy() + i_d["x"] = x + if len(self.net) > 3: + bargs = self.bargs.copy() + bargs["net.2.training"] = self.net[2].training + bargs["net.4.training"] = self.net[4].training + else: + bargs = self.bargs + + return i_d, bargs, *self.aargs + class PositionalEmb(nn.Module): # num_dim: dimension of embedding @@ -177,6 +231,9 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v self.register_buffer('iK', None) self.register_buffer('iV', None) + if self.c_available(): + self.c_init() + # iQ: query (bsize, num_query, vsize) # iK: keys (bsize, seql, vsize) # iV: values (bsize, seql, vsize) @@ -220,7 +277,7 @@ def forward(self, iQ, iK, iV, mask=None, states=None): if self.rel_pemb is not None: self.rel_pos_cache = self.get_rel_pos(seql).narrow(0, seql - nquery, nquery).contiguous() if self.ref_rel_posm is None else self.ref_rel_posm.rel_pos_cache - scores += real_iQ.permute(2, 0, 1, 3).contiguous().view(nquery, bsize * nheads, adim).bmm(self.rel_pemb(self.get_rel_pos(seql).narrow(0, seql - nquery, nquery)).transpose(1, 2)).view(nquery, bsize, nheads, seql).permute(1, 2, 0, 3) + scores += real_iQ.permute(2, 0, 1, 3).contiguous().view(nquery, bsize * nheads, adim).bmm(self.rel_pemb(self.rel_pos_cache).transpose(1, 2)).view(nquery, bsize, nheads, seql).permute(1, 2, 0, 3) scores = scores / sqrt(adim) @@ -244,12 +301,72 @@ def forward(self, iQ, iK, iV, mask=None, states=None): def train(self, mode=True): super(MultiHeadAttn, self).train(mode) - if mode: self.reset_buffer() return self + def c_available(self): + + return use_c_backend_mhattn and (type(self) == MultiHeadAttn) and (type(self.normer) == nn.Softmax) + + def c_init(self, bind=bind_c_forward): + + try: + import attn_cpp + except Exception as e: + attn_cpp = load(name="attn_cpp", sources=["modules/cpp/base/attn/attn.cpp"]) + self.c_forward_func = attn_cpp.forward + self.c_build_cache() + if bind: + MultiHeadAttn.forward = MultiHeadAttn.c_forward + + def c_forward(self, iQ, iK, iV, mask=None, states=None): + + return self.c_process_output(self.c_forward_func(*self.c_build_inputs(iQ, iK, iV, mask=mask, states=states)), iK, iV, states=states) + + def c_build_cache(self): + + iargs = {"num_head": self.num_head, "attn_dim": self.attn_dim} + if self.rel_pemb is not None: + iargs.update({"rel_pemb.padding_idx": self.rel_pemb.padding_idx, "clamp_min": self.clamp_min, "clamp_max": self.clamp_max, "rel_shift": self.rel_shift}) + self.aargs = (iargs, 0.0 if self.drop is None else self.drop.p, inf_default,) + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, iQ, iK, iV, mask=None, states=None): + + i_d = self.targs.copy() + i_d.update({"iQ": iQ, "iK":iK, "iV": iV}) + if mask is not None: + i_d["mask"] = mask + _buf_d = dict(self.named_buffers()) + if "iK" in _buf_d: + _buf_d["buf_iK"] = _buf_d.pop("iK") + if "iV" in _buf_d: + _buf_d["buf_iV"] = _buf_d.pop("iV") + i_d.update(_buf_d) + if self.rel_pemb is not None: + if self.ref_rel_posm is not None: + i_d["rel_pos_cache"] = self.ref_rel_posm.rel_pos_cache + + return i_d, [] if states is None else transfer_CNone_tuple(states), *self.aargs, {"drop.inplace": False if self.drop is None else self.drop.inplace, "training": self.training, "drop.training": self.training if self.drop is None else self.drop.training} + + def c_process_output(self, rs, iK, iV, states=None): + + if self.rel_pemb is not None: + self.rel_pos_cache = rs["rel_pos_cache"] + + evaluation = not self.training + if (states is not None) or evaluation: + real_iK, real_iV = rs["real_iK"], rs["real_iV"] + if evaluation: + self.iK, self.real_iK, self.iV, self.real_iV = iK, real_iK, iV, real_iV + + if states is None: + return rs["out"] + else: + return rs["out"], (real_iK, real_iV,) + def get_rel_pos(self, length): if length <= self.xseql: @@ -276,62 +393,6 @@ def index_buffer(self, indices, dim=0): if self.real_iV is not None: self.real_iV = self.real_iV.index_select(dim, indices) -# Average Attention is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) -class AverageAttn(nn.Module): - - # isize: input size of Feed-forward NN - # hsize: hidden size of Feed-forward NN - # dropout: dropout rate for Feed-forward NN - # enable_ffn: using FFN to process the average bag-of-words representation - # num_pos: maximum length of sentence cached, extended length will be generated while needed and droped immediately after that - - def __init__(self, isize, hsize=None, dropout=0.0, enable_ffn=False, num_pos=cache_len_default, custom_act=use_adv_act_default): - - super(AverageAttn, self).__init__() - - _hsize = isize if hsize is None else hsize - - self.num_pos = num_pos - self.register_buffer('w', torch.Tensor(num_pos, 1)) - - if enable_ffn: - self.ffn = nn.Sequential(Linear(isize, _hsize), Custom_Act() if custom_act else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_Custom_Act), Linear(_hsize, isize), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize)) - else: - self.ffn = None - - self.gw = Linear(isize * 2, isize * 2) - - self.reset_parameters() - - # iQ: keys (bsize, seql, vsize) for training, (bsize, 1, vsize) for decoding - # iV: values (bsize, seql, vsize) - # decoding: training state or decoding state - - def forward(self, iQ, iV, decoding=False): - - if decoding: - avg = iV - else: - seql = iV.size(1) - - # avg: (bsize, seql, vsize) - avg = iV.cumsum(dim=1) * (self.get_ext(seql) if seql > self.num_pos else self.w.narrow(0, 0, seql)) - - if self.ffn is not None: - avg = self.ffn(avg) - - igate, fgate = self.gw(torch.cat((iQ, avg), -1)).sigmoid().chunk(2, -1) - - return igate * iQ + fgate * avg - - def reset_parameters(self): - - self.w = self.get_ext(self.num_pos) - - def get_ext(self, npos): - - return (torch.arange(1, npos + 1, dtype=self.w.dtype, device=self.w.device).reciprocal_()).unsqueeze(-1) - # Accelerated MultiHeadAttn for self attention, use when Q == K == V class SelfAttn(nn.Module): @@ -383,6 +444,9 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena else: self.rel_pemb = None + if self.c_available(): + self.c_init() + def forward(self, iQ, mask=None, states=None): bsize, nquery = iQ.size()[:2] @@ -427,6 +491,56 @@ def forward(self, iQ, mask=None, states=None): else: return out, (real_iK, real_iV,) + def c_available(self): + + return use_c_backend_selfattn and (type(self) == SelfAttn) and (type(self.normer) == nn.Softmax) + + def c_init(self, bind=bind_c_forward): + + try: + import self_attn_cpp + except Exception as e: + self_attn_cpp = load(name="self_attn_cpp", sources=["modules/cpp/base/attn/self/attn.cpp"]) + self.c_forward_func = self_attn_cpp.forward + self.c_build_cache() + if bind: + SelfAttn.forward = SelfAttn.c_forward + + def c_forward(self, iQ, mask=None, states=None): + + return self.c_process_output(self.c_forward_func(*self.c_build_inputs(iQ, mask=mask, states=states)), states=states) + + def c_build_cache(self): + + iargs = {"num_head": self.num_head, "attn_dim": self.attn_dim} + if self.rel_pemb is not None: + iargs.update({"rel_pemb.padding_idx": self.rel_pemb.padding_idx, "clamp_min": self.clamp_min, "clamp_max": self.clamp_max, "rel_shift": self.rel_shift}) + self.aargs = (iargs, 0.0 if self.drop is None else self.drop.p, inf_default,) + self.targs = dict(self.named_parameters()) + self.targs.update(self.named_buffers()) + + def c_build_inputs(self, iQ, mask=None, states=None): + + i_d = self.targs.copy() + i_d["iQ"] = iQ + if mask is not None: + i_d["mask"] = mask + if self.rel_pemb is not None: + if self.ref_rel_posm is not None: + i_d["rel_pos_cache"] = self.ref_rel_posm.rel_pos_cache + + return i_d, [] if states is None else transfer_CNone_tuple(states), *self.aargs, {"drop.inplace": False if self.drop is None else self.drop.inplace, "drop.training": self.training if self.drop is None else self.drop.training} + + def c_process_output(self, rs, states=None): + + if self.rel_pemb is not None: + self.rel_pos_cache = rs["rel_pos_cache"] + + if states is None: + return rs["out"] + else: + return rs["out"], (rs["real_iK"], rs["real_iV"],) + def get_rel_pos(self, length): if length <= self.xseql: @@ -465,6 +579,9 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, e self.register_buffer('real_iV', None) self.register_buffer('iK', None) + if self.c_available(): + self.c_init() + def forward(self, iQ, iK, mask=None): bsize, nquery = iQ.size()[:2] @@ -493,6 +610,52 @@ def forward(self, iQ, iK, mask=None): return self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)) + def c_available(self): + + return use_c_backend_crossattn and (type(self) == CrossAttn) and (type(self.normer) == nn.Softmax) + + def c_init(self, bind=bind_c_forward): + + try: + import cross_attn_cpp + except Exception as e: + cross_attn_cpp = load(name="cross_attn_cpp", sources=["modules/cpp/base/attn/cross/attn.cpp"]) + self.c_forward_func = cross_attn_cpp.forward + self.c_build_cache() + if bind: + CrossAttn.forward = CrossAttn.c_forward + + def c_forward(self, iQ, iK, mask=None): + + return self.c_process_output(self.c_forward_func(*self.c_build_inputs(iQ, iK, mask=mask)), iK) + + def c_build_cache(self): + + self.aargs = ({"num_head": self.num_head, "attn_dim": self.attn_dim}, 0.0 if self.drop is None else self.drop.p, inf_default,) + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, iQ, iK, mask=None): + + i_d = self.targs.copy() + i_d["iQ"] = iQ + i_d["iK"] = iK + if mask is not None: + i_d["mask"] = mask + i_d.update(self.named_parameters()) + _buf_d = dict(self.named_buffers()) + if "iK" in _buf_d: + _buf_d["buf_iK"] = _buf_d.pop("iK") + i_d.update(_buf_d) + + return i_d, *self.aargs, {"drop.inplace": False if self.drop is None else self.drop.inplace, "training": self.training, "drop.training": self.training if self.drop is None else self.drop.training} + + def c_process_output(self, rs, iK): + + if not self.training: + self.iK, self.real_iK, self.real_iV = iK, rs["real_iK"], rs["real_iV"] + + return rs["out"] + def train(self, mode=True): super(CrossAttn, self).train(mode) @@ -516,6 +679,291 @@ def index_buffer(self, indices, dim=0): if self.real_iK is not None: self.real_iK, self.real_iV = self.real_iK.index_select(dim, indices), self.real_iV.index_select(dim, indices) +class ResMHAttn(nn.Module): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResMHAttn, self).__init__() + + self.net = MultiHeadAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + self.normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + self.norm_residual = norm_residual + + if self.c_available(): + self.c_init() + + def forward(self, iQ, iK, iV, *inputs, **kwargs): + + _iQ = self.normer(iQ) + _iK = _iQ if iK.is_set_to(iQ) else iK + _iV = _iK if iV.is_set_to(iK) else iV + + outs = self.net(_iQ, _iK, _iV, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + + def load_base(self, base_module): + + self.normer, self.drop, self.norm_residual = base_module.normer, base_module.drop, base_module.norm_residual + if hasattr(self.net, "load_base"): + self.net.load_base(base_module.net) + else: + self.net = base_module.net + + def c_available(self): + + return use_c_backend_mhattn and (type(self) == ResMHAttn) and self.net.c_available() + + def c_init(self, bind=bind_c_forward): + + try: + import res_attn_cpp + except Exception as e: + res_attn_cpp = load(name="res_attn_cpp", sources=["modules/cpp/base/resattn/attn.cpp"]) + self.c_forward_func = res_attn_cpp.forward + self.c_build_cache() + if bind: + ResMHAttn.forward = ResMHAttn.c_forward + + def c_forward(self, iQ, iK, iV, mask=None, states=None): + + return self.c_process_output(self.c_forward_func(*self.c_build_inputs(iQ, iK, iV, mask=mask, states=states)), iK, iV, states=states) + + def c_build_cache(self): + + iargs = {"net.num_head": self.net.num_head, "net.attn_dim": self.net.attn_dim} + if self.net.rel_pemb is not None: + iargs.update({"net.rel_pemb.padding_idx": self.net.rel_pemb.padding_idx, "net.clamp_min": self.net.clamp_min, "net.clamp_max": self.net.clamp_max, "net.rel_shift": self.net.rel_shift}) + self.aargs = (iargs, {"normer.eps": self.normer.eps, "inf_value": inf_default, "net.drop.p": 0.0 if self.net.drop is None else self.net.drop.p, "drop.p": 0.0 if self.drop is None else self.drop.p}, self.normer.normalized_shape,) + self.bargs = {"net.drop.inplace": False if self.net.drop is None else self.net.drop.inplace, "drop.inplace": False if self.drop is None else self.drop.inplace, "norm_residual": self.norm_residual} + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, iQ, iK, iV, mask=None, states=None): + + i_d = self.targs.copy() + i_d.update({"iQ": iQ, "iK":iK, "iV": iV}) + if mask is not None: + i_d["mask"] = mask + i_d.update(self.named_buffers()) + if self.net.rel_pemb is not None: + if self.net.ref_rel_posm is not None: + i_d["net.rel_pos_cache"] = self.net.ref_rel_posm.rel_pos_cache + bargs = self.bargs.copy() + bargs.update({"net.training": self.net.training, "net.drop.training": self.net.training if self.net.drop is None else self.net.drop.training, "drop.training": self.training if self.drop is None else self.drop.training}) + + return i_d, [] if states is None else transfer_CNone_tuple(states), *self.aargs, bargs + + def c_process_output(self, rs, iK, iV, states=None): + + if self.net.rel_pemb is not None: + self.net.rel_pos_cache = rs["net.rel_pos_cache"] + + evaluation = not self.net.training + if (states is not None) or evaluation: + real_iK, real_iV = rs["net.real_iK"], rs["net.real_iV"] + if evaluation: + self.net.iK, self.net.real_iK, self.net.iV, self.net.real_iV = iK, real_iK, iV, real_iV + + if states is None: + return rs["out"] + else: + return rs["out"], (real_iK, real_iV,) + +class ResSelfAttn(nn.Module): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResSelfAttn, self).__init__() + + self.net = SelfAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + self.normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + self.norm_residual = norm_residual + + if self.c_available(): + self.c_init() + + def forward(self, iQ, *inputs, **kwargs): + + _iQ = self.normer(iQ) + + outs = self.net(_iQ, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + + def load_base(self, base_module): + + self.normer, self.drop, self.norm_residual = base_module.normer, base_module.drop, base_module.norm_residual + if hasattr(self.net, "load_base"): + self.net.load_base(base_module.net) + else: + self.net = base_module.net + + def c_available(self): + + return use_c_backend_selfattn and (type(self) == ResSelfAttn) and self.net.c_available() + + def c_init(self, bind=bind_c_forward): + + try: + import res_self_attn_cpp + except Exception as e: + res_self_attn_cpp = load(name="res_self_attn_cpp", sources=["modules/cpp/base/resattn/self/attn.cpp"]) + self.c_forward_func = res_self_attn_cpp.forward + self.c_build_cache() + if bind: + ResSelfAttn.forward = ResSelfAttn.c_forward + + def c_forward(self, iQ, mask=None, states=None): + + return self.c_process_output(self.c_forward_func(*self.c_build_inputs(iQ, mask=mask, states=states)), states=states) + + def c_build_cache(self): + + iargs = {"net.num_head": self.net.num_head, "net.attn_dim": self.net.attn_dim} + if self.net.rel_pemb is not None: + iargs.update({"net.rel_pemb.padding_idx": self.net.rel_pemb.padding_idx, "net.clamp_min": self.net.clamp_min, "net.clamp_max": self.net.clamp_max, "net.rel_shift": self.net.rel_shift}) + self.aargs = (iargs, {"normer.eps": self.normer.eps, "inf_value": inf_default, "net.drop.p": 0.0 if self.net.drop is None else self.net.drop.p, "drop.p": 0.0 if self.drop is None else self.drop.p}, self.normer.normalized_shape,) + self.bargs = {"net.drop.inplace": False if self.net.drop is None else self.net.drop.inplace, "drop.inplace": False if self.drop is None else self.drop.inplace, "norm_residual": self.norm_residual} + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, iQ, mask=None, states=None): + + i_d = self.targs.copy() + i_d["iQ"] = iQ + if mask is not None: + i_d["mask"] = mask + if self.net.rel_pemb is not None: + if self.net.ref_rel_posm is not None: + i_d["net.rel_pos_cache"] = self.net.ref_rel_posm.rel_pos_cache + bargs = self.bargs.copy() + bargs.update({"net.training": self.net.training, "net.drop.training": self.net.training if self.net.drop is None else self.net.drop.training, "drop.training": self.training if self.drop is None else self.drop.training}) + + return i_d, [] if states is None else transfer_CNone_tuple(states), *self.aargs, bargs + + def c_process_output(self, rs, states=None): + + if self.net.rel_pemb is not None: + self.net.rel_pos_cache = rs["net.rel_pos_cache"] + + if states is None: + return rs["out"] + else: + return rs["out"], (rs["net.real_iK"], rs["net.real_iV"],) + +class ResCrossAttn(nn.Module): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResCrossAttn, self).__init__() + + self.net = CrossAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + self.normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + self.norm_residual = norm_residual + + if self.c_available(): + self.c_init() + + def forward(self, iQ, iK, *inputs, **kwargs): + + _iQ = self.normer(iQ) + + outs = self.net(_iQ, iK, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + + def load_base(self, base_module): + + self.normer, self.drop, self.norm_residual = base_module.normer, base_module.drop, base_module.norm_residual + if hasattr(self.net, "load_base"): + self.net.load_base(base_module.net) + else: + self.net = base_module.net + + def c_available(self): + + return use_c_backend_crossattn and (type(self) == ResCrossAttn) and self.net.c_available() + + def c_init(self, bind=bind_c_forward): + + try: + import res_cross_attn_cpp + except Exception as e: + res_cross_attn_cpp = load(name="res_cross_attn_cpp", sources=["modules/cpp/base/resattn/cross/attn.cpp"]) + self.c_forward_func = res_cross_attn_cpp.forward + self.c_build_cache() + if bind: + ResCrossAttn.forward = ResCrossAttn.c_forward + + def c_forward(self, iQ, iK, mask=None): + + return self.c_process_output(self.c_forward_func(*self.c_build_inputs(iQ, iK, mask=mask)), iK) + + def c_build_cache(self): + + iargs = {"net.num_head": self.net.num_head, "net.attn_dim": self.net.attn_dim} + self.aargs = (iargs, {"normer.eps": self.normer.eps, "inf_value": inf_default, "net.drop.p": 0.0 if self.net.drop is None else self.net.drop.p, "drop.p": 0.0 if self.drop is None else self.drop.p}, self.normer.normalized_shape,) + self.bargs = {"net.drop.inplace": False if self.net.drop is None else self.net.drop.inplace, "drop.inplace": False if self.drop is None else self.drop.inplace, "norm_residual": self.norm_residual} + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, iQ, iK, mask=None): + + i_d = self.targs.copy() + i_d["iQ"] = iQ + i_d["iK"] = iK + if mask is not None: + i_d["mask"] = mask + i_d.update(self.named_buffers()) + bargs = self.bargs.copy() + bargs.update({"net.training": self.net.training, "net.drop.training": self.net.training if self.net.drop is None else self.net.drop.training, "drop.training": self.training if self.drop is None else self.drop.training}) + + return i_d, *self.aargs, bargs + + def c_process_output(self, rs, iK): + + if not self.net.training: + self.net.iK, self.net.real_iK, self.net.real_iV = iK, rs["net.real_iK"], rs["net.real_iV"] + + return rs["out"] + # Aggregation from: Exploiting Deep Representations for Neural Machine Translation class ResidueCombiner(nn.Module): @@ -597,6 +1045,8 @@ def backward(ctx, grad_outputs): else: return None, None +GradientReversalFunc = GradientReversalFunction.apply + class GradientReversalLayer(nn.Module): def __init__(self, adv_weight=1.0): @@ -607,7 +1057,7 @@ def __init__(self, adv_weight=1.0): def forward(self, *inputs): - return (GradientReversalFunction.apply(inputu, self.adv_weight) for inputu in inputs) if len(inputs) > 1 else GradientReversalFunction.apply(inputs[0], self.adv_weight) + return tuple(GradientReversalFunc(inputu, self.adv_weight) for inputu in inputs) if len(inputs) > 1 else GradientReversalFunc(inputs[0], self.adv_weight) class ACTLossFunction(Function): diff --git a/modules/cpp/act/act.cpp b/modules/cpp/act/act.cpp new file mode 100644 index 0000000..0a7e1bd --- /dev/null +++ b/modules/cpp/act/act.cpp @@ -0,0 +1,7 @@ +#include +#include "act_func.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_func", &get_func, "Get activation function"); + m.def("forward", &act_forward, "Activation function forward"); +} diff --git a/modules/cpp/act/act_func.cpp b/modules/cpp/act/act_func.cpp new file mode 100644 index 0000000..fa18aed --- /dev/null +++ b/modules/cpp/act/act_func.cpp @@ -0,0 +1,29 @@ +#include +#include +#include "base.h" + +inline Act_Func get_func_core(std::string func_name) { + if (func_name == "gelu") { + return gelu_forward; + } + else if (func_name == "swish") { + return swish_forward; + } + else if (func_name == "sigmoid") { + return sigmoid_forward; + } + else if (func_name == "mish") { + return mish_forward; + } + else { + return relu_forward; + } +} + +const void* get_func(std::string func_name) { + return (void*)get_func_core(func_name); +} + +at::Tensor act_forward(void* act, torch::Tensor input, bool inplace=false) { + return (*(Act_Func)act)(input, inplace=inplace); +} diff --git a/modules/cpp/act/act_func.h b/modules/cpp/act/act_func.h new file mode 100644 index 0000000..da4a1d6 --- /dev/null +++ b/modules/cpp/act/act_func.h @@ -0,0 +1,12 @@ +#ifndef _NEUTRON_MODULES_CPP_ACT_ACT_FUNC +#define _NEUTRON_MODULES_CPP_ACT_FUNC + +#include +#include +#include "base.h" + +const void* get_func(std::string func_name); + +at::Tensor act_forward(void* act, torch::Tensor input, bool inplace=false); + +#endif diff --git a/modules/cpp/act/base.h b/modules/cpp/act/base.h new file mode 100644 index 0000000..c96bf1c --- /dev/null +++ b/modules/cpp/act/base.h @@ -0,0 +1,45 @@ +#ifndef _NEUTRON_MODULES_CPP_ACT_BASE +#define _NEUTRON_MODULES_CPP_ACT_BASE + +#include +#define _USE_MATH_DEFINES +#include + +typedef at::Tensor (*Act_Func) (torch::Tensor, bool inplace); + +inline at::Tensor relu_forward(torch::Tensor x, bool inplace=false) { + + return torch::nn::functional::relu(x, torch::nn::functional::ReLUFuncOptions().inplace(inplace)); +} + +inline at::Tensor gelu_gpt_forward(torch::Tensor x, bool inplace=false) { + + return x * 0.5 * (1.0 + (sqrt(2.0 / M_PI) * (x + 0.044715 * x.pow(3.0))).tanh()); +} + +inline at::Tensor gelu_bert_forward(torch::Tensor x, bool inplace=false) { + + return x * 0.5 * (1.0 + (x / sqrt(2.0)).erf()); +} + +inline at::Tensor gelu_forward(torch::Tensor x, bool inplace=false) { + + return torch::nn::functional::gelu(x); +} + +inline at::Tensor sigmoid_forward(torch::Tensor x, bool inplace=false) { + + return x.sigmoid(); +} + +inline at::Tensor swish_forward(torch::Tensor x, bool inplace=false) { + + return x.sigmoid() * x; +} + +inline at::Tensor mish_forward(torch::Tensor x, bool inplace=false) { + + return x * torch::nn::functional::softplus(x).tanh(); +} + +#endif diff --git a/modules/cpp/act/setup.py b/modules/cpp/act/setup.py new file mode 100644 index 0000000..d633573 --- /dev/null +++ b/modules/cpp/act/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='act_cpp', ext_modules=[cpp_extension.CppExtension('act_cpp', ['modules/cpp/act/act.cpp', 'modules/cpp/act/act_func.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/attn/attn.cpp b/modules/cpp/base/attn/attn.cpp new file mode 100644 index 0000000..02a8c7c --- /dev/null +++ b/modules/cpp/base/attn/attn.cpp @@ -0,0 +1,6 @@ +#include +#include "attn_func.cpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &_NEUTRON_MODULES_BASE_ATTN_FUNC_NAME, "Multi-head attention forward"); +} diff --git a/modules/cpp/base/attn/attn_func.cpp b/modules/cpp/base/attn/attn_func.cpp new file mode 100644 index 0000000..f8eab25 --- /dev/null +++ b/modules/cpp/base/attn/attn_func.cpp @@ -0,0 +1,98 @@ +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN +#include "common.cpp" +#undef _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN + +/*#include +#include +#include +#include +#include +#include "../../../../utils/cpp/base.h" +#include "base.h" + +std::map attn_forward(std::map tensors, std::vector states, std::map iargs, double p, double inf_value, std::map bargs) { + + auto iQ = tensors["iQ"]; + auto bsize = iQ.size(0); + auto nquery = iQ.size(1); + auto iK = tensors["iK"]; + auto seql = iK.size(1); + auto iV = tensors["iV"]; + auto nheads = iargs["num_head"]; + auto adim = iargs["attn_dim"]; + torch::Tensor rel_pos_cache; + + auto real_iQ = torch::nn::functional::linear(iQ, tensors["query_adaptor.weight"], map_get(tensors, "query_adaptor.bias")).view({bsize, nquery, nheads, adim}).transpose(1, 2); + at::Tensor real_iK, real_iV; + + auto evaluation = not bargs["training"]; + auto buf_real_iK = map_get(tensors, "real_iK"); + auto buf_iK = map_get(tensors, "buf_iK"); + if (ct_is_not_none(buf_real_iK) and iK.is_set_to(buf_iK) and evaluation) { + real_iK = buf_real_iK; + } + else { + real_iK = torch::nn::functional::linear(iK, tensors["key_adaptor.weight"], map_get(tensors, "key_adaptor.bias")).view({bsize, seql, nheads, adim}).permute({0, 2, 3, 1}); + } + + auto buf_real_iV = map_get(tensors, "real_iV"); + auto buf_iV = map_get(tensors, "buf_iV"); + if (ct_is_not_none(buf_real_iV) and iV.is_set_to(buf_iV) and evaluation) { + real_iV = buf_real_iV; + } + else { + real_iV = torch::nn::functional::linear(iV, tensors["value_adaptor.weight"], map_get(tensors, "value_adaptor.bias")).view({bsize, seql, nheads, adim}).transpose(1, 2); + } + + bool not_non_states = states.size() > 0; + if (not_non_states) { + auto _h_real_iK = states[0]; + auto _h_real_iV = states[1]; + if (pyt_is_not_none(_h_real_iK)) { + seql += _h_real_iK.size(-1); + real_iK = at::cat({_h_real_iK, real_iK}, -1); + real_iV = at::cat({_h_real_iV, real_iV}, 2); + } + } + + auto scores = real_iQ.matmul(real_iK); + + auto rel_pemb_weight = map_get(tensors, "rel_pemb.weight"); + bool not_none_rel = ct_is_not_none(rel_pemb_weight); + if (not_none_rel) { + auto emb_option = torch::nn::functional::EmbeddingFuncOptions(); + auto padding_idx = map_get(iargs, "rel_pemb.padding_idx", -1); + if (padding_idx >= 0) { + emb_option = emb_option.padding_idx(padding_idx); + } + rel_pos_cache = map_get(tensors, "rel_pos_cache"); + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(seql, iargs, tensors["rel_pos"]).narrow(0, seql - nquery, nquery).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, seql}).permute({1, 2, 0, 3}); + } + + scores = scores / sqrt(adim); + + auto mask = map_get(tensors, "mask"); + if (ct_is_not_none(mask)) { + scores.masked_fill_(mask.unsqueeze(1), -inf_value); + } + scores = scores.softmax(-1); + if (p > 0.0) { + scores = torch::nn::functional::dropout(scores, torch::nn::functional::DropoutFuncOptions().p(p).inplace(bargs["drop.inplace"]).training(bargs["drop.training"])); + } + + auto out = torch::nn::functional::linear(scores.matmul(real_iV).transpose(1, 2).contiguous().view({bsize, nquery, nheads * adim}), tensors["outer.weight"], map_get(tensors, "outer.bias")); + + std::map rs; + rs["out"] = out; + if ((not_non_states) or evaluation) { + rs["real_iK"] = real_iK; + rs["real_iV"] = real_iV; + } + if (not_none_rel) { + rs["rel_pos_cache"] = rel_pos_cache; + } + return rs; +}*/ diff --git a/modules/cpp/base/attn/base.h b/modules/cpp/base/attn/base.h new file mode 100644 index 0000000..797dfaa --- /dev/null +++ b/modules/cpp/base/attn/base.h @@ -0,0 +1,25 @@ +#ifndef _NEUTRON_MODULES_CPP_ATTN_BASE +#define _NEUTRON_MODULES_CPP_ATTN_BASE + +#include +#include +#include +#include + +inline torch::Tensor get_rel_pos(int64_t length, std::map iargs, torch::Tensor rel_pos) { + + auto xseql = rel_pos.size(0); + if (length <= xseql) { + return rel_pos.narrow(0, 0, length).narrow(1, 0, length); + } + else { + auto _rpm = torch::arange(-length + 1, 1, rel_pos.options()).unsqueeze(0); + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_RES + return ((_rpm - _rpm.t()).clamp(iargs["net.clamp_min"], iargs["net.clamp_max"]) + iargs["net.rel_shift"]); + #else + return ((_rpm - _rpm.t()).clamp(iargs["clamp_min"], iargs["clamp_max"]) + iargs["rel_shift"]); + #endif + } +} + +#endif diff --git a/modules/cpp/base/attn/common.cpp b/modules/cpp/base/attn/common.cpp new file mode 100644 index 0000000..1162746 --- /dev/null +++ b/modules/cpp/base/attn/common.cpp @@ -0,0 +1,177 @@ +#include +#include +#include +#include +#include +#include "../../../../utils/cpp/base.h" + +#if !(defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN)||defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN)||defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN)) +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN +#endif + +#ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN +#define _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME self_attn_forward +#elif defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN) +#define _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME cross_attn_forward +#else +#define _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME multi_head_attn_forward +#endif + +#ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN +#include "base.h" +std::map _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME(std::map tensors, std::vector states, std::map iargs, double p, double inf_value, std::map bargs) { +#else +std::map _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME(std::map tensors, std::map iargs, double p, double inf_value, std::map bargs) { +#endif + + auto iQ = tensors["iQ"]; + auto bsize = iQ.size(0); + auto nquery = iQ.size(1); + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + auto iK = tensors["iK"]; + auto seql = iK.size(1); + #endif + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN + auto iV = tensors["iV"]; + #endif + auto nheads = iargs["num_head"]; + auto adim = iargs["attn_dim"]; + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + int64_t seql; + #endif + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + torch::Tensor rel_pos_cache; + #endif + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + auto real_iQ = torch::nn::functional::linear(iQ, tensors["query_adaptor.weight"], map_get(tensors, "query_adaptor.bias")).view({bsize, nquery, nheads, adim}).transpose(1, 2); + at::Tensor real_iK, real_iV; + + auto evaluation = not bargs["training"]; + auto buf_real_iK = map_get(tensors, "real_iK"); + auto buf_iK = map_get(tensors, "buf_iK"); + if (ct_is_not_none(buf_real_iK) and iK.is_set_to(buf_iK) and evaluation) { + real_iK = buf_real_iK; + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + real_iV = tensors["real_iV"]; + #endif + } + else { + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + auto _reals = torch::nn::functional::linear(iK, tensors["kv_adaptor.weight"], map_get(tensors, "kv_adaptor.bias")).view({bsize, seql, 2, nheads, adim}).unbind(2); + real_iK = _reals[0].permute({0, 2, 3, 1}); + real_iV = _reals[1].transpose(1, 2); + #else + real_iK = torch::nn::functional::linear(iK, tensors["key_adaptor.weight"], map_get(tensors, "key_adaptor.bias")).view({bsize, seql, nheads, adim}).permute({0, 2, 3, 1}); + #endif + } + #else + auto _reals = torch::nn::functional::linear(iQ, tensors["adaptor.weight"], map_get(tensors, "adaptor.bias")).view({bsize, nquery, 3, nheads, adim}).unbind(2); + auto real_iQ = _reals[0].transpose(1, 2); + auto real_iK = _reals[1].permute({0, 2, 3, 1}); + auto real_iV = _reals[2].transpose(1, 2); + #endif + + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN + auto buf_real_iV = map_get(tensors, "real_iV"); + auto buf_iV = map_get(tensors, "buf_iV"); + if (ct_is_not_none(buf_real_iV) and iV.is_set_to(buf_iV) and evaluation) { + real_iV = buf_real_iV; + } + else { + real_iV = torch::nn::functional::linear(iV, tensors["value_adaptor.weight"], map_get(tensors, "value_adaptor.bias")).view({bsize, seql, nheads, adim}).transpose(1, 2); + } + #endif + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + bool not_non_states = states.size() > 0; + if (not_non_states) { + auto _h_real_iK = states[0]; + auto _h_real_iV = states[1]; + if (pyt_is_not_none(_h_real_iK)) { + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + seql = nquery + _h_real_iK.size(-1); + #else + seql += _h_real_iK.size(-1); + #endif + real_iK = at::cat({_h_real_iK, real_iK}, -1); + real_iV = at::cat({_h_real_iV, real_iV}, 2); + } + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + else { + seql = nquery; + } + #endif + } + #endif + + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + auto scores = real_iQ.matmul(real_iK) / sqrt(adim); + #else + auto scores = real_iQ.matmul(real_iK); + #endif + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + auto rel_pemb_weight = map_get(tensors, "rel_pemb.weight"); + bool not_none_rel = ct_is_not_none(rel_pemb_weight); + if (not_none_rel) { + auto emb_option = torch::nn::functional::EmbeddingFuncOptions(); + auto padding_idx = map_get(iargs, "rel_pemb.padding_idx", -1); + if (padding_idx >= 0) { + emb_option = emb_option.padding_idx(padding_idx); + } + rel_pos_cache = map_get(tensors, "rel_pos_cache"); + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + if (not_non_states) { + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(seql, iargs, tensors["rel_pos"]).narrow(0, seql - nquery, nquery).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, seql}).permute({1, 2, 0, 3}); + } + else { + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(nquery, iargs, tensors["rel_pos"]).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, nquery}).permute({1, 2, 0, 3}); + } + #else + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(seql, iargs, tensors["rel_pos"]).narrow(0, seql - nquery, nquery).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, seql}).permute({1, 2, 0, 3}); + #endif + } + + scores = scores / sqrt(adim); + #endif + + auto mask = map_get(tensors, "mask"); + if (ct_is_not_none(mask)) { + scores.masked_fill_(mask.unsqueeze(1), -inf_value); + } + scores = scores.softmax(-1); + if (p > 0.0) { + scores = torch::nn::functional::dropout(scores, torch::nn::functional::DropoutFuncOptions().p(p).inplace(bargs["drop.inplace"]).training(bargs["drop.training"])); + } + + auto out = torch::nn::functional::linear(scores.matmul(real_iV).transpose(1, 2).contiguous().view({bsize, nquery, nheads * adim}), tensors["outer.weight"], map_get(tensors, "outer.bias")); + + std::map rs; + rs["out"] = out; + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + if (not_non_states) { + #elif defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN) + if (evaluation) { + #else + if ((not_non_states) or evaluation) { + #endif + rs["real_iK"] = real_iK; + rs["real_iV"] = real_iV; + } + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + if (not_none_rel) { + rs["rel_pos_cache"] = rel_pos_cache; + } + #endif + return rs; +} diff --git a/modules/cpp/base/attn/cross/attn.cpp b/modules/cpp/base/attn/cross/attn.cpp new file mode 100644 index 0000000..bbb1360 --- /dev/null +++ b/modules/cpp/base/attn/cross/attn.cpp @@ -0,0 +1,6 @@ +#include +#include "attn_func.cpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &_NEUTRON_MODULES_BASE_ATTN_FUNC_NAME, "Cross attention forward"); +} diff --git a/modules/cpp/base/attn/cross/attn_func.cpp b/modules/cpp/base/attn/cross/attn_func.cpp new file mode 100644 index 0000000..16003a8 --- /dev/null +++ b/modules/cpp/base/attn/cross/attn_func.cpp @@ -0,0 +1,58 @@ +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN +#include "../common.cpp" +#undef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + +/*#include +#include +#include +#include +#include +#include "../../../../../utils/cpp/base.h" + +std::map attn_forward(std::map tensors, std::map iargs, double p, double inf_value, std::map bargs) { + + auto iQ = tensors["iQ"]; + auto bsize = iQ.size(0); + auto nquery = iQ.size(1); + auto iK = tensors["iK"]; + auto seql = iK.size(1); + auto nheads = iargs["num_head"]; + auto adim = iargs["attn_dim"]; + + auto real_iQ = torch::nn::functional::linear(iQ, tensors["query_adaptor.weight"], map_get(tensors, "query_adaptor.bias")).view({bsize, nquery, nheads, adim}).transpose(1, 2); + at::Tensor real_iK, real_iV; + + auto evaluation = not bargs["training"]; + auto buf_real_iK = map_get(tensors, "real_iK"); + auto buf_iK = map_get(tensors, "buf_iK"); + if (ct_is_not_none(buf_real_iK) and iK.is_set_to(buf_iK) and evaluation) { + real_iK = buf_real_iK; + real_iV = tensors["real_iV"]; + } + else { + auto _reals = torch::nn::functional::linear(iK, tensors["kv_adaptor.weight"], map_get(tensors, "kv_adaptor.bias")).view({bsize, seql, 2, nheads, adim}).unbind(2); + real_iK = _reals[0].permute({0, 2, 3, 1}); + real_iV = _reals[1].transpose(1, 2); + } + + auto scores = real_iQ.matmul(real_iK) / sqrt(adim); + + auto mask = map_get(tensors, "mask"); + if (ct_is_not_none(mask)) { + scores.masked_fill_(mask.unsqueeze(1), -inf_value); + } + scores = scores.softmax(-1); + if (p > 0.0) { + scores = torch::nn::functional::dropout(scores, torch::nn::functional::DropoutFuncOptions().p(p).inplace(bargs["drop.inplace"]).training(bargs["drop.training"])); + } + + auto out = torch::nn::functional::linear(scores.matmul(real_iV).transpose(1, 2).contiguous().view({bsize, nquery, nheads * adim}), tensors["outer.weight"], map_get(tensors, "outer.bias")); + + std::map rs; + rs["out"] = out; + if (evaluation) { + rs["real_iK"] = real_iK; + rs["real_iV"] = real_iV; + } + return rs; +}*/ diff --git a/modules/cpp/base/attn/cross/setup.py b/modules/cpp/base/attn/cross/setup.py new file mode 100644 index 0000000..8f885e5 --- /dev/null +++ b/modules/cpp/base/attn/cross/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='cross_attn_cpp', ext_modules=[cpp_extension.CppExtension('cross_attn_cpp', ['modules/cpp/base/attn/cross/attn.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/attn/self/attn.cpp b/modules/cpp/base/attn/self/attn.cpp new file mode 100644 index 0000000..8b7efe9 --- /dev/null +++ b/modules/cpp/base/attn/self/attn.cpp @@ -0,0 +1,6 @@ +#include +#include "attn_func.cpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &_NEUTRON_MODULES_BASE_ATTN_FUNC_NAME, "Self attention forward"); +} diff --git a/modules/cpp/base/attn/self/attn_func.cpp b/modules/cpp/base/attn/self/attn_func.cpp new file mode 100644 index 0000000..7d3ff9f --- /dev/null +++ b/modules/cpp/base/attn/self/attn_func.cpp @@ -0,0 +1,90 @@ +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN +#include "../common.cpp" +#undef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + +/*#include +#include +#include +#include +#include +#include "../../../../../utils/cpp/base.h" +#include "../base.h" + +std::map attn_forward(std::map tensors, std::vector states, std::map iargs, double p, double inf_value, std::map bargs) { + + auto iQ = tensors["iQ"]; + auto bsize = iQ.size(0); + auto nquery = iQ.size(1); + auto nheads = iargs["num_head"]; + auto adim = iargs["attn_dim"]; + int64_t seql; + torch::Tensor rel_pos_cache; + + auto _reals = torch::nn::functional::linear(iQ, tensors["adaptor.weight"], map_get(tensors, "adaptor.bias")).view({bsize, nquery, 3, nheads, adim}).unbind(2); + auto real_iQ = _reals[0].transpose(1, 2); + auto real_iK = _reals[1].permute({0, 2, 3, 1}); + auto real_iV = _reals[2].transpose(1, 2); + + bool not_non_states = states.size() > 0; + if (not_non_states) { + auto _h_real_iK = states[0]; + auto _h_real_iV = states[1]; + if (pyt_is_none(_h_real_iK)) { + seql = nquery; + } + else { + seql = nquery + _h_real_iK.size(-1); + real_iK = at::cat({_h_real_iK, real_iK}, -1); + real_iV = at::cat({_h_real_iV, real_iV}, 2); + } + } + + auto scores = real_iQ.matmul(real_iK); + + auto rel_pemb_weight = map_get(tensors, "rel_pemb.weight"); + bool not_none_rel = ct_is_not_none(rel_pemb_weight); + if (not_none_rel) { + auto emb_option = torch::nn::functional::EmbeddingFuncOptions(); + auto padding_idx = map_get(iargs, "rel_pemb.padding_idx", -1); + if (padding_idx >= 0) { + emb_option = emb_option.padding_idx(padding_idx); + } + rel_pos_cache = map_get(tensors, "rel_pos_cache"); + if (not_non_states) { + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(seql, iargs, tensors["rel_pos"]).narrow(0, seql - nquery, nquery).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, seql}).permute({1, 2, 0, 3}); + } + else { + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(nquery, iargs, tensors["rel_pos"]).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, nquery}).permute({1, 2, 0, 3}); + } + } + + scores = scores / sqrt(adim); + + auto mask = map_get(tensors, "mask"); + if (ct_is_not_none(mask)) { + scores.masked_fill_(mask.unsqueeze(1), -inf_value); + } + scores = scores.softmax(-1); + if (p > 0.0) { + scores = torch::nn::functional::dropout(scores, torch::nn::functional::DropoutFuncOptions().p(p).inplace(bargs["drop.inplace"]).training(bargs["drop.training"])); + } + + auto out = torch::nn::functional::linear(scores.matmul(real_iV).transpose(1, 2).contiguous().view({bsize, nquery, nheads * adim}), tensors["outer.weight"], map_get(tensors, "outer.bias")); + + std::map rs; + rs["out"] = out; + if (not_non_states) { + rs["real_iK"] = real_iK; + rs["real_iV"] = real_iV; + } + if (not_none_rel) { + rs["rel_pos_cache"] = rel_pos_cache; + } + return rs; +}*/ diff --git a/modules/cpp/base/attn/self/setup.py b/modules/cpp/base/attn/self/setup.py new file mode 100644 index 0000000..489c6e7 --- /dev/null +++ b/modules/cpp/base/attn/self/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='self_attn_cpp', ext_modules=[cpp_extension.CppExtension('self_attn_cpp', ['modules/cpp/base/attn/self/attn.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/attn/setup.py b/modules/cpp/base/attn/setup.py new file mode 100644 index 0000000..0e749b6 --- /dev/null +++ b/modules/cpp/base/attn/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='attn_cpp', ext_modules=[cpp_extension.CppExtension('attn_cpp', ['modules/cpp/base/attn/attn.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/ffn/pff.cpp b/modules/cpp/base/ffn/pff.cpp new file mode 100644 index 0000000..b5bf31a --- /dev/null +++ b/modules/cpp/base/ffn/pff.cpp @@ -0,0 +1,6 @@ +#include +#include "pff_func.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &positionwise_ff_forward, "Positionwise FF forward"); +} diff --git a/modules/cpp/base/ffn/pff_func.cpp b/modules/cpp/base/ffn/pff_func.cpp new file mode 100644 index 0000000..5f84e22 --- /dev/null +++ b/modules/cpp/base/ffn/pff_func.cpp @@ -0,0 +1,45 @@ +#include +#include +#include +#include +#include "../../act/act_func.h" +#include "../../../../utils/cpp/base.h" + +at::Tensor positionwise_ff_forward(std::map tensors, std::map bargs, void* act, std::map dargs, std::vector normalized_shape) { + + auto x = tensors["x"]; + + auto ln_opts = torch::nn::functional::LayerNormFuncOptions(normalized_shape).eps(dargs["normer.eps"]); + auto ln_weight = map_get(tensors, "normer.weight"); + if (ct_is_not_none(ln_weight)) { + ln_opts = ln_opts.weight(ln_weight); + } + auto ln_bias = map_get(tensors, "normer.bias"); + if (ct_is_not_none(ln_bias)) { + ln_opts = ln_opts.bias(ln_bias); + } + + auto _out = torch::nn::functional::layer_norm(x, ln_opts); + + auto p = dargs["net.2.p"]; + + auto out = torch::nn::functional::linear(_out, tensors["net.0.weight"], map_get(tensors, "net.0.bias")); + out = act_forward(act, out, bargs["net.1.inplace"]); + if (p > 0.0) { + out = torch::nn::functional::dropout(out, torch::nn::functional::DropoutFuncOptions().p(p).training(bargs["net.2.training"]).inplace(bargs["net.2.inplace"])); + out = torch::nn::functional::linear(out, tensors["net.3.weight"], map_get(tensors, "net.3.bias")); + out = torch::nn::functional::dropout(out, torch::nn::functional::DropoutFuncOptions().p(p).training(bargs["net.4.training"]).inplace(bargs["net.4.inplace"])); + } + else { + out = torch::nn::functional::linear(out, tensors["net.2.weight"], map_get(tensors, "net.2.bias")); + } + + if (bargs["norm_residual"]) { + out = out + _out; + } + else { + out = out + x; + } + + return out; +} diff --git a/modules/cpp/base/ffn/pff_func.h b/modules/cpp/base/ffn/pff_func.h new file mode 100644 index 0000000..4d8042d --- /dev/null +++ b/modules/cpp/base/ffn/pff_func.h @@ -0,0 +1,11 @@ +#ifndef _NEUTRON_MODULES_CPP_BASE_FFN_PFF_FUNC +#define _NEUTRON_MODULES_CPP_BASE_FFN_PFF_FUNC + +#include +#include +#include +#include + +at::Tensor positionwise_ff_forward(std::map tensors, std::map bargs, void* act, std::map dargs, std::vector normalized_shape); + +#endif diff --git a/modules/cpp/base/ffn/setup.py b/modules/cpp/base/ffn/setup.py new file mode 100644 index 0000000..2bd090d --- /dev/null +++ b/modules/cpp/base/ffn/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='pff_cpp', ext_modules=[cpp_extension.CppExtension('pff_cpp', ['modules/cpp/base/ffn/pff.cpp', 'modules/cpp/base/ffn/pff_func.cpp', 'modules/cpp/act/act_func.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/resattn/attn.cpp b/modules/cpp/base/resattn/attn.cpp new file mode 100644 index 0000000..93064f0 --- /dev/null +++ b/modules/cpp/base/resattn/attn.cpp @@ -0,0 +1,6 @@ +#include +#include "attn_func.cpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &_NEUTRON_MODULES_BASE_ATTN_FUNC_NAME, "Residual Multi-head attention forward"); +} diff --git a/modules/cpp/base/resattn/attn_func.cpp b/modules/cpp/base/resattn/attn_func.cpp new file mode 100644 index 0000000..8b58afc --- /dev/null +++ b/modules/cpp/base/resattn/attn_func.cpp @@ -0,0 +1,3 @@ +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN +#include "common.cpp" +#undef _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN diff --git a/modules/cpp/base/resattn/common.cpp b/modules/cpp/base/resattn/common.cpp new file mode 100644 index 0000000..589aece --- /dev/null +++ b/modules/cpp/base/resattn/common.cpp @@ -0,0 +1,205 @@ +#include +#include +#include +#include +#include +#include "../../../../utils/cpp/base.h" + +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_RES + +#if !(defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN)||defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN)||defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN)) +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN +#endif + +#ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN +#define _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME res_self_attn_forward +#elif defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN) +#define _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME res_cross_attn_forward +#else +#define _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME res_multi_head_attn_forward +#endif + +#ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN +#include "../attn/base.h" +std::map _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME(std::map tensors, std::vector states, std::map iargs, std::map dargs, std::vector normalized_shape, std::map bargs) { +#else +std::map _NEUTRON_MODULES_BASE_ATTN_FUNC_NAME(std::map tensors, std::map iargs, std::map dargs, std::vector normalized_shape, std::map bargs) { +#endif + + auto iQ = tensors["iQ"]; + auto bsize = iQ.size(0); + auto nquery = iQ.size(1); + + auto ln_opts = torch::nn::functional::LayerNormFuncOptions(normalized_shape).eps(dargs["normer.eps"]); + auto ln_weight = map_get(tensors, "normer.weight"); + if (ct_is_not_none(ln_weight)) { + ln_opts = ln_opts.weight(ln_weight); + } + auto ln_bias = map_get(tensors, "normer.bias"); + if (ct_is_not_none(ln_bias)) { + ln_opts = ln_opts.bias(ln_bias); + } + + auto _iQ = torch::nn::functional::layer_norm(iQ, ln_opts); + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + auto iK = tensors["iK"]; + auto seql = iK.size(1); + #endif + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN + auto iV = tensors["iV"]; + #endif + auto nheads = iargs["net.num_head"]; + auto adim = iargs["net.attn_dim"]; + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + int64_t seql; + #endif + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + torch::Tensor rel_pos_cache; + #endif + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + auto real_iQ = torch::nn::functional::linear(_iQ, tensors["net.query_adaptor.weight"], map_get(tensors, "net.query_adaptor.bias")).view({bsize, nquery, nheads, adim}).transpose(1, 2); + at::Tensor real_iK, real_iV; + + auto evaluation = not bargs["net.training"]; + auto buf_real_iK = map_get(tensors, "net.real_iK"); + auto buf_iK = map_get(tensors, "net.iK"); + if (ct_is_not_none(buf_real_iK) and iK.is_set_to(buf_iK) and evaluation) { + real_iK = buf_real_iK; + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + real_iV = tensors["net.real_iV"]; + #endif + } + else { + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + auto _reals = torch::nn::functional::linear(iK, tensors["net.kv_adaptor.weight"], map_get(tensors, "net.kv_adaptor.bias")).view({bsize, seql, 2, nheads, adim}).unbind(2); + real_iK = _reals[0].permute({0, 2, 3, 1}); + real_iV = _reals[1].transpose(1, 2); + #else + real_iK = torch::nn::functional::linear(iK, tensors["net.key_adaptor.weight"], map_get(tensors, "net.key_adaptor.bias")).view({bsize, seql, nheads, adim}).permute({0, 2, 3, 1}); + #endif + } + #else + auto _reals = torch::nn::functional::linear(_iQ, tensors["net.adaptor.weight"], map_get(tensors, "net.adaptor.bias")).view({bsize, nquery, 3, nheads, adim}).unbind(2); + auto real_iQ = _reals[0].transpose(1, 2); + auto real_iK = _reals[1].permute({0, 2, 3, 1}); + auto real_iV = _reals[2].transpose(1, 2); + #endif + + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_MHATTN + auto buf_real_iV = map_get(tensors, "net.real_iV"); + auto buf_iV = map_get(tensors, "net.iV"); + if (ct_is_not_none(buf_real_iV) and iV.is_set_to(buf_iV) and evaluation) { + real_iV = buf_real_iV; + } + else { + real_iV = torch::nn::functional::linear(iV, tensors["net.value_adaptor.weight"], map_get(tensors, "net.value_adaptor.bias")).view({bsize, seql, nheads, adim}).transpose(1, 2); + } + #endif + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + bool not_non_states = states.size() > 0; + if (not_non_states) { + auto _h_real_iK = states[0]; + auto _h_real_iV = states[1]; + if (pyt_is_not_none(_h_real_iK)) { + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + seql = nquery + _h_real_iK.size(-1); + #else + seql += _h_real_iK.size(-1); + #endif + real_iK = at::cat({_h_real_iK, real_iK}, -1); + real_iV = at::cat({_h_real_iV, real_iV}, 2); + } + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + else { + seql = nquery; + } + #endif + } + #endif + + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + auto scores = real_iQ.matmul(real_iK) / sqrt(adim); + #else + auto scores = real_iQ.matmul(real_iK); + #endif + + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + auto rel_pemb_weight = map_get(tensors, "net.rel_pemb.weight"); + bool not_none_rel = ct_is_not_none(rel_pemb_weight); + if (not_none_rel) { + auto emb_option = torch::nn::functional::EmbeddingFuncOptions(); + auto padding_idx = map_get(iargs, "net.rel_pemb.padding_idx", -1); + if (padding_idx >= 0) { + emb_option = emb_option.padding_idx(padding_idx); + } + rel_pos_cache = map_get(tensors, "net.rel_pos_cache"); + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + if (not_non_states) { + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(seql, iargs, tensors["net.rel_pos"]).narrow(0, seql - nquery, nquery).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, seql}).permute({1, 2, 0, 3}); + } + else { + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(nquery, iargs, tensors["net.rel_pos"]).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, nquery}).permute({1, 2, 0, 3}); + } + #else + if (ct_is_none(rel_pos_cache)) { + rel_pos_cache = get_rel_pos(seql, iargs, tensors["net.rel_pos"]).narrow(0, seql - nquery, nquery).contiguous(); + } + scores += real_iQ.permute({2, 0, 1, 3}).contiguous().view({nquery, bsize * nheads, adim}).bmm(torch::nn::functional::embedding(rel_pos_cache, rel_pemb_weight, emb_option).transpose(1, 2)).view({nquery, bsize, nheads, seql}).permute({1, 2, 0, 3}); + #endif + } + + scores = scores / sqrt(adim); + #endif + + auto mask = map_get(tensors, "mask"); + if (ct_is_not_none(mask)) { + scores.masked_fill_(mask.unsqueeze(1), -dargs["inf_value"]); + } + scores = scores.softmax(-1); + auto attn_p = dargs["net.drop.p"]; + if (attn_p > 0.0) { + scores = torch::nn::functional::dropout(scores, torch::nn::functional::DropoutFuncOptions().p(attn_p).inplace(bargs["net.drop.inplace"]).training(bargs["net.drop.training"])); + } + + auto out = torch::nn::functional::linear(scores.matmul(real_iV).transpose(1, 2).contiguous().view({bsize, nquery, nheads * adim}), tensors["net.outer.weight"], map_get(tensors, "net.outer.bias")); + + auto out_p = dargs["drop.p"]; + if (out_p > 0.0) { + out = torch::nn::functional::dropout(out, torch::nn::functional::DropoutFuncOptions().p(out_p).inplace(bargs["drop.inplace"]).training(bargs["drop.training"])); + } + + if (bargs["norm_residual"]) { + out = out + _iQ; + } + else { + out = out + iQ; + } + + std::map rs; + rs["out"] = out; + #ifdef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN + if (not_non_states) { + #elif defined(_NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN) + if (evaluation) { + #else + if ((not_non_states) or evaluation) { + #endif + rs["net.real_iK"] = real_iK; + rs["net.real_iV"] = real_iV; + } + #ifndef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN + if (not_none_rel) { + rs["net.rel_pos_cache"] = rel_pos_cache; + } + #endif + return rs; +} diff --git a/modules/cpp/base/resattn/cross/attn.cpp b/modules/cpp/base/resattn/cross/attn.cpp new file mode 100644 index 0000000..b29df93 --- /dev/null +++ b/modules/cpp/base/resattn/cross/attn.cpp @@ -0,0 +1,6 @@ +#include +#include "attn_func.cpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &_NEUTRON_MODULES_BASE_ATTN_FUNC_NAME, "Residual Cross attention forward"); +} diff --git a/modules/cpp/base/resattn/cross/attn_func.cpp b/modules/cpp/base/resattn/cross/attn_func.cpp new file mode 100644 index 0000000..321c35b --- /dev/null +++ b/modules/cpp/base/resattn/cross/attn_func.cpp @@ -0,0 +1,3 @@ +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN +#include "../common.cpp" +#undef _NEUTRON_MODULES_BASE_ATTN_BUILD_CATTN diff --git a/modules/cpp/base/resattn/cross/setup.py b/modules/cpp/base/resattn/cross/setup.py new file mode 100644 index 0000000..bb2d00f --- /dev/null +++ b/modules/cpp/base/resattn/cross/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='res_cross_attn_cpp', ext_modules=[cpp_extension.CppExtension('res_cross_attn_cpp', ['modules/cpp/base/resattn/cross/attn.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/resattn/self/attn.cpp b/modules/cpp/base/resattn/self/attn.cpp new file mode 100644 index 0000000..24c7a0f --- /dev/null +++ b/modules/cpp/base/resattn/self/attn.cpp @@ -0,0 +1,6 @@ +#include +#include "attn_func.cpp" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &_NEUTRON_MODULES_BASE_ATTN_FUNC_NAME, "Residual Self attention forward"); +} diff --git a/modules/cpp/base/resattn/self/attn_func.cpp b/modules/cpp/base/resattn/self/attn_func.cpp new file mode 100644 index 0000000..f5decac --- /dev/null +++ b/modules/cpp/base/resattn/self/attn_func.cpp @@ -0,0 +1,3 @@ +#define _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN +#include "../common.cpp" +#undef _NEUTRON_MODULES_BASE_ATTN_BUILD_SATTN diff --git a/modules/cpp/base/resattn/self/setup.py b/modules/cpp/base/resattn/self/setup.py new file mode 100644 index 0000000..dbf7b23 --- /dev/null +++ b/modules/cpp/base/resattn/self/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='res_self_attn_cpp', ext_modules=[cpp_extension.CppExtension('res_self_attn_cpp', ['modules/cpp/base/resattn/self/attn.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/resattn/setup.py b/modules/cpp/base/resattn/setup.py new file mode 100644 index 0000000..cbf0ad1 --- /dev/null +++ b/modules/cpp/base/resattn/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='res_attn_cpp', ext_modules=[cpp_extension.CppExtension('res_attn_cpp', ['modules/cpp/base/resattn/attn.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/cpp/group/group.cpp b/modules/cpp/group/group.cpp new file mode 100644 index 0000000..6dd9492 --- /dev/null +++ b/modules/cpp/group/group.cpp @@ -0,0 +1,6 @@ +#include +#include "group_func.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &group_linear_forward, "Group linear forward"); +} diff --git a/modules/cpp/group/group_func.cpp b/modules/cpp/group/group_func.cpp new file mode 100644 index 0000000..18e095c --- /dev/null +++ b/modules/cpp/group/group_func.cpp @@ -0,0 +1,55 @@ +#include +#include +#include +#include +#include "../../../utils/cpp/base.h" + +at::Tensor group_linear_forward(std::map tensors, std::map iargs, std::map bargs) { + + auto inputu = tensors["inputu"]; + auto _sizes = inputu.sizes().vec(); + auto _s_last_dim_ind = _sizes.size() - 1; + at::Tensor _id, out; + auto trans_input = bargs["trans_input"]; + auto idm = inputu.dim(); + auto ngroup = iargs["ngroup"]; + if ((idm != 3) or trans_input) { + int64_t _ldsize; + if (trans_input) { + _ldsize = iargs["isize"]; + } + else { + _ldsize = _sizes[_s_last_dim_ind]; + } + _id = inputu.view({-1, ngroup, _ldsize}); + } + else { + _id = inputu; + } + _id = _id.transpose(0, 1); + + auto weight = tensors["weight"]; + auto bias = map_get(tensors, "bias"); + if (ct_is_none(bias)) { + out = _id.bmm(weight); + } + else { + out = bias.baddbmm(_id, weight); + } + if (bargs["shuffle"]) { + out = out.permute({1, 2, 0}); + } + else { + out = out.transpose(0, 1); + } + + _sizes[_s_last_dim_ind] = -1; + if (bargs["i_gdim"]) { + _sizes.insert(_sizes.end() - 1, ngroup); + } + else if (bargs["del_gdim"]) { + _sizes.erase(_sizes.end() - 2); + } + + return out.contiguous().view(_sizes); +} diff --git a/modules/cpp/group/group_func.h b/modules/cpp/group/group_func.h new file mode 100644 index 0000000..f20f420 --- /dev/null +++ b/modules/cpp/group/group_func.h @@ -0,0 +1,10 @@ +#ifndef _NEUTRON_MODULES_CPP_GROUP_FUNC +#define _NEUTRON_MODULES_CPP_GROUP_FUNC + +#include +#include +#include + +at::Tensor group_linear_forward(std::map tensors, std::map iargs, std::map bargs); + +#endif diff --git a/modules/cpp/group/setup.py b/modules/cpp/group/setup.py new file mode 100644 index 0000000..f4850d4 --- /dev/null +++ b/modules/cpp/group/setup.py @@ -0,0 +1,6 @@ +#encoding: utf-8 + +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup(name='group_cpp', ext_modules=[cpp_extension.CppExtension('group_cpp', ['modules/cpp/group/group.cpp', 'modules/cpp/group/group_func.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/hplstm/cpp/lgate.cpp b/modules/cpp/hplstm/lgate.cpp similarity index 89% rename from modules/hplstm/cpp/lgate.cpp rename to modules/cpp/hplstm/lgate.cpp index dcc2f82..01e0b86 100644 --- a/modules/hplstm/cpp/lgate.cpp +++ b/modules/cpp/hplstm/lgate.cpp @@ -1,35 +1,35 @@ #include #include -at::Tensor lgate_forward(torch::Tensor fgate, torch::Tensor igh, torch::Tensor init_cell, int64_t dim, bool inplace=false){ +at::Tensor lgate_forward(torch::Tensor fgate, torch::Tensor igh, torch::Tensor init_cell, int64_t dim, bool inplace=false) { torch::Tensor cell; if (inplace) { cell = igh; } - else{ + else { cell = igh.clone(); } auto seqlen = cell.size(dim); int64_t i; cell.select(dim, 0).addcmul_(init_cell, fgate.select(dim, 0)); - for (i = 1; i < seqlen; i++){ + for (i = 1; i < seqlen; i++) { cell.select(dim, i).addcmul_(cell.select(dim, i - 1), fgate.select(dim, i)); } return cell; } -std::vector lgate_backward(torch::Tensor grad_cell, torch::Tensor cell, torch::Tensor fgate, torch::Tensor init_cell, int64_t dim){ +std::vector lgate_backward(torch::Tensor grad_cell, torch::Tensor cell, torch::Tensor fgate, torch::Tensor init_cell, int64_t dim) { auto grad_fgate = grad_cell.clone(); auto grad_igh = grad_cell.clone(); auto last_index = grad_cell.size(dim) - 1; auto grad_prev_cell = grad_cell.select(dim, last_index) * fgate.select(dim, last_index); - if (last_index > 0){ + if (last_index > 0) { grad_fgate.select(dim, last_index).mul_(cell.select(dim, last_index - 1)); int64_t i; - for (i = last_index - 1; i > 0; i--){ + for (i = last_index - 1; i > 0; i--) { auto acc_grad_cell = grad_fgate.select(dim, i).add_(grad_prev_cell);// grad_fgate is initialized as a copy of grad_cell, performing the accumulation directly on grad_fgate is more efficient. grad_igh.select(dim, i).add_(grad_prev_cell); grad_prev_cell = acc_grad_cell * fgate.select(dim, i); @@ -40,7 +40,7 @@ std::vector lgate_backward(torch::Tensor grad_cell, torch::Tensor grad_prev_cell = acc_grad_cell * fgate.select(dim, 0); acc_grad_cell.mul_(init_cell); } - else{ + else { grad_fgate.select(dim, last_index).mul_(init_cell); } diff --git a/modules/hplstm/cpp/setup.py b/modules/cpp/hplstm/setup.py similarity index 60% rename from modules/hplstm/cpp/setup.py rename to modules/cpp/hplstm/setup.py index d56397d..5037e35 100644 --- a/modules/hplstm/cpp/setup.py +++ b/modules/cpp/hplstm/setup.py @@ -3,4 +3,4 @@ from setuptools import setup, Extension from torch.utils import cpp_extension -setup(name='lgate_cpp', ext_modules=[cpp_extension.CppExtension('lgate_cpp', ['lgate.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) +setup(name='lgate_cpp', ext_modules=[cpp_extension.CppExtension('lgate_cpp', ['modules/cpp/hplstm/lgate.cpp'])], cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/modules/group/base.py b/modules/group/base.py index 5469f0c..f7171a6 100644 --- a/modules/group/base.py +++ b/modules/group/base.py @@ -4,6 +4,8 @@ import torch from torch import nn +from cnfg.ihyp import use_c_backend_group, bind_c_forward + class GroupLinear(nn.Module): # isize: input dimension (dimension per group * ngroup) @@ -30,6 +32,9 @@ def __init__(self, isize, osize, ngroup, bias=True, trans_input=True, shuffle=Fa else: self.bias = None + if self.c_available(): + self.c_init() + # inputu: (..., isize) def forward(self, inputu, weight=None, bias=None): @@ -54,6 +59,49 @@ def forward(self, inputu, weight=None, bias=None): return out.contiguous().view(_size) + def c_available(self): + + return use_c_backend_group and (type(self) == GroupLinear) + + def c_init(self, bind=bind_c_forward): + + try: + try: + import group_cpp + except Exception as e: + from torch.utils.cpp_extension import load + group_cpp = load(name="group_cpp", sources=["modules/cpp/group/group.cpp", "modules/cpp/group/group_func.cpp"]) + self.c_forward_func = group_cpp.forward + except: + self.c_forward_func = None + if self.c_forward_func is not None: + self.c_build_cache() + if bind: + GroupLinear.forward = GroupLinear.c_forward + + def c_forward(self, inputu, weight=None, bias=None): + + return self.c_forward_func(*self.c_build_inputs(inputu, weight=weight, bias=bias)) + + def c_build_cache(self): + + self.aargs = ({"isize": self.isize, "ngroup": self.ngroup}, {"trans_input": self.trans_input, "shuffle": self.shuffle, "i_gdim": self.i_gdim, "del_gdim": self.del_gdim},) + self.targs = dict(self.named_parameters()) + + def c_build_inputs(self, inputu, weight=None, bias=None): + + i_d = self.targs.copy() + i_d["inputu"] = inputu + if weight is not None: + i_d["weight"] = weight + if bias is not None: + i_d["bias"] = bias + + return i_d, *self.aargs + + def extra_repr(self): + return 'groups={}, in_features={}, out_features={}, bias={}'.format(self.ngroup, self.ngroup * self.isize, self.ngroup * self.osize, self.bias is not None) + def fix_init(self): with torch.no_grad(): diff --git a/modules/hplstm/LGate.py b/modules/hplstm/LGate.py index 6c12626..e5b269e 100644 --- a/modules/hplstm/LGate.py +++ b/modules/hplstm/LGate.py @@ -6,7 +6,7 @@ import lgate_cpp except Exception as e: from torch.utils.cpp_extension import load - lgate_cpp = load(name="lgate_cpp", sources=["modules/hplstm/cpp/lgate.cpp"]) + lgate_cpp = load(name="lgate_cpp", sources=["modules/cpp/hplstm/lgate.cpp"]) class LGateFunction(Function): @@ -22,8 +22,12 @@ def forward(ctx, fgate, igh, init_cell, dim=None, inplace=False): @staticmethod def backward(ctx, grad_cell): - cell, fgate, init_cell = ctx.saved_variables - grad_fgate, grad_igh, grad_init_cell = lgate_cpp.backward(grad_cell, cell, fgate, init_cell, ctx.dim) - return grad_fgate, grad_igh, grad_init_cell, None, None + needs_grad_fgate, needs_grad_igh, needs_grad_init_cell = ctx.needs_input_grad[0:3] + if needs_grad_fgate or needs_grad_igh or needs_grad_init_cell: + cell, fgate, init_cell = ctx.saved_variables + grad_fgate, grad_igh, grad_init_cell = lgate_cpp.backward(grad_cell, cell, fgate, init_cell, ctx.dim) + return grad_fgate if needs_grad_fgate else None, grad_igh if needs_grad_igh else None, grad_init_cell if needs_grad_init_cell else None, None, None + else: + return None, None, None, None, None LGateFunc = LGateFunction.apply diff --git a/modules/hplstm/base.py b/modules/hplstm/base.py index 599330d..b124e52 100644 --- a/modules/hplstm/base.py +++ b/modules/hplstm/base.py @@ -6,6 +6,7 @@ from modules.group.base import GroupLinear from modules.act import Custom_Act from modules.hplstm.LGate import LGateFunc +from utils.base import float2odd from cnfg.ihyp import * class MHPLSTMCore(nn.Module): @@ -16,18 +17,20 @@ def __init__(self, isize, num_head=8, osize=None, dropout=0.0, custom_act=use_ad _osize = isize if osize is None else osize - head_dim = isize // num_head - hsize = head_dim * num_head + i_head_dim = float2odd(float(isize) / num_head) + i_hsize = i_head_dim * num_head + o_head_dim = float2odd(float(_osize) / num_head) + o_hsize = o_head_dim * num_head - self.trans_hid = GroupLinear(hsize + hsize, hsize * 3, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False) - self.trans_og = nn.Sequential(GroupLinear(hsize + hsize, hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)) + self.trans_hid = GroupLinear(i_hsize + i_hsize, o_hsize * 3, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False) + self.trans_og = nn.Sequential(GroupLinear(i_hsize + o_hsize, o_hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, o_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)) - self.normer_csum = nn.LayerNorm((num_head, head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - self.normer_hid = nn.LayerNorm((num_head, 3, head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.normer_csum = nn.LayerNorm((num_head, i_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.normer_hid = nn.LayerNorm((num_head, 3, o_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.act = Custom_Act() if custom_act else nn.ReLU()#Tanh() self.drop = Dropout(dropout, inplace=inplace_after_Custom_Act) if dropout > 0.0 else None - self.init_cx = nn.Parameter(torch.zeros(1, num_head, head_dim)) + self.init_cx = nn.Parameter(torch.zeros(1, num_head, o_head_dim)) # heads_input: (bsize, seql, nheads, adim) # states: ((bsize, 1, num_head, head_dim), (bsize, 1, num_head, head_dim),) @@ -47,7 +50,7 @@ def forward(self, heads_input, states=None, head_mask=None): _csum_state = states[0] csum = self.normer_csum(_csum_state) csum_state_return = _csum_state + heads_input - igate, fgate, hidden = self.normer_hid(self.trans_hid(torch.cat((heads_input, csum,), dim=-1)).view(bsize, seql, nheads, 3, adim)).unbind(-2) + igate, fgate, hidden = self.normer_hid(self.trans_hid(torch.cat((heads_input, csum,), dim=-1)).view(bsize, seql, nheads, 3, -1)).unbind(-2) fgate = fgate.sigmoid() hidden = self.act(hidden) @@ -78,14 +81,15 @@ def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias= super(HPLSTM, self).__init__() _osize = isize if osize is None else osize + o_hsize = float2odd(float(_osize) / num_head) * num_head - self.head_dim = isize // num_head - self.hsize = self.head_dim * num_head + self.head_dim = float2odd(float(isize) / num_head) + i_hsize = self.head_dim * num_head self.num_head = num_head - self.trans_input = Linear(isize, self.hsize, bias=enable_proj_bias) - self.net = MHPLSTMCore(isize, num_head=self.num_head, osize=_osize, dropout=dropout) - self.trans_output = Linear(self.hsize, _osize, bias=enable_proj_bias) + self.trans_input = Linear(isize, i_hsize, bias=enable_proj_bias) + self.net = MHPLSTMCore(i_hsize, num_head=self.num_head, osize=o_hsize, dropout=dropout) + self.trans_output = Linear(o_hsize, _osize, bias=enable_proj_bias) def forward(self, inpute, states=None, head_mask=None): @@ -97,7 +101,7 @@ def forward(self, inpute, states=None, head_mask=None): else: out, states_return = self.net(heads_input, states=states, head_mask=head_mask) - out = self.trans_output(out.view(bsize, seql, self.hsize)) + out = self.trans_output(out.view(bsize, seql, -1)) if states is None: return out @@ -111,26 +115,27 @@ def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias= super(BiHPLSTM, self).__init__() _osize = isize if osize is None else osize + o_hsize = float2odd(float(_osize) / num_head) * num_head - self.head_dim = isize // num_head - self.hsize = self.head_dim * num_head + self.head_dim = float2odd(float(isize) / num_head) + i_hsize = self.head_dim * num_head self.num_head = num_head - self.trans_input = Linear(isize, self.hsize + self.hsize, bias=enable_proj_bias) - self.net = MHPLSTMCore(isize + isize, num_head=self.num_head + self.num_head, osize=_osize + _osize, dropout=dropout) - self.trans_output = Linear(self.hsize + self.hsize, _osize, bias=enable_proj_bias) + self.trans_input = Linear(isize, i_hsize + i_hsize, bias=enable_proj_bias) + self.net = MHPLSTMCore(i_hsize + i_hsize, num_head=self.num_head + self.num_head, osize=o_hsize + o_hsize, dropout=dropout) + self.trans_output = Linear(o_hsize + o_hsize, _osize, bias=enable_proj_bias) # inpute: (bsize, seql, isize) - # reversed_mask: (bsize, seql, 1, 1) input.eq(0).view(bsize, seql, 1, 1).flip(1) + # reversed_mask: (bsize, seql, 1, 1), generated by input.eq(0).view(bsize, seql, 1, 1).flip(1) def forward(self, inpute, reversed_mask=None): bsize, seql = inpute.size()[:2] - nheads, adim = self.num_head, self.head_dim - heads_input_fwd, heads_input_bwd = self.trans_input(inpute).view(bsize, seql, 2, nheads, adim).unbind(2) + nheads = self.num_head + heads_input_fwd, heads_input_bwd = self.trans_input(inpute).view(bsize, seql, 2, nheads, self.head_dim).unbind(2) heads_input_bwd_rvs = heads_input_bwd.flip(1) _r_mask = None if reversed_mask is None else torch.cat((reversed_mask.new_zeros(bsize, seql, nheads, 1), reversed_mask.expand(bsize, seql, nheads, 1)), dim=2) o_fwd, o_bwd_rvs = self.net(torch.cat((heads_input_fwd, heads_input_bwd_rvs,), dim=2), head_mask=_r_mask).chunk(2, dim=-2) o_bwd = o_bwd_rvs.flip(1) - return self.trans_output(torch.cat((o_fwd.view(bsize, seql, self.hsize), o_bwd.view(bsize, seql, self.hsize),), dim=-1)) + return self.trans_output(torch.cat((o_fwd.view(bsize, seql, -1), o_bwd.view(bsize, seql, -1),), dim=-1)) diff --git a/modules/hplstm/hfn.py b/modules/hplstm/hfn.py index 2735b83..e79a3cf 100644 --- a/modules/hplstm/hfn.py +++ b/modules/hplstm/hfn.py @@ -6,9 +6,9 @@ from modules.group.base import GroupLinear from modules.act import Custom_Act from modules.hplstm.LGate import LGateFunc +from utils.base import float2odd -from modules.hplstm.base import HPLSTM as HPLSTMBase -from modules.hplstm.base import BiHPLSTM as BiHPLSTMBase +from modules.hplstm.base import HPLSTM as HPLSTMBase, BiHPLSTM as BiHPLSTMBase from cnfg.ihyp import * @@ -19,21 +19,22 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, cust super(MHPLSTMCore, self).__init__() _osize = isize if osize is None else osize - _fhsize = _osize * 4 if fhsize is None else fhsize - _head_fhsize = _fhsize // num_head - _fhsize = _head_fhsize * num_head - head_dim = isize // num_head - hsize = head_dim * num_head + i_head_dim = float2odd(float(isize) / num_head) + i_hsize = i_head_dim * num_head + o_head_dim = float2odd(float(_osize) / num_head) + o_hsize = o_head_dim * num_head + _head_fhsize = float2odd(float(o_hsize * 4 if fhsize is None else fhsize) / num_head) + _fhsize = _head_fhsize * num_head - self.trans_hid = nn.Sequential(GroupLinear(hsize + hsize, _fhsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, _head_fhsize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_Custom_Act), GroupLinear(_fhsize, hsize, num_head, bias=enable_proj_bias, shuffle=False, trans_input=False, flatten_output=False), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(GroupLinear(hsize + hsize, _fhsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, _head_fhsize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), GroupLinear(_fhsize, hsize, num_head, bias=enable_proj_bias, shuffle=False, trans_input=False, flatten_output=False)) - self.trans_ifg = GroupLinear(hsize + hsize, hsize + hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False) - self.trans_og = nn.Sequential(GroupLinear(hsize + hsize, hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)) + self.trans_hid = nn.Sequential(GroupLinear(i_hsize + i_hsize, _fhsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, _head_fhsize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_Custom_Act), GroupLinear(_fhsize, o_hsize, num_head, bias=enable_proj_bias, shuffle=False, trans_input=False, flatten_output=False), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(GroupLinear(i_hsize + i_hsize, _fhsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, _head_fhsize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), GroupLinear(_fhsize, o_hsize, num_head, bias=enable_proj_bias, shuffle=False, trans_input=False, flatten_output=False)) + self.trans_ifg = GroupLinear(i_hsize + i_hsize, o_hsize + o_hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False) + self.trans_og = nn.Sequential(GroupLinear(i_hsize + o_hsize, o_hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, o_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)) - self.normer_csum = nn.LayerNorm((num_head, head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - self.normer_ifg = nn.LayerNorm((num_head, 2, head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.normer_csum = nn.LayerNorm((num_head, i_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.normer_ifg = nn.LayerNorm((num_head, 2, o_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - self.init_cx = nn.Parameter(torch.zeros(1, num_head, head_dim)) + self.init_cx = nn.Parameter(torch.zeros(1, num_head, o_head_dim)) def forward(self, heads_input, states=None, head_mask=None): @@ -50,7 +51,7 @@ def forward(self, heads_input, states=None, head_mask=None): csum = self.normer_csum(_csum_state) csum_state_return = _csum_state + heads_input gh_input = torch.cat((heads_input, csum,), dim=-1) - (igate, fgate,), hidden = self.normer_ifg(self.trans_ifg(gh_input).view(bsize, seql, nheads, 2, adim)).unbind(-2), self.trans_hid(gh_input) + (igate, fgate,), hidden = self.normer_ifg(self.trans_ifg(gh_input).view(bsize, seql, nheads, 2, -1)).unbind(-2), self.trans_hid(gh_input) fgate = fgate.sigmoid() igh = igate.sigmoid() * hidden if head_mask is not None: @@ -75,19 +76,25 @@ class HPLSTM(HPLSTMBase): def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kwargs): _osize = isize if osize is None else osize - _fhsize = _osize * 4 if fhsize is None else fhsize super(HPLSTM, self).__init__(isize, num_head=num_head, osize=_osize, dropout=dropout, **kwargs) - self.net = MHPLSTMCore(isize, num_head=self.num_head, osize=_osize, fhsize=_fhsize, dropout=dropout) + i_hsize = float2odd(float(isize) / num_head) * num_head + o_hsize = float2odd(float(_osize) / num_head) * num_head + _fhsize = float2odd(float(o_hsize * 4 if fhsize is None else fhsize) / num_head) * num_head + + self.net = MHPLSTMCore(i_hsize, num_head=self.num_head, osize=o_hsize, fhsize=_fhsize, dropout=dropout) class BiHPLSTM(BiHPLSTMBase): def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kwargs): _osize = isize if osize is None else osize - _fhsize = _osize * 4 if fhsize is None else fhsize super(BiHPLSTM, self).__init__(isize, num_head=num_head, osize=_osize, dropout=dropout, **kwargs) - self.net = MHPLSTMCore(isize + isize, num_head=self.num_head + self.num_head, osize=_osize + _osize, fhsize=_fhsize + _fhsize, dropout=dropout) + i_hsize = float2odd(float(isize) / num_head) * num_head + o_hsize = float2odd(float(_osize) / num_head) * num_head + _fhsize = float2odd(float(o_hsize * 4 if fhsize is None else fhsize) / num_head) * num_head + + self.net = MHPLSTMCore(i_hsize + i_hsize, num_head=self.num_head + self.num_head, osize=o_hsize + o_hsize, fhsize=_fhsize + _fhsize, dropout=dropout) diff --git a/modules/mulang/__init__.py b/modules/mulang/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/modules/mulang/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/modules/mulang/eff/__init__.py b/modules/mulang/eff/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/modules/mulang/eff/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/modules/mulang/eff/base.py b/modules/mulang/eff/base.py new file mode 100644 index 0000000..9874e79 --- /dev/null +++ b/modules/mulang/eff/base.py @@ -0,0 +1,143 @@ +#encoding: utf-8 + +import torch +from torch import nn +from torch.nn import functional as nnFunc + +from modules.base import ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase + +from math import sqrt +from numbers import Integral + +from cnfg.ihyp import * + +class MBLinear(nn.Linear): + + def __init__(self, in_features, out_features, nbias, bias=True): + + super(MBLinear, self).__init__(in_features, out_features, bias=False) + + if bias: + self.bias = nn.Parameter(torch.zeros(nbias, out_features)) + + def forward(self, x, taskid): + + return nnFunc.linear(x, self.weight, None if self.bias is None else self.bias[taskid]) + + def fix_init(self): + + if self.bias is not None: + with torch.no_grad(): + self.bias.zero_() + +class MWLinear(MBLinear): + + def __init__(self, in_features, out_features, nbias, bias=True): + + super(MWLinear, self).__init__(in_features, out_features, nbias, bias=False) + + self.weight = nn.Parameter(torch.Tensor(nbias, out_features, in_features).uniform_(- sqrt(1.0 / in_features), sqrt(1.0 / in_features))) + + def forward(self, x, taskid): + + return nnFunc.linear(x, self.weight[taskid], None if self.bias is None else self.bias[taskid]) + + def fix_init(self): + + _isize = self.weight.size(-1) + with torch.no_grad(): + self.weight.data.uniform_(- sqrt(1.0 / _isize), sqrt(1.0 / _isize)) + super(MWLinear, self).fix_init() + +class LayerNorm(nn.LayerNorm): + + def __init__(self, normalized_shape, ntask=None, eps=1e-5, elementwise_affine=True, **kwargs): + + if isinstance(normalized_shape, Integral): + normalized_shape = (ntask, normalized_shape,) + else: + normalized_shape = tuple([ntask, *normalized_shape]) + + super(LayerNorm, self).__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, **kwargs) + + self.normalized_shape = self.normalized_shape[1:] + + def forward(self, input, taskid=None): + + return nnFunc.layer_norm(input, self.normalized_shape, None if self.weight is None else self.weight[taskid], None if self.bias is None else self.bias[taskid], self.eps) + +class ResSelfAttn(ResSelfAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, ntask=None, **kwargs): + + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + + def forward(self, iQ, *inputs, taskid=None, **kwargs): + + _iQ = self.normer(iQ, taskid=taskid) + + outs = self.net(_iQ, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + +class ResCrossAttn(ResCrossAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, ntask=None, **kwargs): + + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + + def forward(self, iQ, iK, *inputs, taskid=None, **kwargs): + + _iQ = self.normer(iQ, taskid=taskid) + + outs = self.net(_iQ, iK, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + +class PositionwiseFF(PositionwiseFFBase): + + def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, ntask=None, **kwargs): + + _hsize = isize * 4 if hsize is None else hsize + + super(PositionwiseFF, self).__init__(isize, hsize=_hsize, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + + def forward(self, x, taskid=None): + + _out = self.normer(x, taskid=taskid) + + out = self.net(_out) + + out = out + (_out if self.norm_residual else x) + + return out diff --git a/modules/noise.py b/modules/noise.py index 8e5ccde..7618618 100644 --- a/modules/noise.py +++ b/modules/noise.py @@ -3,7 +3,7 @@ import torch from torch import nn -from modules.base import PositionwiseFF as PositionwiseFFBase +from modules.base import ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase from cnfg.ihyp import * @@ -69,6 +69,70 @@ def get_noise(self, inpute, mask=None): Noiser = UniNoiserVec +class ResSelfAttn(ResSelfAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, power=None, custom_noiser=None, **kwargs): + + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + _noiser = Noiser if custom_noiser is None else custom_noiser + self.noiser = None if power is None else _noiser(power, inplace=True) + + def forward(self, iQ, *inputs, noise_mask=None, **kwargs): + + _iQ = self.normer(iQ) + + if self.noiser is not None: + _iQ = self.noiser(_iQ, noise_mask) + + outs = self.net(_iQ, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + +class ResCrossAttn(ResCrossAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, power=None, custom_noiser=None, **kwargs): + + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + _noiser = Noiser if custom_noiser is None else custom_noiser + self.noiser = None if power is None else _noiser(power, inplace=True) + + def forward(self, iQ, iK, *inputs, noise_mask=None, **kwargs): + + _iQ = self.normer(iQ) + + if self.noiser is not None: + _iQ = self.noiser(_iQ, noise_mask) + + outs = self.net(_iQ, iK, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + (_iQ if self.norm_residual else iQ) + class PositionwiseFF(PositionwiseFFBase): def __init__(self, isize, power=None, custom_noiser=None, **kwargs): diff --git a/parallel/base.py b/parallel/base.py index 327d5a9..d0c1c3d 100644 --- a/parallel/base.py +++ b/parallel/base.py @@ -1,9 +1,11 @@ #encoding: utf-8 import torch +from torch import nn import torch.cuda.comm as comm from torch.cuda.amp import autocast from utils.comm import secure_broadcast_coalesced +from utils.contpara import get_contiguous_parameters_m, get_all_contiguous_parameters_m, get_contiguous_parameters_p from torch.jit import ScriptModule from torch._C import ScriptMethod @@ -18,7 +20,7 @@ from parallel.optm import MultiGPUOptimizer -""" Example:: +""" Example: >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]) @@ -26,6 +28,11 @@ >>> loss = criterion(y, target) """ +def replicate_fixing(module): + + if hasattr(module, "c_available") and hasattr(module, "c_build_cache") and module.c_available(): + module.c_build_cache() + class DataParallelModel(DataParallel): # host replicates should improve a little bit performance if there are additional calls to update_replicas and collect_gradients in the training scripts. @@ -33,13 +40,14 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0, host_repl super(DataParallelModel, self).__init__(module, device_ids=device_ids, output_device=output_device, dim=dim) + self.is_contiguous_parameters = False if host_replicate and self.device_ids and (len(self.device_ids) > 1): self.make_replicas() else: self.nets = None - - self.optm_splt = None self.gather_output = gather_output + self.lock = Lock() + self.optm_splt = None self.ngradev = 0 def forward(self, *inputs, **kwargs): @@ -60,54 +68,94 @@ def forward(self, *inputs, **kwargs): else: devices = self.device_ids[:ngpu] replicas = self.replicate(self.module, devices) if self.nets is None else self.nets[:ngpu] - outputs = parallel_apply(replicas, inputs, devices, kwargs) + outputs = parallel_apply(replicas, inputs, devices, kwargs, lock=self.lock) if self.gather_output: return self.gather(outputs, self.output_device) else: return tuple(zip(*outputs)) if isinstance(outputs[0], tuple) else outputs - def train(self, mode=True): + def zero_grad(self, set_to_none=True): + + if self.is_contiguous_parameters: + with torch.no_grad(): + for para in get_all_contiguous_parameters_m(self.module): + para.grad.zero_() + if self.nets is not None and self.ngradev > 1: + for net in self.nets[1:self.ngradev]: + for para in get_all_contiguous_parameters_m(net): + para.grad.zero_() + else: + for para in filter_para_grad(self.module.parameters()): + para.grad = None + if self.nets is not None and self.ngradev > 1: + for net in self.nets[1:self.ngradev]: + for para in filter_para_grad(net.parameters()): + para.grad = None + self.ngradev = 0 + + # below 2 functions support direct access to the wrapped module parameters/modules, but exclude direct access to copies (self.nets) + def named_parameters(self, prefix='', recurse=True): - super(DataParallelModel, self).train(mode) + return self.module.named_parameters(prefix=prefix, recurse=recurse) - if self.nets is not None: - for net in self.nets[1:]: - net.train(mode) + def named_modules(self, memo=None, prefix='', remove_duplicate=True): - return self + return self.module.named_modules(memo=memo, prefix=prefix, remove_duplicate=remove_duplicate) def make_replicas(self): - self.nets = replicate(self.module, self.device_ids, True) + self.nets = nn.ModuleList(replicate(self.module, self.device_ids, True)) + for net in self.nets[1:]: + net.apply(replicate_fixing) self.ngradev = 0 def collect_gradients(self): if self.optm_splt is not None: - grads = [[p.grad for p in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]] - for i, (net, device, (lind, rind,),) in enumerate(zip(self.nets, self.device_ids, self.optm_splt)): - _dev_grads = [gradu[lind:rind] for gradu in grads] - if i > 0: - _dev_grads.insert(0, _dev_grads.pop(i) if i < self.ngradev else [_pg.new_zeros(_pg.size(), device=device) for _pg in _dev_grads[0]]) - _dev_grads = comm.reduce_add_coalesced(_dev_grads, device) - for mp, grad in zip(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter), _dev_grads): - mp.grad = grad + if self.is_contiguous_parameters: + for i, (net, device,) in enumerate(zip(self.nets, self.device_ids)): + _dev_grads = [[para.grad for para in get_contiguous_parameters_m(_net, index=i)] for _net in self.nets[:self.ngradev]] + if i > 0: + _dev_grads.insert(0, _dev_grads.pop(i) if i < self.ngradev else [para.grad for para in get_contiguous_parameters_m(net, index=i)]) + _dev_grads = comm.reduce_add_coalesced(_dev_grads, device) + for mp, grad in zip(get_contiguous_parameters_m(net, index=i), _dev_grads): + mp.grad.copy_(grad) + else: + grads = [[para.grad for para in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]] + for i, (net, device, (lind, rind,),) in enumerate(zip(self.nets, self.device_ids, self.optm_splt)): + _dev_grads = [gradu[lind:rind] for gradu in grads] + if i > 0: + _dev_grads.insert(0, _dev_grads.pop(i) if i < self.ngradev else [_pg.new_zeros(_pg.size(), device=device) for _pg in _dev_grads[0]]) + _dev_grads = comm.reduce_add_coalesced(_dev_grads, device) + for mp, grad in zip(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter), _dev_grads): + mp.grad = grad elif self.ngradev > 1: - # in case some parameters might not be used during the forward propagation on some GPUs: p.data.new_zeros(p.data.size()) if p.grad is None else p.grad instead of p.grad, but in most cases, this can warn you in case you miss the use of some parameters in the forward computation. - grads = comm.reduce_add_coalesced([[p.grad for p in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device)# if self.ngradev > 1 else [p.grad for p in filter_para_grad(self.nets[0].parameters())] - for mp, grad in zip(filter_para_grad(self.module.parameters()), grads): - mp.grad = grad + if self.is_contiguous_parameters: + grads = comm.reduce_add_coalesced([[para.grad for para in get_all_contiguous_parameters_m(net)] for net in self.nets[:self.ngradev]], self.output_device) + for mp, grad in zip(get_all_contiguous_parameters_m(self.module), grads): + mp.grad.copy_(grad) + else: + # in case some parameters might not be used during the forward propagation on some GPUs: p.data.new_zeros(p.data.size()) if p.grad is None else p.grad instead of p.grad, but in most cases, this can warn you in case you miss the use of some parameters in the forward computation. + grads = comm.reduce_add_coalesced([[para.grad for para in filter_para_grad(net.parameters())] for net in self.nets[:self.ngradev]], self.output_device)# if self.ngradev > 1 else [p.grad for p in filter_para_grad(self.nets[0].parameters())] + for mp, grad in zip(filter_para_grad(self.module.parameters()), grads): + mp.grad = grad # the parallelization of the update of parameters can be supported, but not adviced, since the cost of multi threads is much higher and thus slower than the loop unless you are running on lots of GPUs. # Note that gradients will be cleared every time this function was called def update_replicas(self): if self.optm_splt is None: - params = [para.data for para in filter_para_grad(self.module.parameters())] - - if len(params) > 0: + if self.is_contiguous_parameters: + params = [para.data for para in get_all_contiguous_parameters_m(self.module)] + param_copies = comm.broadcast_coalesced(params, self.device_ids) + with torch.no_grad(): + for module, param_copy in zip(self.nets[1:], param_copies[1:]): + for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): + mp.data.copy_(para) + mp.grad.zero_() + else: + params = [para.data for para in filter_para_grad(self.module.parameters())] param_copies = comm.broadcast_coalesced(params, self.device_ids) - # currently, pytorch broadcast binds parameters between self.nets[0] and self.module, so the following line ensures correctness but less efficient #for module, param_copy in zip(self.nets, param_copies): for module, param_copy in zip(self.nets[1:], param_copies[1:]): @@ -115,7 +163,7 @@ def update_replicas(self): mp.data, mp.grad = para, None else: for i, (net, (lind, rind,),) in enumerate(zip(self.nets, self.optm_splt)): - _dev_params = [para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)] + _dev_params = [para.data for para in get_contiguous_parameters_m(net, index=i)] if self.is_contiguous_parameters else [para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)] if i > 0: _devices = self.device_ids[:] _devices.insert(0, _devices.pop(i)) @@ -128,26 +176,38 @@ def update_replicas(self): pc.extend(_dpc) else: param_copies = _dev_param_copies - for module, param_copy in zip(self.nets, param_copies): - for mp, para in zip(filter_para_grad(module.parameters()), param_copy): - mp.data, mp.grad = para, None + if self.is_contiguous_parameters: + with torch.no_grad(): + for module, param_copy in zip(self.nets, param_copies): + for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): + mp.data.copy_(para) + mp.grad.zero_() + else: + for module, param_copy in zip(self.nets, param_copies): + for mp, para in zip(filter_para_grad(module.parameters()), param_copy): + mp.data, mp.grad = para, None self.ngradev = 0 def update_replicas_para(self): if self.optm_splt is None: - params = [para.data for para in filter_para_grad(self.module.parameters())] - - if len(params) > 0: + if self.is_contiguous_parameters: + params = [para.data for para in get_all_contiguous_parameters_m(self.module)] + param_copies = comm.broadcast_coalesced(params, self.device_ids) + with torch.no_grad(): + for module, param_copy in zip(self.nets[1:], param_copies[1:]): + for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): + mp.data.copy_(para) + else: + params = [para.data for para in filter_para_grad(self.module.parameters())] param_copies = comm.broadcast_coalesced(params, self.device_ids) - for module, param_copy in zip(self.nets[1:], param_copies[1:]): for mp, para in zip(filter_para_grad(module.parameters()), param_copy): mp.data = para else: for i, (net, (lind, rind,),) in enumerate(zip(self.nets, self.optm_splt)): - _dev_params = [para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)] + _dev_params = [para.data for para in get_contiguous_parameters_m(net, index=i)] if self.is_contiguous_parameters else [para.data for para in range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)] if i > 0: _devices = self.device_ids[:] _devices.insert(0, _devices.pop(i)) @@ -160,59 +220,72 @@ def update_replicas_para(self): pc.extend(_dpc) else: param_copies = _dev_param_copies - for module, param_copy in zip(self.nets, param_copies): - for mp, para in zip(filter_para_grad(module.parameters()), param_copy): - mp.data = para - - self.ngradev = 0 - - def zero_grad(self, set_to_none=True): + if self.is_contiguous_parameters: + with torch.no_grad(): + for module, param_copy in zip(self.nets, param_copies): + for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): + mp.data.copy_(para) + else: + for module, param_copy in zip(self.nets, param_copies): + for mp, para in zip(filter_para_grad(module.parameters()), param_copy): + mp.data = para - self.module.zero_grad(set_to_none=set_to_none) - if self.nets is not None and self.ngradev > 1: - # currently, pytorch broadcast binds parameters between self.nets[0] and self.module, so the following line ensures correctness but less efficient - #for net in self.nets: - for net in self.nets[1:]: - net.zero_grad(set_to_none=set_to_none) self.ngradev = 0 def collect_gradients_func(self, func): if self.ngradev > 1: grads = comm.reduce_add_coalesced([[p.grad for p in filter_para_grad(func(net).parameters())] for net in self.nets[:self.ngradev]], self.output_device) - for mp, grad in zip(filter_para_grad(func(self.module).parameters()), grads): - mp.grad = grad + if self.is_contiguous_parameters: + for mp, grad in zip(filter_para_grad(func(self.module).parameters()), grads): + mp.grad.copy_(grad) + else: + for mp, grad in zip(filter_para_grad(func(self.module).parameters()), grads): + mp.grad = grad def zero_replicas_grad(self, func=None): if self.nets is not None and self.ngradev > 1: if func is None: - for net in self.nets[1:self.ngradev]: - for para in filter_para_grad(net.parameters()): - para.grad = None + if self.is_contiguous_parameters: + for net in self.nets[1:self.ngradev]: + for para in get_all_contiguous_parameters_m(net): + para.grad.zero_() + else: + for net in self.nets[1:self.ngradev]: + for para in filter_para_grad(net.parameters()): + para.grad = None else: - for net in self.nets[1:self.ngradev]: - for para in filter_para_grad(func(net).parameters()): - para.grad = None - - def reset_grad(self): - - for para in filter_para_grad(self.module.parameters()): - para.grad = None - if self.nets is not None and self.ngradev > 1: - for net in self.nets[1:self.ngradev]: - for para in filter_para_grad(net.parameters()): - para.grad = None - self.ngradev = 0 + if self.is_contiguous_parameters: + for net in self.nets[1:self.ngradev]: + for para in filter_para_grad(func(net).parameters()): + para.grad.zero_() + else: + for net in self.nets[1:self.ngradev]: + for para in filter_para_grad(func(net).parameters()): + para.grad = None - def build_optimizer(self, optm_func, *optm_args, **optm_kwargs): + def build_optimizer(self, optm_func, *optm_args, multi_gpu_optimizer=False, contiguous_parameters=False, **optm_kwargs): + self.is_contiguous_parameters = contiguous_parameters paras = filter_para_grad(self.module.parameters()) - if self.nets is None or (len(paras) < 2): - return optm_func(self.module.parameters(), *optm_args, **optm_kwargs) + if (not multi_gpu_optimizer) or self.nets is None or (len(paras) < 2): + if contiguous_parameters: + if self.nets is not None: + for net in self.nets[1:]: + get_contiguous_parameters_m(net) + _mp = get_contiguous_parameters_m(self.module) + else: + _mp = self.module.parameters() + return optm_func(_mp, *optm_args, **optm_kwargs) else: self.optm_splt, _np = divide_para_ind(paras, len(self.device_ids), return_np=True) - optml = [optm_func(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter), *optm_args, **optm_kwargs) for net, (lind, rind,) in zip(self.nets, self.optm_splt)] + if contiguous_parameters: + for net in self.nets: + get_contiguous_parameters_p([list(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter)) for lind, rind in self.optm_splt], model=net) + optml = [optm_func(get_contiguous_parameters_m(net, index=i), *optm_args, **optm_kwargs) for i, net in enumerate(self.nets)] + else: + optml = [optm_func(range_parameter_iter(net, lind, rind, func=filter_para_grad_iter), *optm_args, **optm_kwargs) for net, (lind, rind,) in zip(self.nets, self.optm_splt)] # sort the optimizers with slightly more parameters ahead to start their optimization steps earlier optml, _device_ids = reorder_by_sort(_np, optml, self.device_ids[:len(optml)], reverse=True) return MultiGPUOptimizer(optml, device_ids=_device_ids) @@ -225,9 +298,10 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0, replicate super(DataParallelCriterion, self).__init__(module, device_ids=device_ids, output_device=output_device, dim=dim) if replicate_once and self.device_ids and (len(self.device_ids) > 1): - self.nets = replicate(self.module, self.device_ids, True) + self.nets = nn.ModuleList(replicate(self.module, self.device_ids, True)) else: self.nets = None + self.lock = Lock() def forward(self, inputs, *targets, **kwargs): # input should be already scatterd @@ -241,7 +315,7 @@ def forward(self, inputs, *targets, **kwargs): return self.module(inputs[0], *targets[0], **kwargs[0]) devices = self.device_ids[:ngpu] replicas = self.replicate(self.module, devices) if self.nets is None else self.nets[:ngpu] - outputs = criterion_parallel_apply(replicas, inputs, targets, devices, kwargs) + outputs = criterion_parallel_apply(replicas, inputs, targets, devices, kwargs, lock=self.lock) return self.gather(outputs, self.output_device) @@ -283,8 +357,7 @@ def clear_gradient(para): module_indices[module] = i for j in range(num_replicas): if isinstance(module, ScriptModule): - # we have to initialize ScriptModule properly so that - # it works with pybind11 + # we have to initialize ScriptModule properly so that it works with pybind11 replica = module._replicate_for_data_parallel() replica._former_parameters = OrderedDict() '''replica = ScriptModule() @@ -351,12 +424,12 @@ def clear_gradient(para): # update these two functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py) -def parallel_apply(modules, inputs, devices, kwargs_tup=None): +def parallel_apply(modules, inputs, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) - lock = Lock() + lock = Lock() if lock is None else lock results = {} grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() @@ -381,14 +454,15 @@ def _worker(i, module, input, kwargs, device=None): for i in range(len(inputs)): output = results[i] outputs.append(output) + return outputs -def criterion_parallel_apply(modules, inputs, targets, devices, kwargs_tup=None): +def criterion_parallel_apply(modules, inputs, targets, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) - lock = Lock() + lock = Lock() if lock is None else lock results = {} grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() @@ -414,4 +488,5 @@ def _worker(i, module, input, target, kwargs, device): for i in range(len(inputs)): output = results[i] outputs.append(output) + return outputs diff --git a/parallel/parallelMT.py b/parallel/parallelMT.py index 402ec94..31e0c26 100644 --- a/parallel/parallelMT.py +++ b/parallel/parallelMT.py @@ -26,7 +26,7 @@ def decode(self, *inputs, **kwargs): replicas = self.replicate(self.module, devices) else: replicas = self.nets[:ngpu] - outputs = parallel_apply_decode(replicas, inputs, devices, kwargs) + outputs = parallel_apply_decode(replicas, inputs, devices, kwargs, lock=self.lock) return self.gather(pad_tensors(outputs), self.output_device) if self.gather_output else outputs def train_decode(self, *inputs, **kwargs): @@ -43,17 +43,17 @@ def train_decode(self, *inputs, **kwargs): replicas = self.replicate(self.module, devices) else: replicas = self.nets[:ngpu] - outputs = parallel_apply_train_decode(replicas, inputs, devices, kwargs) + outputs = parallel_apply_train_decode(replicas, inputs, devices, kwargs, lock=self.lock) return self.gather(pad_tensors(outputs), self.output_device) if self.gather_output else outputs # update these two functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py) -def parallel_apply_decode(modules, inputs, devices, kwargs_tup=None): +def parallel_apply_decode(modules, inputs, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) - lock = Lock() + lock = Lock() if lock is None else lock results = {} grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() @@ -77,14 +77,15 @@ def _worker(i, module, input, kwargs, device=None): for i in range(len(inputs)): output = results[i] outputs.append(output) + return outputs -def parallel_apply_train_decode(modules, inputs, devices, kwargs_tup=None): +def parallel_apply_train_decode(modules, inputs, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) - lock = Lock() + lock = Lock() if lock is None else lock results = {} grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() @@ -108,4 +109,5 @@ def _worker(i, module, input, kwargs, device=None): for i in range(len(inputs)): output = results[i] outputs.append(output) + return outputs diff --git a/predict.py b/predict.py index 538a2a4..5297f17 100644 --- a/predict.py +++ b/predict.py @@ -63,7 +63,6 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) beam_size = cnfg.beam_size - length_penalty = cnfg.length_penalty ens = "\n".encode("utf-8") diff --git a/requirements.txt b/requirements.txt index 0d6034f..4aa3753 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -tqdm>=4.59.0 +tqdm>=4.62.0 torch>=1.9.0 h5py>=3.2.1 diff --git a/scripts/README.md b/scripts/README.md index 3796acf..8236547 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -68,6 +68,8 @@ export dataid=w14ende # number of GPU(s) plan to use for decoding. export ngpu=1 +# gather sentences of similar lengths together for decoding. +export sort_decode=true # merge sub-words export debpe=true ``` diff --git a/scripts/ape/mktest.sh b/scripts/ape/mktest.sh index efe5bd4..62ec523 100644 --- a/scripts/ape/mktest.sh +++ b/scripts/ape/mktest.sh @@ -16,7 +16,9 @@ export dataid=w19ape export ngpu=1 +export sort_decode=true export debpe=true +export spm_bpe=false export tgtd=$cachedir/$dataid @@ -32,14 +34,33 @@ fi mkdir -p $rsd -python tools/sort.py $srcd/$srctf $srcd/$srcmf $tgtd/$srctf.srt $tgtd/$srcmf.srt 1048576 -python tools/mkiodata.py $tgtd/$srctf.srt $tgtd/$srcmf.srt $src_vcb $tgt_vcb $tgtd/test.h5 $ngpu +if $sort_decode; then + export srt_input_f=$tgtd/$srctf.srt + export srt_input_fm=$tgtd/$srcmf.srt + python tools/sort.py $srcd/$srctf $srcd/$srcmf $srt_input_f $srt_input_fm 1048576 +else + export srt_input_f=$srcd/$srctf + export srt_input_fm=$srcd/$srcmf +fi + +python tools/mkiodata.py $srt_input_f $srt_input_fm $src_vcb $tgt_vcb $tgtd/test.h5 $ngpu python predict_ape.py $tgtd/$bpef.srt $tgt_vcb $modelf -python tools/restore.py $srcd/$srctf $srcd/$srcmf $tgtd/$srctf.srt $tgtd/$srcmf.srt $tgtd/$bpef.srt $tgtd/$bpef + +if $sort_decode; then + python tools/restore.py $srcd/$srctf $srcd/$srcmf $srt_input_f $srt_input_fm $tgtd/$bpef.srt $tgtd/$bpef + rm $srt_input_f $srt_input_fm $tgtd/$bpef.srt +else + mv $tgtd/$bpef.srt $tgtd/$bpef +fi + if $debpe; then - sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + if $spm_bpe; then + python tools/spm/decode.py --model $tgtd/bpe.model --input_format piece --input $tgtd/$bpef > $rsf + + else + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + fi rm $tgtd/$bpef else mv $tgtd/$bpef $rsf fi -rm $tgtd/$srctf.srt $tgtd/$srcmf.srt $tgtd/$bpef.srt diff --git a/scripts/bpe/mk.sh b/scripts/bpe/mk.sh index b462179..6e2a0bd 100644 --- a/scripts/bpe/mk.sh +++ b/scripts/bpe/mk.sh @@ -13,7 +13,7 @@ export srcvf=dev.tc.en export tgtvf=dev.tc.de export vratio=0.2 -export rratio=0.4 +export rratio=0.6 export maxtokens=256 export bpeops=32000 diff --git a/scripts/doc/para/mktest.sh b/scripts/doc/para/mktest.sh index 67d7cdd..8b4a49a 100644 --- a/scripts/doc/para/mktest.sh +++ b/scripts/doc/para/mktest.sh @@ -15,7 +15,9 @@ export dataid=w19edoc export ngpu=1 +export sort_decode=true export debpe=true +export spm_bpe=false export tgtd=$cachedir/$dataid @@ -31,14 +33,31 @@ fi mkdir -p $rsd -python tools/doc/sort.py $srcd/$srctf $tgtd/$srctf.srt 1048576 -python tools/doc/para/mktest.py $tgtd/$srctf.srt $src_vcb $tgtd/test.h5 $ngpu +if $sort_decode; then + export srt_input_f=$tgtd/$srctf.srt + python tools/doc/sort.py $srcd/$srctf $srt_input_f 1048576 +else + export srt_input_f=$srcd/$srctf +fi + +python tools/doc/para/mktest.py $srt_input_f $src_vcb $tgtd/test.h5 $ngpu python predict_doc_para.py $tgtd/$bpef.srt $tgt_vcb $modelf -python tools/doc/para/restore.py $srcd/$srctf w19ed/test.en.w19ed w19edtrs/base_avg.tbrs $tgtd/$srctf.srt $tgtd/$bpef.srt $tgtd/$bpef + +if $sort_decode; then + python tools/doc/para/restore.py $srcd/$srctf w19ed/test.en.w19ed w19edtrs/base_avg.tbrs $srt_input_f $tgtd/$bpef.srt $tgtd/$bpef + rm $srt_input_f $tgtd/$bpef.srt +else + mv $tgtd/$bpef.srt $tgtd/$bpef +fi + if $debpe; then - sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + if $spm_bpe; then + python tools/spm/decode.py --model $tgtd/bpe.model --input_format piece --input $tgtd/$bpef > $rsf + + else + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + fi rm $tgtd/$bpef else mv $tgtd/$bpef $rsf fi -rm $tgtd/$srctf.srt $tgtd/$bpef.srt diff --git a/scripts/mktest.sh b/scripts/mktest.sh index 142054e..c84db38 100644 --- a/scripts/mktest.sh +++ b/scripts/mktest.sh @@ -15,7 +15,9 @@ export dataid=w14ed32 export ngpu=1 +export sort_decode=true export debpe=true +export spm_bpe=false export tgtd=$cachedir/$dataid @@ -31,14 +33,31 @@ fi mkdir -p $rsd -python tools/sort.py $srcd/$srctf $tgtd/$srctf.srt 1048576 -python tools/mktest.py $tgtd/$srctf.srt $src_vcb $tgtd/test.h5 $ngpu +if $sort_decode; then + export srt_input_f=$tgtd/$srctf.srt + python tools/sort.py $srcd/$srctf $srt_input_f 1048576 +else + export srt_input_f=$srcd/$srctf +fi + +python tools/mktest.py $srt_input_f $src_vcb $tgtd/test.h5 $ngpu python predict.py $tgtd/$bpef.srt $tgt_vcb $modelf -python tools/restore.py $srcd/$srctf $tgtd/$srctf.srt $tgtd/$bpef.srt $tgtd/$bpef + +if $sort_decode; then + python tools/restore.py $srcd/$srctf $srt_input_f $tgtd/$bpef.srt $tgtd/$bpef + rm $srt_input_f $tgtd/$bpef.srt +else + mv $tgtd/$bpef.srt $tgtd/$bpef +fi + if $debpe; then - sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + if $spm_bpe; then + python tools/spm/decode.py --model $tgtd/bpe.model --input_format piece --input $tgtd/$bpef > $rsf + + else + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + fi rm $tgtd/$bpef else mv $tgtd/$bpef $rsf fi -rm $tgtd/$srctf.srt $tgtd/$bpef.srt diff --git a/scripts/mulang/mktest.sh b/scripts/mulang/mktest.sh new file mode 100644 index 0000000..ded6712 --- /dev/null +++ b/scripts/mulang/mktest.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +set -e -o pipefail -x + +export srcd=opus +export srctf=$1 +export modelf="expm/opus/std/base/checkpoint.h5" +export rsd=opurs +export rsf=$rsd/trans.txt + +export share_vcb=true + +export cachedir=cache +export dataid=opus + +export ngpu=1 + +export sort_decode=true +export debpe=true +export spm_bpe=false + +export tgtd=$cachedir/$dataid + +export bpef=out.bpe + +if $share_vcb; then + export src_vcb=$tgtd/common.vcb + export tgt_vcb=$src_vcb +else + export src_vcb=$tgtd/src.vcb + export tgt_vcb=$tgtd/tgt.vcb +fi + +mkdir -p $rsd + +if $sort_decode; then + export srt_input_f=$tgtd/$srctf.srt + python tools/mulang/eff/sort.py $srcd/$srctf $srt_input_f 1048576 +else + export srt_input_f=$srcd/$srctf +fi + +python tools/mulang/eff/mktest.py $srt_input_f $src_vcb $tgtd/lang.vcb $tgtd/test.h5 $ngpu +python predict_mulang.py $tgtd/$bpef.srt $tgt_vcb $modelf + +if $sort_decode; then + python tools/restore.py $srcd/$srctf $srt_input_f $tgtd/$bpef.srt $tgtd/$bpef + rm $srt_input_f $tgtd/$bpef.srt $tgtd/$bpef +else + mv $tgtd/$bpef.srt $tgtd/$bpef +fi + +if $debpe; then + if $spm_bpe; then + python tools/spm/decode.py --model $tgtd/bpe.model --input_format piece --input $tgtd/$bpef > $rsf + + else + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + fi +else + mv $tgtd/$bpef $rsf +fi diff --git a/scripts/mulang/mktrain.sh b/scripts/mulang/mktrain.sh new file mode 100644 index 0000000..ffa89ad --- /dev/null +++ b/scripts/mulang/mktrain.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +set -e -o pipefail -x + +# take the processed data from scripts/mkbpe.sh and convert to tensor representation. + +export cachedir=cache +export dataid=opus + +export srcd=$cachedir/$dataid +export srctf=src.train.bpe +export tgttf=tgt.train.bpe +export srcvf=src.dev.bpe +export tgtvf=tgt.dev.bpe + +export rsf_train=train.h5 +export rsf_dev=dev.h5 + +export share_vcb=true +export vsize=65536 + +export maxtokens=257 + +export ngpu=1 + +export do_sort=true +export build_vocab=true + +export wkd=$cachedir/$dataid + +mkdir -p $wkd + +if $do_sort; then + python tools/mulang/eff/sort.py $srcd/$srctf $srcd/$tgttf $wkd/src.train.srt $wkd/tgt.train.srt $maxtokens + python tools/mulang/eff/sort.py $srcd/$srcvf $srcd/$tgtvf $wkd/src.dev.srt $wkd/tgt.dev.srt 1048576 +fi + +if $share_vcb; then + export src_vcb=$wkd/common.vcb + export tgt_vcb=$src_vcb + if $build_vocab; then + python tools/mulang/share_vocab.py $wkd/src.train.srt --target $wkd/tgt.train.srt $src_vcb $wkd/lang.vcb $vsize + python tools/check/mulang/fbindexes.py $tgt_vcb $wkd/src.train.srt $wkd/tgt.train.srt $wkd/src.dev.srt $wkd/tgt.dev.srt $wkd/lang.vcb $wkd/fbind.py + fi +else + export src_vcb=$wkd/src.vcb + export tgt_vcb=$wkd/tgt.vcb + if $build_vocab; then + python tools/mulang/vocab.py $wkd/src.train.srt $src_vcb $wkd/lang.vcb $vsize + python tools/vocab.py $wkd/tgt.train.srt $tgt_vcb $vsize + python tools/check/mulang/fbindexes.py $tgt_vcb $wkd/src.train.srt $wkd/tgt.train.srt $wkd/src.dev.srt $wkd/tgt.dev.srt $wkd/lang.vcb $wkd/fbind.py + fi +fi + +python tools/mulang/eff/mkiodata.py $wkd/src.train.srt $wkd/tgt.train.srt $src_vcb $tgt_vcb $wkd/lang.vcb $wkd/$rsf_train $ngpu +python tools/mulang/eff/mkiodata.py $wkd/src.dev.srt $wkd/tgt.dev.srt $src_vcb $tgt_vcb $wkd/lang.vcb $wkd/$rsf_dev $ngpu diff --git a/scripts/spm/clean.sh b/scripts/spm/clean.sh new file mode 100644 index 0000000..06afd58 --- /dev/null +++ b/scripts/spm/clean.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +set -e -o pipefail -x + +export cachedir=cache + +export dataid=w14ed32 + +export srcd=w14ende +export srcvf=dev.tc.en +export tgtvf=dev.tc.de + +export maxtokens=256 + +export bpeops=32000 +export minfreq=8 +# 0.9995 +export charcov=1.0 +# unigram, bpe, char, or word +export mtype="unigram" +export share_bpe=true + +export tgtd=$cachedir/$dataid + +# options for cleaning the data processed by bpe, +# advised values except numrules can be calculated by: +# python tools/check/charatio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe, and +# python tools/check/biratio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe +# with development set. +# As for numrules, choose from [1, 6], fewer data will be droped with larger value, none data would be droped if it was set to 6, details are described in: +# tools/check/chars.py +export charatio=0.973 +export bperatio=36.01 +export seperatio=1.01 +export bibperatio=7.51 +export bioratio=7.51 +export numrules=1 + +# cleaning bpe results and bpe again +python tools/clean/chars.py $tgtd/src.train.bpe $tgtd/tgt.train.bpe $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $charatio $bperatio $seperatio $bibperatio $bioratio $numrules + +if $share_bpe; then + export src_cdsf=$tgtd/bpe + export tgt_cdsf=$tgtd/bpe +else + export src_cdsf=$tgtd/src + export tgt_cdsf=$tgtd/tgt +fi + +spm_decode --model=$src_cdsf.model --input_format=piece < $tgtd/src.clean.tmp > $tgtd/src.train.tok.clean +spm_decode --model=$tgt_cdsf.model --input_format=piece < $tgtd/tgt.clean.tmp > $tgtd/tgt.train.tok.clean +rm -fr $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp + +if $share_bpe; then +# to learn joint bpe + cat $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean | shuf > $tgtd/bpe.train.txt + spm_train --input=$tgtd/bpe.train.txt --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe + spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe + rm $tgtd/bpe.train.txt +else +# to learn independent bpe: + spm_train --input=$tgtd/src.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + spm_train --input=$tgtd/tgt.train.tok.clean --model_prefix=$tgt_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe + spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe +fi + +spm_encode --model=$src_cdsf.model --vocabulary=$tgtd/src.vcb.bpe --vocabulary_threshold=$minfreq < $tgtd/src.train.tok.clean > $tgtd/src.train.bpe +spm_encode --model=$tgt_cdsf.model --vocabulary=$tgtd/tgt.vcb.bpe --vocabulary_threshold=$minfreq < $tgtd/tgt.train.tok.clean > $tgtd/tgt.train.bpe + +spm_encode --model=$src_cdsf.model --vocabulary=$tgtd/src.vcb.bpe --vocabulary_threshold=$minfreq < $srcd/$srcvf > $tgtd/src.dev.bpe +spm_encode --model=$tgt_cdsf.model --vocabulary=$tgtd/tgt.vcb.bpe --vocabulary_threshold=$minfreq < $srcd/$tgtvf > $tgtd/tgt.dev.bpe + +# then execute scripts/mktrain.sh to generate training and development data. diff --git a/scripts/spm/mk.sh b/scripts/spm/mk.sh new file mode 100644 index 0000000..8d82b6a --- /dev/null +++ b/scripts/spm/mk.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +set -e -o pipefail -x + +export cachedir=cache + +export dataid=w14ed32 + +export srcd=w14ende +export srctf=train.tc.en +export tgttf=train.tc.de +export srcvf=dev.tc.en +export tgtvf=dev.tc.de + +export vratio=0.2 +export rratio=0.6 +export maxtokens=256 + +export bpeops=32000 +export minfreq=8 +# 0.9995 +export charcov=1.0 +# unigram, bpe, char, or word +export mtype="unigram" +export share_bpe=true + +export tgtd=$cachedir/$dataid + +mkdir -p $tgtd + +# clean the data first by removing different translations with lower frequency of same sentences +python tools/clean/maxkeeper.py $srcd/$srctf $srcd/$tgttf $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $maxtokens +python tools/clean/token_repeat.py $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $tgtd/src.clean.rtmp $tgtd/tgt.clean.rtmp $rratio +mv $tgtd/src.clean.rtmp $tgtd/src.clean.tmp +mv $tgtd/tgt.clean.rtmp $tgtd/tgt.clean.tmp + +python tools/vocab.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 +python tools/vocab.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 +python tools/clean/vocab/ratio.py $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean $tgtd/src.full.vcb $tgtd/tgt.full.vcb $vratio +rm -fr $tgtd/src.full.vcb $tgtd/tgt.full.vcb $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp + +if $share_bpe; then +# to learn joint bpe + export src_cdsf=$tgtd/bpe + export tgt_cdsf=$tgtd/bpe + cat $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean | shuf > $tgtd/bpe.train.txt + spm_train --input=$tgtd/bpe.train.txt --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe + spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe + rm $tgtd/bpe.train.txt +else +# to learn independent bpe: + export src_cdsf=$tgtd/src + export tgt_cdsf=$tgtd/tgt + spm_train --input=$tgtd/src.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + spm_train --input=$tgtd/tgt.train.tok.clean --model_prefix=$tgt_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe + spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe +fi + +spm_encode --model=$src_cdsf.model --vocabulary=$tgtd/src.vcb.bpe --vocabulary_threshold=$minfreq < $tgtd/src.train.tok.clean > $tgtd/src.train.bpe +spm_encode --model=$tgt_cdsf.model --vocabulary=$tgtd/tgt.vcb.bpe --vocabulary_threshold=$minfreq < $tgtd/tgt.train.tok.clean > $tgtd/tgt.train.bpe + +spm_encode --model=$src_cdsf.model --vocabulary=$tgtd/src.vcb.bpe --vocabulary_threshold=$minfreq < $srcd/$srcvf > $tgtd/src.dev.bpe +spm_encode --model=$tgt_cdsf.model --vocabulary=$tgtd/tgt.vcb.bpe --vocabulary_threshold=$minfreq < $srcd/$tgtvf > $tgtd/tgt.dev.bpe + +# report devlopment set features for cleaning +python tools/check/charatio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe +python tools/check/biratio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe diff --git a/tools/ape/mkiodata.py b/tools/ape/mkiodata.py index b26c236..c34eb10 100644 --- a/tools/ape/mkiodata.py +++ b/tools/ape/mkiodata.py @@ -3,8 +3,8 @@ import sys import numpy -import h5py +from utils.h5serial import h5File from utils.fmt.base import ldvocab from utils.fmt.ape.triple import batch_padder @@ -19,23 +19,22 @@ def handle(finput, fmt, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for else: _bsize = bsize _maxtoken = maxtoken - rsf = h5py.File(frs, 'w') - src_grp = rsf.create_group("src") - mt_grp = rsf.create_group("mt") - tgt_grp = rsf.create_group("tgt") - curd = 0 - for i_d, md, td in batch_padder(finput, fmt, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype=numpy.int32) - rmd = numpy.array(md, dtype=numpy.int32) - rtd = numpy.array(td, dtype=numpy.int32) - wid = str(curd) - src_grp.create_dataset(wid, data=rid, **h5datawargs) - mt_grp.create_dataset(wid, data=rmd, **h5datawargs) - tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) - curd += 1 - rsf["ndata"] = numpy.array([curd], dtype=numpy.int32) - rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) - rsf.close() + with h5File(frs, 'w') as rsf: + src_grp = rsf.create_group("src") + mt_grp = rsf.create_group("mt") + tgt_grp = rsf.create_group("tgt") + curd = 0 + for i_d, md, td in batch_padder(finput, fmt, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = numpy.array(i_d, dtype=numpy.int32) + rmd = numpy.array(md, dtype=numpy.int32) + rtd = numpy.array(td, dtype=numpy.int32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + mt_grp.create_dataset(wid, data=rmd, **h5datawargs) + tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) + curd += 1 + rsf["ndata"] = numpy.array([curd], dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (curd, nwordi, nwordt,)) if __name__ == "__main__": diff --git a/tools/check/avg_bsize.py b/tools/check/avg_bsize.py index f328c3d..ff14e96 100644 --- a/tools/check/avg_bsize.py +++ b/tools/check/avg_bsize.py @@ -3,32 +3,31 @@ import sys import torch -import h5py +from utils.h5serial import h5File from tqdm import tqdm -from random import shuffle -from random import seed as rpyseed +from random import shuffle, seed as rpyseed -def handle(h5f, bsize, shuf=True): +from cnfg.ihyp import tqdm_mininterval - td = h5py.File(h5f, "r") - ntest = td["ndata"][:].item() - tl = list(range(ntest)) - if shuf: - shuffle(tl) +def handle(h5f, bsize, shuf=True): - tgt_grp = td["tgt"] ntoken = 0 rsl = [] - for tid in tqdm(tl, mininterval=tqdm_mininterval): - seq_batch = torch.from_numpy(tgt_grp[str(tid)][:]) - ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) - ntoken += ot.ne(0).int().sum().item() - if ntoken >= bsize: - rsl.append(ntoken) - ntoken = 0 - - td.close() + with h5File(h5f, "r") as td: + ntest = td["ndata"][:].item() + tl = list(range(ntest)) + if shuf: + shuffle(tl) + + tgt_grp = td["tgt"] + for tid in tqdm(tl, mininterval=tqdm_mininterval): + seq_batch = torch.from_numpy(tgt_grp[str(tid)][:]) + ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) + ntoken += ot.ne(0).int().sum().item() + if ntoken >= bsize: + rsl.append(ntoken) + ntoken = 0 return sum(rsl)/float(len(rsl)) diff --git a/tools/check/doc/para/epoch_steps.py b/tools/check/doc/para/epoch_steps.py index 3da2be4..a649b35 100644 --- a/tools/check/doc/para/epoch_steps.py +++ b/tools/check/doc/para/epoch_steps.py @@ -3,31 +3,30 @@ import sys import torch -import h5py +from utils.h5serial import h5File from tqdm import tqdm -from random import shuffle -from random import seed as rpyseed +from random import shuffle, seed as rpyseed + +from cnfg.ihyp import tqdm_mininterval def handle(h5f, bsize, shuf=True): - td = h5py.File(h5f, "r") - tl = [(str(nsent), str(_curd),) for nsent, ndata in zip(td["nsent"][:].tolist(), td["ndata"][:].tolist()) for _curd in range(ndata)] - if shuf: - shuffle(tl) - - tgt_grp = td["tgt"] - ntoken = 0 - nstep = 0 - for nsent, i_d in tqdm(tl, mininterval=tqdm_mininterval): - seq_batch = torch.from_numpy(tgt_grp[nsent][i_d][:]) - ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) - ntoken += ot.ne(0).int().sum().item() - if ntoken >= bsize: - nstep += 1 - ntoken = 0 - - td.close() + with h5File(h5f, "r") as td: + tl = [(str(nsent), str(_curd),) for nsent, ndata in zip(td["nsent"][:].tolist(), td["ndata"][:].tolist()) for _curd in range(ndata)] + if shuf: + shuffle(tl) + + tgt_grp = td["tgt"] + ntoken = 0 + nstep = 0 + for nsent, i_d in tqdm(tl, mininterval=tqdm_mininterval): + seq_batch = torch.from_numpy(tgt_grp[nsent][i_d][:]) + ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) + ntoken += ot.ne(0).int().sum().item() + if ntoken >= bsize: + nstep += 1 + ntoken = 0 return nstep diff --git a/tools/check/dynb/report_dynb.py b/tools/check/dynb/report_dynb.py index b7aa45c..834b74d 100644 --- a/tools/check/dynb/report_dynb.py +++ b/tools/check/dynb/report_dynb.py @@ -13,6 +13,7 @@ from utils.base import * from utils.init import init_model_params +from utils.contpara import get_model_parameters from utils.dynbatch import GradientMonitor from utils.h5serial import h5save, h5load from utils.fmt.base import tostr, save_states, load_states, pad_id, parse_double_value_tuple @@ -25,9 +26,6 @@ from tqdm import tqdm -from os import makedirs -from os.path import exists as p_check - import h5py import cnfg.dynb as cnfg @@ -111,8 +109,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _log_f_dynbatch.write(("%d\n" % (_done_tokens,)).encode("utf-8")) if multi_gpu: model.collect_gradients() - optm_step(optm, scaler) - optm.zero_grad(set_to_none=True) + optm_step(optm, scaler, zero_grad_none=optm_step_zero_grad_set_none) + optm.zero_grad(set_to_none=optm_step_zero_grad_set_none) if multi_gpu: model.update_replicas() lrsch.step() @@ -122,7 +120,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if multi_gpu: model.reset_grad() else: - optm.zero_grad(set_to_none=True) + optm.zero_grad(set_to_none=optm_step_zero_grad_set_none) log_dynb = random() <= log_dyn_p _done_tokens = 0 if _cur_rstep is not None: @@ -136,7 +134,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -171,7 +169,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -233,30 +231,22 @@ def load_fixing(module): rid = sys.argv[1] earlystop = cnfg.earlystop - maxrun = cnfg.maxrun - tokens_optm = cnfg.tokens_optm - done_tokens = 0 - batch_report = cnfg.batch_report report_eva = cnfg.report_eva - use_ams = cnfg.use_ams - save_optm_state = cnfg.save_optm_state - +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva save_every = cnfg.save_every start_chkp_save = cnfg.epoch_start_checkpoint_save - epoch_save = cnfg.epoch_save - remain_steps = cnfg.training_steps wkdir = "".join(("expm/", cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) -if not p_check(wkdir): - makedirs(wkdir) +mkdir(wkdir) chkpf = None chkpof = None @@ -317,12 +307,11 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -if multi_gpu_optimizer: - optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - mymodel.zero_grad(set_to_none=True) +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) else: - optimizer = Optimizer((mymodel.module if multi_gpu else mymodel).parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - optimizer.zero_grad(set_to_none=True) + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) fine_tune_state = cnfg.fine_tune_state if fine_tune_state is not None: @@ -340,16 +329,16 @@ def load_fixing(module): logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: - save_model(mymodel, wkdir + "init.h5", multi_gpu, logger) + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) logger.info("Initial model saved") else: cnt_states = cnfg.train_statesf - if (cnt_states is not None) and p_check(cnt_states): + if cnt_states is not None: logger.info("Continue last epoch") tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) - save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec)) logger.info("New best model saved") @@ -378,7 +367,7 @@ def load_fixing(module): logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): - save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) logger.info("New best model saved") @@ -393,11 +382,11 @@ def load_fixing(module): else: if terr < tminerr: tminerr = terr - save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) elif epoch_save: - save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, print_func=logger.info) namin += 1 if namin >= earlystop: @@ -423,7 +412,7 @@ def load_fixing(module): if done_tokens > 0: optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) -save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "last.optm.h5") logger.info("model saved") diff --git a/tools/check/epoch_steps.py b/tools/check/epoch_steps.py index 5304451..c5f6187 100644 --- a/tools/check/epoch_steps.py +++ b/tools/check/epoch_steps.py @@ -3,32 +3,31 @@ import sys import torch -import h5py +from utils.h5serial import h5File from tqdm import tqdm -from random import shuffle -from random import seed as rpyseed +from random import shuffle, seed as rpyseed -def handle(h5f, bsize, shuf=True): +from cnfg.ihyp import tqdm_mininterval - td = h5py.File(h5f, "r") - ntest = td["ndata"][:].item() - tl = list(range(ntest)) - if shuf: - shuffle(tl) +def handle(h5f, bsize, shuf=True): - tgt_grp = td["tgt"] ntoken = 0 nstep = 0 - for tid in tqdm(tl, mininterval=tqdm_mininterval): - seq_batch = torch.from_numpy(tgt_grp[str(tid)][:]) - ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) - ntoken += ot.ne(0).int().sum().item() - if ntoken >= bsize: - nstep += 1 - ntoken = 0 - - td.close() + with h5File(h5f, "r") as td: + ntest = td["ndata"][:].item() + tl = list(range(ntest)) + if shuf: + shuffle(tl) + + tgt_grp = td["tgt"] + for tid in tqdm(tl, mininterval=tqdm_mininterval): + seq_batch = torch.from_numpy(tgt_grp[str(tid)][:]) + ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) + ntoken += ot.ne(0).int().sum().item() + if ntoken >= bsize: + nstep += 1 + ntoken = 0 return nstep diff --git a/tools/check/mulang/cnfg b/tools/check/mulang/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/check/mulang/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/check/mulang/eff/cnfg b/tools/check/mulang/eff/cnfg new file mode 120000 index 0000000..2f54778 --- /dev/null +++ b/tools/check/mulang/eff/cnfg @@ -0,0 +1 @@ +../../../../cnfg/ \ No newline at end of file diff --git a/tools/check/mulang/eff/epoch_steps.py b/tools/check/mulang/eff/epoch_steps.py new file mode 100644 index 0000000..2ec4af7 --- /dev/null +++ b/tools/check/mulang/eff/epoch_steps.py @@ -0,0 +1,35 @@ +#encoding: utf-8 + +import sys + +import torch +from utils.h5serial import h5File + +from tqdm import tqdm +from random import shuffle, seed as rpyseed + +from cnfg.ihyp import tqdm_mininterval + +def handle(h5f, bsize, shuf=True): + + with h5File(h5f, "r") as td: + ntest = td["ndata"][:].tolist() + tl = [(i, str(_task),) for _nd, _task in zip(ntest, td["taskorder"][:].tolist()) for i in range(_nd)] + if shuf: + shuffle(tl) + + ntoken = 0 + nstep = 0 + for tid, taskid in tqdm(tl, mininterval=tqdm_mininterval): + seq_batch = torch.from_numpy(td[taskid]["tgt"][str(tid)][:]) + ot = seq_batch.narrow(-1, 1, seq_batch.size(-1) - 1) + ntoken += ot.ne(0).int().sum().item() + if ntoken >= bsize: + nstep += 1 + ntoken = 0 + + return nstep + +if __name__ == "__main__": + rpyseed(666666) + print(handle(sys.argv[1], int(sys.argv[2]))) diff --git a/tools/check/mulang/eff/utils b/tools/check/mulang/eff/utils new file mode 120000 index 0000000..c2519a9 --- /dev/null +++ b/tools/check/mulang/eff/utils @@ -0,0 +1 @@ +../../../../utils/ \ No newline at end of file diff --git a/tools/check/mulang/fbindexes.py b/tools/check/mulang/fbindexes.py new file mode 100644 index 0000000..c1ed9c9 --- /dev/null +++ b/tools/check/mulang/fbindexes.py @@ -0,0 +1,42 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import ldvocab, init_vocab + +def handle(vcbf, srcfl, fvocab_task, rsf, minfreq=False, vsize=False): + + vcb, nwords = ldvocab(vcbf, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbtask, nwordtask = ldvocab(fvocab_task, minf=False, omit_vsize=False, vanilla=True) + + fvcb = {} + + for srcf, tgtf in zip(srcfl[0::2], srcfl[1::2]): + with open(srcf, "rb") as fsrc, open(tgtf, "rb") as ftgt: + for lsrc, ltgt in zip(fsrc, ftgt): + tsrc, ttgt = lsrc.strip(), ltgt.strip() + if tsrc and ttgt: + task = vcbtask[tsrc.decode("utf-8").split()[0]] + if task not in fvcb: + fvcb[task] = set(init_vocab.keys()) + wset = fvcb[task] + for token in ttgt.decode("utf-8").split(): + if token and (token not in wset): + wset.add(token) + + rsl = [] + for i in range(nwordtask): + wset = fvcb[i] + tmp = [] + for wd, ind in vcb.items(): + if wd not in wset: + tmp.append(ind) + rsl.append(tmp) + + with open(rsf, "wb") as f: + f.write("#encoding: utf-8\n\nfbl = ".encode("utf-8")) + f.write(repr(rsl).encode("utf-8")) + f.write("\n".encode("utf-8")) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2:-2], sys.argv[-2], sys.argv[-1]) diff --git a/tools/check/mulang/utils b/tools/check/mulang/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/check/mulang/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/check/para.py b/tools/check/para.py index d6af351..100e9e7 100644 --- a/tools/check/para.py +++ b/tools/check/para.py @@ -6,7 +6,7 @@ import sys -import h5py +from utils.h5serial import h5File def handle_group(srcg): @@ -21,9 +21,8 @@ def handle_group(srcg): def handle(srcf): - sfg = h5py.File(srcf, "r") - rs = handle_group(sfg) - sfg.close() + with h5File(srcf, "r") as sfg: + rs = handle_group(sfg) print(rs) if __name__ == "__main__": diff --git a/tools/check/probe/merge_probe.py b/tools/check/probe/merge_probe.py index cc8ad33..7b5ae4c 100644 --- a/tools/check/probe/merge_probe.py +++ b/tools/check/probe/merge_probe.py @@ -2,10 +2,9 @@ import sys -import h5py - from utils.base import * from utils.init import init_model_params +from utils.h5serial import h5File import cnfg.probe as cnfg from cnfg.ihyp import * @@ -15,9 +14,8 @@ def handle(cnfg, srcmtf, decf, rsf): - tdf = h5py.File(cnfg.dev_data, "r") - nwordi, nwordt = tdf["nword"][:].tolist() - tdf.close() + with h5File(cnfg.dev_data, "r") as tdf: + nwordi, nwordt = tdf["nword"][:].tolist() mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) init_model_params(mymodel) @@ -32,7 +30,7 @@ def handle(cnfg, srcmtf, decf, rsf): mymodel.dec.classifier.weight = mymodel.dec.wemb.weight _tmpm = None - save_model(mymodel, rsf, sub_module=False, logger=None, h5args=h5zipargs) + save_model(mymodel, rsf, sub_module=False, h5args=h5zipargs) if __name__ == "__main__": handle(cnfg, sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/check/tspeed.py b/tools/check/tspeed.py index 3192476..684cb6f 100644 --- a/tools/check/tspeed.py +++ b/tools/check/tspeed.py @@ -77,7 +77,6 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) beam_size = cnfg.beam_size - length_penalty = cnfg.length_penalty src_grp = td["src"] diff --git a/tools/clean/sampler/eff_sampler.py b/tools/clean/sampler/eff_sampler.py index eee5adf..f54adc8 100644 --- a/tools/clean/sampler/eff_sampler.py +++ b/tools/clean/sampler/eff_sampler.py @@ -4,8 +4,7 @@ # python tools/clean/sampler/eff_sampler.py srcf1 ... srcfn tgtf1 ... tgtfn keep_ratio import sys -from random import random -from random import seed as rpyseed +from random import random, seed as rpyseed from utils.fmt.base import FileList @@ -17,9 +16,9 @@ def handle(srcfl, tgtfl, ratio): with FileList(srcfl, "rb") as sfl, FileList(tgtfl, "rb") as tfl: for srcl in zip(*sfl): if random() <=ratio: - tmp = [tl.strip().decode("utf-8") for tl in srcl] + tmp = [tl.strip().decode("utf-8").encode("utf-8") for tl in srcl] for line, wrtf in zip(tmp, tfl): - wrtf.write(line.encode("utf-8")) + wrtf.write(line) wrtf.write(ens) nkeep += 1 ntotal += 1 diff --git a/tools/clean/sampler/strict_sampler.py b/tools/clean/sampler/strict_sampler.py index 4d82ab3..52d97ee 100644 --- a/tools/clean/sampler/strict_sampler.py +++ b/tools/clean/sampler/strict_sampler.py @@ -4,8 +4,7 @@ # python tools/clean/sampler/strict_sampler.py srcf1 ... srcfn tgtf1 ... tgtfn keep_ratio import sys -from random import shuffle -from random import seed as rpyseed +from random import shuffle, seed as rpyseed from utils.fmt.base import FileList @@ -14,7 +13,7 @@ def handle(srcfl, tgtfl, ratio): rs = [] with FileList(srcfl, "rb") as fl: for srcl in zip(*fl): - tmp = [tl.strip().decode("utf-8") for tl in srcl] + tmp = [tl.strip().decode("utf-8").encode("utf-8") for tl in srcl] rs.append(tmp) shuffle(rs) @@ -27,10 +26,10 @@ def handle(srcfl, tgtfl, ratio): with open(tgtf, "wb") as f: # following 3 lines for memory #for line in data: - #f.write(line.encode("utf-8")) + #f.write(line) #f.write(ens) # use following lines for efficiency - f.write("\n".join(data).encode("utf-8")) + f.write(ens.join(data)) f.write(ens) print("%d in %d data keeped with ratio %.2f" % (nkeep, ntotal, float(nkeep) / float(ntotal) * 100.0 if ntotal > 0 else 0.0)) diff --git a/tools/clean/token_repeat.py b/tools/clean/token_repeat.py index ed6d8ac..37f3d08 100644 --- a/tools/clean/token_repeat.py +++ b/tools/clean/token_repeat.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import clean_list, all_true, all_gt, FileList +from utils.fmt.base import clean_list, all_gt, FileList def handle(srcfl, tgtfl, r=0.4): @@ -10,7 +10,7 @@ def handle(srcfl, tgtfl, r=0.4): with FileList(srcfl, "rb") as rfl, FileList(tgtfl, "wb") as wfl: for lines in zip(*rfl): lines = [line.strip() for line in lines] - if all_true(lines): + if all(lines): lines = [clean_list(line.decode("utf-8").split()) for line in lines] ratios = [float(len(set(line))) / float(len(line)) for line in lines] if all_gt(ratios, r): diff --git a/tools/doc/para/mkiodata.py b/tools/doc/para/mkiodata.py index a34727c..89a5cf7 100644 --- a/tools/doc/para/mkiodata.py +++ b/tools/doc/para/mkiodata.py @@ -3,8 +3,8 @@ import sys import numpy -import h5py +from utils.h5serial import h5File from utils.fmt.base import ldvocab, dict2pairs from utils.fmt.doc.para.dual import batch_padder @@ -19,27 +19,26 @@ def handle(finput, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for_mulg else: _bsize = bsize _maxtoken = maxtoken - rsf = h5py.File(frs, 'w') - src_grp = rsf.create_group("src") - tgt_grp = rsf.create_group("tgt") - curd = {} - for i_d, td, nsent in batch_padder(finput, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype=numpy.int32) - rtd = numpy.array(td, dtype=numpy.int32) - _nsentgid = str(nsent) - _curd = curd.get(nsent, 0) - if _curd == 0: - src_grp.create_group(_nsentgid) - tgt_grp.create_group(_nsentgid) - _curid = str(_curd) - src_grp[_nsentgid].create_dataset(_curid, data=rid, **h5datawargs) - tgt_grp[_nsentgid].create_dataset(_curid, data=rtd, **h5datawargs) - curd[nsent] = _curd + 1 - sents, ndl = dict2pairs(curd) - rsf["nsent"] = numpy.array(sents, dtype=numpy.int32) - rsf["ndata"] = numpy.array(ndl, dtype=numpy.int32) - rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) - rsf.close() + with h5File(frs, 'w') as rsf: + src_grp = rsf.create_group("src") + tgt_grp = rsf.create_group("tgt") + curd = {} + for i_d, td, nsent in batch_padder(finput, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = numpy.array(i_d, dtype=numpy.int32) + rtd = numpy.array(td, dtype=numpy.int32) + _nsentgid = str(nsent) + _curd = curd.get(nsent, 0) + if _curd == 0: + src_grp.create_group(_nsentgid) + tgt_grp.create_group(_nsentgid) + _curid = str(_curd) + src_grp[_nsentgid].create_dataset(_curid, data=rid, **h5datawargs) + tgt_grp[_nsentgid].create_dataset(_curid, data=rtd, **h5datawargs) + curd[nsent] = _curd + 1 + sents, ndl = dict2pairs(curd) + rsf["nsent"] = numpy.array(sents, dtype=numpy.int32) + rsf["ndata"] = numpy.array(ndl, dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (sum(ndl), nwordi, nwordt,)) if __name__ == "__main__": diff --git a/tools/doc/para/mktest.py b/tools/doc/para/mktest.py index c0548b2..77ff7d9 100644 --- a/tools/doc/para/mktest.py +++ b/tools/doc/para/mktest.py @@ -3,8 +3,8 @@ import sys import numpy -import h5py +from utils.h5serial import h5File from utils.fmt.base import ldvocab, dict2pairs from utils.fmt.doc.para.single import batch_padder @@ -18,22 +18,21 @@ def handle(finput, fvocab_i, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_ else: _bsize = bsize _maxtoken = maxtoken - rsf = h5py.File(frs, 'w') - src_grp = rsf.create_group("src") - curd = {} - for i_d, nsent in batch_padder(finput, vcbi, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype=numpy.int32) - _nsentgid = str(nsent) - _curd = curd.get(nsent, 0) - if _curd == 0: - src_grp.create_group(_nsentgid) - src_grp[_nsentgid].create_dataset(str(_curd), data=rid, **h5datawargs) - curd[nsent] = _curd + 1 - sents, ndl = dict2pairs(curd) - rsf["nsent"] = numpy.array(sents, dtype=numpy.int32) - rsf["ndata"] = numpy.array(ndl, dtype=numpy.int32) - rsf["nword"] = numpy.array([nwordi], dtype=numpy.int32) - rsf.close() + with h5File(frs, 'w') as rsf: + src_grp = rsf.create_group("src") + curd = {} + for i_d, nsent in batch_padder(finput, vcbi, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = numpy.array(i_d, dtype=numpy.int32) + _nsentgid = str(nsent) + _curd = curd.get(nsent, 0) + if _curd == 0: + src_grp.create_group(_nsentgid) + src_grp[_nsentgid].create_dataset(str(_curd), data=rid, **h5datawargs) + curd[nsent] = _curd + 1 + sents, ndl = dict2pairs(curd) + rsf["nsent"] = numpy.array(sents, dtype=numpy.int32) + rsf["ndata"] = numpy.array(ndl, dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi], dtype=numpy.int32) print("Number of batches: %d\nSource Vocabulary Size: %d" % (sum(ndl), nwordi,)) if __name__ == "__main__": diff --git a/tools/doc/sort.py b/tools/doc/sort.py index 9deb54e..d39b6bd 100644 --- a/tools/doc/sort.py +++ b/tools/doc/sort.py @@ -3,7 +3,7 @@ import sys from random import seed as rpyseed -from utils.fmt.base import clean_liststr_lentok, all_true, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length @@ -20,7 +20,7 @@ def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=F with FileList(srcfl, "rb") as fl: for lines in zip(*fl): lines = [line.strip() for line in lines] - if all_true(lines): + if all(lines): lines, lens = zip(*[clean_liststr_lentok(line.decode("utf-8").split()) for line in lines]) if all_le(lens, max_len): lgth = sum(lens) diff --git a/tools/h5/compress.py b/tools/h5/compress.py index 3dd86df..ae773dc 100644 --- a/tools/h5/compress.py +++ b/tools/h5/compress.py @@ -3,7 +3,7 @@ import sys import h5py -from utils.h5serial import h5save, h5load +from utils.h5serial import h5save, h5load, h5File from cnfg.ihyp import * @@ -21,10 +21,8 @@ def handle(srcf, rsf, h5args=h5zipargs): if srcf == rsf: h5save(h5load(srcf, restore_list=False), rsf, h5args=h5args) else: - sfg, rfg = h5py.File(srcf, "r"), h5py.File(rsf, 'w') - handle_group(sfg, rfg, h5args=h5args) - sfg.close() - rfg.close() + with h5File(srcf, "r") as sfg, h5File(rsf, 'w') as rfg: + handle_group(sfg, rfg, h5args=h5args) if __name__ == "__main__": handle(sys.argv[1], sys.argv[-1]) diff --git a/tools/lsort/merge.py b/tools/lsort/merge.py index b60d6a1..b84fd96 100644 --- a/tools/lsort/merge.py +++ b/tools/lsort/merge.py @@ -2,10 +2,11 @@ import sys -from utils.fmt.base import clean_liststr_lentok, all_true, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList from random import seed as rpyseed -from os import walk, path +from os import walk +from os.path import join as pjoin # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length @@ -13,12 +14,12 @@ def handle(cached, tgtfl, remove_same=False, shuf=True, max_remove=True): - def paral_reader(fl): + def paral_reader(srcfl): with FileList(srcfl, "rb") as fl: for lines in zip(*fl): lines = [line.strip() for line in lines] - if all_true(lines): + if all(lines): lines, lens = zip(*[clean_liststr_lentok(line.decode("utf-8").split()) for line in lines]) lgth = sum(lens) yield tuple(line.encode("utf-8") for line in lines), lgth, *reversed(lens[1:]) @@ -32,7 +33,7 @@ def open_files(cache_dir, num_files): for file in files: curfid = file.split(".")[1] if curfid not in opened: - pg = paral_reader([path.join(cache_dir, "%d.%s.txt" % (i, curfid,)) for i in range(num_files)]) + pg = paral_reader([pjoin(cache_dir, "%d.%s.txt" % (i, curfid,)) for i in range(num_files)]) opened.add(curfid) try: prd = next(pg) diff --git a/tools/lsort/partsort.py b/tools/lsort/partsort.py index e592fa6..010a3a3 100644 --- a/tools/lsort/partsort.py +++ b/tools/lsort/partsort.py @@ -1,9 +1,9 @@ #encoding: utf-8 import sys -from os import path +from os.path import join as pjoin -from utils.fmt.base import clean_liststr_lentok, all_true, all_le, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import clean_liststr_lentok, all_le, iter_dict_sort, dict_insert_list, dict_insert_set, FileList def handle(srcfl, tgtd, max_len=256, remove_same=False, cache_token=500000000): @@ -27,19 +27,19 @@ def save_cache(cache, tgtfl): with FileList(srcfl, "rb") as fl: for lines in zip(*fl): lines = [line.strip() for line in lines] - if all_true(lines): + if all(lines): lines, lens = zip(*[clean_liststr_lentok(line.decode("utf-8").split()) for line in lines]) if all_le(lens, max_len): lgth = sum(lens) data = _insert_func(data, tuple(line.encode("utf-8") for line in lines), lgth, *reversed(lens[1:])) mem_token += lgth if mem_token >= cache_token: - save_cache(data, [path.join(tgtd, "%d.%d.txt" % (i, curf,)) for i in range(num_files)]) + save_cache(data, [pjoin(tgtd, "%d.%d.txt" % (i, curf,)) for i in range(num_files)]) data = {} mem_token = 0 curf += 1 if data: - save_cache(data, [path.join(tgtd, "%d.%d.txt" % (i, curf,)) for i in range(num_files)]) + save_cache(data, [pjoin(tgtd, "%d.%d.txt" % (i, curf,)) for i in range(num_files)]) if __name__ == "__main__": handle(sys.argv[1:-2], sys.argv[-2], int(sys.argv[-1])) diff --git a/tools/mkiodata.py b/tools/mkiodata.py index b453db7..4cca15d 100644 --- a/tools/mkiodata.py +++ b/tools/mkiodata.py @@ -3,7 +3,7 @@ import sys import numpy -import h5py +from utils.h5serial import h5File from utils.fmt.base import ldvocab from utils.fmt.dual import batch_padder @@ -19,22 +19,21 @@ def handle(finput, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for_mulg else: _bsize = bsize _maxtoken = maxtoken - rsf = h5py.File(frs, 'w') - src_grp = rsf.create_group("src") - tgt_grp = rsf.create_group("tgt") - curd = 0 - for i_d, td in batch_padder(finput, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype=numpy.int32) - rtd = numpy.array(td, dtype=numpy.int32) - #rld = numpy.array(ld, dtype=numpy.int32) - wid = str(curd) - src_grp.create_dataset(wid, data=rid, **h5datawargs) - tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) - #rsf["l" + wid] = rld - curd += 1 - rsf["ndata"] = numpy.array([curd], dtype=numpy.int32) - rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) - rsf.close() + with h5File(frs,'w') as rsf: + src_grp = rsf.create_group("src") + tgt_grp = rsf.create_group("tgt") + curd = 0 + for i_d, td in batch_padder(finput, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = numpy.array(i_d, dtype=numpy.int32) + rtd = numpy.array(td, dtype=numpy.int32) + #rld = numpy.array(ld, dtype=numpy.int32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) + #rsf["l" + wid] = rld + curd += 1 + rsf["ndata"] = numpy.array([curd], dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (curd, nwordi, nwordt,)) if __name__ == "__main__": diff --git a/tools/mktest.py b/tools/mktest.py index 15c67f9..81377f6 100644 --- a/tools/mktest.py +++ b/tools/mktest.py @@ -3,9 +3,9 @@ import sys import numpy -import h5py from utils.fmt.base import ldvocab +from utils.h5serial import h5File from utils.fmt.single import batch_padder from cnfg.ihyp import * @@ -20,19 +20,18 @@ def handle(finput, fvocab_i, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_ else: _bsize = bsize _maxtoken = maxtoken - rsf = h5py.File(frs,'w') - src_grp = rsf.create_group("src") - curd = 0 - for i_d in batch_padder(finput, vcbi, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype=numpy.int32) - #rld = numpy.array(ld, dtype=numpy.int32) - wid = str(curd) - src_grp.create_dataset(wid, data=rid, **h5datawargs) - #rsf["l" + wid] = rld - curd += 1 - rsf["ndata"] = numpy.array([curd], dtype=numpy.int32) - rsf["nword"] = numpy.array([nwordi], dtype=numpy.int32) - rsf.close() + with h5File(frs,'w') as rsf: + src_grp = rsf.create_group("src") + curd = 0 + for i_d in batch_padder(finput, vcbi, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = numpy.array(i_d, dtype=numpy.int32) + #rld = numpy.array(ld, dtype=numpy.int32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + #rsf["l" + wid] = rld + curd += 1 + rsf["ndata"] = numpy.array([curd], dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi], dtype=numpy.int32) print("Number of batches: %d\nSource Vocabulary Size: %d" % (curd, nwordi,)) if __name__ == "__main__": diff --git a/tools/mulang/cnfg b/tools/mulang/cnfg new file mode 120000 index 0000000..bcd9a88 --- /dev/null +++ b/tools/mulang/cnfg @@ -0,0 +1 @@ +../../cnfg/ \ No newline at end of file diff --git a/tools/mulang/eff/cnfg b/tools/mulang/eff/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/mulang/eff/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/mulang/eff/mkiodata.py b/tools/mulang/eff/mkiodata.py new file mode 100644 index 0000000..ab8c111 --- /dev/null +++ b/tools/mulang/eff/mkiodata.py @@ -0,0 +1,51 @@ +#encoding: utf-8 + +import sys + +import numpy + +from utils.h5serial import h5File +from utils.fmt.base import ldvocab +from utils.fmt.mulang.eff.dual import batch_padder + +from cnfg.ihyp import * + +def handle(finput, ftarget, fvocab_i, fvocab_t, fvocab_task, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): + vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbt, nwordt = ldvocab(fvocab_t, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbtask, nwordtask = ldvocab(fvocab_task, minf=False, omit_vsize=False, vanilla=True) + if expand_for_mulgpu: + _bsize = bsize * minbsize + _maxtoken = maxtoken * minbsize + else: + _bsize = bsize + _maxtoken = maxtoken + with h5File(frs, 'w') as rsf: + curd = {} + torder = [] + for i_d, td, taskd in batch_padder(finput, ftarget, vcbi, vcbt, vcbtask, _bsize, maxpad, maxpart, _maxtoken, minbsize): + _str_taskd = str(taskd) + if _str_taskd in rsf: + task_grp = rsf[_str_taskd] + src_grp = task_grp["src"] + tgt_grp = task_grp["tgt"] + else: + task_grp = rsf.create_group(_str_taskd) + src_grp = task_grp.create_group("src") + tgt_grp = task_grp.create_group("tgt") + torder.append(taskd) + rid = numpy.array(i_d, dtype=numpy.int32) + rtd = numpy.array(td, dtype=numpy.int32) + _id = curd.get(taskd, 0) + wid = str(_id) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) + curd[taskd] = _id + 1 + rsf["taskorder"] = numpy.array(torder, dtype=numpy.int32) + curd = [curd[tmp] for tmp in torder] + rsf["ndata"] = numpy.array(curd, dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordtask, nwordt], dtype=numpy.int32) + print("Number of Batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d\nNumber of Tasks: %d" % (sum(curd), nwordi, nwordt, nwordtask,)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6], int(sys.argv[7])) diff --git a/tools/mulang/eff/mktest.py b/tools/mulang/eff/mktest.py new file mode 100644 index 0000000..2a51ff8 --- /dev/null +++ b/tools/mulang/eff/mktest.py @@ -0,0 +1,46 @@ +#encoding: utf-8 + +import sys + +import numpy + +from utils.h5serial import h5File +from utils.fmt.base import ldvocab +from utils.fmt.mulang.eff.single import batch_padder + +from cnfg.ihyp import * + +# maxtoken should be the maxtoken in mkiodata.py / 2 / beam size roughly, similar for bsize + +def handle(finput, fvocab_i, fvocab_task, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): + vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbtask, nwordtask = ldvocab(fvocab_task, minf=False, omit_vsize=False, vanilla=True) + if expand_for_mulgpu: + _bsize = bsize * minbsize + _maxtoken = maxtoken * minbsize + else: + _bsize = bsize + _maxtoken = maxtoken + with h5File(frs, 'w') as rsf: + curd = {} + torder = [] + for i_d, taskd in batch_padder(finput, vcbi, vcbtask, _bsize, maxpad, maxpart, _maxtoken, minbsize): + _str_taskd = str(taskd) + if _str_taskd in rsf: + src_grp = rsf[_str_taskd]["src"] + else: + src_grp = rsf.create_group(_str_taskd).create_group("src") + torder.append(taskd) + rid = numpy.array(i_d, dtype=numpy.int32) + _id = curd.get(taskd, 0) + wid = str(_id) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + curd[taskd] = _id + 1 + rsf["taskorder"] = numpy.array(torder, dtype=numpy.int32) + curd = [curd[tmp] for tmp in torder] + rsf["ndata"] = numpy.array(curd, dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordtask], dtype=numpy.int32) + print("Number of batches: %d\nSource Vocabulary Size: %d\nNumber of Tasks: %d" % (sum(curd), nwordi, nwordtask,)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], int(sys.argv[5])) diff --git a/tools/mulang/eff/sort.py b/tools/mulang/eff/sort.py new file mode 100644 index 0000000..f4617a2 --- /dev/null +++ b/tools/mulang/eff/sort.py @@ -0,0 +1,51 @@ +#encoding: utf-8 + +import sys +from random import seed as rpyseed + +from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList + +# remove_same: reduce same data in the corpus +# shuf: shuffle the data of same source/target length +# max_remove: if one source has several targets, only keep those with highest frequency + +def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=False): + + _max_len = max(1, max_len - 2) + + _insert_func = dict_insert_set if remove_same and (not max_remove) else dict_insert_list + data = {} + + with FileList(srcfl, "rb") as fl: + for lines in zip(*fl): + lines = [line.strip() for line in lines] + if all(lines): + lines, lens = zip(*[clean_liststr_lentok(line.decode("utf-8").split()) for line in lines]) + if all_le(lens, max_len): + lgth = sum(lens) + ls = lines[0] + data = _insert_func(data, tuple(line.encode("utf-8") for line in lines), ls[:ls.find(" ")], lgth, *reversed(lens[1:])) + + ens = "\n".encode("utf-8") + + with FileList(tgtfl, "wb") as fl: + for tmp in iter_dict_sort(data): + lines = zip(*tmp) + if len(tmp) > 1: + if max_remove: + lines = maxfreq_filter(*lines) + if shuf: + lines = shuffle_pair(*lines) + for du, f in zip(lines, fl): + f.write(ens.join(du)) + f.write(ens) + +if __name__ == "__main__": + rpyseed(666666) + _nargs = len(sys.argv) + if _nargs % 2 == 0: + _sep_ind = _nargs // 2 + handle(sys.argv[1:_sep_ind], sys.argv[_sep_ind:-1], max_len=int(sys.argv[-1])) + else: + _sep_ind = (_nargs + 1) // 2 + handle(sys.argv[1:_sep_ind], sys.argv[_sep_ind:]) diff --git a/tools/mulang/eff/utils b/tools/mulang/eff/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/mulang/eff/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/mulang/share_vocab.py b/tools/mulang/share_vocab.py new file mode 100644 index 0000000..8499754 --- /dev/null +++ b/tools/mulang/share_vocab.py @@ -0,0 +1,39 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_list, clean_list_iter, save_vocab + +def handle(srcfl, rsf, rslangf, vsize=65532): + + vocab = {} + lang_vocab = {} + + curid = 0 + for srcf in srcfl: + if srcf == "--target": + break + with open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + tokens = clean_list(tmp.decode("utf-8").split()) + for token in tokens[1:]: + vocab[token] = vocab.get(token, 0) + 1 + token = tokens[0] + lang_vocab[token] = lang_vocab.get(token, 0) + 1 + curid += 1 + + for srcf in srcfl[curid+1:]: + with open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + for token in clean_list_iter(tmp.decode("utf-8").split()): + vocab[token] = vocab.get(token, 0) + 1 + + save_vocab(vocab, rsf, omit_vsize=vsize) + save_vocab(lang_vocab, rslangf, omit_vsize=False) + +if __name__ == "__main__": + handle(sys.argv[1:-3], sys.argv[-3], sys.argv[-2], int(sys.argv[-1])) diff --git a/tools/mulang/utils b/tools/mulang/utils new file mode 120000 index 0000000..7d6b64a --- /dev/null +++ b/tools/mulang/utils @@ -0,0 +1 @@ +../../utils/ \ No newline at end of file diff --git a/tools/mulang/vocab.py b/tools/mulang/vocab.py new file mode 100644 index 0000000..2d1425c --- /dev/null +++ b/tools/mulang/vocab.py @@ -0,0 +1,26 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_list, save_vocab + +def handle(srcf, rsf, rslangf, vsize=65532): + + vocab = {} + lang_vocab = {} + + with open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + tokens = clean_list(tmp.decode("utf-8").split()) + for token in tokens[1:]: + vocab[token] = vocab.get(token, 0) + 1 + token = tokens[0] + lang_vocab[token] = lang_vocab.get(token, 0) + 1 + + save_vocab(vocab, rsf, omit_vsize=vsize) + save_vocab(lang_vocab, rslangf, omit_vsize=False) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3]) if len(sys.argv) == 4 else handle(sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[-1])) diff --git a/tools/prune_model_vocab.py b/tools/prune_model_vocab.py index 117105e..4a5e910 100644 --- a/tools/prune_model_vocab.py +++ b/tools/prune_model_vocab.py @@ -34,7 +34,7 @@ def handle(common, src, tgt, srcm, rsm, minfreq=False, vsize=False): mymodel = NMT(cnfg.isize, nwordf, nwordf, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(srcm, mymodel) mymodel.update_vocab(src_indices=src_indices, tgt_indices=tgt_indices) - save_model(mymodel, rsm, sub_module=False, logger=None, h5args=h5zipargs) + save_model(mymodel, rsm, sub_module=False, h5args=h5zipargs) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5]) diff --git a/tools/restore.py b/tools/restore.py index a3f57e7..3152fc8 100644 --- a/tools/restore.py +++ b/tools/restore.py @@ -2,8 +2,8 @@ import sys -# WARNING: all_true might be too strict in some cases which may use any_true -from utils.fmt.base import clean_str, all_true, FileList +# WARNING: all() might be too strict in some cases which may use any() +from utils.fmt.base import clean_str, FileList # srtfl: (k - 1) source + 1 target def handle(srcfl, srtfl, tgtf): @@ -13,7 +13,7 @@ def handle(srcfl, srtfl, tgtf): with FileList(srtfl, "rb") as fs: for lines in zip(*fs): lines = tuple(line.strip() for line in lines) - if all_true(lines): + if all(lines): lines = tuple(clean_str(line.decode("utf-8")) for line in lines) data[lines[:-1]] = lines[-1].encode("utf-8") @@ -21,7 +21,7 @@ def handle(srcfl, srtfl, tgtf): with FileList(srcfl, "rb") as fs, open(tgtf, "wb") as ft: for lines in zip(*fs): lines = tuple(line.strip() for line in lines) - if all_true(lines): + if all(lines): lines = tuple(clean_str(line.decode("utf-8")) for line in lines) if lines in data: ft.write(data[lines]) diff --git a/tools/shuffle.py b/tools/shuffle.py index 2e6f88d..c6b78ad 100644 --- a/tools/shuffle.py +++ b/tools/shuffle.py @@ -2,8 +2,7 @@ import sys -from random import seed as rpyseed -from random import shuffle +from random import seed as rpyseed, shuffle from utils.fmt.base import clean_str, FileList diff --git a/tools/sort.py b/tools/sort.py index b9d0cf5..73718b9 100644 --- a/tools/sort.py +++ b/tools/sort.py @@ -3,7 +3,7 @@ import sys from random import seed as rpyseed -from utils.fmt.base import clean_liststr_lentok, all_true, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length @@ -19,7 +19,7 @@ def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=F with FileList(srcfl, "rb") as fl: for lines in zip(*fl): lines = [line.strip() for line in lines] - if all_true(lines): + if all(lines): lines, lens = zip(*[clean_liststr_lentok(line.decode("utf-8").split()) for line in lines]) if all_le(lens, max_len): lgth = sum(lens) diff --git a/tools/spm/decode.py b/tools/spm/decode.py new file mode 100644 index 0000000..73ecda8 --- /dev/null +++ b/tools/spm/decode.py @@ -0,0 +1,50 @@ +#encoding: utf-8 + +# portal from fairseq: https://github.com/pytorch/fairseq/blob/master/scripts/spm_encode.py + +import sys +from argparse import ArgumentParser +from sentencepiece import SentencePieceProcessor + +def main(): + parser = ArgumentParser() + parser.add_argument("--model", required=True, help="sentencepiece model to use for decoding") + parser.add_argument("--input", default="-", help="input file to decode") + parser.add_argument("--input_format", choices=["piece", "id"], default="piece") + args = parser.parse_args() + + sp = SentencePieceProcessor() + sp.Load(args.model) + + if args.input_format == "piece": + + def decode(l): + return "".join(sp.DecodePieces(l)) + + elif args.input_format == "id": + + def decode(l): + return "".join(sp.DecodeIds(l)) + + def tok2int(tok): + # remap reference-side to 0 + return int(tok) if tok != "" else 0 + + if args.input == "-": + if args.input_format == "id": + for line in sys.stdin: + print(decode(list(map(tok2int, line.rstrip().split())))) + elif args.input_format == "piece": + for line in sys.stdin: + print(decode(line.rstrip().split())) + else: + with open(args.input, "r", encoding="utf-8") as h: + if args.input_format == "id": + for line in h: + print(decode(list(map(tok2int, line.rstrip().split())))) + elif args.input_format == "piece": + for line in h: + print(decode(line.rstrip().split())) + +if __name__ == "__main__": + main() diff --git a/tools/spm/encode.py b/tools/spm/encode.py new file mode 100644 index 0000000..d2c1a2c --- /dev/null +++ b/tools/spm/encode.py @@ -0,0 +1,65 @@ +#encoding: utf-8 + +# portal from fairseq: https://github.com/pytorch/fairseq/blob/master/scripts/spm_encode.py + +import sys +from contextlib import ExitStack +from argparse import ArgumentParser +from sentencepiece import SentencePieceProcessor + +def main(): + + parser = ArgumentParser() + parser.add_argument("--model", required=True, help="sentencepiece model to use for encoding") + parser.add_argument("--inputs", nargs="+", default=["-"], help="input files to filter/encode") + parser.add_argument("--outputs", nargs="+", default=["-"], help="path to save encoded outputs") + parser.add_argument("--output_format", choices=["piece", "id"], default="piece") + parser.add_argument("--min-len", type=int, metavar="N", help="filter sentence pairs with fewer than N tokens") + parser.add_argument("--max-len", type=int, metavar="N", help="filter sentence pairs with more than N tokens") + args = parser.parse_args() + + sp = SentencePieceProcessor() + sp.Load(args.model) + + if args.output_format == "piece": + + def encode(l): + return sp.EncodeAsPieces(l) + + elif args.output_format == "id": + + def encode(l): + return list(map(str, sp.EncodeAsIds(l))) + + if args.min_len is not None or args.max_len is not None: + + def valid(line): + return (args.min_len is None or len(line) >= args.min_len) and (args.max_len is None or len(line) <= args.max_len) + + else: + + def valid(lines): + return True + + with ExitStack() as stack: + inputs = [stack.enter_context(open(input, "r", encoding="utf-8")) if input != "-" else sys.stdin for input in args.inputs] + outputs = [stack.enter_context(open(output, "w", encoding="utf-8")) if output != "-" else sys.stdout for output in args.outputs] + + def encode_line(line): + line = line.strip() + if len(line) > 0: + line = encode(line) + if valid(line): + return line + return None + + for i, lines in enumerate(zip(*inputs), start=1): + enc_lines = list(map(encode_line, lines)) + if not any(enc_line is None for enc_line in enc_lines): + for enc_line, output_h in zip(enc_lines, outputs): + print(" ".join(enc_line), file=output_h) + if i % 10000 == 0: + print("processed {} lines".format(i), file=sys.stderr) + +if __name__ == "__main__": + main() diff --git a/tools/spm/train.py b/tools/spm/train.py new file mode 100644 index 0000000..0b6fb43 --- /dev/null +++ b/tools/spm/train.py @@ -0,0 +1,9 @@ +#encoding: utf-8 + +# portal from fairseq: https://github.com/pytorch/fairseq/blob/master/scripts/spm_train.py + +import sys +from sentencepiece SentencePieceTrainer + +if __name__ == "__main__": + SentencePieceTrainer.Train(" ".join(sys.argv[1:])) diff --git a/train.py b/train.py index e05cfd0..7ba0321 100644 --- a/train.py +++ b/train.py @@ -12,6 +12,7 @@ from utils.base import * from utils.init import init_model_params +from utils.contpara import get_model_parameters from utils.h5serial import h5save, h5load from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb @@ -23,9 +24,6 @@ from tqdm import tqdm -from os import makedirs -from os.path import exists as p_check - import h5py import cnfg.base as cnfg @@ -75,7 +73,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _done_tokens += wd_add if _done_tokens >= tokens_optm: - optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -88,7 +86,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - save_model(model, _chkpf, multi_gpu, logger) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -122,8 +120,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok else: _chkpf = chkpf _chkpof = chkpof - #save_model(model, _chkpf, isinstance(model, nn.DataParallel), logger) - save_model(model, _chkpf, multi_gpu, logger) + #save_model(model, _chkpf, isinstance(model, nn.DataParallel), print_func=logger.info) + save_model(model, _chkpf, multi_gpu, print_func=logger.info) if chkpof is not None: h5save(optm.state_dict(), _chkpof) if statesf is not None: @@ -181,32 +179,23 @@ def load_fixing(module): module.fix_load() rid = cnfg.run_id - earlystop = cnfg.earlystop - maxrun = cnfg.maxrun - tokens_optm = cnfg.tokens_optm - done_tokens = 0 - batch_report = cnfg.batch_report report_eva = cnfg.report_eva - use_ams = cnfg.use_ams - save_optm_state = cnfg.save_optm_state - +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva save_every = cnfg.save_every start_chkp_save = cnfg.epoch_start_checkpoint_save - epoch_save = cnfg.epoch_save - remain_steps = cnfg.training_steps wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) -if not p_check(wkdir): - makedirs(wkdir) +mkdir(wkdir) chkpf = None chkpof = None @@ -247,9 +236,7 @@ def load_fixing(module): mymodel = load_model_cpu(fine_tune_m, mymodel) mymodel.apply(load_fixing) -#lw = torch.ones(nwordt).float() -#lw[0] = 0.0 -#lossf = nn.NLLLoss(lw, ignore_index=0, reduction='sum') +#lossf = NLLLoss(ignore_index=pad_id, reduction='sum') lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) if cnfg.src_emb is not None: @@ -271,21 +258,20 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -if multi_gpu_optimizer: - optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - mymodel.zero_grad(set_to_none=True) +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) else: # lr will be over written by LRScheduler before used - optimizer = Optimizer((mymodel.module if multi_gpu else mymodel).parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) - optimizer.zero_grad(set_to_none=True) + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) fine_tune_state = cnfg.fine_tune_state if fine_tune_state is not None: logger.info("Load optimizer state from: " + fine_tune_state) optimizer.load_state_dict(h5load(fine_tune_state)) +# lrsch.step() will be automatically called with the constructor lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) -#lrsch.step() num_checkpoint = cnfg.num_checkpoint cur_checkid = 0 @@ -296,16 +282,16 @@ def load_fixing(module): logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: - save_model(mymodel, wkdir + "init.h5", multi_gpu, logger) + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) logger.info("Initial model saved") else: cnt_states = cnfg.train_statesf - if (cnt_states is not None) and p_check(cnt_states): + if cnt_states is not None: logger.info("Continue last epoch") tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,)) - save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, logger) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec,)) logger.info("New best model saved") @@ -334,7 +320,7 @@ def load_fixing(module): logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,)) if (vprec <= minerr) or (vloss <= minloss): - save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, logger) + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) logger.info("New best model saved") @@ -349,11 +335,11 @@ def load_fixing(module): else: if terr < tminerr: tminerr = terr - save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, logger) + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec,)) elif epoch_save: - save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, logger) + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info) namin += 1 if namin >= earlystop: @@ -389,7 +375,7 @@ def load_fixing(module): #lrsch.step() #done_tokens = 0 -save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) if save_optm_state: h5save(optimizer.state_dict(), wkdir + "last.optm.h5") logger.info("model saved") diff --git a/transformer/AGG/HierDecoder.py b/transformer/AGG/HierDecoder.py index 5dfcf2d..e5a13a1 100644 --- a/transformer/AGG/HierDecoder.py +++ b/transformer/AGG/HierDecoder.py @@ -4,8 +4,7 @@ from torch import nn from modules.base import * -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from utils.base import align_modules_by_type diff --git a/transformer/AGG/HierEncoder.py b/transformer/AGG/HierEncoder.py index fb199de..1d3b665 100644 --- a/transformer/AGG/HierEncoder.py +++ b/transformer/AGG/HierEncoder.py @@ -3,8 +3,7 @@ from torch import nn from modules.base import * -from transformer.Encoder import EncoderLayer as EncoderLayerBase -from transformer.Encoder import Encoder as EncoderBase +from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase from utils.base import align_modules_by_type diff --git a/transformer/APE/Decoder.py b/transformer/APE/Decoder.py index 0daf633..c3b6e64 100644 --- a/transformer/APE/Decoder.py +++ b/transformer/APE/Decoder.py @@ -2,10 +2,9 @@ import torch from torch import nn -from modules.base import CrossAttn +from modules.base import ResCrossAttn -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from utils.base import all_done, index_tensors, expand_bsize_for_beam, mask_tensor_type from math import sqrt @@ -22,46 +21,18 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) - self.cross_attn_mt = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) - self.layer_normer3 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.cross_attn_mt = ResCrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual) def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, tgt_pad_mask=None, query_unit=None): if query_unit is None: - _inputo = self.layer_normer1(inputo) - - context = self.self_attn(_inputo, mask=tgt_pad_mask) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_inputo if self.norm_residual else inputo) - + context = self.self_attn(inputo, mask=tgt_pad_mask) else: - _query_unit = self.layer_normer1(query_unit) - - context, states_return = self.self_attn(_query_unit, states=inputo) - - if self.drop is not None: - context = self.drop(context) + context, states_return = self.self_attn(query_unit, states=inputo) - context = context + (_query_unit if self.norm_residual else query_unit) - - _context = self.layer_normer2(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = _context_new + (_context if self.norm_residual else context) - - _context = self.layer_normer3(context) - _context_new = self.cross_attn_mt(_context, inputm, mask=mt_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) + context = self.cross_attn(context, inpute, mask=src_pad_mask) - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn_mt(context, inputm, mask=mt_pad_mask) context = self.ff(context) diff --git a/transformer/APE/Encoder.py b/transformer/APE/Encoder.py index e816949..04f23b8 100644 --- a/transformer/APE/Encoder.py +++ b/transformer/APE/Encoder.py @@ -16,22 +16,9 @@ class MSEncoderLayer(MSEncoderLayerBase): def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): - _inputo = self.layer_normer1(inputo) + context = self.self_attn(inputo, mask=tgt_pad_mask) - context = self.self_attn(_inputo, mask=tgt_pad_mask) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_inputo if self.norm_residual else inputo) - - _context = self.layer_normer2(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn(context, inpute, mask=src_pad_mask) context = self.ff(context) diff --git a/transformer/APE/NMT.py b/transformer/APE/NMT.py index e0e7e37..3605f40 100644 --- a/transformer/APE/NMT.py +++ b/transformer/APE/NMT.py @@ -11,7 +11,7 @@ class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): enc_layer, dec_layer = parse_double_value_tuple(num_layer) diff --git a/transformer/AvgDecoder.py b/transformer/AvgDecoder.py index a9c58a9..3691152 100644 --- a/transformer/AvgDecoder.py +++ b/transformer/AvgDecoder.py @@ -2,7 +2,7 @@ import torch from torch import nn -from modules.base import * +from modules.aan import AverageAttn from utils.sampler import SampleMax from utils.base import all_done, index_tensors, expand_bsize_for_beam from utils.aan import share_aan_cache @@ -10,13 +10,12 @@ from utils.fmt.base import pad_id -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase - -# Average Decoder is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from cnfg.ihyp import * +# Average Decoder is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) + class DecoderLayer(DecoderLayerBase): # isize: input size @@ -32,6 +31,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + self.drop, self.layer_normer1, self.norm_residual = self.self_attn.drop, self.self_attn.normer, self.self_attn.norm_residual self.self_attn = AverageAttn(isize, _fhsize, dropout) # inpute: encoded representation from encoder (bsize, seql, isize) @@ -69,13 +69,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None, step=1): context = context + (_query_unit if self.norm_residual else query_unit) - _context = self.layer_normer2(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn(context, inpute, mask=src_pad_mask) context = self.ff(context) diff --git a/transformer/Decoder.py b/transformer/Decoder.py index 879975d..1966d65 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -28,18 +28,11 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a _ahsize = isize if ahsize is None else ahsize _fhsize = _ahsize * 4 if fhsize is None else fhsize - self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, uni_direction_reduction=True) - self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) + self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) - self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - self.layer_normer2 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - - self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - - self.norm_residual = norm_residual - # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: embedding of decoded translation (bsize, nquery, isize) # src_pad_mask: mask for given encoding source sentence (bsize, nquery, seql), see Encoder, expanded after generated with: @@ -47,6 +40,44 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a # tgt_pad_mask: mask to hide the future input # query_unit: single query to decode, used to support decoding for given step + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + + if query_unit is None: + context = self.self_attn(inputo, mask=tgt_pad_mask) + else: + context, states_return = self.self_attn(query_unit, states=inputo) + + context = self.cross_attn(context, inpute, mask=src_pad_mask) + + context = self.ff(context) + + if query_unit is None: + return context + else: + return context, states_return + +# Not used, keep this class to remind the DecoderLayer implementation before v0.3.5. +class NAWDecoderLayer(DecoderLayer): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(NAWDecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos) + + self.layer_normer1, self.drop, self.norm_residual = self.self_attn.normer, self.self_attn.drop, self.self_attn.norm_residual + self.self_attn = self.self_attn.net + self.layer_normer2 = self.cross_attn.normer + self.cross_attn = self.cross_attn.net + #self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, uni_direction_reduction=True) + #self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) + #self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + #self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + #self.layer_normer2 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + #self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + #self.norm_residual = norm_residual + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): if query_unit is None: @@ -483,7 +514,7 @@ def unbind_classifier_weight(self): def update_vocab(self, indices): _nwd = len(indices) - _wemb = nn.Embedding(_nwd, self.wemb.weight.size(-1), padding_idx=pad_id) + _wemb = nn.Embedding(_nwd, self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) _classifier = Linear(self.classifier.weight.size(-1), _nwd) with torch.no_grad(): _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) diff --git a/transformer/Doc/Para/Base/Decoder.py b/transformer/Doc/Para/Base/Decoder.py index 01ddaeb..8ca53ff 100644 --- a/transformer/Doc/Para/Base/Decoder.py +++ b/transformer/Doc/Para/Base/Decoder.py @@ -10,8 +10,7 @@ from utils.fmt.base import pad_id -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from cnfg.ihyp import * @@ -26,36 +25,17 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.cattns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(ncross)]) self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) for i in range(ncross)]) self.grs = nn.ModuleList([GateResidual(isize) for i in range(ncross)]) + self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + self.norm_residual = self.cross_attn.norm_residual def forward(self, inpute, inputo, inputc, src_pad_mask=None, tgt_pad_mask=None, context_mask=None, query_unit=None): if query_unit is None: - _inputo = self.layer_normer1(inputo) - - context = self.self_attn(_inputo, mask=tgt_pad_mask) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_inputo if self.norm_residual else inputo) - + context = self.self_attn(inputo, mask=tgt_pad_mask) else: - _query_unit = self.layer_normer1(query_unit) - - context, states_return = self.self_attn(_query_unit, states=inputo) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_query_unit if self.norm_residual else query_unit) - - _context = self.layer_normer2(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) + context, states_return = self.self_attn(query_unit, states=inputo) - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn(context, inpute, mask=src_pad_mask) for _ln, _cattn, _gr, _inputc, _maskc in zip(self.cattn_ln, self.cattns, self.grs, inputc, [None for i in range(len(inputc))] if context_mask is None else context_mask): _inputs = _ln(context) diff --git a/transformer/Doc/Para/Base/Encoder.py b/transformer/Doc/Para/Base/Encoder.py index 9f6604f..45c98e8 100644 --- a/transformer/Doc/Para/Base/Encoder.py +++ b/transformer/Doc/Para/Base/Encoder.py @@ -8,8 +8,7 @@ from utils.base import mask_tensor_type -from transformer.Encoder import EncoderLayer as EncoderLayerBase -from transformer.Encoder import Encoder as EncoderBase +from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase from cnfg.ihyp import * diff --git a/transformer/Doc/Para/Base/NMT.py b/transformer/Doc/Para/Base/NMT.py index 70127b8..67e3be5 100644 --- a/transformer/Doc/Para/Base/NMT.py +++ b/transformer/Doc/Para/Base/NMT.py @@ -12,7 +12,7 @@ class NMT(nn.Module): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None, nprev_context=2, num_layer_context=1): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, nprev_context=2, num_layer_context=1): super(NMT, self).__init__() diff --git a/transformer/Encoder.py b/transformer/Encoder.py index c1e415f..6271942 100644 --- a/transformer/Encoder.py +++ b/transformer/Encoder.py @@ -34,17 +34,37 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a _ahsize = isize if ahsize is None else ahsize _fhsize = _ahsize * 4 if fhsize is None else fhsize - self.attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos) + self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos) self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) - self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + # inputs: input of this layer (bsize, seql, isize) - self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + def forward(self, inputs, mask=None): - self.norm_residual = norm_residual + context = self.attn(inputs, mask=mask) - # inputs: input of this layer (bsize, seql, isize) + context = self.ff(context) + + return context + +# Not used, keep this class to remind the EncoderLayer implementation before v0.3.5. +class NAWEncoderLayer(EncoderLayer): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(NAWEncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos) + + #self.attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos) + #self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + self.layer_normer, self.drop, self.norm_residual = self.attn.normer, self.attn.drop, self.attn.norm_residual + self.attn = self.attn.net + #self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + #self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + #self.norm_residual = norm_residual def forward(self, inputs, mask=None): @@ -128,7 +148,7 @@ def load_base(self, base_encoder): def update_vocab(self, indices): - _wemb = nn.Embedding(len(indices), self.wemb.weight.size(-1), padding_idx=pad_id) + _wemb = nn.Embedding(len(indices), self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) with torch.no_grad(): _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) self.wemb = _wemb diff --git a/transformer/HPLSTM/Decoder.py b/transformer/HPLSTM/Decoder.py index 9e55caa..7156dd8 100644 --- a/transformer/HPLSTM/Decoder.py +++ b/transformer/HPLSTM/Decoder.py @@ -2,7 +2,7 @@ import torch from torch import nn -from modules.base import CrossAttn, Dropout +from modules.base import ResCrossAttn, Dropout from modules.hplstm.hfn import HPLSTM from utils.sampler import SampleMax from utils.base import all_done, index_tensors, expand_bsize_for_beam @@ -26,14 +26,12 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a _fhsize = _ahsize * 4 if fhsize is None else fhsize self.net = HPLSTM(isize, num_head=num_head, osize=isize, fhsize=_fhsize, dropout=dropout) - self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - self.norm_residual = norm_residual - def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): if query_unit is None: @@ -53,13 +51,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): context = context + query_unit - _context = self.layer_normer(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn(context, inpute, mask=src_pad_mask) if query_unit is None: return context diff --git a/transformer/HPLSTM/FNDecoder.py b/transformer/HPLSTM/FNDecoder.py index fb4f764..d22a512 100644 --- a/transformer/HPLSTM/FNDecoder.py +++ b/transformer/HPLSTM/FNDecoder.py @@ -2,7 +2,7 @@ import torch from torch import nn -from modules.base import CrossAttn, Dropout, PositionwiseFF +from modules.base import ResCrossAttn, Dropout, PositionwiseFF from modules.hplstm.hfn import HPLSTM from utils.sampler import SampleMax from utils.base import all_done, index_tensors, expand_bsize_for_beam @@ -23,14 +23,11 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a _fhsize = _ahsize * 4 if fhsize is None else fhsize self.net = HPLSTM(isize, num_head=num_head, osize=isize, fhsize=_fhsize, dropout=dropout) - self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) - self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - self.norm_residual = norm_residual - def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): if query_unit is None: @@ -50,13 +47,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): context = context + query_unit - _context = self.layer_normer(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn(context, inpute, mask=src_pad_mask) context = self.ff(context) diff --git a/transformer/LD/Decoder.py b/transformer/LD/Decoder.py index dcc49b5..c33418e 100644 --- a/transformer/LD/Decoder.py +++ b/transformer/LD/Decoder.py @@ -12,8 +12,7 @@ from utils.fmt.base import pad_id -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from cnfg.ihyp import * @@ -35,33 +34,15 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, tgt_pad_mask=None, query_unit=None): if query_unit is None: - context = self.self_attn(inputo, mask=tgt_pad_mask) - - if self.drop is not None: - context = self.drop(context) - - context = context + inputo - else: - context, states_return = self.self_attn(query_unit, states=inputo) - if self.drop is not None: - context = self.drop(context) - - context = context + query_unit - _context = self.layer_normer1(context) _context = self.scff(_context, self.cattn(_context, inputh, mask=chk_pad_mask)) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = self.layer_normer2(_context_new + _context) + context = self.cross_attn(_context, inpute, mask=src_pad_mask) context = self.ff(context) diff --git a/transformer/LD/Encoder.py b/transformer/LD/Encoder.py index 7c61a7f..3edf728 100644 --- a/transformer/LD/Encoder.py +++ b/transformer/LD/Encoder.py @@ -6,8 +6,7 @@ from math import sqrt, ceil -from transformer.TA.Encoder import EncoderLayer as EncoderLayerBase -from transformer.TA.Encoder import Encoder as EncoderBase +from transformer.TA.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase from cnfg.ihyp import * @@ -32,11 +31,6 @@ def forward(self, inputs, sumr, mask=None, rmask=None): context = self.attn(inputs, mask=mask) - if self.drop is not None: - context = self.drop(context) - - context = self.layer_normer(context + inputs) - context = self.ff(context) return context diff --git a/transformer/LD/NMT.py b/transformer/LD/NMT.py index a42417b..bcb03a9 100644 --- a/transformer/LD/NMT.py +++ b/transformer/LD/NMT.py @@ -11,7 +11,7 @@ class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): enc_layer, dec_layer = parse_double_value_tuple(num_layer) diff --git a/transformer/MuLang/Eff/Base/Decoder.py b/transformer/MuLang/Eff/Base/Decoder.py new file mode 100644 index 0000000..c388143 --- /dev/null +++ b/transformer/MuLang/Eff/Base/Decoder.py @@ -0,0 +1,276 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules.mulang.eff.base import LayerNorm, MBLinear, ResSelfAttn, ResCrossAttn, PositionwiseFF +from utils.sampler import SampleMax +from utils.base import all_done, index_tensors, expand_bsize_for_beam +from math import sqrt + +from utils.fmt.base import pad_id + +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase + +from cnfg.ihyp import * + +class DecoderLayer(DecoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, ntask=None, k_rel_pos=use_k_relative_position_decoder, **kwargs): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) + + self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.self_attn.norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, ntask=ntask) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual, ntask=ntask) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=self.ff.norm_residual, ntask=ntask) + + def forward(self, inpute, inputo, taskid=None, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + + if query_unit is None: + context = self.self_attn(inputo, taskid=taskid, mask=tgt_pad_mask) + else: + context, states_return = self.self_attn(query_unit, taskid=taskid, states=inputo) + + context = self.cross_attn(context, inpute, taskid=taskid, mask=src_pad_mask) + + context = self.ff(context, taskid=taskid) + + if query_unit is None: + return context + else: + return context, states_return + +class Decoder(DecoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, ntask=None, task_emb_w=None, share_layer=False, **kwargs): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=None, share_layer=share_layer, **kwargs) + + self.task_emb = nn.Embedding(ntask, isize, padding_idx=None) + if task_emb_w is not None: + self.task_emb.weight = task_emb_w + self.classifier = MBLinear(isize, nwd, ntask) + if bindemb: + self.classifier.weight = self.wemb.weight + if share_layer: + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) for i in range(num_layer)]) + + self.out_normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None + + if forbidden_index is not None: + self.fbl = [tuple(set(fblu)) for fblu in forbidden_index] + + def forward(self, inpute, inputo, taskid=None, src_pad_mask=None): + + nquery = inputo.size(-1) + + out = self.wemb(inputo) + self.task_emb.weight[taskid] + + out = out * sqrt(out.size(-1)) + if self.pemb is not None: + out = out + self.pemb(inputo, expand=False) + + if self.drop is not None: + out = self.drop(out) + + _mask = self._get_subsequent_mask(nquery) + + for net in self.nets: + out = net(inpute, out, taskid=taskid, src_pad_mask=src_pad_mask, tgt_pad_mask=_mask) + + if self.out_normer is not None: + out = self.out_normer(out, taskid=taskid) + + out = self.lsm(self.classifier(out, taskid)) + + return out + + def decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + + return self.beam_decode(inpute, taskid, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, taskid, src_pad_mask, max_len, fill_pad=fill_pad) + + def greedy_decode(self, inpute, taskid=None, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + + bsize = inpute.size(0) + + sos_emb = self.get_sos_emb(inpute) + sqrt_isize = sqrt(sos_emb.size(-1)) + _task_emb = self.task_emb.weight[taskid] + + out = (sos_emb + _task_emb) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(0) + + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, (None, None,), taskid=taskid, src_pad_mask=src_pad_mask, tgt_pad_mask=None, query_unit=out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out, taskid=taskid) + + out = self.classifier(out, taskid) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) + + trans = [wds] + done_trans = wds.eq(2) + + for i in range(1, max_len): + + out = (self.wemb(wds) + _task_emb) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(i) + + if self.drop is not None: + out = self.drop(out) + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, states[_tmp], taskid=taskid, src_pad_mask=src_pad_mask, tgt_pad_mask=None, query_unit=out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out, taskid=taskid) + + out = self.classifier(out, taskid) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) + + trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) + + done_trans = done_trans | wds.eq(2) + if all_done(done_trans, bsize): + break + + return torch.cat(trans, 1) + + def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + + bsize, seql = inpute.size()[:2] + + beam_size2 = beam_size * beam_size + bsizeb2 = bsize * beam_size2 + real_bsize = bsize * beam_size + + sos_emb = self.get_sos_emb(inpute) + isize = sos_emb.size(-1) + sqrt_isize = sqrt(isize) + _task_emb = self.task_emb.weight[taskid] + + if length_penalty > 0.0: + lpv = sos_emb.new_ones(real_bsize, 1) + lpv_base = 6.0 ** length_penalty + + out = (sos_emb + _task_emb) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(0) + + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, (None, None,), taskid=taskid, src_pad_mask=src_pad_mask, tgt_pad_mask=None, query_unit=out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out, taskid=taskid) + + out = self.lsm(self.classifier(out, taskid)) + + scores, wds = out.topk(beam_size, dim=-1) + scores = scores.squeeze(1) + sum_scores = scores + wds = wds.view(real_bsize, 1) + trans = wds + + done_trans = wds.view(bsize, beam_size).eq(2) + + self.repeat_cross_attn_buffer(beam_size) + + _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) + + states = expand_bsize_for_beam(states, beam_size=beam_size) + + for step in range(1, max_len): + + out = (self.wemb(wds) + _task_emb) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(step) + + if self.drop is not None: + out = self.drop(out) + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, states[_tmp], taskid=taskid, src_pad_mask=_src_pad_mask, tgt_pad_mask=None, query_unit=out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out, taskid=taskid) + + out = self.lsm(self.classifier(out, taskid)).view(bsize, beam_size, -1) + + _scores, _wds = out.topk(beam_size, dim=-1) + _scores = (_scores.masked_fill(done_trans.unsqueeze(2).expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size)) + + if length_penalty > 0.0: + lpv.masked_fill_(~done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base) + + if clip_beam and (length_penalty > 0.0): + scores, _inds = (_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size) + else: + scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = scores + + wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1) + + _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + + trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) + + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + + _done = False + if length_penalty > 0.0: + lpv = lpv.index_select(0, _inds) + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): + _done = True + + if _done or all_done(done_trans, real_bsize): + break + + states = index_tensors(states, indices=_inds, dim=0) + + if (not clip_beam) and (length_penalty > 0.0): + scores = scores / lpv.view(bsize, beam_size) + scores, _inds = scores.topk(beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) + + if return_all: + + return trans, scores + else: + + return trans.view(bsize, beam_size, -1).select(1, 0) + + def fix_load(self): + + if self.fbl is not None: + with torch.no_grad(): + for ind, fblu in enumerate(self.fbl): + self.classifier.bias[ind].index_fill_(0, torch.tensor(fblu, dtype=torch.long, device=self.classifier.bias.device), -inf_default) diff --git a/transformer/MuLang/Eff/Base/Encoder.py b/transformer/MuLang/Eff/Base/Encoder.py new file mode 100644 index 0000000..3f1baa1 --- /dev/null +++ b/transformer/MuLang/Eff/Base/Encoder.py @@ -0,0 +1,67 @@ +#encoding: utf-8 + +from torch import nn +from modules.mulang.eff.base import LayerNorm, MWLinear, ResSelfAttn, PositionwiseFF +from math import sqrt + +from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase + +from cnfg.ihyp import * + +class EncoderLayer(EncoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, ntask=None, k_rel_pos=use_k_relative_position_encoder, **kwargs): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + + self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.attn.norm_residual, k_rel_pos=k_rel_pos, ntask=ntask) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=self.ff.norm_residual, ntask=ntask) + + def forward(self, inputs, taskid=None, mask=None): + + context = self.attn(inputs, taskid=taskid, mask=mask) + + context = self.ff(context, taskid=taskid) + + return context + +class Encoder(EncoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, ntask=None, **kwargs): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + + self.task_emb = nn.Embedding(ntask, isize, padding_idx=None) + self.transo = MWLinear(isize, isize, ntask, bias=enable_proj_bias_default) + + if share_layer: + _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) for i in range(num_layer)]) + + self.out_normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None + + def forward(self, inputs, taskid=None, mask=None): + + out = self.wemb(inputs) + out = out * sqrt(out.size(-1)) + self.task_emb.weight[taskid] + if self.pemb is not None: + out = out + self.pemb(inputs, expand=False) + + if self.drop is not None: + out = self.drop(out) + + for net in self.nets: + out = net(out, taskid=taskid, mask=mask) + + if self.out_normer is not None: + out = self.out_normer(out, taskid=taskid) + + return self.transo(out, taskid) diff --git a/transformer/MuLang/Eff/Base/NMT.py b/transformer/MuLang/Eff/Base/NMT.py new file mode 100644 index 0000000..ff6c1e9 --- /dev/null +++ b/transformer/MuLang/Eff/Base/NMT.py @@ -0,0 +1,45 @@ +#encoding: utf-8 + +from utils.relpos import share_rel_pos_cache +from utils.fmt.base import parse_double_value_tuple + +from transformer.NMT import NMT as NMTBase + +from transformer.MuLang.Eff.Base.Encoder import Encoder +from transformer.MuLang.Eff.Base.Decoder import Decoder + +from cnfg.ihyp import * + +class NMT(NMTBase): + + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, ntask=None, **kwargs): + + enc_layer, dec_layer = parse_double_value_tuple(num_layer) + + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=None) + + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, ntask=ntask) + + if global_emb: + emb_w = self.enc.wemb.weight + task_emb_w = self.enc.task_emb.weight + else: + emb_w = task_emb_w = None + + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, ntask=ntask, task_emb_w=task_emb_w) + + if rel_pos_enabled: + share_rel_pos_cache(self) + + def forward(self, inpute, inputo, taskid=None, mask=None): + + _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + + return self.dec(self.enc(inpute, taskid=taskid, mask=_mask), inputo, taskid=taskid, src_pad_mask=_mask) + + def decode(self, inpute, taskid=None, beam_size=1, max_len=None, length_penalty=0.0): + + mask = inpute.eq(0).unsqueeze(1) + _max_len = inpute.size(1) + max(64, inpute.size(1) // 4) if max_len is None else max_len + + return self.dec.decode(self.enc(inpute, taskid=taskid, mask=mask), taskid=taskid, src_pad_mask=mask, beam_size=beam_size, max_len=_max_len, length_penalty=length_penalty) diff --git a/transformer/MuLang/Eff/Base/__init__.py b/transformer/MuLang/Eff/Base/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/MuLang/Eff/Base/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/MuLang/Eff/__init__.py b/transformer/MuLang/Eff/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/MuLang/Eff/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/MuLang/__init__.py b/transformer/MuLang/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/MuLang/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/NMT.py b/transformer/NMT.py index f8e3af0..46e6add 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -28,7 +28,7 @@ class NMT(nn.Module): # xseql: maxmimum length of sequence # ahsize: number of hidden units for MultiHeadAttention - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): super(NMT, self).__init__() diff --git a/transformer/Probe/Decoder.py b/transformer/Probe/Decoder.py index 1a56995..bbc73ea 100644 --- a/transformer/Probe/Decoder.py +++ b/transformer/Probe/Decoder.py @@ -4,12 +4,11 @@ from torch import nn from modules.base import Linear, Dropout -from modules.attn.rap import CrossAttn +from modules.attn.rap import ResCrossAttn from math import sqrt -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from cnfg.ihyp import * @@ -21,42 +20,19 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) - self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual) def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, compute_ffn=True): if query_unit is None: - _inputo = self.layer_normer1(inputo) - - context = self.self_attn(_inputo, mask=tgt_pad_mask) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_inputo if self.norm_residual else inputo) - + context = self.self_attn(inputo, mask=tgt_pad_mask) else: - _query_unit = self.layer_normer1(query_unit) - - context, states_return = self.self_attn(_query_unit, states=inputo) + context, states_return = self.self_attn(query_unit, states=inputo) - if self.drop is not None: - context = self.drop(context) - - context = context + (_query_unit if self.norm_residual else query_unit) - - _context = self.layer_normer2(context) - _context_new, _attn = self.cross_attn(_context, inpute, mask=src_pad_mask) + context, _attn = self.cross_attn(context, inpute, mask=src_pad_mask) if compute_ffn: - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = _context_new + (_context if self.norm_residual else context) - context = self.ff(context) - else: - context = _context_new if query_unit is None: return context, _attn diff --git a/transformer/Probe/NMT.py b/transformer/Probe/NMT.py index d2720c6..26cd048 100644 --- a/transformer/Probe/NMT.py +++ b/transformer/Probe/NMT.py @@ -12,7 +12,7 @@ class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None, num_layer_ana=0): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, num_layer_ana=0): super(NMT, self).__init__(isize, snwd, tnwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) diff --git a/transformer/Probe/ReDecoder.py b/transformer/Probe/ReDecoder.py index 2cdb4e3..b6da96d 100644 --- a/transformer/Probe/ReDecoder.py +++ b/transformer/Probe/ReDecoder.py @@ -7,8 +7,7 @@ from modules.base import Linear -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from cnfg.ihyp import * @@ -25,31 +24,18 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un if query_unit is None: if self.perform_self_attn: - _inputo = self.layer_normer1(inputo) - - context = self.self_attn(_inputo, mask=tgt_pad_mask) - if self.drop is not None: - context = self.drop(context) - context = context + (_inputo if self.norm_residual else inputo) + context = self.self_attn(inputo, mask=tgt_pad_mask) else: context, states_return = inputo, None else: if self.perform_self_attn: - _query_unit = self.layer_normer1(query_unit) - context, states_return = self.self_attn(_query_unit, states=inputo) - if self.drop is not None: - context = self.drop(context) - context = context + (_query_unit if self.norm_residual else query_unit) + context, states_return = self.self_attn(query_unit, states=inputo) else: context, states_return = query_unit, query_unit if inputo is None else torch.cat((inputo, query_unit,), 1) if self.perform_cross_attn: - _context = self.layer_normer2(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - if self.drop is not None: - _context_new = self.drop(_context_new) - context = _context_new + (_context if self.norm_residual else context) + context = self.cross_attn(context, inpute, mask=src_pad_mask) context = self.ff(context) diff --git a/transformer/Probe/ReNMT.py b/transformer/Probe/ReNMT.py index 6aea4ef..f22dee5 100644 --- a/transformer/Probe/ReNMT.py +++ b/transformer/Probe/ReNMT.py @@ -11,7 +11,7 @@ class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None, num_layer_ana=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, num_layer_ana=None): super(NMT, self).__init__(isize, snwd, tnwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) diff --git a/transformer/RealFormer/Decoder.py b/transformer/RealFormer/Decoder.py index 5338420..6047403 100644 --- a/transformer/RealFormer/Decoder.py +++ b/transformer/RealFormer/Decoder.py @@ -3,10 +3,9 @@ import torch from torch import nn -from modules.attn.res import SelfAttn, CrossAttn +from modules.attn.res import ResSelfAttn, ResCrossAttn -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from utils.sampler import SampleMax from utils.base import all_done, index_tensors, expand_bsize_for_beam @@ -25,8 +24,8 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, **kwargs) - self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, uni_direction_reduction=True) - self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) + self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, resin=None): @@ -36,32 +35,11 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un sresin, cresin = resin if query_unit is None: - _inputo = self.layer_normer1(inputo) - - context, sresout = self.self_attn(_inputo, mask=tgt_pad_mask, resin=sresin) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_inputo if self.norm_residual else inputo) - + context, sresout = self.self_attn(inputo, mask=tgt_pad_mask, resin=sresin) else: - _query_unit = self.layer_normer1(query_unit) - - context, states_return, sresout = self.self_attn(_query_unit, states=inputo, resin=sresin) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_query_unit if self.norm_residual else query_unit) - - _context = self.layer_normer2(context) - _context_new, cresout = self.cross_attn(_context, inpute, mask=src_pad_mask, resin=cresin) - - if self.drop is not None: - _context_new = self.drop(_context_new) + context, states_return, sresout = self.self_attn(query_unit, states=inputo, resin=sresin) - context = _context_new + (_context if self.norm_residual else context) + context, cresout = self.cross_attn(context, inpute, mask=src_pad_mask, resin=cresin) context = self.ff(context) diff --git a/transformer/RealFormer/Encoder.py b/transformer/RealFormer/Encoder.py index 7fafff9..a134657 100644 --- a/transformer/RealFormer/Encoder.py +++ b/transformer/RealFormer/Encoder.py @@ -3,10 +3,9 @@ from torch import nn from math import sqrt -from modules.attn.res import SelfAttn +from modules.attn.res import ResSelfAttn -from transformer.Encoder import EncoderLayer as EncoderLayerBase -from transformer.Encoder import Encoder as EncoderBase +from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase from cnfg.ihyp import * @@ -19,17 +18,11 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, **kwargs) - self.attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos) + self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos) def forward(self, inputs, mask=None, resin=None): - _inputs = self.layer_normer(inputs) - context, resout = self.attn(_inputs, mask=mask, resin=resin) - - if self.drop is not None: - context = self.drop(context) - - context = context + (_inputs if self.norm_residual else inputs) + context, resout = self.attn(inputs, mask=mask, resin=resin) context = self.ff(context) diff --git a/transformer/SC/Decoder.py b/transformer/SC/Decoder.py index 1e22b40..a50c322 100644 --- a/transformer/SC/Decoder.py +++ b/transformer/SC/Decoder.py @@ -5,15 +5,14 @@ from modules.base import ResidueCombiner from utils.sampler import SampleMax -from modules.TA import PositionwiseFF +from modules.TA import ResCrossAttn, PositionwiseFF from utils.base import all_done, index_tensors, expand_bsize_for_beam, repeat_bsize_for_beam_tensor from math import sqrt from utils.fmt.base import pad_id -from transformer.Decoder import DecoderLayer as DecoderLayerBase -from transformer.Decoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase from cnfg.ihyp import * @@ -26,8 +25,11 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual) self.ff = PositionwiseFF(isize, _fhsize, dropout) self.scff = ResidueCombiner(isize, 2, _fhsize) + self.drop, self.layer_normer1 = self.self_attn.drop, self.self_attn.drop.normer + self.self_attn = self.self_attn.net def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): @@ -55,12 +57,7 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, _context = self.layer_normer1(context) - _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) - - if self.drop is not None: - _context_new = self.drop(_context_new) - - context = self.layer_normer2(_context_new + _context) + context = self.cross_attn(_context, inpute, mask=src_pad_mask) context = self.ff(context) diff --git a/transformer/SC/NMT.py b/transformer/SC/NMT.py index af9cbf2..9d67134 100644 --- a/transformer/SC/NMT.py +++ b/transformer/SC/NMT.py @@ -13,7 +13,7 @@ class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): enc_layer, dec_layer = parse_double_value_tuple(num_layer) diff --git a/transformer/TA/Encoder.py b/transformer/TA/Encoder.py index 5cc90f7..fab4bb6 100644 --- a/transformer/TA/Encoder.py +++ b/transformer/TA/Encoder.py @@ -3,11 +3,10 @@ import torch from torch import nn from modules.base import Dropout -from modules.TA import PositionwiseFF +from modules.TA import ResSelfAttn, PositionwiseFF from math import sqrt -from transformer.Encoder import EncoderLayer as EncoderLayerBase -from transformer.Encoder import Encoder as EncoderBase +from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase from cnfg.ihyp import * @@ -26,6 +25,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop) self.ff = PositionwiseFF(isize, _fhsize, dropout) # inputs: input of this layer (bsize, seql, isize) @@ -34,11 +34,6 @@ def forward(self, inputs, mask=None): context = self.attn(inputs, mask=mask) - if self.drop is not None: - context = self.drop(context) - - context = self.layer_normer(context + inputs) - context = self.ff(context) return context diff --git a/translator.py b/translator.py index 66aa651..ae54a8b 100644 --- a/translator.py +++ b/translator.py @@ -89,9 +89,7 @@ def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mul if self.multi_gpu: model = DataParallelMT(model, device_ids=cuda_devices, output_device=self.cuda_device.index, host_replicate=True, gather_output=False) self.use_amp = cnfg.use_amp and self.use_cuda - self.beam_size = cnfg.beam_size - self.length_penalty = cnfg.length_penalty self.net = model diff --git a/utils/aan.py b/utils/aan.py index 1fdc968..391e69c 100644 --- a/utils/aan.py +++ b/utils/aan.py @@ -1,7 +1,7 @@ #encoding: utf-8 from torch.nn import ModuleList -from modules.base import AverageAttn +from modules.aan import AverageAttn def share_aan_cache(netin): diff --git a/utils/base.py b/utils/base.py index 43f1d38..c315580 100644 --- a/utils/base.py +++ b/utils/base.py @@ -3,21 +3,18 @@ import torch from torch import Tensor from torch.nn import ModuleDict - +from os import makedirs, remove +from os.path import exists as fs_check from threading import Thread - from functools import wraps - -from random import sample -from random import seed as rpyseed - +from random import sample, seed as rpyseed from math import ceil import logging from utils.h5serial import h5save, h5load -from cnfg.ihyp import h5modelwargs +from cnfg.ihyp import h5modelwargs, optm_step_zero_grad_set_none secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64} @@ -193,39 +190,49 @@ def load_model_cpu_old(modf, base_model): return base_model -def save_model(model, fname, sub_module=False, logger=None, h5args=h5modelwargs): +_save_model_cleaner_holder = {} +def save_model_cleaner(fname, typename, holder=_save_model_cleaner_holder): + + if typename in holder: + holder[typename].update(fname) + else: + holder[typename] = bestfkeeper(fname) + +def save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs): _msave = model.module if sub_module else model try: h5save([t.data for t in _msave.parameters()], fname, h5args=h5args) + if mtyp is not None: + save_model_cleaner(fname, mtyp) except Exception as e: - if logger is None: - print(e) - else: - logger.info(str(e)) + if print_func is not None: + print_func(str(e)) -def async_save_model(model, fname, sub_module=False, logger=None, h5args=h5modelwargs, para_lock=None, log_success=None): +def async_save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs, para_lock=None, log_success=None): - def _worker(model, fname, sub_module=False, logger=None, para_lock=None, log_success=None): + def _worker(model, fname, sub_module=False, print_func=print, mtyp=None, para_lock=None, log_success=None): success = True _msave = model.module if sub_module else model try: if para_lock is None: h5save([t.data for t in _msave.parameters()], fname, h5args=h5args) + if mtyp is not None: + save_model_cleaner(fname, mtyp) else: with para_lock: h5save([t.data for t in _msave.parameters()], fname, h5args=h5args) + if mtyp is not None: + save_model_cleaner(fname, mtyp) except Exception as e: - if logger is None: - print(e) - else: - logger.info(str(e)) + if print_func is not None: + print_func(str(e)) success = False - if success and (logger is not None) and (log_success is not None): - logger.info(log_success) + if success and (print_func is not None) and (log_success is not None): + print_func(str(log_success)) - Thread(target=_worker, args=(model, fname, sub_module, logger, para_lock, log_success)).start() + Thread(target=_worker, args=(model, fname, sub_module, print_func, mtyp, para_lock, log_success)).start() def get_logger(fname): @@ -416,7 +423,7 @@ def iternext(iterin): return rs -def optm_step(optm, model=None, scaler=None, closure=None, multi_gpu=False, multi_gpu_optimizer=False): +def optm_step(optm, model=None, scaler=None, closure=None, multi_gpu=False, multi_gpu_optimizer=False, zero_grad_none=optm_step_zero_grad_set_none): if multi_gpu: model.collect_gradients() @@ -426,7 +433,7 @@ def optm_step(optm, model=None, scaler=None, closure=None, multi_gpu=False, mult scaler.step(optm, closure=closure) scaler.update() if not multi_gpu_optimizer: - optm.zero_grad(set_to_none=True) + optm.zero_grad(set_to_none=zero_grad_none) if multi_gpu: model.update_replicas() @@ -512,3 +519,38 @@ def iter_func(*args, **kwargs): break return iter_func + +def mkdir(pth): + + if not fs_check(pth): + makedirs(pth) + +class holder(dict): + + def __enter__(self): + + return self + + def get_hold(self, k, sv=None): + + if k in self: + return self[k] + else: + self[k] = sv + return sv + + def __exit__(self, *inputs, **kwargs): + + pass + +class bestfkeeper: + + def __init__(self, fname=None): + + self.prev_fname = fname + + def update(self, fname=None): + + if self.prev_fname is not None and fs_check(self.prev_fname): + remove(self.prev_fname) + self.prev_fname = fname diff --git a/utils/contpara.py b/utils/contpara.py new file mode 100644 index 0000000..6aa7880 --- /dev/null +++ b/utils/contpara.py @@ -0,0 +1,130 @@ +#encoding: utf-8 + +# WARNING: this file may create _contiguous_parameters to the model + +import torch +from torch import nn + +from utils.base import filter_para_grad + +class ContiguousParams(nn.Module): + + def __init__(self, parameters=None, init_tensors=None): + + super(ContiguousParams, self).__init__() + + self.weights = self.pll = None + + parameters = tuple(parameters) + if not isinstance(parameters[0], (tuple, list,)): + parameters = (parameters,) + if init_tensors is not None and (not isinstance(init_tensors, (tuple, list,))): + init_tensors = (init_tensors,) + + self.allocate(parameters=parameters, init_tensors=init_tensors) + self.bind(update=init_tensors is None) + + def allocate(self, parameters=None, init_tensors=None): + + self.pll = self.pll if parameters is None else [filter_para_grad(pl) for pl in parameters] + cpl = [] + if init_tensors is None: + for pl in self.pll: + if len(pl) > 1: + _numel = sum(para.numel() for para in pl) + _weight = nn.Parameter(pl[0].new_empty(_numel)) + _weight.grad = pl[0].new_zeros(_numel) + cpl.append(_weight) + else: + _weight = pl[0] + if _weight.grad is None: + _weight.grad = _weight.new_zeros(_weight.size()) + cpl.append(_weight) + else: + for pl, init_tensor in zip(self.pll, init_tensors): + if len(pl) > 1: + _numel = sum(para.numel() for para in pl) if init_tensor is None else init_tensor.numel() + _weight = nn.Parameter(pl[0].new_empty(_numel) if init_tensor is None else init_tensor) + _weight.grad = pl[0].new_zeros(init_tensor.numel()) if (init_tensor is None) or (init_tensor.grad is None) else init_tensor.grad + cpl.append(_weight) + else: + _weight = pl[0] + if _weight.grad is None: + _weight.grad = _weight.new_zeros(_weight.size()) if (init_tensor is None) or (init_tensor.grad is None) else init_tensor.grad.view(_weight.size()) + cpl.append(_weight) + self.weights = nn.ParameterList(cpl) + + def bind(self, update=True): + + with torch.no_grad(): + for pl, weight in zip(self.pll, self.weights): + if len(pl) > 1: + lind = 0 + for para in pl: + rind = lind + para.numel() + _sizes = para.size() + if update: + weight.data[lind:rind].copy_(para.data.view(-1)) + para.data = weight.data[lind:rind].view(_sizes) + if update and (para.grad is not None): + weight.grad[lind:rind].copy_(para.grad.view(-1)) + para.grad = weight.grad[lind:rind].view(_sizes) + lind = rind + + def bind_data(self, update=True): + + with torch.no_grad(): + for pl, weight in zip(self.pll, self.weights): + if len(pl) > 1: + lind = 0 + for para in pl: + rind = lind + para.numel() + if update: + weight.data[lind:rind].copy_(para.data.view(-1)) + para.data = weight.data[lind:rind].view(para.size()) + lind = rind + + def bind_grad(self, update=True): + + for pl, weight in zip(self.pll, self.weights): + if len(pl) > 1: + lind = 0 + for para in pl: + rind = lind + para.numel() + if update and (para.grad is not None): + weight.grad[lind:rind].copy_(para.grad.view(-1)) + para.grad = weight.grad[lind:rind].view(para.size()) + lind = rind + +def is_model_contiguous_parameters(model): + + return hasattr(model, "_contiguous_parameters") + +def get_contiguous_parameters_m(model, index=0): + + if is_model_contiguous_parameters(model): + return [model._contiguous_parameters[index]] + else: + _contiguous_parameters = ContiguousParams(parameters=model.parameters()).parameters() + model._contiguous_parameters = list(_contiguous_parameters) + return _contiguous_parameters + +def get_contiguous_parameters_p(parameters, model=None): + + _contiguous_parameters = ContiguousParams(parameters=parameters).parameters() + if model is not None: + if is_model_contiguous_parameters(model): + model._contiguous_parameters.extend(list(_contiguous_parameters)) + else: + model._contiguous_parameters = list(_contiguous_parameters) + + return _contiguous_parameters + +def get_all_contiguous_parameters_m(model): + + for para in model._contiguous_parameters: + yield para + +def get_model_parameters(model, contiguous_parameters=False): + + return get_contiguous_parameters_m(model) if contiguous_parameters else model.parameters() diff --git a/utils/cpp/base.h b/utils/cpp/base.h new file mode 100644 index 0000000..42e56e5 --- /dev/null +++ b/utils/cpp/base.h @@ -0,0 +1,93 @@ +#ifndef _NEUTRON_UTILS_CPP_BASE +#define _NEUTRON_UTILS_CPP_BASE + +#include +#include +#include +#include + +template inline T map_get(std::map mp, std::string key, T dv=NULL) { + auto iter = mp.find(key); + if (iter == mp.end()) { + return dv; + } + else { + return iter->second; + } +} + +inline torch::Tensor map_get(std::map mp, std::string key, torch::Tensor dv=torch::Tensor()) { + auto iter = mp.find(key); + if (iter == mp.end()) { + return dv; + } + else { + return iter->second; + } +} + +inline int64_t map_get(std::map mp, std::string key, int64_t dv=-1) { + auto iter = mp.find(key); + if (iter == mp.end()) { + return dv; + } + else { + return iter->second; + } +} + +inline double map_get(std::map mp, std::string key, double dv=0.0) { + auto iter = mp.find(key); + if (iter == mp.end()) { + return dv; + } + else { + return iter->second; + } +} + +inline bool map_get(std::map mp, std::string key, bool dv=false) { + auto iter = mp.find(key); + if (iter == mp.end()) { + return dv; + } + else { + return iter->second; + } +} + +inline bool is_not_none(torch::Tensor input=torch::Tensor()) { + if (input.defined() and input.size(-1) > 0) { + return true; + } + else { + return false; + } +} + +inline bool is_none(torch::Tensor input=torch::Tensor()) { + if (input.defined() and input.size(-1) > 0) { + return false; + } + else { + return true; + } +} + +inline bool pyt_is_not_none(torch::Tensor input=torch::Tensor()) { + return input.size(-1) > 0; +} + +inline bool pyt_is_none(torch::Tensor input=torch::Tensor()) { + return input.size(-1) == 0; +} + +inline bool ct_is_not_none(torch::Tensor input=torch::Tensor()) { + return input.defined(); +} + +inline bool ct_is_none(torch::Tensor input=torch::Tensor()) { + return not input.defined(); +} + +#endif diff --git a/utils/dynbatch.py b/utils/dynbatch.py index a71247b..a853704 100644 --- a/utils/dynbatch.py +++ b/utils/dynbatch.py @@ -5,6 +5,7 @@ from math import log2, exp, pi, acos from random import random from utils.angle import prep_cos, cos_acc_pg +from utils.random import multinomial # comment the following line and uncomment the 4 lines following it to load para_group_select_alpha from cnfg.dynb para_group_select_alpha = 3.0 @@ -162,27 +163,11 @@ def select_gumble_max(lin): def sample_norm_softmax(lin): - _p = random() - rs_ind = len(lin) - 1 - for i, lu in enumerate(softmax(lin)): - _p -= lu - if _p < 0.0: - rs_ind = i - break - - return rs_ind + return multinomial(softmax(lin), s=1.0) def sample_norm(lin, alpha=1.0): - _p = random() - rs_ind = len(lin) - 1 - for i, lu in enumerate(pos_norm(lin, alpha)): - _p -= lu - if _p < 0.0: - rs_ind = i - break - - return rs_ind + return multinomial(pos_norm(lin, alpha), s=1.0) def sample_gumble_norm(lin, alpha=para_group_select_alpha): diff --git a/utils/fmt/base.py b/utils/fmt/base.py index b98b935..d73e297 100644 --- a/utils/fmt/base.py +++ b/utils/fmt/base.py @@ -346,46 +346,17 @@ def legal_vocab(sent, ilgset, ratio): return False if rt > ratio else True -def iter_check_func(lin, func=None, ri=False): +def all_in(lin, setin): - if func is None: - for lu in lin: - if lu: - return ri - else: - for lu in lin: - if func(lu): - return ri - - return not ri - -def all_true(lin): - - return iter_check_func(lin, func=lambda x: not x, ri=False) - -def any_true(lin): - - return iter_check_func(lin, func=None, ri=True) - -def iter_cmp_func(lin, v, func, ri=False): - - for lu in lin: - if func(lu, v): - return ri - - return not ri - -def all_in(lin, sin): - - return iter_cmp_func(lin, sin, lambda x, y: not x in y, ri=False) + return all(lu in setin for lu in lin) def all_le(lin, value): - return iter_cmp_func(lin, value, lambda x, y: x > y, ri=False) + return all(lu <= value for lu in lin) def all_gt(lin, value): - return iter_cmp_func(lin, value, lambda x, y: x <= y, ri=False) + return all(lu > value for lu in lin) def get_char_ratio(strin): diff --git a/utils/fmt/mulang/__init__.py b/utils/fmt/mulang/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/mulang/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/mulang/eff/__init__.py b/utils/fmt/mulang/eff/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/mulang/eff/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/mulang/eff/dual.py b/utils/fmt/mulang/eff/dual.py new file mode 100644 index 0000000..b6f7e0a --- /dev/null +++ b/utils/fmt/mulang/eff/dual.py @@ -0,0 +1,54 @@ +#encoding: utf-8 + +from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + + rsi = [] + rst = [] + rstask = None + nd = maxlen = mlen_i = mlen_t = 0 + for i_d, td in zip(list_reader(finput), list_reader(ftarget)): + lid = len(i_d) - 1 + ltd = len(td) + lgth = lid + ltd + _task = i_d[0] + # uncomment the following 2 lines to filter out empty data (e.g. in OPUS-100). + #if (lid <= 0) or (ltd <= 0): + #continue + if maxlen == 0: + maxlen = lgth + min(maxpad, lgth // maxpart + 1) + _bsize = get_bsize(maxlen, maxtoken, bsize) + rstask = _task + if (rstask == _task) and ((nd < minbsize) or (lgth <= maxlen and nd < _bsize)): + rsi.append(i_d[1:]) + rst.append(td) + if lid > mlen_i: + mlen_i = lid + if ltd > mlen_t: + mlen_t = ltd + nd += 1 + else: + yield rsi, rst, rstask, mlen_i, mlen_t + rsi = [i_d[1:]] + rstask = _task + rst = [td] + mlen_i = lid + mlen_t = ltd + maxlen = lgth + min(maxpad, lgth // maxpart + 1) + _bsize = get_bsize(maxlen, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rst, rstask, mlen_i, mlen_t + +def batch_mapper(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize): + + for i_d, td, taskd, mlen_i, mlen_t in batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + rsi, extok_i = map_batch(i_d, vocabi) + rst, extok_t = map_batch(td, vocabt) + yield rsi, rst, vocabtask[taskd], mlen_i + extok_i, mlen_t + extok_t + +def batch_padder(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize): + + for i_d, td, taskd, mlen_i, mlen_t in batch_mapper(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize): + yield pad_batch(i_d, mlen_i), pad_batch(td, mlen_t), taskd diff --git a/utils/fmt/mulang/eff/single.py b/utils/fmt/mulang/eff/single.py new file mode 100644 index 0000000..04bd18d --- /dev/null +++ b/utils/fmt/mulang/eff/single.py @@ -0,0 +1,49 @@ +#encoding: utf-8 + +from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch + +def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): + + rsi = [] + rstask = None + nd = maxlen = minlen = mlen_i = 0 + _bsize = bsize + for i_d in list_reader(finput): + lgth = len(i_d) - 1 + _task = i_d[0] + #if lgth <= 0: + #continue + if maxlen == 0: + _maxpad = max(1, min(maxpad, lgth // maxpart + 1) // 2) + maxlen = lgth + _maxpad + minlen = lgth - _maxpad + _bsize = get_bsize(maxlen, maxtoken, bsize) + rstask = _task + if (rstask == _task) and ((nd < minbsize) or (lgth <= maxlen and lgth >= minlen and nd < _bsize)): + rsi.append(i_d[1:]) + if lgth > mlen_i: + mlen_i = lgth + nd += 1 + else: + yield rsi, rstask, mlen_i + rsi = [i_d[1:]] + rstask = _task + mlen_i = lgth + _maxpad = max(1, min(maxpad, lgth // maxpart + 1) // 2) + maxlen = lgth + _maxpad + minlen = lgth - _maxpad + _bsize = get_bsize(maxlen, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rstask, mlen_i + +def batch_mapper(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize): + + for i_d, taskd, mlen_i in batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): + rsi, extok_i = map_batch(i_d, vocabi) + yield rsi, vocabtask[taskd], mlen_i + extok_i + +def batch_padder(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize): + + for i_d, taskd, mlen_i in batch_mapper(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize): + yield pad_batch(i_d, mlen_i), taskd diff --git a/utils/h5serial.py b/utils/h5serial.py index 9630e94..0906786 100644 --- a/utils/h5serial.py +++ b/utils/h5serial.py @@ -1,13 +1,23 @@ #encoding: utf-8 -import torch, h5py - +import torch +from h5py import File as h5FileBase, Dataset from collections.abc import Iterator from utils.fmt.base import list2dict, dict_is_list from cnfg.ihyp import * +class h5File(h5FileBase): + + def __enter__(self): + + return self + + def __exit__(self, *inputs, **kwargs): + + self.close() + def h5write_dict(gwrt, dtw, h5args=h5modelwargs): for k, v in dtw.items(): @@ -30,15 +40,14 @@ def h5write_list(gwrt, ltw, h5args=h5modelwargs): def h5save(obj_save, fname, h5args=h5modelwargs): - h5f = h5py.File(fname, 'w') - _obj_save = tuple(obj_save) if isinstance(obj_save, Iterator) else obj_save - if isinstance(_obj_save, dict): - h5write_dict(h5f, _obj_save, h5args=h5args) - elif isinstance(_obj_save, (list, tuple,)): - h5write_list(h5f, _obj_save, h5args=h5args) - else: - h5write_list(h5f, [_obj_save], h5args=h5args) - h5f.close() + with h5File(fname, 'w') as h5f: + _obj_save = tuple(obj_save) if isinstance(obj_save, Iterator) else obj_save + if isinstance(_obj_save, dict): + h5write_dict(h5f, _obj_save, h5args=h5args) + elif isinstance(_obj_save, (list, tuple,)): + h5write_list(h5f, _obj_save, h5args=h5args) + else: + h5write_list(h5f, [_obj_save], h5args=h5args) def restore_list_in_dict(din): @@ -55,17 +64,18 @@ def h5load_group(grd): rsd = {} for k, v in grd.items(): - if isinstance(v, h5py.Dataset): + if isinstance(v, Dataset): rsd[k] = torch.from_numpy(v[:]) else: rsd[k] = h5load_group(v) + return rsd def h5load(fname, restore_list=True): - f = h5py.File(fname, "r") - rsd = h5load_group(f) - f.close() + with h5File(fname, "r") as f: + rsd = h5load_group(f) if restore_list: rsd = restore_list_in_dict(rsd) + return rsd diff --git a/utils/mulang.py b/utils/mulang.py new file mode 100644 index 0000000..f98219d --- /dev/null +++ b/utils/mulang.py @@ -0,0 +1,42 @@ +#encoding: utf-8 + +from utils.random import multinomial +from random import shuffle + +def T_normalize(wl, T): + + _t = 1.0 / T + _tmp = [_wu ** _t for _wu in wl] + _s = sum(_tmp) + + return [_tu / _s for _tu in _tmp] + +def data_generator(dlin, shuf=True): + + tmp = list(dlin) + while True: + if shuf: + shuffle(tmp) + for tmpu in tmp: + yield tmpu + +def sample_iter(wl, T, ntrain, taskl): + + samples = {} + for i, (nd, task,) in enumerate(zip(ntrain, taskl)): + samples[i] = (task, data_generator(str(i) for i in range(nd)),) + pl = T_normalize(wl, T) + while True: + task, dg = samples[multinomial(pl, s=1.0)] + yield next(dg), task + +class data_sampler: + + def __init__(self, task_weight, task_weight_T, ntrain, train_taskl, nsample=None): + + self.generator = sample_iter(task_weight, task_weight_T, ntrain, train_taskl) + self.nsample = nsample + + def generate(self, nsample=None): + + return [next(self.generator) for i in range(self.nsample if nsample is None else nsample)] diff --git a/utils/pyctorch.py b/utils/pyctorch.py new file mode 100644 index 0000000..e114f49 --- /dev/null +++ b/utils/pyctorch.py @@ -0,0 +1,24 @@ +#encoding: utf-8 + +import torch + +non_tensor = torch.Tensor() + +def transfer_CNone_tuple(lin): + + return tuple(non_tensor if lu is None else lu for lu in lin) + +def transfer_CNone_list(lin): + + return [non_tensor if lu is None else lu for lu in lin] + +def transfer_CNone(din): + + if isinstance(din, list): + return [transfer_CNone(du) for du in din] + elif isinstance(din, tuple): + return tuple(transfer_CNone(du) for du in din) + elif isinstance(din, dict): + return {k: transfer_CNone(du) for k, du in din.items()} + else: + return non_tensor if din is None else din diff --git a/utils/random.py b/utils/random.py new file mode 100644 index 0000000..90edd1d --- /dev/null +++ b/utils/random.py @@ -0,0 +1,37 @@ +#encoding: utf-8 + +from random import random + +def multinomial(lin, s=None): + + _s = sum(lin) if s is None else s + _p = random() + if _s != 1.0: + _p *= _s + rs_ind = len(lin) - 1 + for i, lu in enumerate(lin): + _p -= lu + if _p <= 0.0: + rs_ind = i + break + + return rs_ind + +def multinomial_k(lin, k, s=None): + + _s = sum(lin) if s is None else s + rs = [] + init_rs_ind = len(lin) - 1 + for i in range(k): + _p = random() + if _s != 1.0: + _p *= _s + rs_ind = init_rs_ind + for i, lu in enumerate(lin): + _p -= lu + if _p <= 0.0: + rs_ind = i + break + rs.append(rs_ind) + + return rs diff --git a/utils/torch.py b/utils/torch.py index acd4892..e5737aa 100644 --- a/utils/torch.py +++ b/utils/torch.py @@ -1,5 +1,92 @@ #encoding: utf-8 +import torch +from numbers import Number + +from cnfg.ihyp import ieps_upper_bound_default + +upper_one = 1.0 - ieps_upper_bound_default + def bmv(inputm, inputv): return inputm.bmm(inputv.unsqueeze(-1)).squeeze(-1) + +def randint_t_core(high): + + return high.new_empty(high.size()).uniform_(0.0, upper_one).mul_(high).floor_().to(torch.long) + +def randint_t(low, high): + + if isinstance(low, Number) and (low == 0.0): + return randint_t_core(high) + else: + rs = randint_t_core(high - low) + return rs.add_(low.to(rs.dtype)) + +def multinomial(x, num_samples, replacement=False, generator=None, dim=-1, **kwargs): + + _ndim = x.dim() + + if (dim == -1) or (dim == (_ndim - 1)): + _t_output, out = False, x + else: + _t_output, out = True, x.transpose(dim, -1) + + if _ndim > 2: + _osize = list(out.size()) + out = out.view(-1, _osize[-1]) + _osize[-1] = num_samples + + out = out.multinomial(num_samples, replacement=replacement, generator=generator, **kwargs) + + if _ndim > 2: + out = out.view(_osize) + + if _t_output: + out = out.transpose(dim, -1) + + return out + +def ensure_num_threads(n): + + if torch.get_num_threads() < n: + torch.set_num_threads(n) + + return torch.get_num_threads() + +def ensure_num_interop_threads(n): + + if torch.get_num_interop_threads() < n: + torch.set_num_interop_threads(n) + + return torch.get_num_interop_threads() + +class num_threads: + + def __init__(self, n): + + self.num_threads_exe = n + + def __enter__(self): + + self.num_threads_env = torch.get_num_threads() + torch.set_num_threads(self.num_threads_exe) + + def __exit__(self, *inputs, **kwargs): + + torch.set_num_threads(self.num_threads_env) + +class num_interop_threads: + + def __init__(self, n): + + self.num_threads_exe = n + + def __enter__(self): + + self.num_threads_env = torch.get_num_interop_threads() + torch.set_num_interop_threads(self.num_threads_exe) + + def __exit__(self, *inputs, **kwargs): + + torch.set_num_interop_threads(self.num_threads_env)