Simplifying the use of the model to perform different tasks #25

Merged 14 commits on Jul 18, 2018
loss.py: 69 additions, 0 deletions
@@ -0,0 +1,69 @@
import torch

class MultipleChoiceLossCompute:
"A Loss compute and train function for multiple choice tasks."

def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
self.lm_criterion = lm_criterion
self.clf_criterion = clf_criterion
self.lm_coef = lm_coef
self.opt = opt

def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
# Language modeling loss
if lm_logits is not None:
            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shifted targets, flattened to (n_batch * n_choices * (seq_len - 1),)
M = M.view(-1, M.size(2))
lm_losses = self.lm_criterion(lm_logits, x_shifted)
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
lm_losses = lm_losses * M[:, 1:]
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
# Classification loss
clf_losses = self.clf_criterion(clf_logits, Y)
if only_return_losses:
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

if self.lm_coef > 0 and lm_logits is not None:
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
else:
train_loss = clf_losses.sum()
train_loss.backward()
if self.opt is not None:
self.opt.step()
self.opt.zero_grad()
return train_loss.item()

class ClassificationLossCompute:
"A Loss compute and train function for classification tasks."

def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
self.lm_criterion = lm_criterion
self.clf_criterion = clf_criterion
self.lm_coef = lm_coef
self.opt = opt

def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
# Language modeling loss
if lm_logits is not None:
x_shifted = X[:, 1:, 0].contiguous().view(-1)
M = M.view(-1, M.size(-1))
lm_losses = self.lm_criterion(lm_logits, x_shifted)
lm_losses = lm_losses.view(X.size(0), X.size(-2) - 1)
lm_losses = lm_losses * M[:, 1:]
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
# Classification loss
clf_losses = self.clf_criterion(clf_logits, Y)
if only_return_losses:
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

if self.lm_coef > 0 and lm_logits is not None:
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
else:
train_loss = clf_losses.sum()
train_loss.backward()
if self.opt is not None:
self.opt.step()
self.opt.zero_grad()
return train_loss.item()

# TODO: Implement a LossCompute class for similarity tasks.
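
For reference, a minimal sketch of how MultipleChoiceLossCompute is wired up in a training script. It mirrors train.py below; dh_model, X, Y and M are assumed from that script, and the optimizer is left out here.

import torch.nn as nn

from loss import MultipleChoiceLossCompute

# A per-token criterion (no reduction) lets the class mask and re-normalize
# the language-modeling loss itself, as the code above does with M.
criterion = nn.CrossEntropyLoss(reduce=False)

# opt=None means the caller steps the optimizer; train.py passes OpenAIAdam instead.
compute_loss_fct = MultipleChoiceLossCompute(criterion, criterion,
                                             lm_coef=0.5, opt=None)

# Inside the training loop one would then call:
#   lm_logits, clf_logits = dh_model(X)
#   train_loss = compute_loss_fct(X, Y, M, clf_logits, lm_logits)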
model_pytorch.py: 79 additions, 10 deletions
@@ -2,6 +2,7 @@
import json
import math
import re
import collections.abc

import numpy as np
import torch
@@ -158,7 +159,6 @@ def __init__(self, cfg, vocab=40990, n_ctx=512):
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
self.decoder.weight = self.embed.weight # Tied weights
self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation

nn.init.normal_(self.embed.weight, std=0.02)

@@ -188,22 +188,23 @@ def forward(self, h):
return lm_logits


class ClfHead(nn.Module):
class MultipleChoiceHead(nn.Module):
""" Classifier Head for the transformer """

def __init__(self, clf_token, cfg):
super(ClfHead, self).__init__()
super(MultipleChoiceHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout2d(cfg.clf_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(cfg.n_embd, 1)
nn.init.normal_(self.linear.weight, std=0.02)

nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)

def forward(self, h, x):
# Classification logits
clf_h = h.view(-1, self.n_embd)
flat = x[:, :, :, 0].contiguous().view(-1)
flat = x[..., 0].contiguous().view(-1)
clf_h = clf_h[flat == self.clf_token, :]
clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
# This double transposition is there to replicate the behavior
@@ -213,22 +214,90 @@ def forward(self, h, x):
clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
clf_h = clf_h.contiguous().view(-1, self.n_embd)
clf_logits = self.linear(clf_h)

return clf_logits.view(-1, x.size(1))


class ClfHead(nn.Module):
"""Classification Head for the transformer

TODO: test this class."""
def __init__(self, clf_token, cfg, n_class):
super(ClfHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout(cfg.clf_pdrop)
self.linear = nn.Linear(cfg.n_embd, n_class)

nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)

def forward(self, h, x):
clf_h = h.view(-1, self.n_embd)
flat = x[..., 0].contiguous().view(-1)
clf_h = clf_h[flat == self.clf_token, :]
clf_h = self.dropout(clf_h)
clf_logits = self.linear(clf_h)

return clf_logits

class SimilarityHead(nn.Module):
""" Similarity Head for the transformer

TODO: test this class."""
def __init__(self, clf_token, cfg):
super(SimilarityHead, self).__init__()
self.n_embd = cfg.n_embd
self.clf_token = clf_token
self.dropout = nn.Dropout(cfg.clf_pdrop)
self.linear = nn.Linear(cfg.n_embd, 1)

nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0)

def forward(self, h, x):
sim_h = h.view(-1, self.n_embd)
flat = x[..., 0].contiguous().view(-1)
sim_h = sim_h[flat == self.clf_token, :]
sim_h = self.dropout(sim_h)
sim_h = sim_h.sum(dim = 1)
sim_logits = self.linear(sim_h)

return sim_logits

class DoubleHeadModel(nn.Module):
""" Transformer with language model and classification heads """
def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512):
""" Transformer with language model and task specific heads """
def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
super(DoubleHeadModel, self).__init__()
self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
self.lm_head = LMHead(self.transformer, cfg)
self.clf_head = ClfHead(clf_token, cfg)
if isinstance(task_head_type, str):
if task_head_type == 'multiple_choice':
self.task_head = MultipleChoiceHead(clf_token, cfg)
elif task_head_type == 'similarity':
self.task_head = SimilarityHead(clf_token, cfg)
elif task_head_type == 'inference':
# the three classes correspond to entailment, contradiction and neutral.
self.task_head = ClfHead(clf_token, cfg, 3)
else:
raise ValueError("task_head_type is expected to be 'multiple_choice' "
"'similarity', 'inference' or ('classification', n_class) "
f"got {task_head_type}.")
elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
task_head_type[0] == 'classification':
n_class = task_head_type[1]
self.task_head = ClfHead(clf_token, cfg, n_class)
else:
raise ValueError("task_head_type is expected to be 'multiple_choice' "
"'similarity', 'inference' or ('classification', n_class) "
f"got {task_head_type}.")

def forward(self, x):
h = self.transformer(x)
lm_logits = self.lm_head(h)
clf_logits = self.clf_head(h, x)
return lm_logits, clf_logits
task_logits = self.task_head(h, x)

return lm_logits, task_logits


def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
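
To illustrate the new task_head_type argument, here is a minimal construction sketch. The config values and the clf_token id are illustrative assumptions (the usual 12-layer, 768-dim setup); in practice train.py builds them from its command-line arguments and the text encoder.

from types import SimpleNamespace

from model_pytorch import DoubleHeadModel

# Hypothetical config mirroring the usual hyper-parameters of the model.
cfg = SimpleNamespace(n_embd=768, n_head=12, n_layer=12, afn='gelu',
                      embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1,
                      clf_pdrop=0.1)
clf_token = 40480  # illustrative id for the _classify_ special token

# 'multiple_choice' reproduces the previous ClfHead behaviour (ROCStories);
# the other values select the heads introduced in this PR.
model_mc  = DoubleHeadModel(cfg, clf_token, 'multiple_choice')
model_nli = DoubleHeadModel(cfg, clf_token, 'inference')            # 3-way ClfHead
model_clf = DoubleHeadModel(cfg, clf_token, ('classification', 2))  # binary ClfHead

# Every variant returns (lm_logits, task_logits) from forward().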
train.py: 6 additions, 40 deletions
@@ -15,41 +15,7 @@
from text_utils import TextEncoder
from utils import (encode_dataset, iter_data,
ResultLogger, make_path)


class LossCompute:
"A Loss compute and train function."

def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
self.lm_criterion = lm_criterion
self.clf_criterion = clf_criterion
self.lm_coef = lm_coef
self.opt = opt

def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
# Language modeling loss
if lm_logits is not None:
x_shifted = X[:, :, 1:, 0].contiguous().view(-1) # Shape: 252
M = M.view(-1, M.size(2))
lm_losses = self.lm_criterion(lm_logits, x_shifted)
lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
lm_losses = lm_losses * M[:, 1:]
lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
# Classification loss
clf_losses = self.clf_criterion(clf_logits, Y)
if only_return_losses:
return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

if self.lm_coef > 0 and lm_logits is not None:
train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
else:
train_loss = clf_losses.sum()
train_loss.backward()
if self.opt is not None:
self.opt.step()
self.opt.zero_grad()
return train_loss.item()

from loss import MultipleChoiceLossCompute

def transform_roc(X1, X2, X3):
n_batch = len(X1)
@@ -263,7 +229,7 @@ def run_epoch():
n_batch_train = args.n_batch * max(n_gpu, 1)
n_updates_total = (n_train // n_batch_train) * args.n_iter

dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)

criterion = nn.CrossEntropyLoss(reduce=False)
model_opt = OpenAIAdam(dh_model.parameters(),
@@ -277,10 +243,10 @@ def run_epoch():
l2=args.l2,
vector_l2=args.vector_l2,
max_grad_norm=args.max_grad_norm)
compute_loss_fct = LossCompute(criterion,
criterion,
args.lm_coef,
model_opt)
compute_loss_fct = MultipleChoiceLossCompute(criterion,
criterion,
args.lm_coef,
model_opt)
load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)

dh_model.to(device)
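
Not part of this diff, but as a sketch of what the PR enables: switching the same script to a plain classification task would only require picking a different head and loss-compute class. The snippet below is hypothetical and reuses the names already defined in train.py (args, clf_token, vocab, n_ctx, criterion, model_opt).

from loss import ClassificationLossCompute

# Hypothetical: a 2-class task instead of ROCStories multiple choice.
dh_model = DoubleHeadModel(args, clf_token, ('classification', 2), vocab, n_ctx)
compute_loss_fct = ClassificationLossCompute(criterion,
                                             criterion,
                                             args.lm_coef,
                                             model_opt)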