import argparse
import collections

import jsonlines
import numpy as np
from datasets import load_dataset
from scipy.stats import spearmanr
from tqdm.auto import tqdm
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments, Trainer


def add_answers_column(example):
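    """Turn an example's hard labels into a SQuAD-style `answers` column (start indices and answer texts)."""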
    starts, texts = [], []
    for hard_label in example["hard_labels"]:
        starts.append(hard_label[0])
        texts.append(example["context"][hard_label[0]:hard_label[1]])
    example["answers"] = {"answer_start": starts, "text": texts}
    return example


def to_dataset(file_path):
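    """Load a MUSHROOM JSONL file and rename its columns to the question-answering format."""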
    mushroom = load_dataset("json", data_files=file_path)["train"]
    mushroom = mushroom.rename_column("model_output_text", "context")
    mushroom = mushroom.rename_column("model_input", "question")
    if "hard_labels" in mushroom.column_names:
        mushroom = mushroom.map(add_answers_column)
    else:
        print("No hard labels found in the evaluation data: only generating predictions.")

    return mushroom


def preprocess_examples(examples, tokenizer):
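    """Tokenize question/context pairs with an overlapping stride, keeping offset mappings for context tokens only."""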
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
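        # Mask the offsets of question and special tokens so that only positions
        # inside the model output text (sequence id 1) are considered later on.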
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs


def score_iou(ref_dict, pred_dict):
    """
    Computes intersection-over-union between reference and predicted hard
    labels, for a single datapoint.

    Arguments:
        ref_dict (dict): a gold reference datapoint,
        pred_dict (dict): a model's prediction

    Returns:
        float: The IoU, or 1.0 if neither the reference nor the prediction contains hallucinations
    """
    assert ref_dict['id'] == pred_dict['id']
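    # Expand each [start, end) span into the set of character indices it covers.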
    ref_indices = {idx for span in ref_dict['hard_labels'] for idx in range(*span)}
    pred_indices = {idx for span in pred_dict['hard_labels'] for idx in range(*span)}

    if not pred_indices and not ref_indices:
        return 1.

    return len(ref_indices & pred_indices) / len(ref_indices | pred_indices)


def score_cor(ref_dict, pred_dict):
    """
    Computes Spearman correlation between predicted and reference soft labels,
    for a single datapoint.

    Arguments:
        ref_dict (dict): a gold reference datapoint,
        pred_dict (dict): a model's prediction

    Returns:
        float: The Spearman correlation, or a binarized exact match (0.0 or 1.0)
        if the reference or prediction contains no variation
    """
    assert ref_dict['id'] == pred_dict['id']

    ref_vec = [0.] * ref_dict['text_len']
    pred_vec = [0.] * ref_dict['text_len']
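    # Densify the span-level soft labels into per-character probability vectors.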
    for span in ref_dict['soft_labels']:
        for idx in range(span['start'], span['end']):
            ref_vec[idx] = span['prob']
    for span in pred_dict['soft_labels']:
        for idx in range(span['start'], span['end']):
            pred_vec[idx] = span['prob']
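    # Spearman correlation is undefined for constant vectors: if either vector is
    # constant, return 1.0 when both are constant and 0.0 otherwise
    # (rounding guards against floating-point noise).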
    if len({round(flt, 8) for flt in pred_vec}) == 1 or len({round(flt, 8) for flt in ref_vec}) == 1:
        return float(len({round(flt, 8) for flt in ref_vec}) == len({round(flt, 8) for flt in pred_vec}))

    return spearmanr(ref_vec, pred_vec).correlation


def infer_soft_labels(hard_labels):
    """Reformat hard labels into soft labels with probability 1."""
    return [
        {
            'start': start,
            'end': end,
            'prob': 1.0,
        }
        for start, end in hard_labels
    ]


def find_possible_spans(answers, example):
    """
    Creates and filters possible hallucination spans.

    Arguments:
        answers (list): List of dictionaries with candidate spans as text and
            their logit scores.
        example: The instance being predicted. Its context is used to map the
            predicted text back to start and end indices in the target context.

    Returns:
        tuple: A list of hard labels ([start, end] pairs) and the corresponding
        soft labels.
    """
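    # Keep every candidate span whose logit score exceeds 80% of the best span's score.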
    best_answer = max(answers, key=lambda x: x["logit_score"])
    threshold = best_answer["logit_score"] * 0.8
    hard_labels = []
    for answer in answers:
        if answer["logit_score"] > threshold:
            start_index = example["context"].index(answer["text"])
            end_index = start_index + len(answer["text"])
            hard_labels.append([start_index, end_index])
    soft_labels = infer_soft_labels(hard_labels)
    return hard_labels, soft_labels


def compute_metrics(start_logits, end_logits, features, examples, predictions_file):
    """
    Processes the predictions into spans, writes them to the predictions file
    and, when gold labels are available, computes IoU and correlation scores.

    Arguments:
        start_logits (list): Logits of all start positions.
        end_logits (list): Logits of all end positions.
        features (Dataset): Dataset containing the tokenized features.
        examples (Dataset): Dataset containing the original examples (with hard
            labels if available).
        predictions_file (str): Path of the JSONL file to write predictions to.

    Returns:
        None
    """
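    # Map each example id to the indices of all features created from it
    # (long contexts can overflow into several features).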
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]
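            # Consider the 20 highest-scoring start and end positions and keep
            # candidate spans of at most 30 tokens that fall inside the context.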
            start_indexes = np.argsort(start_logit)[-1: -20 - 1: -1].tolist()
            end_indexes = np.argsort(end_logit)[-1: -20 - 1: -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue

                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > 30
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0]: offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        if len(answers) > 0:
            hard_labels, soft_labels = find_possible_spans(answers, example)
            predicted_answers.append(
                {"id": example_id, "hard_labels": hard_labels, "soft_labels": soft_labels}
            )
        else:
            predicted_answers.append({"id": example_id, "hard_labels": [], "soft_labels": []})

    with jsonlines.open(predictions_file, mode="w") as writer:
        writer.write_all(predicted_answers)

if "answers" in examples.column_names: |
|
true_answers = [{"id": ex["id"], "hard_labels": ex["hard_labels"], "soft_labels": ex["soft_labels"], |
|
"text_len": len(ex["context"])} for ex in examples] |
|
ious = np.array([score_iou(r, d) for r, d in zip(true_answers, predicted_answers)]) |
|
cors = np.array([score_cor(r, d) for r, d in zip(true_answers, predicted_answers)]) |
|
|
|
print(f"IOU: {ious.mean():.8f}, COR: {cors.mean():.8f}") |
|
else: |
|
print("Evaluation data contained no answers. No scores to show.") |


def main(model_path, evaluation_file_path, output_file):
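    """Run the QA model over the evaluation file, write predictions and report scores when gold labels exist."""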
    model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    args = TrainingArguments(
        output_dir="output_dir",
        per_device_eval_batch_size=16,
        report_to="none",
    )

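    # The Trainer is used only as a convenient wrapper for batched inference.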
    trainer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
    )

    mushroom_dataset = to_dataset(evaluation_file_path)
    features = mushroom_dataset.map(
        preprocess_examples,
        batched=True,
        remove_columns=mushroom_dataset.column_names,
        fn_kwargs={"tokenizer": tokenizer},
    )

    predictions, _, _ = trainer.predict(features)
    start_logits, end_logits = predictions
    compute_metrics(start_logits, end_logits, features, mushroom_dataset, output_file)


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('model_name', type=str)
    p.add_argument('evaluation_file_path', type=str)
    p.add_argument('output_file', type=str)
    a = p.parse_args()
    main(a.model_name, a.evaluation_file_path, a.output_file)