Hjgugugjhuhjggg committed
Commit 9cd71e4 · verified · 1 Parent(s): 4ec33a6

Update text_generation.py

Files changed (1): text_generation.py (+26 -5)
text_generation.py CHANGED
@@ -1,4 +1,3 @@
-
 import torch
 import torch.nn.functional as F
 from tqdm import trange
@@ -47,7 +46,8 @@ try:
     device
 except NameError:
     device = "cpu"
-if device.startswith("cuda"):
+
+if torch.device(device).type == "cuda":
     torch.backends.cudnn.benchmark = True

 MAX_GENERATION_LENGTH = 512
@@ -106,17 +106,38 @@ def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_tok
 def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
     context_tokens = enc.encode(prompt)
     context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
-    return _generate_sequence(lambda ct, past: model(ct, past_key_values=past), context_tensor, list(context_tokens), lambda token: enc.decode([token]), lambda token: token == enc.encoder[END_OF_TEXT_TOKEN], temperature, top_k, top_p, repetition_penalty, max_length)
+    return _generate_sequence(
+        lambda ct, past: model(ct, past_key_values=past),
+        context_tensor,
+        list(context_tokens),
+        lambda token: enc.decode([token]),
+        lambda token: token == enc.encoder[END_OF_TEXT_TOKEN],
+        temperature, top_k, top_p, repetition_penalty, max_length
+    )

 def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
     context_tokens = tokenizer.encode(prompt)
     context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
-    return _generate_sequence(lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None), context_tensor, list(context_tokens), lambda token: tokenizer.decode([token]), lambda token: token == 50256, temperature, top_k, top_p, repetition_penalty, max_length)
+    return _generate_sequence(
+        lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
+        context_tensor,
+        list(context_tokens),
+        lambda token: tokenizer.decode([token]),
+        lambda token: token == 50256,
+        temperature, top_k, top_p, repetition_penalty, max_length
+    )

 def summarize_text(text):
     if summarization_model and summarization_tokenizer:
         input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
-        summary_ids = summarization_model.generate(input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
+        summary_ids = summarization_model.generate(
+            input_ids,
+            max_length=150,
+            min_length=40,
+            length_penalty=2.0,
+            num_beams=4,
+            early_stopping=True
+        )
         return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return text[:300] + "..." if len(text) > 300 else text
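
Note on the device check, the one substantive change here (the other hunks are formatting-only reflows of long call sites): torch.device(device).type == "cuda" is more robust than the old device.startswith("cuda"), since it normalizes both strings and torch.device objects while still matching indexed devices like "cuda:0". A minimal sketch of the difference:

import torch

# torch.device(...).type reports the bare device kind, so "cuda",
# "cuda:0", and torch.device("cuda:1") all normalize to "cuda".
# Constructing the device objects does not require a GPU.
for dev in ["cpu", "cuda", "cuda:0", torch.device("cuda:1")]:
    print(dev, "->", torch.device(dev).type)

# The old test, device.startswith("cuda"), only works when device is a
# plain str; a torch.device object has no .startswith attribute and
# would raise AttributeError.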
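
_generate_sequence itself is outside this diff, but the temperature/top_k/top_p parameters threaded through both wrappers follow the standard filtered-sampling recipe popularized by the Hugging Face run_generation example. A rough sketch of that filtering step for 1-D logits, as an illustration rather than this repository's code:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    # logits: 1-D tensor of next-token scores, shape (vocab_size,).
    if top_k > 0:
        # Mask everything below the k-th highest logit.
        kth_best = torch.topk(logits, top_k)[0][-1]
        logits[logits < kth_best] = filter_value
    if top_p > 0.0:
        # Nucleus (top-p) filtering: drop the tail of the sorted
        # distribution once cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        sorted_mask[1:] = sorted_mask[:-1].clone()  # keep the first token past the threshold
        sorted_mask[0] = False
        logits[sorted_indices[sorted_mask]] = filter_value
    return logits

# In the usual loop, logits are divided by temperature first, then
# filtered, then sampled:
#   filtered = top_k_top_p_filtering(next_token_logits / temperature, top_k, top_p)
#   next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)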
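
For orientation, a hypothetical call to the refactored sample_sequence; how model and enc are loaded, and what _generate_sequence returns, are defined elsewhere in the repository, so the names and values here are illustrative only:

# Assumes a GPT-2-style model and BPE encoder loaded elsewhere in the
# project (enc exposes encode/decode and an encoder vocab dict).
output = sample_sequence(
    "Once upon a time",
    model,
    enc,
    max_length=128,          # bounded by MAX_GENERATION_LENGTH = 512
    temperature=0.8,         # <1.0 sharpens the next-token distribution
    top_k=40,                # keep only the 40 most likely tokens
    top_p=0.9,               # nucleus sampling threshold
    repetition_penalty=1.2,  # >1.0 discourages repeated tokens
    device="cuda" if torch.cuda.is_available() else "cpu",
)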
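
The summarize_text hunk is formatting only; the arguments are standard transformers generate options (length_penalty=2.0 biases beam search toward longer summaries, early_stopping=True ends the search once every beam has emitted EOS). A standalone equivalent, assuming a BART-style checkpoint since the commit does not show which model summarization_model actually is:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# "facebook/bart-large-cnn" is an assumed stand-in for the repository's
# summarization_model / summarization_tokenizer, which load elsewhere.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

article = "..."  # any long input text
input_ids = tokenizer.encode(article, return_tensors="pt", truncation=True, max_length=1024)
summary_ids = model.generate(
    input_ids,
    max_length=150,
    min_length=40,
    length_penalty=2.0,
    num_beams=4,
    early_stopping=True,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))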