pierrefdz's picture
inintal commit
8e6cbe9
raw
history blame
8.65 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .hashing import get_seed_rng
# Module-wide compute device: the default CUDA device when torch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class WmGenerator():
    """Base generator: autoregressive decoding with temperature and top-p sampling.

    ``generate`` runs the decoding loop and, at every step, hands the last-position
    logits plus a context dict (the preceding ``ngram`` token ids and the current
    position) to ``sample_next``. Subclasses override ``sample_next`` to embed a
    watermark; this base implementation samples without one.
    """

    def __init__(self,
            model: AutoModelForCausalLM,
            tokenizer: AutoTokenizer,
            ngram: int = 1,
            seed: int = 0,
            **kwargs
        ):
        """
        Args:
            model: causal LM; its ``config`` and ``forward`` (with KV cache) are used.
            tokenizer: used to encode the prompt and decode the result.
            ngram: number of preceding tokens passed to ``sample_next`` as context.
            seed: base seed stored on the instance and used to seed ``self.rng``.
            **kwargs: extra keyword arguments are accepted and ignored.
        """
        # model config
        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size
        self.model = model
        # fall back to 2048 when the config has no 'max_sequence_length' entry
        self.max_seq_len = model.config.max_sequence_length if 'max_sequence_length' in model.config.to_dict() else 2048
        self.pad_id = model.config.pad_token_id if model.config.pad_token_id is not None else -1
        self.eos_id = model.config.eos_token_id
        # watermark config
        self.ngram = ngram
        self.seed = seed
        self.rng = torch.Generator()
        self.rng.manual_seed(self.seed)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
        return_aux: bool = False,
    ) -> str:
        """Generate a continuation of ``prompt`` and decode it.

        Args:
            prompt: text to continue.
            max_gen_len: maximum number of new tokens to generate.
            temperature: sampling temperature, forwarded to ``sample_next``.
            top_p: nucleus-sampling threshold, forwarded to ``sample_next``.
            return_aux: if True, also return a dict with the token ids ``t``,
                ``finish_reason`` (``'eos'`` or ``'length'``), ``n_toks_gen``
                and ``n_toks_tot``.

        Returns:
            The decoded token sequence (prompt tokens followed by generated tokens,
            cut before the first generated EOS), or ``(decoded, aux_info)`` when
            ``return_aux`` is True.
        """
        prompt_tokens = self.tokenizer.encode(prompt)
        total_len = min(self.max_seq_len, max_gen_len + len(prompt_tokens))
        if total_len < len(prompt_tokens):
            print("prompt is bigger than max sequence length")
            prompt_tokens = prompt_tokens[:total_len]
        # Measure the prompt after any truncation so generation starts right after it
        # and the generated-token count reported below can never be negative.
        prompt_size = len(prompt_tokens)

        tokens = torch.full((1, total_len), self.pad_id, dtype=torch.long, device=device)
        tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long()

        past_key_values = None
        prev_pos = 0
        for cur_pos in range(prompt_size, total_len):
            # With a KV cache only the tokens not yet seen by the model are fed in.
            outputs = self.model.forward(
                tokens[:, prev_pos:cur_pos],
                use_cache=True,
                past_key_values=past_key_values,
            )
            past_key_values = outputs.past_key_values
            # Up to `ngram` tokens immediately before cur_pos; the start is clamped at 0
            # so a prompt shorter than `ngram` does not wrap around to the buffer's end.
            ngram_tokens = tokens[0, max(cur_pos - self.ngram, 0):cur_pos].tolist()
            aux = {
                'ngram_tokens': ngram_tokens,
                'cur_pos': cur_pos,
            }
            next_tok = self.sample_next(outputs.logits[:, -1, :], aux, temperature, top_p)
            # Every position >= prompt_size is still padding, so it is always overwritten.
            tokens[0, cur_pos] = next_tok
            prev_pos = cur_pos
            if next_tok == self.eos_id:
                break

        # cut to max gen len
        t = tokens[0, :prompt_size + max_gen_len].tolist()
        # cut to eos tok if any (the EOS token itself is dropped)
        finish_reason = 'length'
        try:
            find_eos = t[prompt_size:].index(self.eos_id)
        except ValueError:
            pass
        else:
            # Index 0 (EOS as the very first generated token) is a valid hit; a
            # truthiness test here would keep trailing pad ids in `t`.
            t = t[:prompt_size + find_eos]
            finish_reason = 'eos'

        aux_info = {
            't': t,
            'finish_reason': finish_reason,
            'n_toks_gen': len(t) - prompt_size,
            'n_toks_tot': len(t),
        }
        decoded = self.tokenizer.decode(t)
        if return_aux:
            return decoded, aux_info
        return decoded

    def sample_next(
        self,
        logits: torch.FloatTensor,  # (1, vocab_size): logits for last token
        aux: dict,  # step context ('ngram_tokens', 'cur_pos'); unused by this base sampler
        temperature: float = 0.8,  # temperature for sampling
        top_p: float = 0.95,  # top p for sampling
    ):
        """Vanilla sampling with temperature and top p.

        With ``temperature > 0``, sample from the softmax of ``logits / temperature``
        restricted to the top-p nucleus. Otherwise return the argmax of ``logits``.
        Returns the chosen token id as a 0-d tensor.
        """
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            # drop tokens whose preceding cumulative mass already exceeds top_p
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)  # index into the sorted order
            next_token = torch.gather(probs_idx, -1, next_token)  # mapped back to a vocabulary id
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)[0]  # Get the single token value
        return next_token
class OpenaiGenerator(WmGenerator):
    """
    Generate text using a causal LM and Aaronson's watermarking method.
    From ngram tokens, select the next token based on the following:
    - hash the ngram tokens and get a seed
    - use the seed to generate V random numbers r in [0, 1)
    - select argmax ( r^(1/p) )
    """

    def sample_next(
        self,
        logits: torch.FloatTensor,  # (1, vocab_size): logits for last token
        aux: dict,  # must contain 'ngram_tokens': list of token ids used for seeding
        temperature: float = 0.8,  # temperature for sampling
        top_p: float = 0.95,  # top p for sampling
    ):
        """Choose the next token with the r^(1/p) rule.

        With ``temperature > 0``:
        1. Build the temperature-scaled, top-p-filtered, renormalised distribution p.
        2. Reseed ``self.rng`` with ``get_seed_rng(self.seed, aux['ngram_tokens'])``.
        3. Draw one uniform r per vocabulary entry.
        4. Return the token maximising r^(1/p).

        Tokens removed by top-p have p = 0, which maps to r^inf = 0, so they are
        not selected. With ``temperature <= 0`` no watermark is applied and the
        argmax of ``logits`` is returned. Returns the token id as a 0-d tensor.
        """
        ngram_tokens = aux['ngram_tokens']
        if temperature > 0:
            probs = torch.softmax(logits / temperature, dim=-1)
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            # seed with hash of ngram tokens
            seed = get_seed_rng(self.seed, ngram_tokens)
            self.rng.manual_seed(seed)
            # One r per vocabulary id, then reordered so rs[i] pairs with probs_sort[0, i].
            # NOTE(review): assumes logits.shape[-1] == self.vocab_size, since probs_idx
            # indexes rs — confirm for models whose embedding is padded beyond the tokenizer.
            rs = torch.rand(self.vocab_size, generator=self.rng)
            rs = rs.to(probs_sort.device)[probs_idx[0]]
            # compute r^(1/p)
            probs_sort[0] = torch.pow(rs, 1 / probs_sort[0])
            # select argmax ( r^(1/p) ), then map the sorted index back to a vocabulary id
            next_token = torch.argmax(probs_sort, dim=-1, keepdim=True)
            next_token = torch.gather(probs_idx, -1, next_token)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)[0]  # Get the single token value
        return next_token
class MarylandGenerator(WmGenerator):
    """
    Generate text using a causal LM and Maryland's watermarking method.
    From ngram tokens, select the next token based on the following:
    - hash the ngram tokens and get a seed
    - use the seed to partition the vocabulary into greenlist (gamma*V words) and blacklist
    - add delta to greenlist words' logits
    """

    def __init__(self,
            *args,
            gamma: float = 0.5,
            delta: float = 1.0,
            **kwargs
        ):
        """
        Args:
            *args, **kwargs: forwarded to ``WmGenerator.__init__``.
            gamma: fraction of the vocabulary placed in the greenlist.
            delta: value added to the logits of greenlist tokens.
        """
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def sample_next(
        self,
        logits: torch.FloatTensor,  # (1, vocab_size): logits for last token
        aux: dict,  # must contain 'ngram_tokens': list of token ids used for seeding
        temperature: float = 0.8,  # temperature for sampling
        top_p: float = 0.95,  # top p for sampling
    ):
        """Bias the greenlist logits, then sample exactly like the base generator.

        The greenlist bias is applied before both the sampling path and the
        greedy (``temperature <= 0``) path. Returns the token id as a 0-d tensor.
        """
        logits = self.logits_processor(logits, aux['ngram_tokens'])
        # The remaining temperature / top-p / greedy logic is identical to the
        # base class, so reuse it rather than keeping a second copy in sync.
        return super().sample_next(logits, aux, temperature, top_p)

    def logits_processor(self, logits, ngram_tokens):
        """Return a copy of ``logits`` with ``self.delta`` added to greenlist tokens.

        ``self.rng`` is reseeded with ``get_seed_rng(self.seed, ngram_tokens)``.
        The greenlist is the first ``int(gamma * vocab_size)`` entries of a random
        permutation of the vocabulary drawn from it. The input tensor is not
        modified.
        """
        logits = logits.clone()
        seed = get_seed_rng(self.seed, ngram_tokens)
        self.rng.manual_seed(seed)
        vocab_permutation = torch.randperm(self.vocab_size, generator=self.rng)
        greenlist = vocab_permutation[:int(self.gamma * self.vocab_size)]  # gamma * n
        bias = torch.zeros(self.vocab_size).to(logits.device)
        bias[greenlist] = self.delta
        logits[0] += bias  # add bias to greenlist words
        return logits