
Feature/transformer #58

Open
wants to merge 55 commits into base: master
Commits (55, all authored by emanjavacas):
ffdbe67  fixes to cnn embedding (Apr 30, 2019)
a35e89b  fixes (Apr 30, 2019)
b7d51ad  Merge branch 'dev' of https://github.com/mikekestemont/pie into dev (Apr 30, 2019)
6e2baf2  cnn embedding fix (Apr 30, 2019)
ed7eca1  minor fixes (May 6, 2019)
896f6ed  Merge branch 'master' into dev (May 6, 2019)
104c40b  removed processing files (May 8, 2019)
1f0cae8  custom lstm fix (May 8, 2019)
663ef0d  allow env nested keys and minor refactor (May 8, 2019)
63ca969  moved lr scheduler to use pytorch builtin (May 8, 2019)
7dfb65c  added optimize script for random hyperparam search (May 8, 2019)
4d3d1dd  minor refactor (May 8, 2019)
d2a0ff8  added inspect model script (May 8, 2019)
71d115a  serializing without tmp files (May 8, 2019)
0c134c2  serializing without tmp files (May 8, 2019)
c7fe111  refactor scheduler init (May 8, 2019)
eb6e6f6  read-only tasks (May 8, 2019)
94a45d2  fixes and refactors (May 9, 2019)
98502fd  added morph condition embeddings (May 9, 2019)
7b77dd4  added example opt.json file (May 9, 2019)
56962f0  fixes (May 30, 2019)
6b4bf70  minor (May 30, 2019)
7f19b0a  abstract stats function over readers (May 30, 2019)
c47c2a0  Merge branch 'feature/conditioning' into dev (May 30, 2019)
e1f2318  Merge branch 'master' into dev (Apr 21, 2020)
8035d07  Added beam options to evaluate (Apr 21, 2020)
fc87d00  deprecation fix (Apr 24, 2020)
73156e9  highway was broken :-s (Apr 24, 2020)
e0edf12  add option to cache batches (speed up for lengthy preprocessing workf… (Apr 24, 2020)
bf29618  added optimize (random search) (Apr 24, 2020)
c5ed597  Merge branch 'dev' into feature/transformer (Apr 24, 2020)
c84e968  minimum formatting changes (Apr 24, 2020)
8c997e8  rm read_only (Apr 24, 2020)
81bf39a  Merge branch 'master' into feature/transformer (Apr 24, 2020)
882bb70  fixes (Apr 24, 2020)
dc131e1  train run changed signature (Apr 24, 2020)
d8cca50  Merge branch 'feature-transformer' into feature/transformer (Apr 24, 2020)
0869987  cosmetic (Apr 24, 2020)
e140024  Added docstring to run_optimize (Apr 25, 2020)
ddad681  abstracted out build_embeddings (Apr 25, 2020)
814726d  fixing LSTM :-s (Apr 26, 2020)
b797646  deprecation fix (Apr 26, 2020)
e7f9be9  forgotten option (Apr 26, 2020)
f821750  fixes to abstracted out embeddings (Apr 26, 2020)
d3a8cdb  more deprecation fixes (Apr 26, 2020)
63d6091  formatting (Apr 28, 2020)
7575ed6  added transformer.py (Apr 28, 2020)
ccb93c4  added transformer.py (Apr 28, 2020)
974c869  Merge branch 'master' into feature/transformer (Apr 28, 2020)
23b23e0  make transformers work for all tokenization types (Apr 29, 2020)
cf558bb  fixed bug (May 12, 2020)
a691072  fixed serialization issue (May 12, 2020)
6626479  made transformer model serializable, rearrange scripts (May 12, 2020)
17d6a03  added evaluate option to script (May 12, 2020)
95153e2  Merge branch 'master' into feature/transformer (May 12, 2020)
1 change: 1 addition & 0 deletions pie/models/__init__.py
@@ -1,6 +1,7 @@

from .base_model import BaseModel
from .model import SimpleModel
+from .transformer import TransformerDataset, TransformerModel
from .encoder import RNNEncoder
from .embedding import CNNEmbedding, RNNEmbedding, EmbeddingConcat, EmbeddingMixer
from .embedding import build_embeddings
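
With this re-export in place, the new classes become importable from the pie.models namespace next to the existing models; a minimal import check, assuming nothing beyond what the diff above shows:

from pie.models import BaseModel, SimpleModel, TransformerDataset, TransformerModel
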
2 changes: 1 addition & 1 deletion pie/models/base_model.py
@@ -113,7 +113,7 @@ def save(self, fpath, infix=None, settings=None):

# create dir if necessary
dirname = os.path.dirname(fpath)
-if not os.path.isdir(dirname):
+if dirname and not os.path.isdir(dirname):
os.makedirs(dirname)

with tarfile.open(fpath, 'w') as tar:
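
The added dirname guard matters when save() is given a bare filename with no directory component: os.path.dirname() then returns an empty string and os.makedirs('') raises. A small sketch of both cases, with made-up paths:

import os

for fpath in ('model.tar', 'models/latin/model.tar'):  # hypothetical paths
    dirname = os.path.dirname(fpath)  # '' for the first path, 'models/latin' for the second
    if dirname and not os.path.isdir(dirname):  # the new check skips os.makedirs('') altogether
        os.makedirs(dirname)
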
307 changes: 307 additions & 0 deletions pie/models/transformer.py
@@ -0,0 +1,307 @@

import logging
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F

from pie import torch_utils
from pie.data import Dataset
from pie.data.dataset import pack_batch
from .base_model import BaseModel
from .decoder import (LinearDecoder, CRFDecoder, AttentionalDecoder)
from .embedding import build_embeddings


def get_instance_spans(tokenizer, text):
index = []
tokens = []
for (i, token) in enumerate(text.split()):
index.append(len(tokens))
for sub_token in tokenizer.tokenize(token, add_prefix_space=True):
tokens.append(sub_token)
index.append(len(tokens))
spans = list(zip(index[:-1], index[1:]))
return spans


def get_spans(tokenizer, text, batch):
spans = [get_instance_spans(tokenizer, inp) for inp in text]
max_span_len = max(end - start for sent in spans for start, end in sent)
max_spans = max(map(len, spans))
batch_size, _, emb_dim = batch.shape
output = torch.zeros(
batch_size, max_spans, max_span_len, emb_dim, device=batch.device)
mask = torch.zeros(batch_size, max_spans, max_span_len)

for i in range(batch_size):
for span, (start, end) in enumerate(spans[i]):
output[i, span, 0:end-start].copy_(batch[i, start:end])
mask[i, span, 0:end-start] = 1

return output, mask.bool()


def check_alignment(tokenizer, text):
spans = get_instance_spans(tokenizer, text)
orig_tokens = text.split()
assert len(spans) == len(orig_tokens)
tokens = tokenizer.tokenize(text)
output = []
for idx, (start, end) in enumerate(spans):
output.append((tokens[start:end], orig_tokens[idx]))
return output


class TransformerDataset(Dataset):
def __init__(self, settings, reader, label_encoder, tokenizer, model):
super().__init__(settings, reader, label_encoder)

self.tokenizer = tokenizer
self.model = model

def get_transformer_output(self, text, device):
encoded = self.tokenizer.batch_encode_plus(
text, return_tensors='pt', pad_to_max_length=True)
encoded = {k: val.to(self.model.device) for k, val in encoded.items()}
with torch.no_grad():
batch = self.model(**encoded)[0] # some models return 2 items, others 1
# remove <s>, </s> tokens
batch = batch[:, 1:-1]
# get spans
context, mask = get_spans(self.tokenizer, text, batch)
context, mask = context.to(device), mask.to(device)

return context, mask

def pack_batch(self, batch, device=None):
device = device or self.device
(word, char), tasks = pack_batch(self.label_encoder, batch, device)
context, mask = self.get_transformer_output(
[' '.join(inp) for inp, _ in batch], device)

return (word, char, (context, mask)), tasks


class SpanSelfAttention(nn.Module):
def __init__(self, context_dim, hidden_size, dropout=0.0):
self.context_dim = context_dim
self.hidden_size = hidden_size
self.dropout = dropout
super().__init__()

self.W = nn.Linear(context_dim, hidden_size)
self.v_a = nn.Parameter(torch.Tensor(hidden_size, 1))
self.init()

def init(self):
self.v_a.data.uniform_(-0.05, 0.05)
nn.init.xavier_uniform_(self.W.weight)

def forward(self, context, mask):
# (batch, num_spans, max_span_len, 1)
weights = self.W(context) @ self.v_a.unsqueeze(0).unsqueeze(0)
weights = weights.squeeze(3)
# apply mask
weights.masked_fill_(~mask, -float('inf'))
# softmax
weights = F.softmax(weights, dim=-1)
# remove nans that arise in padding
weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)
# weighted sum (batch, num_spans, max_span_len, dim) -> (batch, num_spans, dim)
context = (context * weights.unsqueeze(-1)).sum(2)
context = F.dropout(context, p=self.dropout, training=self.training)
# transpose to batch-second
context = context.transpose(0, 1)
return context


class TransformerModel(BaseModel):
def __init__(self, label_encoder, tasks, context_dim,
# input embeddings
wemb_dim=0, cemb_dim=0, cemb_type='RNN', custom_cemb_cell=False,
cemb_layers=1, cell='GRU', init_rnn='default', merge_type='concat',
# decoder
linear_layers=1, dropout=0.0, scorer='general'):
self.context_dim = context_dim
self.linear_layers = linear_layers
self.wemb_dim = wemb_dim
self.cemb_dim = cemb_dim
self.cemb_type = cemb_type
self.custom_cemb_cell = custom_cemb_cell
self.cemb_layers = cemb_layers
self.cell = cell
self.merge_type = merge_type
self.scorer = scorer
self.dropout = dropout
super().__init__(label_encoder, tasks)

hidden_size = context_dim

# embeddings
(self.wemb, self.cemb, self.merger), in_dim = build_embeddings(
label_encoder, wemb_dim,
cemb_dim, cemb_type, custom_cemb_cell, cemb_layers, cell, init_rnn,
merge_type, dropout)

# self attention
self.self_att = SpanSelfAttention(
context_dim, hidden_size, dropout=dropout)

# decoders
decoders = {}
for tname, task in self.tasks.items():

if task['level'].lower() == 'char':
if task['decoder'].lower() == 'attentional':
decoder = AttentionalDecoder(
label_encoder.tasks[tname], cemb_dim, self.cemb.embedding_dim,
context_dim=hidden_size + in_dim, scorer=scorer,
num_layers=cemb_layers, cell=cell, dropout=dropout,
init_rnn=init_rnn)

elif task['level'].lower() == 'token':
# linear
if task['decoder'].lower() == 'linear':
decoder = LinearDecoder(
label_encoder.tasks[tname], hidden_size + in_dim,
highway_layers=linear_layers - 1)
# crf
elif task['decoder'].lower() == 'crf':
decoder = CRFDecoder(
label_encoder.tasks[tname], hidden_size + in_dim,
highway_layers=linear_layers - 1)

else:
raise ValueError(
"Unknown decoder type {} for token-level task: {}".format(
task['decoder'], tname))

self.add_module('{}_decoder'.format(tname), decoder)
decoders[tname] = decoder

self.decoders = decoders

def get_args_and_kwargs(self):
return {'args': (self.context_dim, ),
'kwargs': {'linear_layers': self.linear_layers,
"wemb_dim": self.wemb_dim, "cemb_dim": self.cemb_dim,
"cemb_type": self.cemb_type,
"custom_cemb_cell": self.custom_cemb_cell,
"cemb_layers": self.cemb_layers, "cell": self.cell,
"merge_type": self.merge_type, "scorer": self.scorer}}

def embedding(self, word, wlen, char, clen):
wemb, cemb, cemb_outs = None, None, None
if self.wemb is not None:
# set words to unknown with prob `p` depending on word frequency
word = torch_utils.word_dropout(
word, self.word_dropout, self.training, self.label_encoder.word)
wemb = self.wemb(word)
if self.cemb is not None:
# cemb_outs: (seq_len x batch x emb_dim)
cemb, cemb_outs = self.cemb(char, clen, wlen)

if wemb is None:
emb = cemb
elif cemb is None:
emb = wemb
elif self.merger is not None:
emb = self.merger(wemb, cemb)
else:
emb = None

return emb, (wemb, cemb, cemb_outs)

def loss(self, batch_data, *target_tasks):
((word, wlen), (char, clen), (context, mask)), tasks = batch_data
output = {}

emb, (_, _, cemb_outs) = self.embedding(word, wlen, char, clen)

outs = self.self_att(context, mask)
if emb is not None:
outs = torch.cat([outs, emb], dim=-1)

for task in target_tasks:
(target, length), decoder = tasks[task], self.decoders[task]

if self.tasks[task]['level'].lower() == 'char':
cemb_outs = F.dropout(
cemb_outs, p=self.dropout, training=self.training)
logits = decoder(target, length, cemb_outs, clen,
context=torch_utils.flatten_padded_batch(outs, wlen))
output[task] = decoder.loss(logits, target)
else:
if isinstance(decoder, LinearDecoder):
logits = decoder(outs)
output[task] = decoder.loss(logits, target)
elif isinstance(decoder, CRFDecoder):
logits = decoder(outs)
output[task] = decoder.loss(logits, target, length)

return output

def predict(self, inp, *tasks, use_beam=False, beam_width=10, **kwargs):
tasks = set(self.label_encoder.tasks if not len(tasks) else tasks)
(word, wlen), (char, clen), (context, mask) = inp

emb, (_, _, cemb_outs) = self.embedding(word, wlen, char, clen)

outs = self.self_att(context, mask)
if emb is not None:
outs = torch.cat([outs, emb], dim=-1)

preds = {}
for task in tasks:
decoder = self.decoders[task]

if self.label_encoder.tasks[task].level.lower() == 'char':
if not use_beam:
hyps, _ = decoder.predict_max(
cemb_outs, clen,
context=torch_utils.flatten_padded_batch(outs, wlen))
else:
hyps, _ = decoder.predict_beam(
cemb_outs, clen,
context=torch_utils.flatten_padded_batch(outs, wlen),
width=beam_width)
if self.label_encoder.tasks[task].preprocessor_fn is None:
hyps = [''.join(hyp) for hyp in hyps]
else:
if isinstance(decoder, LinearDecoder):
hyps, _ = decoder.predict(outs, wlen)
elif isinstance(decoder, CRFDecoder):
hyps, _ = decoder.predict(outs, wlen)
else:
raise ValueError()

preds[task] = hyps

return preds


# transformer_path = '../latin-data/latin-model/v4/checkpoint-110000/'
# from transformers import AutoModel, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained(transformer_path)
# model = AutoModel.from_pretrained(transformer_path)
# from pie.settings import settings_from_file
# settings = settings_from_file('transformer-lemma.json')
# from pie.data import Reader, MultiLabelEncoder
# reader = Reader(settings, settings.input_path)
# label_encoder = MultiLabelEncoder.from_settings(settings).fit_reader(reader)
# r = reader.readsents()
# sents = []
# for _ in range(10):
# _, (inp, tasks) = next(r)
# sents.append(inp)
# text = [' '.join(s) for s in sents]
# encoded = tokenizer.batch_encode_plus(
# text, return_tensors='pt', pad_to_max_length=True)
# encoded = {k: val.to(model.device) for k, val in encoded.items()}
# with torch.no_grad():
# batch = model(**encoded)[0]
# # some models return 2 items, others 1
# get_instance_spans(tokenizer, text[0])
# get_spans(tokenizer, text, batch)
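
For orientation, a minimal sketch of what the span helpers above compute. DummyTokenizer is a made-up stand-in for a HuggingFace tokenizer, and the sentences and embedding size are arbitrary:

import torch

from pie.models.transformer import get_instance_spans, get_spans


class DummyTokenizer:
    # hypothetical tokenizer: splits every word into chunks of at most 3 characters
    def tokenize(self, token, add_prefix_space=True):
        return [token[i:i + 3] for i in range(0, len(token), 3)]


tokenizer = DummyTokenizer()
text = ["arma virumque cano", "troiae qui"]

print(get_instance_spans(tokenizer, text[0]))
# [(0, 2), (2, 5), (5, 7)]: one (start, end) sub-token span per whitespace token

# fake contextual embeddings: (batch, padded sub-token length, emb_dim)
batch = torch.randn(2, 7, 4)
context, mask = get_spans(tokenizer, text, batch)
print(context.shape, mask.shape)
# torch.Size([2, 3, 3, 4]) torch.Size([2, 3, 3]):
# (batch, words, longest sub-token span, emb_dim), the input expected by SpanSelfAttention
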
14 changes: 1 addition & 13 deletions pie/scripts/train.py
@@ -6,7 +6,7 @@
import logging

import pie
-from pie.settings import settings_from_file
+from pie.settings import settings_from_file, get_targets, get_fname_infix
from pie.trainer import Trainer
from pie import initialization
from pie.data import Dataset, Reader, MultiLabelEncoder
@@ -19,18 +19,6 @@
import torch


-def get_targets(settings):
-    return [task['name'] for task in settings.tasks if task.get('target')]
-
-
-def get_fname_infix(settings):
-    # fname
-    fname = os.path.join(settings.modelpath, settings.modelname)
-    timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
-    infix = '+'.join(get_targets(settings)) + '-' + timestamp
-    return fname, infix


def run(settings):
now = datetime.now()

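
Since get_targets and get_fname_infix now live in pie.settings, other scripts can share them with train.py; a short sketch of the intended call pattern, assuming the moved helpers keep the signatures shown above (the settings file name is only illustrative, borrowed from the scratch code earlier in the diff):

from pie.settings import settings_from_file, get_targets, get_fname_infix

settings = settings_from_file('transformer-lemma.json')  # illustrative config path
targets = get_targets(settings)           # names of the tasks marked as target
fname, infix = get_fname_infix(settings)  # modelpath/modelname and 'target1+target2-<timestamp>'
print(targets, fname, infix)
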