black formatting
aakbik committed Nov 27, 2019
1 parent 026cc79 commit 8e3b09b
Showing 5 changed files with 41 additions and 33 deletions.
flair/embeddings.py (17 changes: 4 additions & 13 deletions)
@@ -556,10 +556,7 @@ class HashEmbeddings(TokenEmbeddings):
     """Standard embeddings with Hashing Trick."""
 
     def __init__(
-        self,
-        num_embeddings: int = 1000,
-        embedding_length: int = 300,
-        hash_method='md5'
+        self, num_embeddings: int = 1000, embedding_length: int = 300, hash_method="md5"
     ):
 
         super().__init__()
@@ -579,7 +576,6 @@ def __init__(
 
         self.to(flair.device)
 
-
     @property
     def num_embeddings(self) -> int:
         return self.__num_embeddings
@@ -589,23 +585,18 @@ def embedding_length(self) -> int:
         return self.__embedding_length
 
     def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
-
         def get_idx_for_item(text):
             hash_function = hashlib.new(self.__hash_method)
-            hash_function.update(bytes(str(text), 'utf-8'))
+            hash_function.update(bytes(str(text), "utf-8"))
             return int(hash_function.hexdigest(), 16) % self.__num_embeddings
 
         hash_sentences = []
         for i, sentence in enumerate(sentences):
-            context_idxs = [
-                get_idx_for_item(t.text) for t in sentence.tokens
-            ]
+            context_idxs = [get_idx_for_item(t.text) for t in sentence.tokens]
 
             hash_sentences.extend(context_idxs)
 
-        hash_sentences = torch.tensor(hash_sentences, dtype=torch.long).to(
-            flair.device
-        )
+        hash_sentences = torch.tensor(hash_sentences, dtype=torch.long).to(flair.device)
 
         embedded = self.embedding_layer.forward(hash_sentences)
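
For readers unfamiliar with the hashing trick used by HashEmbeddings, the reformatted get_idx_for_item logic can be tried on its own. Below is a minimal standalone sketch (not part of this commit), assuming md5 and a table of 1000 rows as in the defaults above:

    import hashlib

    def get_idx_for_item(text, num_embeddings=1000, hash_method="md5"):
        # Hash the token text and fold the digest into the embedding table size.
        hash_function = hashlib.new(hash_method)
        hash_function.update(bytes(str(text), "utf-8"))
        return int(hash_function.hexdigest(), 16) % num_embeddings

    print(get_idx_for_item("hello"))  # deterministic index in [0, 1000)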
flair/models/sequence_tagger_model.py (14 changes: 8 additions & 6 deletions)
@@ -743,17 +743,19 @@ def _viterbi_decode(
         for index, (tag_id, tag_scores) in enumerate(zip(best_path, all_scores_np)):
             if type(tag_id) != int and tag_id.item() != tag_scores.argmax():
                 swap_index_score = tag_scores.argmax()
-                all_scores_np[index][tag_id.item()], all_scores_np[index][
-                    swap_index_score
-                ] = (
+                (
+                    all_scores_np[index][tag_id.item()],
+                    all_scores_np[index][swap_index_score],
+                ) = (
                     all_scores_np[index][swap_index_score],
                     all_scores_np[index][tag_id.item()],
                 )
             elif type(tag_id) == int and tag_id != tag_scores.argmax():
                 swap_index_score = tag_scores.argmax()
-                all_scores_np[index][tag_id], all_scores_np[index][
-                    swap_index_score
-                ] = (
+                (
+                    all_scores_np[index][tag_id],
+                    all_scores_np[index][swap_index_score],
+                ) = (
                     all_scores_np[index][swap_index_score],
                     all_scores_np[index][tag_id],
                 )
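
The change above is purely cosmetic: black wraps the left-hand side of the swap in parentheses. The underlying operation, exchanging the score at the predicted tag index with the row maximum, can be sketched in isolation (a standalone illustration with made-up numbers, not code from this commit):

    import numpy as np

    scores = np.array([0.1, 0.7, 0.2])
    tag_id = 0                           # predicted tag index
    swap_index_score = int(scores.argmax())
    (
        scores[tag_id],
        scores[swap_index_score],
    ) = (
        scores[swap_index_score],
        scores[tag_id],
    )
    print(scores)  # [0.7 0.1 0.2]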
flair/models/similarity_learning_model.py (2 changes: 1 addition & 1 deletion)
@@ -168,7 +168,7 @@ def __init__(
         target_mapping: torch.nn.Module = None,
         recall_at_points: List[int] = [1, 5, 10, 20],
         recall_at_points_weights: List[float] = [0.4, 0.3, 0.2, 0.1],
-        interleave_embedding_updates: bool = False
+        interleave_embedding_updates: bool = False,
     ):
         super(SimilarityLearner, self).__init__()
         self.source_embeddings: Embeddings = source_embeddings
flair/trainers/trainer.py (39 changes: 27 additions & 12 deletions)
@@ -83,8 +83,8 @@ def train(
         sampler=None,
         use_amp: bool = False,
         amp_opt_level: str = "O1",
-        eval_on_train_fraction = 0.,
-        eval_on_train_shuffle = False,
+        eval_on_train_fraction=0.0,
+        eval_on_train_shuffle=False,
         **kwargs,
     ) -> dict:
         """
@@ -181,15 +181,24 @@ def train(
             else False
         )
         log_dev = True if not train_with_dev else False
-        log_train_part = True if (eval_on_train_fraction == 'dev' or eval_on_train_fraction > 0.) else False
+        log_train_part = (
+            True
+            if (eval_on_train_fraction == "dev" or eval_on_train_fraction > 0.0)
+            else False
+        )
 
         if log_train_part:
-            train_part_size = len(self.corpus.dev) if eval_on_train_fraction == 'dev' \
-                else int(len(self.corpus.train) * eval_on_train_fraction)
-            assert(train_part_size > 0)
+            train_part_size = (
+                len(self.corpus.dev)
+                if eval_on_train_fraction == "dev"
+                else int(len(self.corpus.train) * eval_on_train_fraction)
+            )
+            assert train_part_size > 0
             if not eval_on_train_shuffle:
                 train_part_indices = list(range(train_part_size))
-                train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)
+                train_part = torch.utils.data.dataset.Subset(
+                    self.corpus.train, train_part_indices
+                )
 
         # prepare loss logging file and set up header
         loss_txt = init_output_file(base_path, "loss.tsv")
@@ -248,7 +257,9 @@ def train(
                     train_part_indices = list(range(self.corpus.train))
                     random.shuffle(train_part_indices)
                     train_part_indices = train_part_indices[:train_part_size]
-                    train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)
+                    train_part = torch.utils.data.dataset.Subset(
+                        self.corpus.train, train_part_indices
+                    )
 
                 # get new learning rate
                 for group in optimizer.param_groups:
@@ -384,11 +395,13 @@ def train(
                         DataLoader(
                             train_part,
                             batch_size=mini_batch_chunk_size,
-                            num_workers=num_workers
+                            num_workers=num_workers,
                         ),
-                        embedding_storage_mode=embeddings_storage_mode
+                        embedding_storage_mode=embeddings_storage_mode,
                     )
-                    result_line += f"\t{train_part_loss}\t{train_part_eval_result.log_line}"
+                    result_line += (
+                        f"\t{train_part_loss}\t{train_part_eval_result.log_line}"
+                    )
                     log.info(
                         f"TRAIN_SPLIT : loss {train_part_loss} - score {train_part_eval_result.main_score}"
                     )
@@ -483,7 +496,9 @@ def train(
                 if log_train_part:
                     f.write(
                         "\tTRAIN_PART_LOSS\tTRAIN_PART_"
-                        + "\tTRAIN_PART_".join(train_part_eval_result.log_header.split("\t"))
+                        + "\tTRAIN_PART_".join(
+                            train_part_eval_result.log_header.split("\t")
+                        )
                     )
                 if log_dev:
                     f.write(
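
The trainer changes above only re-wrap the eval_on_train_fraction logic; its behaviour, carving a fixed subset out of the training data for per-epoch evaluation, is unchanged. A minimal sketch of that selection follows, using a placeholder dataset in place of self.corpus.train (an assumption for illustration only, not code from this commit):

    import torch

    train_data = torch.utils.data.TensorDataset(torch.arange(100).unsqueeze(1))
    eval_on_train_fraction = 0.1

    # Take the first fraction of the training data as the evaluation subset.
    train_part_size = int(len(train_data) * eval_on_train_fraction)
    assert train_part_size > 0
    train_part_indices = list(range(train_part_size))
    train_part = torch.utils.data.dataset.Subset(train_data, train_part_indices)
    print(len(train_part))  # 10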
tests/test_transformer_embeddings.py (2 changes: 1 addition & 1 deletion)
@@ -31,7 +31,7 @@
 
 
 def calculate_mean_embedding(
-    subword_embeddings: List[torch.FloatTensor]
+    subword_embeddings: List[torch.FloatTensor],
 ) -> torch.FloatTensor:
     all_embeddings: List[torch.FloatTensor] = [
         embedding.unsqueeze(0) for embedding in subword_embeddings
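
The test helper above averages a list of per-subword embeddings into one vector; only its signature and the start of its body are visible in this diff. A plausible self-contained version, with the final pooling step assumed rather than taken from the file:

    from typing import List
    import torch

    def calculate_mean_embedding(
        subword_embeddings: List[torch.FloatTensor],
    ) -> torch.FloatTensor:
        all_embeddings: List[torch.FloatTensor] = [
            embedding.unsqueeze(0) for embedding in subword_embeddings
        ]
        # Assumed pooling step: stack the subword vectors and take their mean.
        return torch.mean(torch.cat(all_embeddings, dim=0), dim=0)

    print(calculate_mean_embedding([torch.ones(4), torch.zeros(4)]))  # tensor([0.5, 0.5, 0.5, 0.5])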
