Add Guidance for Repetition Penalty
Browse fileshttps://github.com/huggingface/transformers/pull/37625 added support for excluding the input tokens from RepetitionPenaltyLogitsProcessor - this updates the code snippet to do this with a repetition penalty of 3.
README.md
CHANGED
@@ -51,7 +51,7 @@ Then run the code:
|
|
51 |
```python
|
52 |
import torch
|
53 |
import torchaudio
|
54 |
-
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
55 |
from huggingface_hub import hf_hub_download
|
56 |
|
57 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -64,7 +64,6 @@ speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
|
|
64 |
model_name).to(device)
|
65 |
|
66 |
# prepare speech and text prompt, using the appropriate prompt template
|
67 |
-
|
68 |
audio_path = hf_hub_download(repo_id=model_name, filename='10226_10111_000000.wav')
|
69 |
wav, sr = torchaudio.load(audio_path, normalize=True)
|
70 |
assert wav.shape[0] == 1 and sr == 16000 # mono, 16khz
|
@@ -92,7 +91,14 @@ model_inputs = speech_granite_processor(
|
|
92 |
device=device, # Computation device; returned tensors are put on CPU
|
93 |
return_tensors="pt",
|
94 |
).to(device)
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
model_outputs = speech_granite.generate(
|
97 |
**model_inputs,
|
98 |
max_new_tokens=200,
|
@@ -100,9 +106,9 @@ model_outputs = speech_granite.generate(
|
|
100 |
do_sample=False,
|
101 |
min_length=1,
|
102 |
top_p=1.0,
|
103 |
-
repetition_penalty=1.0,
|
104 |
length_penalty=1.0,
|
105 |
temperature=1.0,
|
|
|
106 |
bos_token_id=tokenizer.bos_token_id,
|
107 |
eos_token_id=tokenizer.eos_token_id,
|
108 |
pad_token_id=tokenizer.pad_token_id,
|
|
|
51 |
```python
|
52 |
import torch
|
53 |
import torchaudio
|
54 |
+
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, RepetitionPenaltyLogitsProcessor
|
55 |
from huggingface_hub import hf_hub_download
|
56 |
|
57 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
64 |
model_name).to(device)
|
65 |
|
66 |
# prepare speech and text prompt, using the appropriate prompt template
|
|
|
67 |
audio_path = hf_hub_download(repo_id=model_name, filename='10226_10111_000000.wav')
|
68 |
wav, sr = torchaudio.load(audio_path, normalize=True)
|
69 |
assert wav.shape[0] == 1 and sr == 16000 # mono, 16khz
|
|
|
91 |
device=device, # Computation device; returned tensors are put on CPU
|
92 |
return_tensors="pt",
|
93 |
).to(device)
|
94 |
+
|
95 |
+
# The recommended repetition penalty is 3 as long as input IDs are excluded.
|
96 |
+
# Otherwise, you should use a reptition penalty of 1 to keep results stable.
|
97 |
+
reptition_penalty_processor = RepetitionPenaltyLogitsProcessor(
|
98 |
+
penalty=3.0,
|
99 |
+
prompt_ignore_length=model_inputs["input_ids"].shape[-1],
|
100 |
+
)
|
101 |
+
|
102 |
model_outputs = speech_granite.generate(
|
103 |
**model_inputs,
|
104 |
max_new_tokens=200,
|
|
|
106 |
do_sample=False,
|
107 |
min_length=1,
|
108 |
top_p=1.0,
|
|
|
109 |
length_penalty=1.0,
|
110 |
temperature=1.0,
|
111 |
+
logits_processor=[reptition_penalty_processor],
|
112 |
bos_token_id=tokenizer.bos_token_id,
|
113 |
eos_token_id=tokenizer.eos_token_id,
|
114 |
pad_token_id=tokenizer.pad_token_id,
|