import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .hashing import get_seed_rng

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class WmGenerator():
    def __init__(self,
            model: AutoModelForCausalLM,
            tokenizer: AutoTokenizer,
            ngram: int = 1,
            seed: int = 0,
            **kwargs
        ):
        # model config
        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.model = model
        self.max_seq_len = model.config.max_sequence_length if 'max_sequence_length' in model.config.to_dict() else 2048
        self.pad_id = model.config.pad_token_id if model.config.pad_token_id is not None else -1
        self.eos_id = model.config.eos_token_id
        # watermark config
        self.ngram = ngram
        self.seed = seed
        self.rng = torch.Generator()
        self.rng.manual_seed(self.seed)

    def generate(
        self,
        prompt: str,
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        return_aux: bool = False,
    ) -> str:
        prompt_tokens = self.tokenizer.encode(prompt)
        prompt_size = len(prompt_tokens)
        total_len = min(self.max_seq_len, max_gen_len + prompt_size)
        tokens = torch.full((1, total_len), self.pad_id).to(device).long()
        if total_len < prompt_size:
            print("prompt is longer than the max sequence length, truncating")
            prompt_tokens = prompt_tokens[:total_len]
        tokens[0, :len(prompt_tokens)] = torch.tensor(prompt_tokens).long()
        input_text_mask = tokens != self.pad_id
        start_pos = prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # reuse the KV cache after the first forward pass
            past_key_values = outputs.past_key_values if prev_pos > 0 else None
            outputs = self.model.forward(
                tokens[:, prev_pos:cur_pos],
                use_cache=True,
                past_key_values=past_key_values
            )
            # last `ngram` tokens are used to seed the watermark RNG
            ngram_tokens = tokens[0, cur_pos-self.ngram:cur_pos].tolist()
            aux = {
                'ngram_tokens': ngram_tokens,
                'cur_pos': cur_pos,
            }
            next_tok = self.sample_next(outputs.logits[:, -1, :], aux, temperature, top_p)
            # keep the prompt token if this position is still part of the prompt
            tokens[0, cur_pos] = torch.where(input_text_mask[0, cur_pos], tokens[0, cur_pos], next_tok)
            prev_pos = cur_pos
            if next_tok == self.eos_id:
                break
        # cut to max gen len
        t = tokens[0, :prompt_size + max_gen_len].tolist()
        # cut to eos tok if any (also when eos is the first generated token)
        finish_reason = 'length'
        try:
            find_eos = t[prompt_size:].index(self.eos_id)
            t = t[: prompt_size + find_eos]
            finish_reason = 'eos'
        except ValueError:
            pass
        aux_info = {
            't': t,
            'finish_reason': finish_reason,
            'n_toks_gen': len(t) - prompt_size,
            'n_toks_tot': len(t),
        }
        decoded = self.tokenizer.decode(t)
        if return_aux:
            return decoded, aux_info
        return decoded

    def sample_next(
        self,
        logits: torch.FloatTensor, # (1, vocab_size): logits for last token
        aux: dict, # ngram_tokens (1, ngram): tokens to consider when seeding
        temperature: float = 0.8, # temperature for sampling
        top_p: float = 0.95, # top p for sampling
    ):
        """Vanilla sampling with temperature and top p."""
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1) # index of the sampled token in the sorted order
            next_token = torch.gather(probs_idx, -1, next_token) # map back to the original vocabulary index
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)[0] # get the single token value
        return next_token
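

# Illustrative sketch (added for clarity, not used by the classes in this file):
# the mask `probs_sum - probs_sort > top_p` in `sample_next` keeps the smallest
# prefix of descending-sorted tokens whose cumulative probability reaches top_p,
# and always keeps the most likely token. The helper name below is hypothetical.
def _demo_top_p_mask(top_p: float = 0.9):
    probs_sort = torch.tensor([[0.6, 0.25, 0.1, 0.05]])  # already sorted, descending
    probs_sum = torch.cumsum(probs_sort, dim=-1)          # [0.60, 0.85, 0.95, 1.00]
    mask = probs_sum - probs_sort > top_p                 # True once the cumulative prob
    kept = probs_sort.masked_fill(mask, 0.0)              # *before* a token already exceeds top_p
    return kept / kept.sum(dim=-1, keepdim=True)          # renormalized nucleus distribution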


class OpenaiGenerator(WmGenerator):
    """
    Generate text using a causal LM and Aaronson's watermarking method.
    From the ngram tokens, select the next token as follows:
    - hash the ngram tokens to get a seed
    - use the seed to generate V random numbers r in [0,1]
    - select argmax ( r^(1/p) )
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def sample_next(
        self,
        logits: torch.FloatTensor, # (1, vocab_size): logits for last token
        aux: dict, # ngram_tokens (1, ngram): tokens to consider when seeding
        temperature: float = 0.8, # temperature for sampling
        top_p: float = 0.95, # top p for sampling
    ):
        ngram_tokens = aux['ngram_tokens']
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            # seed the watermark RNG with a hash of the ngram tokens
            seed = get_seed_rng(self.seed, ngram_tokens)
            self.rng.manual_seed(seed)
            # generate V random numbers r in [0,1)
            rs = torch.rand(self.vocab_size, generator=self.rng)
            rs = rs.to(probs_sort.device)
            rs = rs[probs_idx[0]] # reorder r to match the sorted probabilities
            # compute r^(1/p)
            probs_sort[0] = torch.pow(rs, 1/probs_sort[0])
            # select argmax ( r^(1/p) )
            next_token = torch.argmax(probs_sort, dim=-1, keepdim=True)
            next_token = torch.gather(probs_idx, -1, next_token) # map back to the original vocabulary index
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)[0] # get the single token value
        return next_token
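

# Illustrative sketch (not used above): the argmax( r^(1/p) ) rule samples token i
# with probability p_i when the r_i are i.i.d. uniform on [0,1], so the watermark
# does not change the expected next-token distribution. The helper name below is
# hypothetical; it only checks this property empirically on a toy vocabulary.
def _demo_aaronson_rule(n_draws: int = 10_000, seed: int = 0):
    gen = torch.Generator().manual_seed(seed)
    probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # toy next-token distribution
    counts = torch.zeros(4)
    for _ in range(n_draws):
        rs = torch.rand(4, generator=gen)               # fresh watermark randomness r in [0,1)
        counts[torch.argmax(rs ** (1.0 / probs))] += 1  # select argmax ( r^(1/p) )
    return counts / counts.sum()                        # empirically close to `probs`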


class MarylandGenerator(WmGenerator):
    """
    Generate text using a causal LM and the Maryland watermarking method.
    From the ngram tokens, select the next token as follows:
    - hash the ngram tokens to get a seed
    - use the seed to partition the vocabulary into a greenlist (gamma*V words) and a blacklist
    - add delta to the logits of greenlist words
    """
    def __init__(self,
            *args,
            gamma: float = 0.5,
            delta: float = 1.0,
            **kwargs
        ):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def sample_next(
        self,
        logits: torch.FloatTensor, # (1, vocab_size): logits for last token
        aux: dict, # ngram_tokens (1, ngram): tokens to consider when seeding
        temperature: float = 0.8, # temperature for sampling
        top_p: float = 0.95, # top p for sampling
    ):
        ngram_tokens = aux['ngram_tokens']
        logits = self.logits_processor(logits, ngram_tokens) # bias greenlist tokens
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1) # index of the sampled token in the sorted order
            next_token = torch.gather(probs_idx, -1, next_token) # map back to the original vocabulary index
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)[0] # get the single token value
        return next_token

    def logits_processor(self, logits, ngram_tokens):
        """Add delta to the logits of greenlist tokens."""
        logits = logits.clone()
        seed = get_seed_rng(self.seed, ngram_tokens)
        self.rng.manual_seed(seed)
        vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
        greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)] # gamma * V tokens
        bias = torch.zeros(self.vocab_size).to(logits.device)
        bias[greenlist] = self.delta
        logits[0] += bias # add bias to greenlist words
        return logits
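

# Minimal usage sketch (assumption: any Hugging Face causal LM checkpoint works;
# "gpt2" below is only a placeholder for illustration, not the checkpoint used by this repo).
if __name__ == "__main__":
    model_name = "gpt2"  # hypothetical checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
    generator = MarylandGenerator(model, tokenizer, ngram=1, seed=0, gamma=0.5, delta=2.0)
    text, aux = generator.generate(
        "The quick brown fox",
        max_gen_len=32,
        temperature=0.8,
        top_p=0.95,
        return_aux=True,
    )
    print(aux['finish_reason'], text)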