tess-2-demo / sdlm /data /sni /sni_collator.py
hamishivi's picture
commit
17ff0d8 verified
raw
history blame contribute delete
10.5 kB
import logging
import random
import string
from dataclasses import dataclass
from typing import Any, Optional, Union
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
logger = logging.getLogger(__name__)
@dataclass
class DataCollatorForNI:
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_source_length: Optional[int] = None
max_target_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
add_task_name: bool = False
add_task_definition: bool = True
num_pos_examples: int = 0
num_neg_examples: int = 0
add_explanation: bool = False
tk_instruct: bool = False
text_only: bool = False
random_gen: random.Random = random.Random(42)
def __call__(self, batch, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
sources = []
batch = [batch]
for instance in batch:
if self.tk_instruct:
all_valid_encodings = [
# instruction only
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 0,
"num_neg_examples": 0,
"add_explanation": False,
},
# example only
{
"add_task_name": False,
"add_task_definition": False,
"num_pos_examples": 2,
"num_neg_examples": 0,
"add_explanation": False,
},
# instruction + pos examples
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 2,
"num_neg_examples": 0,
"add_explanation": False,
},
# instruction + pos examples + neg examples
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 2,
"num_neg_examples": 2,
"add_explanation": False,
},
# instruction + pos (w. explanation)
{
"add_task_name": False,
"add_task_definition": True,
"num_pos_examples": 2,
"num_neg_examples": 0,
"add_explanation": True,
},
]
encoding_schema = self.random_gen.choice(all_valid_encodings)
add_task_name = encoding_schema["add_task_name"]
add_task_definition = encoding_schema["add_task_definition"]
num_pos_examples = encoding_schema["num_pos_examples"]
num_neg_examples = encoding_schema["num_neg_examples"]
add_explanation = encoding_schema["add_explanation"]
else:
add_task_name = self.add_task_name
add_task_definition = self.add_task_definition
num_pos_examples = self.num_pos_examples
num_neg_examples = self.num_neg_examples
add_explanation = self.add_explanation
task_input = ""
# add the input first.
task_input += "Now complete the following example -\n"
task_input += f"Input: {instance['Instance']['input'].strip()}"
if not task_input[-1] in string.punctuation:
task_input += "."
task_input += "\n"
task_input += "Output: "
task_name = ""
if add_task_name:
task_name += instance["Task"] + ". "
definition = ""
if add_task_definition:
if isinstance(instance["Definition"], list):
definition = (
"Definition: " + instance["Definition"][0].strip()
) # TODO: should we use <Definition>?
else:
definition = "Definition: " + instance["Definition"].strip()
if not definition[-1] in string.punctuation:
definition += "."
definition += "\n\n"
# try to add positive examples.
pos_examples = []
for idx, pos_example in enumerate(
instance["Positive Examples"][:num_pos_examples]
):
pos_example_str = f" Positive Example {idx+1} -\n"
pos_example_str += f"Input: {pos_example['input'].strip()}"
if not pos_example_str[-1] in string.punctuation:
pos_example_str += "."
pos_example_str += "\n"
pos_example_str += f" Output: {pos_example['output'].strip()}"
if not pos_example_str[-1] in string.punctuation:
pos_example_str += "."
pos_example_str += "\n"
if add_explanation and "explanation" in pos_example:
pos_example_str += (
f" Explanation: {pos_example['explanation'].strip()}"
)
if not pos_example_str[-1] in string.punctuation:
pos_example_str += "."
pos_example_str += "\n"
pos_example_str += "\n"
if (
len(
self.tokenizer(
definition
+ " ".join(pos_examples)
+ pos_example_str
+ task_input
)["input_ids"]
)
<= self.max_source_length
):
pos_examples.append(pos_example_str)
else:
break
# try to add negative examples.
neg_examples = []
for idx, neg_example in enumerate(
instance["Negative Examples"][:num_neg_examples]
):
neg_example_str = f" Negative Example {idx+1} -\n"
neg_example_str += f"Input: {neg_example['input'].strip()}"
if not neg_example_str[-1] in string.punctuation:
neg_example_str += "."
neg_example_str += "\n"
neg_example_str += f" Output: {neg_example['output'].strip()}"
if not neg_example_str[-1] in string.punctuation:
neg_example_str += "."
neg_example_str += "\n"
if add_explanation and "explanation" in neg_example:
neg_example_str += (
f" Explanation: {neg_example['explanation'].strip()}"
)
if not neg_example_str[-1] in string.punctuation:
neg_example_str += "."
neg_example_str += "\n"
neg_example_str += "\n"
if (
len(
self.tokenizer(
definition
+ " ".join(pos_examples)
+ " ".join(neg_examples)
+ neg_example_str
+ task_input
)["input_ids"]
)
<= self.max_source_length
):
neg_examples.append(neg_example_str)
else:
break
source = (
task_name
+ definition
+ "".join(pos_examples)
+ "".join(neg_examples)
+ task_input
)
tokenized_source = self.tokenizer(source)["input_ids"]
if len(tokenized_source) <= self.max_source_length:
sources.append(source)
else:
sources.append(
self.tokenizer.decode(
tokenized_source[: self.max_source_length],
skip_special_tokens=True,
)
)
if self.text_only:
model_inputs = {"inputs": sources}
else:
model_inputs = self.tokenizer(
sources,
max_length=self.max_source_length,
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
pad_to_multiple_of=self.pad_to_multiple_of,
)
if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]:
# Randomly select one reference if multiple are provided.
labels = [self.random_gen.choice(ex["Instance"]["output"]) for ex in batch]
if self.text_only:
model_inputs["label"] = labels
else:
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(
labels,
max_length=self.max_target_length,
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
pad_to_multiple_of=self.pad_to_multiple_of,
)
label_mask = labels["attention_mask"].bool()
model_inputs["label"] = labels["input_ids"].masked_fill(
~label_mask, self.label_pad_token_id
)
else:
model_inputs["label"] = None
# prepare decoder_input_ids
if (
self.model is not None
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
and not self.text_only
):
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
labels=model_inputs["label"]
)
model_inputs["decoder_input_ids"] = decoder_input_ids
# flatten the inputs to avoid listing
return {k: v[0] for k, v in model_inputs.items()}