Implement repetition-penalty for generation
tridao committed Dec 20, 2023
1 parent dcc9309 commit 1df0df1
Showing 3 changed files with 29 additions and 3 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -134,8 +134,8 @@ Other configurable options include the top-p (nucleus sampling) probability, and
 To test generation latency (e.g. batch size = 1) with different sampling strategies:
 
 ```
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
 ```
 
 To test generation throughput with random prompts (e.g. large batch size):
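
For reference, `--repetition-penalty 1.2` follows the CTRL-style penalty (https://arxiv.org/abs/1909.05858) used by the new helper in mamba_ssm/utils/generation.py: logits of tokens that already appear in the output are pushed toward lower probability, while 1.0 leaves sampling unchanged. A minimal sketch of the arithmetic, with made-up logit values:

```python
# Illustration only: hypothetical logit values for two previously generated tokens.
penalty = 1.2
for logit in (4.0, -2.0):
    # Negative logits are multiplied by the penalty, positive ones divided,
    # so a repeated token always becomes less likely.
    penalized = logit * penalty if logit < 0 else logit / penalty
    print(logit, "->", round(penalized, 2))  # 4.0 -> 3.33, -2.0 -> -2.4
```
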
benchmarks/benchmark_generation_mamba_simple.py (3 changes: 3 additions & 0 deletions)
@@ -22,6 +22,7 @@
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--topk", type=int, default=1)
parser.add_argument("--topp", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--batch", type=int, default=1)
args = parser.parse_args()

@@ -61,6 +62,7 @@
         temperature=args.temperature,
         top_k=args.topk,
         top_p=args.topp,
+        repetition_penalty=args.repetition_penalty,
     )
 else:
     fn = lambda: model.generate(
@@ -73,6 +75,7 @@
         temperature=args.temperature,
         top_k=args.topk,
         top_p=args.topp,
+        repetition_penalty=args.repetition_penalty,
     )
 out = fn()
 if args.prompt is not None:
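
The same flag can be exercised from Python. A minimal sketch, assuming the API the benchmark script itself uses (MambaLMHeadModel.from_pretrained plus the GPT-NeoX tokenizer); the model name, prompt, and generation settings are illustrative:

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device="cuda", dtype=torch.float16)

input_ids = tokenizer("My cat wrote all this CUDA code", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[1] + 100,
    temperature=0.7,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.2,  # forwarded to the sampling loop in mamba_ssm/utils/generation.py
)
print(tokenizer.decode(out[0]))  # generate returns token ids unless return_dict_in_generate=True
```

With repetition_penalty=1.0 (the default) the new code path is skipped entirely, so existing callers are unaffected.
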
mamba_ssm/utils/generation.py (25 changes: 24 additions & 1 deletion)
@@ -60,6 +60,20 @@ def modify_logits_for_top_p_filtering(logits, top_p):
     logits.masked_fill_(indices_to_remove, float("-inf"))
 
 
+def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
+    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
+    logits: (batch_size, vocab_size)
+    prev_output_tokens: (batch_size, seq_len)
+    """
+    if repetition_penalty == 1.0:
+        return logits
+    score = torch.gather(logits, 1, prev_output_tokens)
+    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+    logits.scatter_(1, prev_output_tokens, score)
+    return logits
+
+
 def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     """Sample from top-k logits.
     Arguments:
@@ -97,6 +111,7 @@ def decode(
     top_k=1,
     top_p=0.0,
     temperature=1.0,
+    repetition_penalty=1.0,
     eos_token_id=None,
     teacher_outputs=None,
     vocab_size=None,
@@ -186,10 +201,18 @@ def should_stop(current_token, inference_params):
     if enable_timing:
         start.record()
     scores, sequences = [], [input_ids]
+    sequences_cat = input_ids
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
         inference_params.seqlen_offset += sequences[-1].shape[1]
-        sampled_tokens = sample_tokens(scores[-1], inference_params)
+        if repetition_penalty == 1.0:
+            sampled_tokens = sample_tokens(scores[-1], inference_params)
+        else:
+            logits = modify_logit_for_repetition_penalty(
+                scores[-1].clone(), sequences_cat, repetition_penalty
+            )
+            sampled_tokens = sample_tokens(logits, inference_params)
+            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
         sequences.append(sampled_tokens)
         if streamer is not None:
             streamer.put(sampled_tokens.cpu())
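
A quick way to sanity-check the new helper in isolation (a minimal sketch, assuming a mamba_ssm install that includes this commit; the logit values and token ids are made up): only the positions listed in prev_output_tokens are rescaled, always toward lower probability.

```python
import torch
from mamba_ssm.utils.generation import modify_logit_for_repetition_penalty

logits = torch.tensor([[2.0, -1.0, 0.5, 3.0, -0.5]])  # (batch=1, vocab=5), made-up values
prev = torch.tensor([[1, 3]])                          # token ids 1 and 3 already appear in the output
out = modify_logit_for_repetition_penalty(logits.clone(), prev, repetition_penalty=1.2)
# Positive logits are divided by the penalty, negative ones multiplied:
print(out)  # tensor([[ 2.0000, -1.2000,  0.5000,  2.5000, -0.5000]])
```

The .clone() mirrors how decode() calls the helper on scores[-1].clone(), so the raw scores it returns stay unmodified even though the helper writes in place.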
