From 56ca23b735c1db33b0c4e9c9fa35ffda6ead5ffb Mon Sep 17 00:00:00 2001 From: Qiuhui Liu Date: Sun, 26 Apr 2020 11:05:10 +0800 Subject: [PATCH] April updates --- adv/train/train_dynb.py | 4 +- cnfg/ihyp.py | 11 +- modules/LD.py | 66 ++++++ modules/act.py | 5 +- modules/base.py | 30 +-- modules/rnncells.py | 4 +- tools/check/dynb/report_dynb.py | 4 +- transformer/Decoder.py | 6 +- transformer/Doc/Para/Base/Decoder.py | 2 +- transformer/Doc/Para/Base/Encoder.py | 2 +- transformer/Doc/Para/Base/NMT.py | 2 +- transformer/Encoder.py | 4 +- transformer/LD/AttnEncoder.py | 103 ++++++++++ transformer/LD/Decoder.py | 288 +++++++++++++++++++++++++++ transformer/LD/Encoder.py | 129 ++++++++++++ transformer/LD/NMT.py | 49 +++++ transformer/LD/__init__.py | 1 + transformer/NMT.py | 2 +- transformer/RNMTDecoder.py | 19 +- transformer/SC/Encoder.py | 2 +- transformer/SC/NMT.py | 2 +- transformer/TA/Encoder.py | 2 +- transformer/UniEncoder.py | 2 +- utils/dynbatch.py | 7 +- utils/init.py | 6 +- 25 files changed, 700 insertions(+), 52 deletions(-) create mode 100644 modules/LD.py create mode 100644 transformer/LD/AttnEncoder.py create mode 100644 transformer/LD/Decoder.py create mode 100644 transformer/LD/Encoder.py create mode 100644 transformer/LD/NMT.py create mode 100644 transformer/LD/__init__.py diff --git a/adv/train/train_dynb.py b/adv/train/train_dynb.py index 8870c42..d5a7240 100644 --- a/adv/train/train_dynb.py +++ b/adv/train/train_dynb.py @@ -43,7 +43,7 @@ def select_function(modin, select_index): return _sel_m.parameters() -grad_mon = GradientMonitor(num_layer * 2, select_function, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_recoder=cnfg.num_dynb_his, num_his_gm=1) +grad_mon = GradientMonitor(num_layer * 2, select_function, module=None, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_recoder=cnfg.num_dynb_his, num_his_gm=1) def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False): @@ -95,6 +95,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _perform_dyn_optm_step, _cos_sim = grad_mon.update(model.module if multi_gpu else model) if _perform_dyn_optm_step or (_done_tokens >= tokens_optm): + if not _perform_dyn_optm_step: + grad_mon.reset() _do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle) if _do_optm_step: if multi_gpu: diff --git a/cnfg/ihyp.py b/cnfg/ihyp.py index c43844f..a957b03 100644 --- a/cnfg/ihyp.py +++ b/cnfg/ihyp.py @@ -8,10 +8,9 @@ from utils.fmt.base import parse_none, parse_double_value_tuple -if ease_optimization: - enable_residual_bias_default = False -else: - enable_residual_bias_default = True +enable_residual_bias_default = not ease_optimization + +enable_ln_parameters = True use_adv_act_default = False override_GeLU_Swish = False @@ -35,14 +34,14 @@ ieps_default = 1e-9 ieps_ln_default = 1e-6 ieps_adam_default = 1e-9 -ieps_noise_default = ieps_ln_default - ieps_ln_default = parse_none(ieps_ln_default, ieps_default) ieps_adam_default = parse_none(ieps_adam_default, ieps_default) +ieps_noise_default = ieps_ln_default adam_betas_default = (0.9, 0.98,) use_k_relative_position_encoder, use_k_relative_position_decoder = parse_double_value_tuple(use_k_relative_position) +rel_pos_enabled = 
(max(use_k_relative_position_encoder, use_k_relative_position_decoder) > 0) disable_std_pemb_encoder, disable_std_pemb_decoder = parse_double_value_tuple(disable_std_pemb) h5datawargs = {} if hdf5_data_compression is None else {"compression": hdf5_data_compression, "compression_opts": hdf5_data_compression_level, "shuffle":True} diff --git a/modules/LD.py b/modules/LD.py new file mode 100644 index 0000000..b218e07 --- /dev/null +++ b/modules/LD.py @@ -0,0 +1,66 @@ +#encoding: utf-8 + +import torch +from torch import nn + +from modules.base import Scorer, Linear, Dropout +from modules.act import GeLU + +from cnfg.ihyp import * + +class ATTNCombiner(nn.Module): + + def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default): + + super(ATTNCombiner, self).__init__() + + _hsize = isize * 4 if hsize is None else hsize + + self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) + + def forward(self, input1, input2, mask=None): + + scores = self.net(torch.cat((input1.expand_as(input2), input2,), dim=-1)) + + _seql = input2.size(-2) + if mask is not None: + _tm = mask.sum(-2, keepdim=True) + _nele = (_seql - _tm).masked_fill(_tm.eq(_seql), 1).to(scores) + scores = scores / _nele + else: + scores = scores / _seql + scores = scores.masked_fill(mask, 0.0) + + out = scores.transpose(1, 2).bmm(input2) + (1.0 - scores.sum(dim=-2, keepdim=True)) * input1 + + return out + +class DATTNCombiner(nn.Module): + + def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default): + + super(DATTNCombiner, self).__init__() + + _hsize = isize * 4 if hsize is None else hsize + + self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize, bias=False)) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize, bias=False)) + + # input1: (bsize, 1, isize) + # input2: (bsize, seql, isize) + # mask: (bsize, seql, 1) + def forward(self, input1, input2, mask=None): + + # scores: (bsize, seql, 1) + scores = self.net(torch.cat((input1.expand_as(input2), input2,), dim=-1)) + + _seql = input2.size(-2) + if mask is not None: + # using math.inf as inf_default will lead to nan after softmax in case a sequence is full of tokens, advice: using the other values as inf_default, like 1e9. 
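
The comment above flags a numerical pitfall: when an entire chunk is padding, masking the scores with math.inf drives every softmax input to -inf and the output becomes NaN, whereas a large finite value such as 1e9 keeps the result well defined. A minimal standalone PyTorch sketch of both behaviours (illustrative only, using the same (bsize, seql, 1) shapes as DATTNCombiner.forward):

import torch

scores = torch.zeros(1, 4, 1)                    # (bsize, seql, 1) attention scores
mask = torch.ones(1, 4, 1, dtype=torch.bool)     # every position of the chunk is padding

bad = scores.masked_fill(mask, -float("inf")).softmax(dim=-2)   # exp(-inf) everywhere -> 0/0 -> all NaN
good = scores.masked_fill(mask, -1e9).softmax(dim=-2)           # finite, degenerates to uniform weights
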
+ scores = scores.masked_fill(mask, -inf_default) + + scores = scores.softmax(dim=-2) + + # out: (bsize, 1, isize) + out = scores.transpose(1, 2).bmm(input2) + + return out diff --git a/modules/act.py b/modules/act.py index c025a42..9a718a6 100644 --- a/modules/act.py +++ b/modules/act.py @@ -54,8 +54,9 @@ def forward(self, x): def fix_init(self): - if self.reset_beta is not None: - self.beta.fill_(self.reset_beta) + with torch.no_grad(): + if self.reset_beta is not None: + self.beta.fill_(self.reset_beta) if override_GeLU_Swish: GeLU = Swish diff --git a/modules/base.py b/modules/base.py index 2d8f37a..b36689a 100644 --- a/modules/base.py +++ b/modules/base.py @@ -27,7 +27,7 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_d self.net = nn.Sequential(Linear(isize, _hsize), GeLU() if use_GeLU else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_GeLU), Linear(_hsize, isize, bias=enable_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize), GeLU() if use_GeLU else nn.ReLU(inplace=True), Linear(_hsize, isize, bias=enable_bias)) - self.normer = nn.LayerNorm(isize, eps=ieps_ln_default) + self.normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.norm_residual = norm_residual @@ -138,7 +138,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v if k_rel_pos > 0: self.k_rel_pos = k_rel_pos self.rel_pemb = nn.Embedding(k_rel_pos * 2 + 1, self.attn_dim) - _rpm = torch.arange(-xseql + 1, 1).unsqueeze(0) + _rpm = torch.arange(-xseql + 1, 1, dtype=torch.long).unsqueeze(0) self.register_buffer("rel_pos", (_rpm - _rpm.t()).clamp(min=-k_rel_pos, max=k_rel_pos) + k_rel_pos) self.xseql = xseql # the buffer can be shared inside the encoder or the decoder across layers for saving memory, by setting self.ref_rel_posm of self attns in deep layers to SelfAttn in layer 0, and sharing corresponding self.rel_pos @@ -275,7 +275,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena if k_rel_pos > 0: self.k_rel_pos = k_rel_pos self.rel_pemb = nn.Embedding(k_rel_pos * 2 + 1, self.attn_dim) - _rpm = torch.arange(-xseql + 1, 1).unsqueeze(0) + _rpm = torch.arange(-xseql + 1, 1, dtype=torch.long).unsqueeze(0) self.register_buffer("rel_pos", (_rpm - _rpm.t()).clamp(min=-k_rel_pos, max=k_rel_pos) + k_rel_pos) self.xseql = xseql # the buffer can be shared inside the encoder or the decoder across layers for saving memory, by setting self.ref_rel_posm of self attns in deep layers to SelfAttn in layer 0, and sharing corresponding self.rel_pos @@ -399,7 +399,7 @@ def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, use_GeLU=use_adv_act # should dropout be in front of sigmoid or not? 
self.net = nn.Sequential(Linear(isize * ncomb, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Dropout(dropout, inplace=inplace_after_GeLU), Linear(_hsize, isize, bias=enable_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize * ncomb, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Linear(_hsize, isize, bias=enable_bias)) - self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default) + self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) def forward(self, *xl): @@ -503,7 +503,7 @@ def _threshold_and_support(input, dim=0): def _make_ix_like(input, dim=0): d = input.size(dim) - rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + rho = torch.arange(1, d + 1, dtype=input.dtype, device=input.device) view = [1] * input.dim() view[0] = -1 @@ -574,14 +574,14 @@ class SparseNormer(nn.Module): # dim: dimension to normalize - def __init__(self, dim=-1, ieps=1e-32): + def __init__(self, dim=-1, eps=ieps_default): super(SparseNormer, self).__init__() self.dim = dim self.bias = nn.Parameter(torch.zeros(1)) self.act = nn.ReLU(inplace=True) - self.ieps = ieps + self.eps = eps def forward(self, x): @@ -589,21 +589,21 @@ def forward(self, x): _tmp = _tmp * _tmp # fix zero-devision in case all elements in _tmp are 0. - return _tmp / (_tmp.sum(self.dim, keepdim=True) + self.ieps) + return _tmp / (_tmp.sum(self.dim, keepdim=True) + self.eps) class MHSparseNormer(nn.Module): # nheads: number of heads # dim: dimension to normalize - def __init__(self, nheads, dim=-1, ieps=1e-32): + def __init__(self, nheads, dim=-1, eps=ieps_default): super(MHSparseNormer, self).__init__() self.dim = dim self.bias = nn.Parameter(torch.zeros(1, nheads, 1, 1)) self.act = nn.ReLU(inplace=True) - self.ieps = ieps + self.eps = eps # input should be: (bsize, nheads, nquery, seql) def forward(self, x): @@ -612,11 +612,12 @@ def forward(self, x): _tmp = _tmp * _tmp # fix zero-devision in case all elements in _tmp are 0. 
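
The eps term added to the denominator below is what prevents a zero-by-zero division when ReLU zeroes out every score. A minimal sketch of the same normalization (learned bias omitted, eps value assumed; not part of the patch):

import torch

def sparse_norm(x, dim=-1, eps=1e-9):
    t = torch.relu(x)                              # negative scores are clipped to zero
    t = t * t                                      # squared ReLU, as in SparseNormer/MHSparseNormer
    return t / (t.sum(dim, keepdim=True) + eps)    # eps keeps an all-zero row finite

print(sparse_norm(torch.full((2, 3), -1.0)))       # rows of zeros instead of NaN
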
- return _tmp / (_tmp.sum(self.dim, keepdim=True) + self.ieps) + return _tmp / (_tmp.sum(self.dim, keepdim=True) + self.eps) def fix_init(self): - self.bias.data.zero_() + with torch.no_grad(): + self.bias.data.zero_() class MHAttnSummer(nn.Module): @@ -753,8 +754,9 @@ def forward(self, x): def fix_init(self): - self.k.data.fill_(1.0) - self.bias.data.zero_() + with torch.no_grad(): + self.k.data.fill_(1.0) + self.bias.data.zero_() def reduce_model(modin): diff --git a/modules/rnncells.py b/modules/rnncells.py index 4b76e32..36bc6d4 100644 --- a/modules/rnncells.py +++ b/modules/rnncells.py @@ -24,7 +24,7 @@ def __init__(self, isize, osize, use_GeLU=use_adv_act_default): # layer normalization is also applied for the computation of hidden for efficiency self.trans = Linear(isize + osize, osize * 4) - self.normer = nn.LayerNorm((4, osize), eps=1e-06) + self.normer = nn.LayerNorm((4, osize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.act = GeLU() if use_GeLU else nn.Tanh() @@ -57,7 +57,7 @@ def __init__(self, isize, osize, use_GeLU=use_adv_act_default): self.transi = Linear(isize, osize) self.transh = Linear(osize, osize) - self.normer = nn.LayerNorm((2, osize), eps=1e-06) + self.normer = nn.LayerNorm((2, osize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.act = GeLU() if use_GeLU else nn.Tanh() diff --git a/tools/check/dynb/report_dynb.py b/tools/check/dynb/report_dynb.py index d3cfe47..55c0fb3 100644 --- a/tools/check/dynb/report_dynb.py +++ b/tools/check/dynb/report_dynb.py @@ -45,7 +45,7 @@ def select_function(modin, select_index): return _sel_m.parameters() -grad_mon = GradientMonitor(num_layer * 2, select_function, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_recoder=cnfg.num_dynb_his, num_his_gm=max_his) +grad_mon = GradientMonitor(num_layer * 2, select_function, module=None, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_recoder=cnfg.num_dynb_his, num_his_gm=max_his) def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False): @@ -107,6 +107,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _log_f_dynbatch.write(("%d %s\n" % (wd_add, " ".join(["%.2f" % (_cu,) for _cu in _cos_sim_l]))).encode("utf-8")) if _perform_dyn_optm_step or (_done_tokens >= tokens_optm): + if not _perform_dyn_optm_step: + grad_mon.reset() _do_optm_step = True if _cos_sim is None else (_cos_sim <= update_angle) if _do_optm_step: if log_dynb: diff --git a/transformer/Decoder.py b/transformer/Decoder.py index d380c76..68f4cac 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -32,8 +32,8 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.ff = PositionwiseFF(isize, _fhsize, dropout, norm_residual) - self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default) - self.layer_normer2 = nn.LayerNorm(isize, eps=ieps_ln_default) + self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + self.layer_normer2 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None @@ -134,7 +134,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, 
attn_drop=0. self.lsm = nn.LogSoftmax(-1) - self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default) if norm_output else None + self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None self.fbl = None if forbidden_index is None else tuple(set(forbidden_index)) diff --git a/transformer/Doc/Para/Base/Decoder.py b/transformer/Doc/Para/Base/Decoder.py index a5cd865..f1e4f14 100644 --- a/transformer/Doc/Para/Base/Decoder.py +++ b/transformer/Doc/Para/Base/Decoder.py @@ -21,7 +21,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize, dropout, attn_drop, num_head, _ahsize) self.cattns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(ncross)]) - self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default) for i in range(ncross)]) + self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) for i in range(ncross)]) self.grs = nn.ModuleList([GateResidual(isize) for i in range(ncross)]) def forward(self, inpute, inputo, inputc, src_pad_mask=None, tgt_pad_mask=None, context_mask=None, query_unit=None): diff --git a/transformer/Doc/Para/Base/Encoder.py b/transformer/Doc/Para/Base/Encoder.py index 938519a..27e6ff8 100644 --- a/transformer/Doc/Para/Base/Encoder.py +++ b/transformer/Doc/Para/Base/Encoder.py @@ -22,7 +22,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(CrossEncoderLayer, self).__init__(isize, fhsize, dropout, attn_drop, num_head, _ahsize) self.cattns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(ncross)]) - self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default) for i in range(ncross)]) + self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) for i in range(ncross)]) self.grs = nn.ModuleList([GateResidual(isize) for i in range(ncross)]) def forward(self, inputs, inputc, mask=None, context_mask=None): diff --git a/transformer/Doc/Para/Base/NMT.py b/transformer/Doc/Para/Base/NMT.py index 8003ff3..70127b8 100644 --- a/transformer/Doc/Para/Base/NMT.py +++ b/transformer/Doc/Para/Base/NMT.py @@ -24,7 +24,7 @@ def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_ self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index, nprev_context) - if use_k_relative_position > 0: + if rel_pos_enabled: share_rel_pos_cache(self) def forward(self, inpute, inputo, inputc, mask=None, context_mask=None): diff --git a/transformer/Encoder.py b/transformer/Encoder.py index 9c32956..a4bc278 100644 --- a/transformer/Encoder.py +++ b/transformer/Encoder.py @@ -38,7 +38,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.ff = PositionwiseFF(isize, _fhsize, dropout, norm_residual) - self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default) + self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None @@ -91,7 +91,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
else: self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) - self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default) if norm_output else None + self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None # inputs: (bsize, seql) # mask: (bsize, 1, seql), generated with: diff --git a/transformer/LD/AttnEncoder.py b/transformer/LD/AttnEncoder.py new file mode 100644 index 0000000..d32002a --- /dev/null +++ b/transformer/LD/AttnEncoder.py @@ -0,0 +1,103 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules.LD import ATTNCombiner + +from math import sqrt, ceil + +from transformer.LD.Encoder import Encoder as EncoderBase + +from cnfg.ihyp import * + +class Encoder(EncoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, num_layer_dec=6, max_chunk_tokens=8, min_chunks=6): + + _ahsize = isize if ahsize is None else ahsize + + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Encoder, self).__init__(isize, nwd, num_layer, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output, num_layer_dec, max_chunk_tokens, min_chunks) + + self.attn_emb = ATTNCombiner(isize, isize, dropout) + self.attns = nn.ModuleList([ATTNCombiner(isize, isize, dropout) for i in range(num_layer)]) + + def forward(self, inputs, mask=None): + + def transform(lin, w, drop): + + _tmp = torch.stack(lin, -1) + _osize = _tmp.size() + _tmp = _tmp.view(-1, _osize[-1]).mm(w.softmax(dim=0) if drop is None else drop(w).softmax(dim=0)) + _osize = list(_osize) + _osize[-1] = -1 + + return _tmp.view(_osize) + + def build_chunk_max(atm, rept, bsize, nchk, ntok, npad, mask=None, rmask=None, chkmask=None): + + pad_out = rept.masked_fill(mask.squeeze(1).unsqueeze(-1), -inf_default) if npad == 0 else torch.cat((rept.masked_fill(mask.squeeze(1).unsqueeze(-1), -inf_default), rept.new_full((bsize, npad, rept.size(-1)), -inf_default),), dim=1) + + # query: bsize, nchk, isize + # kv: bsize, nchk*ntok, isize + query = pad_out.view(bsize, nchk, ntok, -1).max(2)[0].masked_fill(rmask.view(bsize, -1, 1), 0.0) + kv = rept.masked_fill(mask.squeeze(1).unsqueeze(-1), 0.0) if npad == 0 else torch.cat((rept.masked_fill(mask.squeeze(1).unsqueeze(-1), 0.0), rept.new_zeros((bsize, npad, rept.size(-1))),), dim=1) + out = atm(query.view(bsize * nchk, 1, -1), kv.view(bsize * nchk, ntok, -1), chkmask.view(bsize * nchk, ntok, 1)).view(bsize, nchk, -1) + + # mask is not necessary in theory .masked_fill(rmask.squeeze(1).unsqueeze(-1), 0.0) + return out + + def build_chunk_mean(atm, rept, bsize, nchk, ntok, npad, mask=None, rmask=None, chkmask=None, nele=None): + + pad_out = rept.masked_fill(mask.squeeze(1).unsqueeze(-1), 0.0) if npad == 0 else torch.cat((rept.masked_fill(mask.squeeze(1).unsqueeze(-1), 0.0), rept.new_zeros((bsize, npad, rept.size(-1))),), dim=1) + + query = pad_out.view(bsize, nchk, ntok, -1).sum(2) / nele + out = atm(query.view(bsize * nchk, 1, -1), pad_out.view(bsize * nchk, ntok, -1), chkmask.view(bsize * nchk, ntok, 1)).view(bsize, nchk, -1) + + return out + + bsize, seql = inputs.size() + + _ntok = max(min(self.mxct, ceil(seql / self.mnck)), 3) + _npad = (_ntok - (seql % _ntok)) % _ntok + _nchk = int((seql + _npad) / _ntok) + if mask is None: + _chk_mask = None + _rmask = None + else: + _chk_mask = mask if _npad == 0 else torch.cat((mask, mask.new_ones(bsize, 1, 
_npad),), dim=-1) + _nmask = _chk_mask.view(bsize, 1, _nchk, _ntok).sum(-1) + _rmask = _nmask.ge(_ntok) + + out = self.wemb(inputs) + out = out * sqrt(out.size(-1)) + if self.pemb is not None: + out = out + self.pemb(inputs, expand=False) + + #if _rmask is not None: + #_nele = (_ntok - _nmask).masked_fill(_nmask.eq(_ntok), 1).view(bsize, _nchk, 1).to(out) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + outs = [out] + + _ho = build_chunk_max(self.attn_emb, out, bsize, _nchk, _ntok, _npad, mask, _rmask, _chk_mask) + #_ho = build_chunk_mean(self.attn_emb, out, bsize, _nchk, _ntok, _npad, mask, _rmask, _chk_mask, _nele) + hl = [_ho] + + for net, attnm in zip(self.nets, self.attns): + out = net(out, _ho, mask, _rmask) + outs.append(out) + _ho = build_chunk_max(attnm, out, bsize, _nchk, _ntok, _npad, mask, _rmask, _chk_mask) + #_ho = build_chunk_mean(attnm, out, bsize, _nchk, _ntok, _npad, mask, _rmask, _chk_mask, _nele) + hl.append(_ho) + + out = transform(outs, self.tattn_w, self.tattn_drop) + + # hl: (bsize, _nchk, isize, num_layer + 1) + hl = transform(hl, self.sc_tattn_w, self.sc_tattn_drop) + + return out, hl, _rmask diff --git a/transformer/LD/Decoder.py b/transformer/LD/Decoder.py new file mode 100644 index 0000000..a9a72af --- /dev/null +++ b/transformer/LD/Decoder.py @@ -0,0 +1,288 @@ +#encoding: utf-8 + +import torch +from torch import nn + +from modules.base import CrossAttn, ResidueCombiner +from modules.TA import PositionwiseFF + +from utils.base import repeat_bsize_for_beam_tensor +from math import sqrt + +from transformer.Decoder import DecoderLayer as DecoderLayerBase +from transformer.Decoder import Decoder as DecoderBase + +from cnfg.ihyp import * + +class DecoderLayer(DecoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None): + + _ahsize = isize if ahsize is None else ahsize + + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(DecoderLayer, self).__init__(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + + self.cattn = CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) + + self.ff = PositionwiseFF(isize, _fhsize, dropout) + self.scff = ResidueCombiner(isize, 2, _fhsize, dropout) + + def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, tgt_pad_mask=None, query_unit=None, concat_query=False): + + if query_unit is None: + + states_return = None + + context = self.self_attn(inputo, mask=tgt_pad_mask) + + if self.drop is not None: + context = self.drop(context) + + context = context + inputo + + else: + + if concat_query: + + inputo = query_unit if inputo is None else torch.cat((inputo, query_unit,), 1) + + states_return = inputo + + context = self.self_attn(query_unit, iK=inputo) + + if self.drop is not None: + context = self.drop(context) + + context = context + query_unit + + _context = self.layer_normer1(context) + + _context = self.scff(_context, self.cattn(_context, inputh, mask=chk_pad_mask)) + + _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) + + if self.drop is not None: + _context_new = self.drop(_context_new) + + context = self.layer_normer2(_context_new + _context) + + context = self.ff(context) + + if states_return is None: + return context + else: + return context, states_return + +class Decoder(DecoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, 
forbidden_index=None): + + _ahsize = isize if ahsize is None else ahsize + + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Decoder, self).__init__(isize, nwd, num_layer, _fhsize, dropout, attn_drop, emb_w, num_head, xseql, _ahsize, norm_output, bindemb, forbidden_index) + + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + + def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None): + + bsize, nquery = inputo.size() + + out = self.wemb(inputo) + + out = out * sqrt(out.size(-1)) + if self.pemb is not None: + out = out + self.pemb(inputo, expand=False) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + + _mask = self._get_subsequent_mask(nquery) + + for net, inputu, inputhu in zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1)): + out = net(inputu, inputhu, out, src_pad_mask, chk_pad_mask, _mask) + + out = self.lsm(self.classifier(out)) + + return out + + def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, max_len=512, fill_pad=False): + + bsize, seql= inpute.size()[:2] + + sos_emb = self.get_sos_emb(inpute) + + sqrt_isize = sqrt(sos_emb.size(-1)) + + out = sos_emb * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(0) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + + states = {} + + for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): + out, _state = net(inputu, inputhu, None, src_pad_mask, chk_pad_mask, None, out, True) + states[_tmp] = _state + + out = self.lsm(self.classifier(out)) + + wds = out.argmax(dim=-1) + + trans = [wds] + + done_trans = wds.eq(2) + + for i in range(1, max_len): + + out = self.wemb(wds) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(i) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + + for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): + out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, chk_pad_mask, None, out, True) + states[_tmp] = _state + + out = self.lsm(self.classifier(out)) + wds = out.argmax(dim=-1) + + trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) + + done_trans = done_trans | wds.eq(2) + if done_trans.int().sum().item() == bsize: + break + + return torch.cat(trans, 1) + + def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=False, fill_pad=False): + + bsize, seql = inpute.size()[:2] + + beam_size2 = beam_size * beam_size + bsizeb2 = bsize * beam_size2 + real_bsize = bsize * beam_size + + sos_emb = self.get_sos_emb(inpute) + isize = sos_emb.size(-1) + sqrt_isize = sqrt(isize) + + if length_penalty > 0.0: + lpv = sos_emb.new_ones(real_bsize, 1) + lpv_base = 6.0 ** length_penalty + + out = sos_emb * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(0) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + + states = {} + + for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): + out, _state = net(inputu, inputhu, None, src_pad_mask, chk_pad_mask, None, out, True) + states[_tmp] = _state + + out = self.lsm(self.classifier(out)) + + scores, wds = out.topk(beam_size, dim=-1) + scores = scores.squeeze(1) + sum_scores = 
scores + wds = wds.view(real_bsize, 1) + trans = wds + + done_trans = wds.view(bsize, beam_size).eq(2) + + inpute = inpute.repeat(1, beam_size, 1, 1).view(real_bsize, seql, isize, -1) + inputh = repeat_bsize_for_beam_tensor(inputh, beam_size) + + _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) + _chk_pad_mask = None if chk_pad_mask is None else repeat_bsize_for_beam_tensor(chk_pad_mask, beam_size) + + for key, value in states.items(): + states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + + for step in range(1, max_len): + + out = self.wemb(wds) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(step) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + + for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): + out, _state = net(inputu, inputhu, states[_tmp], _src_pad_mask, _chk_pad_mask, None, out, True) + states[_tmp] = _state + + out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1) + + _scores, _wds = out.topk(beam_size, dim=-1) + _scores = (_scores.masked_fill(done_trans.unsqueeze(2).expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size)) + + if length_penalty > 0.0: + lpv = lpv.masked_fill(~done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base) + + if clip_beam and (length_penalty > 0.0): + scores, _inds = (_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size) + else: + scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = scores + + wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1) + + _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + + trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), 0) if fill_pad else wds), 1) + + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + + _done = False + if length_penalty > 0.0: + lpv = lpv.index_select(0, _inds) + elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + _done = True + + if _done or (done_trans.int().sum().item() == real_bsize): + break + + for key, value in states.items(): + states[key] = value.index_select(0, _inds) + + if (not clip_beam) and (length_penalty > 0.0): + scores = scores / lpv.view(bsize, beam_size) + scores, _inds = scores.topk(beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) + + if return_all: + + return trans, scores + else: + + return trans.view(bsize, beam_size, -1).select(1, 0) + + def decode(self, inpute, inputh, src_pad_mask, chk_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + + return 
self.beam_decode(inpute, inputh, src_pad_mask, chk_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputh, src_pad_mask, chk_pad_mask, max_len, fill_pad=fill_pad) diff --git a/transformer/LD/Encoder.py b/transformer/LD/Encoder.py new file mode 100644 index 0000000..a34ead8 --- /dev/null +++ b/transformer/LD/Encoder.py @@ -0,0 +1,129 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules.base import CrossAttn, Dropout, ResidueCombiner + +from math import sqrt, ceil + +from transformer.TA.Encoder import EncoderLayer as EncoderLayerBase +from transformer.TA.Encoder import Encoder as EncoderBase + +from cnfg.ihyp import * + +class EncoderLayer(EncoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(EncoderLayer, self).__init__(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + + self.cattn = CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) + self.scff = ResidueCombiner(isize, 2, _fhsize, dropout) + + def forward(self, inputs, sumr, mask=None, rmask=None): + + #_bsize, _seql, _isize = inputs.size() + #_rep1, _rep2 = self.cattn(inputs.repeat(2, 1, 1), sumr, rmask).view(2, _bsize, _seql, _isize).unbind(0) + #inputs = self.scff(inputs, _rep1, _rep2) + inputs = self.scff(inputs, self.cattn(inputs, sumr, rmask)) + + context = self.attn(inputs, mask=mask) + + if self.drop is not None: + context = self.drop(context) + + context = self.layer_normer(context + inputs) + + context = self.ff(context) + + return context + +class Encoder(EncoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, num_layer_dec=6, max_chunk_tokens=8, min_chunks=4): + + _ahsize = isize if ahsize is None else ahsize + + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Encoder, self).__init__(isize, nwd, num_layer, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output, num_layer_dec) + + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + + self.sc_tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(1.0 / (num_layer + 1)), sqrt(1.0 / (num_layer + 1)))) + self.sc_tattn_drop = Dropout(dropout) if dropout > 0.0 else None + + self.mxct = max_chunk_tokens + self.mnck = float(min_chunks) + + def forward(self, inputs, mask=None): + + def transform(lin, w, drop): + + _tmp = torch.stack(lin, -1) + _osize = _tmp.size() + _tmp = _tmp.view(-1, _osize[-1]).mm(w.softmax(dim=0) if drop is None else drop(w).softmax(dim=0)) + _osize = list(_osize) + _osize[-1] = -1 + + return _tmp.view(_osize) + + def build_chunk_max(rept, bsize, nchk, ntok, npad, mask=None, rmask=None): + + out = rept.masked_fill(mask.squeeze(1).unsqueeze(-1), -inf_default) if npad == 0 else torch.cat((rept.masked_fill(mask.squeeze(1).unsqueeze(-1), -inf_default), rept.new_full((bsize, npad, rept.size(-1)), -inf_default),), dim=1) + + return out.view(bsize, nchk, ntok, -1).max(2)[0].masked_fill(rmask.squeeze(1).unsqueeze(-1), 0.0) + + def build_chunk_mean(rept, bsize, nchk, ntok, npad, mask=None, rmask=None, nele=None): + + out = rept.masked_fill(mask.squeeze(1).unsqueeze(-1), 0.0) if npad == 0 else torch.cat((rept.masked_fill(mask.squeeze(1).unsqueeze(-1), 0.0), 
rept.new_zeros((bsize, npad, rept.size(-1))),), dim=1) + + return (out.view(bsize, nchk, ntok, -1).sum(2) / nele).masked_fill(rmask.squeeze(1).unsqueeze(-1), 0.0) + + bsize, seql = inputs.size() + + _ntok = max(min(self.mxct, ceil(seql / self.mnck)), 2) + _npad = (_ntok - (seql % _ntok)) % _ntok + _nchk = int((seql + _npad) / _ntok) + if mask is None: + _chk_mask = None + _rmask = None + else: + _chk_mask = mask if _npad == 0 else torch.cat((mask, mask.new_ones(bsize, 1, _npad),), dim=-1) + _nmask = _chk_mask.view(bsize, 1, _nchk, _ntok).sum(-1) + _rmask = _nmask.ge(_ntok) + + out = self.wemb(inputs) + out = out * sqrt(out.size(-1)) + if self.pemb is not None: + out = out + self.pemb(inputs, expand=False) + + #if _rmask is not None: + #_nele = (_ntok - _nmask).view(bsize, _nchk, 1).to(out) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + outs = [out] + + _ho = build_chunk_max(out, bsize, _nchk, _ntok, _npad, mask, _rmask) + #_ho = torch.cat((build_chunk_mean(out, bsize, _nchk, _ntok, _npad, mask, _srmask, _nele), build_chunk_max(out, bsize, _nchk, _ntok, _npad, mask, _srmask),), 0) + hl = [_ho] + + for net in self.nets: + out = net(out, _ho, mask, _rmask) + outs.append(out) + _ho = build_chunk_max(out, bsize, _nchk, _ntok, _npad, mask, _rmask) + #_ho = torch.cat((build_chunk_mean(out, bsize, _nchk, _ntok, _npad, mask, _srmask, _nele), build_chunk_max(out, bsize, _nchk, _ntok, _npad, mask, _srmask),), 0) + hl.append(_ho) + + out = transform(outs, self.tattn_w, self.tattn_drop) + + # hl: (bsize, _nchk, isize, num_layer + 1) + hl = transform(hl, self.sc_tattn_w, self.sc_tattn_drop) + + return out, hl, _rmask diff --git a/transformer/LD/NMT.py b/transformer/LD/NMT.py new file mode 100644 index 0000000..38f3817 --- /dev/null +++ b/transformer/LD/NMT.py @@ -0,0 +1,49 @@ +#encoding: utf-8 + +from torch import nn + +from utils.relpos import share_rel_pos_cache +from utils.fmt.base import parse_double_value_tuple + +from transformer.LD.Encoder import Encoder +from transformer.LD.Decoder import Decoder + +from cnfg.ihyp import * + +class NMT(nn.Module): + + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): + + super(NMT, self).__init__() + + enc_layer, dec_layer = parse_double_value_tuple(num_layer) + + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, dec_layer) + + emb_w = self.enc.wemb.weight if global_emb else None + + self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) + + if rel_pos_enabled: + share_rel_pos_cache(self) + + def forward(self, inpute, inputo, mask=None): + + _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + ence, ench, hmask = self.enc(inpute, _mask) + + return self.dec(ence, ench, inputo, _mask, hmask) + + # inpute: source sentences from encoder (bsize, seql) + # beam_size: the beam size for beam search + # max_len: maximum length to generate + + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): + + mask = inpute.eq(0).unsqueeze(1) + + _max_len = inpute.size(1) + max(64, inpute.size(1) // 4) if max_len is None else max_len + + ence, ench, hmask = self.enc(inpute, mask) + + return self.dec.decode(ence, ench, mask, hmask, beam_size, _max_len, length_penalty) diff --git 
a/transformer/LD/__init__.py b/transformer/LD/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/LD/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/NMT.py b/transformer/NMT.py index 3645cc3..91c01e6 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -41,7 +41,7 @@ def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_ self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) #self.dec = Decoder(isize, tnwd, dec_layer, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index)# for RNMT - if use_k_relative_position > 0: + if rel_pos_enabled: share_rel_pos_cache(self) # inpute: source sentences from encoder (bsize, seql) diff --git a/transformer/RNMTDecoder.py b/transformer/RNMTDecoder.py index 83b8514..ad02a35 100644 --- a/transformer/RNMTDecoder.py +++ b/transformer/RNMTDecoder.py @@ -8,6 +8,8 @@ from modules.base import * from modules.rnncells import * +from utils.fmt.base import pad_id + from cnfg.ihyp import * class FirstLayer(nn.Module): @@ -131,7 +133,7 @@ def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None self.lsm = nn.LogSoftmax(-1) - self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default) if norm_output else None + self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None self.fbl = None if forbidden_index is None else tuple(set(forbidden_index)) @@ -440,15 +442,16 @@ def get_sos_emb(self, inpute): return self.wemb.weight[1].reshape(1, 1, -1).expand(bsize, 1, -1) - # will it be better if zero corresponding weights? but called by fix_load prevent doing so - def fix_init(self): - _tmp = list(self.classifier.modules())[-1] - if self.fbl is not None: - for ind in self.fbl: - _tmp.bias.data[ind] = -1e32 + self.fix_load() + with torch.no_grad(): + self.wemb.weight[pad_id].zero_() + self.classifier.weight[pad_id].zero_() def fix_load(self): - self.fix_init() + if self.fbl is not None: + with torch.no_grad(): + #list(self.classifier.modules())[-1].bias.index_fill_(0, torch.tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default) + self.classifier.bias.index_fill_(0, torch.tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default) diff --git a/transformer/SC/Encoder.py b/transformer/SC/Encoder.py index 323c831..cad1828 100644 --- a/transformer/SC/Encoder.py +++ b/transformer/SC/Encoder.py @@ -21,7 +21,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
self.attns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(num_layer)]) - self.sc_tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(2.0 / (num_layer + num_layer_dec + 1)), sqrt(2.0 / (num_layer + num_layer_dec + 1)))) + self.sc_tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(1.0 / (num_layer + 1)), sqrt(1.0 / (num_layer + 1)))) self.sc_tattn_drop = Dropout(dropout) if dropout > 0.0 else None # inputs: (bsize, seql) diff --git a/transformer/SC/NMT.py b/transformer/SC/NMT.py index 5f382b8..949cd78 100644 --- a/transformer/SC/NMT.py +++ b/transformer/SC/NMT.py @@ -26,7 +26,7 @@ def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_ self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) - if use_k_relative_position > 0: + if rel_pos_enabled: share_rel_pos_cache(self) def forward(self, inpute, inputo, mask=None): diff --git a/transformer/TA/Encoder.py b/transformer/TA/Encoder.py index 323aee4..be00c72 100644 --- a/transformer/TA/Encoder.py +++ b/transformer/TA/Encoder.py @@ -64,7 +64,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) - self.tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(2.0 / (num_layer + num_layer_dec + 1)), sqrt(2.0 / (num_layer + num_layer_dec + 1)))) + self.tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(1.0 / (num_layer + 1)), sqrt(1.0 / (num_layer + 1)))) self.tattn_drop = Dropout(dropout) if dropout > 0.0 else None # inputs: (bsize, seql) diff --git a/transformer/UniEncoder.py b/transformer/UniEncoder.py index 5cd5866..77eec64 100644 --- a/transformer/UniEncoder.py +++ b/transformer/UniEncoder.py @@ -42,7 +42,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. self.net = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) self.halter = nn.Sequential(Scorer(isize), nn.Sigmoid()) - self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default) if norm_output else None + self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None self.act_loss = ACT_Loss() diff --git a/utils/dynbatch.py b/utils/dynbatch.py index 4b701b7..e033bd4 100644 --- a/utils/dynbatch.py +++ b/utils/dynbatch.py @@ -85,18 +85,19 @@ class GradientMonitor: # num_his_gm: cache num_his_gm gradients into a history, and return this number of angle changes. # returns: (update_r, angle_r), update_r: to performing an optimization step, angle_r: the angle change in current step. 
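
The comments above describe the GradientMonitor interface; the change below additionally lets the monitored module be bound at construction time (module=...), with update() falling back to it when called without an argument. A hedged usage sketch, with model, num_layer and select_function standing in for the objects built in train_dynb.py:

from utils.dynbatch import GradientMonitor

grad_mon = GradientMonitor(num_layer * 2, select_function, module=model,
                           angle_alpha=1.1, num_tol_amin=3,
                           num_his_recoder=50, num_his_gm=1)

perform_step, cos_sim = grad_mon.update()        # falls back to the bound module
perform_step, cos_sim = grad_mon.update(model)   # an explicit module still takes precedence

if not perform_step:
    grad_mon.reset()   # as train_dynb.py now does when the token budget forces the step instead
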
- def __init__(self, num_group, select_func, angle_alpha=1.1, num_tol_amin=3, num_his_recoder=50, num_his_gm=1): + def __init__(self, num_group, select_func, module=None, angle_alpha=1.1, num_tol_amin=3, num_his_recoder=50, num_his_gm=1): self.scale = 180.0 / pi self.num_group = num_group self.recorder = EffRecoder(num_group, num_his=num_his_recoder, init_value=1.0)#init_value=180.0 if use sample_gumble_norm in self.reset self.select_func = select_func + self.module = module self.alpha, self.num_tol_amin, self.num_his = angle_alpha, num_tol_amin, num_his_gm self.reset() - def update(self, mod): + def update(self, mod=None): - _cur_gg = backup_para_grad(self.select_func(mod, self.sel_ind)) + _cur_gg = backup_para_grad(self.select_func(self.module if mod is None else mod, self.sel_ind)) angle_r = None if self.num_his > 1: if self.prev_grad is None: diff --git a/utils/init.py b/utils/init.py index 39fc081..cfc49fc 100644 --- a/utils/init.py +++ b/utils/init.py @@ -78,8 +78,10 @@ def init_model_params_lipschitz(modin, gain_glorot=sqrt(1.0/3.0), gain_kaiming=s if _m.bias is not None: _m.bias.zero_() elif isinstance(_m, LayerNorm): - _m.weight.fill_(1.0) - _m.bias.zero_() + if _m.weight is not None: + _m.weight.fill_(1.0) + if _m.bias is not None: + _m.bias.zero_() return _tmpm
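
For reference, the chunking arithmetic shared by transformer/LD/Encoder.py and transformer/LD/AttnEncoder.py splits each source sentence into _nchk chunks of _ntok tokens, padding the tail so the length divides evenly (Encoder uses a floor of 2 tokens per chunk, AttnEncoder of 3). A small standalone sketch of that computation, with the defaults from Encoder:

from math import ceil

def chunk_geometry(seql, max_chunk_tokens=8, min_chunks=4, min_tokens=2):
    ntok = max(min(max_chunk_tokens, ceil(seql / float(min_chunks))), min_tokens)  # tokens per chunk
    npad = (ntok - (seql % ntok)) % ntok                                           # padding appended to the last chunk
    nchk = (seql + npad) // ntok                                                   # number of chunks
    return ntok, npad, nchk

print(chunk_geometry(19))   # (5, 1, 4): four 5-token chunks, one padded position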