abrooks9944 commited on
Commit
981696b
·
verified ·
1 Parent(s): df28cad

Add Guidance for Repetition Penalty

Browse files

https://github.com/huggingface/transformers/pull/37625 added support for excluding the input tokens from RepetitionPenaltyLogitsProcessor - this updates the code snippet to do this with a repetition penalty of 3.

Files changed (1) hide show
  1. README.md +10 -4
README.md CHANGED
@@ -51,7 +51,7 @@ Then run the code:
51
  ```python
52
  import torch
53
  import torchaudio
54
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
55
  from huggingface_hub import hf_hub_download
56
 
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -64,7 +64,6 @@ speech_granite = AutoModelForSpeechSeq2Seq.from_pretrained(
64
  model_name).to(device)
65
 
66
  # prepare speech and text prompt, using the appropriate prompt template
67
-
68
  audio_path = hf_hub_download(repo_id=model_name, filename='10226_10111_000000.wav')
69
  wav, sr = torchaudio.load(audio_path, normalize=True)
70
  assert wav.shape[0] == 1 and sr == 16000 # mono, 16khz
@@ -92,7 +91,14 @@ model_inputs = speech_granite_processor(
92
  device=device, # Computation device; returned tensors are put on CPU
93
  return_tensors="pt",
94
  ).to(device)
95
-
 
 
 
 
 
 
 
96
  model_outputs = speech_granite.generate(
97
  **model_inputs,
98
  max_new_tokens=200,
@@ -100,9 +106,9 @@ model_outputs = speech_granite.generate(
100
  do_sample=False,
101
  min_length=1,
102
  top_p=1.0,
103
- repetition_penalty=1.0,
104
  length_penalty=1.0,
105
  temperature=1.0,
 
106
  bos_token_id=tokenizer.bos_token_id,
107
  eos_token_id=tokenizer.eos_token_id,
108
  pad_token_id=tokenizer.pad_token_id,
 
51
  ```python
52
  import torch
53
  import torchaudio
54
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, RepetitionPenaltyLogitsProcessor
55
  from huggingface_hub import hf_hub_download
56
 
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
64
  model_name).to(device)
65
 
66
  # prepare speech and text prompt, using the appropriate prompt template
 
67
  audio_path = hf_hub_download(repo_id=model_name, filename='10226_10111_000000.wav')
68
  wav, sr = torchaudio.load(audio_path, normalize=True)
69
  assert wav.shape[0] == 1 and sr == 16000 # mono, 16khz
 
91
  device=device, # Computation device; returned tensors are put on CPU
92
  return_tensors="pt",
93
  ).to(device)
94
+
95
+ # The recommended repetition penalty is 3 as long as input IDs are excluded.
96
+ # Otherwise, you should use a reptition penalty of 1 to keep results stable.
97
+ reptition_penalty_processor = RepetitionPenaltyLogitsProcessor(
98
+ penalty=3.0,
99
+ prompt_ignore_length=model_inputs["input_ids"].shape[-1],
100
+ )
101
+
102
  model_outputs = speech_granite.generate(
103
  **model_inputs,
104
  max_new_tokens=200,
 
106
  do_sample=False,
107
  min_length=1,
108
  top_p=1.0,
 
109
  length_penalty=1.0,
110
  temperature=1.0,
111
+ logits_processor=[reptition_penalty_processor],
112
  bos_token_id=tokenizer.bos_token_id,
113
  eos_token_id=tokenizer.eos_token_id,
114
  pad_token_id=tokenizer.pad_token_id,