-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* change the signature of node/edge filter * upd filter * Support multi-dimension node feature in SPMV * push transformer * remove some experimental settings * stable version * hotfix * upd tutorial * upd README * merge * remove redundency * remove tqdm * several changes * Refactor * Refactor * tutorial train * fixed a bug * fixed perf issue * upd * change dir * move un-related to contrib * tutuorial code * remove redundency * upd * upd * upd * upd * improve viz * universal done * halt norm * fixed a bug * add draw graph * fixed several bugs * remove dependency on core * upd format of README * trigger * trigger * upd viz * trigger * add transformer tutorial * fix tutorial * fix readme * small fix on tutorials * url fix in readme * fixed func link * upd
- Loading branch information
Showing
25 changed files
with
2,849 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
*~ | ||
data/ | ||
scripts/ | ||
checkpoints/ | ||
log/ | ||
*__pycache__* | ||
*.tar.gz | ||
*.zip | ||
*.pyc | ||
*.lprof | ||
*.swp |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Transformer in DGL | ||
In this example we implement the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) and [Universal Transformer](https://arxiv.org/abs/1807.03819) with ACT in DGL. | ||
|
||
The folder contains a training module and an inference module (beam decoder) for the Transformer, and a training module for the Universal Transformer. | ||
|
||
## Requirements | ||
|
||
- PyTorch 0.4.1+ | ||
- networkx | ||
- tqdm | ||
|
||
## Usage | ||
|
||
- For training: | ||
|
||
``` | ||
python translation_train.py [--gpus id1,id2,...] [--N #layers] [--dataset DATASET] [--batch BATCHSIZE] [--universal] | ||
``` | ||
- For evaluating the BLEU score on the test set (enable `--print` to see the translated text): | ||
``` | ||
python translation_test.py [--gpu id] [--N #layers] [--dataset DATASET] [--batch BATCHSIZE] [--checkpoint CHECKPOINT] [--print] [--universal] | ||
``` | ||
Available datasets: `copy`, `sort`, `wmt14`, `multi30k`(default). | ||
## Test Results | ||
### Transformer | ||
- Multi30k: we achieve BLEU score 35.41 with default setting on Multi30k dataset, without using pre-trained embeddings. (if we set the number of layers to 2, the BLEU score could reach 36.45). | ||
- WMT14: work in progress | ||
### Universal Transformer | ||
- work in progress | ||
## Notes | ||
- Currently we do not support multi-GPU training (this will be fixed soon); you should specify only one gpu\_id when running the training script. | ||
## Reference | ||
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) | ||
- [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from .graph import * | ||
from .fields import * | ||
from .utils import prepare_dataset | ||
import os | ||
import numpy as np | ||
|
||
class ClassificationDataset:
    """Placeholder for a classification dataset; not implemented yet."""

    def __init__(self):
        # Constructing this class always fails: only the stub exists so far.
        raise NotImplementedError()
|
||
class TranslationDataset:
    '''
    Dataset class for translation task.

    By default, the source language shares the same vocabulary with the
    target language.

    Raw lines of the train/valid/test splits are kept in ``self.src`` and
    ``self.tgt`` (dicts keyed by ``'train'``, ``'valid'`` and ``'test'``).
    Calling the dataset yields batched graphs built by a graph pool.
    '''
    INIT_TOKEN = '<sos>'
    EOS_TOKEN = '<eos>'
    PAD_TOKEN = '<pad>'
    MAX_LENGTH = 50

    def __init__(self, path, exts, train='train', valid='valid', test='test', vocab='vocab.txt', replace_oov=None):
        '''
        Load the three splits from ``path`` and build the shared vocabulary.

        args:
            path: directory that holds the data files and the vocab file.
            exts: pair ``(source_ext, target_ext)``; files are named
                ``<split prefix>.<ext>``.
            train/valid/test: file name prefixes of the three splits.
            vocab: file name of the vocabulary, relative to ``path``; it is
                generated from the data when it does not exist yet.
            replace_oov: token used for out-of-vocabulary words, or None.
        '''
        vocab_path = os.path.join(path, vocab)
        self.src = {}
        self.tgt = {}
        # Splits are always stored under the canonical keys, whatever the
        # on-disk file prefix is.
        for mode, prefix in (('train', train), ('valid', valid), ('test', test)):
            with open(os.path.join(path, prefix + '.' + exts[0]), 'r') as f:
                self.src[mode] = f.readlines()
            with open(os.path.join(path, prefix + '.' + exts[1]), 'r') as f:
                self.tgt[mode] = f.readlines()

        if not os.path.exists(vocab_path):
            self._make_vocab(vocab_path)

        vocab = Vocab(init_token=self.INIT_TOKEN,
                      eos_token=self.EOS_TOKEN,
                      pad_token=self.PAD_TOKEN,
                      unk_token=replace_oov)
        vocab.load(vocab_path)
        self.vocab = vocab
        # Post-processing keeps at most MAX_LENGTH items of each sequence.
        truncate = lambda x: x[:self.MAX_LENGTH]
        self.src_field = Field(vocab,
                               preprocessing=None,
                               postprocessing=truncate)
        # Target sequences are wrapped in <sos> ... <eos> before truncation.
        self.tgt_field = Field(vocab,
                               preprocessing=lambda seq: [self.INIT_TOKEN] + seq + [self.EOS_TOKEN],
                               postprocessing=truncate)

    def get_seq_by_id(self, idx, mode='train', field='src'):
        "get raw sequence in dataset by specifying index, mode(train/valid/test), field(src/tgt)"
        if field == 'src':
            return self.src[mode][idx].strip().split()
        # Any other field value is treated as the target side.
        return [self.INIT_TOKEN] + self.tgt[mode][idx].strip().split() + [self.EOS_TOKEN]

    def _make_vocab(self, path, thres=2):
        '''
        Write a vocabulary file (one token per line) built from all splits.

        args:
            path: output file path.
            thres: frequency threshold. The stored count of a token is its
                number of occurrences minus one, and a token is written only
                when that stored count is greater than ``thres``.
        '''
        word_dict = {}
        for mode in ['train', 'valid', 'test']:
            for lines in (self.src[mode], self.tgt[mode]):
                for line in lines:
                    for token in line.strip().split():
                        # The first sighting stores 0, so the value is
                        # (occurrences - 1). Kept as is so that generated
                        # vocab files stay identical to earlier ones.
                        if token not in word_dict:
                            word_dict[token] = 0
                        else:
                            word_dict[token] += 1

        with open(path, 'w') as f:
            for k, v in word_dict.items():
                # Honour the `thres` argument instead of a hard-coded 2.
                if v > thres:
                    print(k, file=f)

    @property
    def vocab_size(self):
        "number of entries in the vocabulary"
        return len(self.vocab)

    @property
    def pad_id(self):
        "index of the padding token"
        return self.vocab[self.PAD_TOKEN]

    @property
    def sos_id(self):
        "index of the start-of-sequence token"
        return self.vocab[self.INIT_TOKEN]

    @property
    def eos_id(self):
        "index of the end-of-sequence token"
        return self.vocab[self.EOS_TOKEN]

    def __call__(self, graph_pool, mode='train', batch_size=32, k=1, devices=['cpu']):
        '''
        Create a batched graph correspond to the mini-batch of the dataset.
        args:
            graph_pool: a GraphPool object for accelerating.
            mode: train/valid/test
            batch_size: batch size
            devices: ['cpu'] or a list of gpu ids.
            k: beam size(only required for test)

        In test mode each yielded item is ``graph_pool.beam(...)`` for one
        batch. Otherwise one graph is built per device; a single graph is
        yielded when there is one device and a list of graphs when there
        are several. The final group may hold fewer graphs than devices.
        '''
        # `devices` is only read, never mutated, so the list default is safe.
        dev_id, gs = 0, []
        src_data, tgt_data = self.src[mode], self.tgt[mode]
        n = len(src_data)
        # Only training batches are shuffled.
        order = np.random.permutation(n) if mode == 'train' else range(n)
        src_buf, tgt_buf = [], []

        for idx in order:
            src_sample = self.src_field(
                src_data[idx].strip().split())
            tgt_sample = self.tgt_field(
                tgt_data[idx].strip().split())
            src_buf.append(src_sample)
            tgt_buf.append(tgt_sample)
            if len(src_buf) == batch_size:
                if mode == 'test':
                    assert len(devices) == 1  # we only allow single gpu for inference
                    yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
                else:
                    gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
                    dev_id += 1
                    if dev_id == len(devices):
                        yield gs if len(devices) > 1 else gs[0]
                        dev_id, gs = 0, []
                src_buf, tgt_buf = [], []

        if mode == 'test':
            if src_buf:
                yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=devices[0])
        else:
            if src_buf:
                gs.append(graph_pool(src_buf, tgt_buf, device=devices[dev_id]))
            # Flush graphs that were built but not yet yielded. Without this
            # they are lost when the data ends between device groups and no
            # partial batch is left over.
            if gs:
                yield gs if len(devices) > 1 else gs[0]

    def get_sequence(self, batch):
        "return a list of sequence from a list of index arrays"
        ret = []
        filter_list = {self.pad_id, self.sos_id, self.eos_id}
        eos_id = self.eos_id
        for seq in batch:
            try:
                l = seq.index(eos_id)
            except (ValueError, AttributeError, TypeError):
                # No EOS in the sequence, or the container has no list-style
                # `index`: fall back to using the whole sequence.
                l = len(seq)
            ret.append(' '.join(self.vocab[token] for token in seq[:l] if token not in filter_list))
        return ret
|
||
def get_dataset(dataset):
    '''
    we wrapped a set of datasets as example

    args:
        dataset: one of 'copy', 'sort', 'multi30k' or 'wmt14'.

    Returns the matching TranslationDataset. Raises NotImplementedError for
    'babi' and KeyError (naming the rejected value) for any other name.
    '''
    prepare_dataset(dataset)
    if dataset == 'babi':
        raise NotImplementedError
    elif dataset in ('copy', 'sort'):
        return TranslationDataset(
            'data/{}'.format(dataset),
            ('in', 'out'),
            train='train',
            valid='valid',
            test='test',
        )
    elif dataset == 'multi30k':
        return TranslationDataset(
            'data/multi30k',
            ('en.atok', 'de.atok'),
            train='train',
            valid='val',
            test='test2016',
            replace_oov='<unk>'
        )
    elif dataset == 'wmt14':
        return TranslationDataset(
            'data/wmt14',
            ('en', 'de'),
            train='train.tok.clean.bpe.32000',
            valid='newstest2013.tok.bpe.32000',
            test='newstest2014.tok.bpe.32000',
            vocab='vocab.bpe.32000')
    else:
        # Same exception type as before, but say which name was rejected.
        raise KeyError(
            "unknown dataset {!r}; expected one of "
            "'copy', 'sort', 'multi30k', 'wmt14'".format(dataset))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
class Vocab:
    """
    Bidirectional token <-> index mapping loaded from a text file.

    Special tokens that are not None take the first indices, in the order
    init, eos, pad, unk; the tokens read from the file follow.
    """

    def __init__(self, init_token=None, eos_token=None, pad_token=None, unk_token=None):
        self.init_token = init_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab_lst = []
        # Filled by `load`; string lookups are not possible before that.
        self.vocab_dict = None

    def load(self, path):
        """Append the special tokens and every stripped line of `path`, then rebuild the index."""
        for special in (self.init_token, self.eos_token, self.pad_token, self.unk_token):
            if special is not None:
                self.vocab_lst.append(special)
        with open(path, 'r') as f:
            # Iterate the file directly instead of materialising readlines().
            self.vocab_lst.extend(line.strip() for line in f)
        # If a token occurs more than once, its last position wins.
        self.vocab_dict = {
            token: idx for idx, token in enumerate(self.vocab_lst)
        }

    def __len__(self):
        return len(self.vocab_lst)

    def __getitem__(self, key):
        """
        Map a token (str) to its index, or an index to its token.

        Unknown tokens map to the index of `unk_token` when one is set;
        otherwise KeyError is raised with the unknown token as argument.
        """
        if isinstance(key, str):
            if key in self.vocab_dict:
                return self.vocab_dict[key]
            if self.unk_token is None:
                # Previously this surfaced as the misleading KeyError(None).
                raise KeyError(key)
            return self.vocab_dict[self.unk_token]
        return self.vocab_lst[key]
|
||
class Field:
    """
    Pipeline that turns a token sequence into vocabulary indices.

    Calling a field runs the optional `preprocessing` hook on the tokens,
    looks every token up in `vocab`, then runs the optional
    `postprocessing` hook on the resulting indices.
    """

    def __init__(self, vocab, preprocessing=None, postprocessing=None):
        self.vocab = vocab
        self.preprocessing = preprocessing
        self.postprocessing = postprocessing

    def preprocess(self, x):
        """Apply the preprocessing hook, or return `x` unchanged when none is set."""
        hook = self.preprocessing
        return x if hook is None else hook(x)

    def postprocess(self, x):
        """Apply the postprocessing hook, or return `x` unchanged when none is set."""
        hook = self.postprocessing
        return x if hook is None else hook(x)

    def numericalize(self, x):
        """Look up every token of `x` in the vocabulary and return the results as a list."""
        table = self.vocab
        indices = []
        for token in x:
            indices.append(table[token])
        return indices

    def __call__(self, x):
        tokens = self.preprocess(x)
        indices = self.numericalize(tokens)
        return self.postprocess(indices)
Oops, something went wrong.