import types
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from transformers.pipelines.base import ArgumentHandler, ChunkPipeline, Dataset
from transformers.utils import is_tf_available, is_torch_available

if is_tf_available():
    import tensorflow as tf
    from transformers.models.auto.modeling_tf_auto import (
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    )
if is_torch_available():
    from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


def list_of_dicts2dict_of_lists(list_of_dicts: list[dict]) -> dict[str, list]:
    """Transpose a list of dicts that share the same keys into a dict of lists."""
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0].keys()}


class FeatureExtractionArgumentHandler(ArgumentHandler):
    """Handles arguments for feature extraction."""

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
            inputs = list(inputs)
            batch_size = len(inputs)
        elif isinstance(inputs, str):
            inputs = [inputs]
            batch_size = 1
        elif (
            Dataset is not None
            and isinstance(inputs, Dataset)
            or isinstance(inputs, types.GeneratorType)
        ):
            return inputs, None
        else:
            raise ValueError("At least one input is required.")

        offset_mapping = kwargs.get("offset_mapping")
        if offset_mapping:
            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
                offset_mapping = [offset_mapping]
            if len(offset_mapping) != batch_size:
                raise ValueError("offset_mapping should have the same batch size as the input")
        return inputs, offset_mapping


class FeatureExtractionPipelineWithStriding(ChunkPipeline):
    """Same as transformers.FeatureExtractionPipeline, but with long input handling. Inspired by
    transformers.TokenClassificationPipeline. The functionality is triggered by providing the
    "stride" parameter (which can be 0). When passing "create_unique_embeddings_per_token=True",
    only one embedding (and the other data, see the flags below) per token is returned (this makes
    use of min_distance_to_border, see "return_min_distance_to_border" below for details). Note
    that this removes the data for special token positions!

    By default, only the embeddings are returned. If any of the return_ADDITIONAL_RESULT flags is
    enabled (see below), a dictionary with "last_hidden_state" and all ADDITIONAL_RESULT entries
    selected via the flags is returned instead.

    Flags to return additional results:
    return_offset_mapping: If enabled, return the offset mapping.
    return_special_tokens_mask: If enabled, return the special tokens mask.
    return_sequence_indices: If enabled, return the sequence indices.
    return_position_ids: If enabled, return the position ids; values are in [0, model_max_length).
    return_min_distance_to_border: If enabled, return the minimum distance to the "border" of
        the input that gets passed into the model. This is useful when striding is used, which
        may produce multiple embeddings for a token (compare the values in offset_mapping). In
        this case, min_distance_to_border can be used to select the embedding that is closer to
        the center of the input by choosing the entry with the *higher* min_distance_to_border.
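
    Example (illustrative sketch; the checkpoint name is only an assumed placeholder, any
    encoder checkpoint with a fast tokenizer should work):

        from transformers import AutoModel, AutoTokenizer

        checkpoint = "bert-base-cased"
        pipe = FeatureExtractionPipelineWithStriding(
            model=AutoModel.from_pretrained(checkpoint),
            tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        )
        # plain embeddings, chunks concatenated along the sequence dimension
        embeddings = pipe("some long text ...", stride=64)
        # dictionary with one embedding per token plus the requested additional data
        features = pipe(
            "some long text ...",
            stride=64,
            create_unique_embeddings_per_token=True,
            return_offset_mapping=True,
        )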
    """

    default_input_names = "sequences"

    def __init__(self, args_parser=FeatureExtractionArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
        )

        self._args_parser = args_parser

    def _sanitize_parameters(
        self,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
        create_unique_embeddings_per_token: Optional[bool] = False,
        return_offset_mapping: Optional[bool] = None,
        return_special_tokens_mask: Optional[bool] = None,
        return_sequence_indices: Optional[bool] = None,
        return_position_ids: Optional[bool] = None,
        return_min_distance_to_border: Optional[bool] = None,
        return_tensors=None,
    ):
        preprocess_params = {}
        if offset_mapping is not None:
            preprocess_params["offset_mapping"] = offset_mapping

        if stride is not None:
            if stride >= self.tokenizer.model_max_length:
                raise ValueError(
                    "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
                )

            if self.tokenizer.is_fast:
                tokenizer_params = {
                    "return_overflowing_tokens": True,
                    "padding": True,
                    "stride": stride,
                }
                preprocess_params["tokenizer_params"] = tokenizer_params
            else:
                raise ValueError(
                    "`stride` was provided to process all the text but you're using a slow tokenizer."
                    " Please use a fast tokenizer."
                )
        postprocess_params = {}
        if create_unique_embeddings_per_token is not None:
            postprocess_params["create_unique_embeddings_per_token"] = (
                create_unique_embeddings_per_token
            )
        if return_offset_mapping is not None:
            postprocess_params["return_offset_mapping"] = return_offset_mapping
        if return_special_tokens_mask is not None:
            postprocess_params["return_special_tokens_mask"] = return_special_tokens_mask
        if return_sequence_indices is not None:
            postprocess_params["return_sequence_indices"] = return_sequence_indices
        if return_position_ids is not None:
            postprocess_params["return_position_ids"] = return_position_ids
        if return_min_distance_to_border is not None:
            postprocess_params["return_min_distance_to_border"] = return_min_distance_to_border
        if return_tensors is not None:
            postprocess_params["return_tensors"] = return_tensors
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
        if offset_mapping:
            kwargs["offset_mapping"] = offset_mapping

        return super().__call__(inputs, **kwargs)

    def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
        tokenizer_params = preprocess_params.pop("tokenizer_params", {})
        truncation = bool(self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0)
        inputs = self.tokenizer(
            sentence,
            return_tensors=self.framework,
            truncation=truncation,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
            **tokenizer_params,
        )
        inputs.pop("overflow_to_sample_mapping", None)
        num_chunks = len(inputs["input_ids"])

        # yield one set of model inputs per overflowing chunk; only the first chunk carries the
        # original sentence, and the last chunk is flagged so that the outputs can be regrouped
        for i in range(num_chunks):
            if self.framework == "tf":
                model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
            else:
                model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
            if offset_mapping is not None:
                model_inputs["offset_mapping"] = offset_mapping
            model_inputs["sentence"] = sentence if i == 0 else None
            model_inputs["is_last"] = i == num_chunks - 1

            yield model_inputs

    def _forward(self, model_inputs, **kwargs):
        # pop everything that is not a model input; it is passed through to postprocess below
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        is_last = model_inputs.pop("is_last")
        if self.framework == "tf":
            last_hidden_state = self.model(**model_inputs)[0]
        else:
            output = self.model(**model_inputs)
            last_hidden_state = (
                output["last_hidden_state"] if isinstance(output, dict) else output[0]
            )

        return {
            "last_hidden_state": last_hidden_state,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "sentence": sentence,
            "is_last": is_last,
            **model_inputs,
        }

    def postprocess_tensor(self, data, return_tensors=False):
        if return_tensors:
            return data
        if self.framework == "pt":
            return data.tolist()
        elif self.framework == "tf":
            return data.numpy().tolist()
        else:
            raise ValueError(f"unknown framework: {self.framework}")

    def make_embeddings_unique_per_token(
        self, data, offset_mapping, special_tokens_mask, min_distance_to_border
    ):
        char_offsets2token_pos = defaultdict(list)
        bs, seq_len = offset_mapping.shape[:2]
        if bs != 1:
            raise ValueError(f"expected result batch size 1, but it is: {bs}")
        # collect all token positions (and their distance to the chunk border) per character span
        for token_idx, ((char_start, char_end), is_special_token, min_dist) in enumerate(
            zip(
                offset_mapping[0].tolist(),
                special_tokens_mask[0].tolist(),
                min_distance_to_border[0].tolist(),
            )
        ):
            if not is_special_token:
                char_offsets2token_pos[(char_start, char_end)].append((token_idx, min_dist))

        # for each character span, keep the token position that is farthest away from the border
        char_offsets2best_token_pos = {
            k: max(v, key=lambda pos_dist: pos_dist[1])[0]
            for k, v in char_offsets2token_pos.items()
        }

        # restore the original token order by sorting on the character offsets
        sorted_char_offsets_token_positions = sorted(
            char_offsets2best_token_pos.items(),
            key=lambda char_offsets_tok_pos: (
                char_offsets_tok_pos[0][0],
                char_offsets_tok_pos[0][1],
            ),
        )
        best_indices = [tok_pos for char_offsets, tok_pos in sorted_char_offsets_token_positions]

        result = {k: v[0][best_indices].unsqueeze(0) for k, v in data.items()}
        return result

    def postprocess(
        self,
        all_outputs,
        create_unique_embeddings_per_token: bool = False,
        return_offset_mapping: bool = False,
        return_special_tokens_mask: bool = False,
        return_sequence_indices: bool = False,
        return_position_ids: bool = False,
        return_min_distance_to_border: bool = False,
        return_tensors: bool = False,
    ):
        all_outputs_dict = list_of_dicts2dict_of_lists(all_outputs)
        if self.framework == "pt":
            # concatenate the per-chunk outputs along the sequence dimension
            result = {
                "last_hidden_state": torch.concat(all_outputs_dict["last_hidden_state"], axis=1)
            }
            if return_offset_mapping or create_unique_embeddings_per_token:
                result["offset_mapping"] = torch.concat(all_outputs_dict["offset_mapping"], axis=1)
            if return_special_tokens_mask or create_unique_embeddings_per_token:
                result["special_tokens_mask"] = torch.concat(
                    all_outputs_dict["special_tokens_mask"], axis=1
                )
            if return_sequence_indices:
                sequence_indices = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    sequence_indices.append(torch.ones_like(model_outputs["input_ids"]) * seq_idx)
                result["sequence_indices"] = torch.concat(sequence_indices, axis=1)
            if return_position_ids:
                position_ids = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    seq_len = model_outputs["input_ids"].size(1)
                    position_ids.append(torch.arange(seq_len).unsqueeze(0))
                result["position_ids"] = torch.concat(position_ids, axis=1)
            if return_min_distance_to_border or create_unique_embeddings_per_token:
                # distance of each token position to the nearer border of its chunk
                min_distance_to_border = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    seq_len = model_outputs["input_ids"].size(1)
                    current_indices = torch.arange(seq_len).unsqueeze(0)
                    min_distance_to_border.append(
                        torch.minimum(current_indices, seq_len - current_indices)
                    )
                result["min_distance_to_border"] = torch.concat(min_distance_to_border, axis=1)
        elif self.framework == "tf":
            raise NotImplementedError()
        else:
            raise ValueError(f"unknown framework: {self.framework}")

        if create_unique_embeddings_per_token:
            offset_mapping = result["offset_mapping"]
            if not return_offset_mapping:
                del result["offset_mapping"]
            special_tokens_mask = result["special_tokens_mask"]
            if not return_special_tokens_mask:
                del result["special_tokens_mask"]
            min_distance_to_border = result["min_distance_to_border"]
            if not return_min_distance_to_border:
                del result["min_distance_to_border"]
            result = self.make_embeddings_unique_per_token(
                data=result,
                offset_mapping=offset_mapping,
                special_tokens_mask=special_tokens_mask,
                min_distance_to_border=min_distance_to_border,
            )

        result = {
            k: self.postprocess_tensor(v, return_tensors=return_tensors) for k, v in result.items()
        }
        if set(result) == {"last_hidden_state"}:
            return result["last_hidden_state"]
        else:
            return result
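

# Illustrative smoke test, kept as a sketch: the checkpoint name below is only an assumption
# (any encoder checkpoint with a fast tokenizer should work), and the block only runs when the
# module is executed directly.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    checkpoint = "bert-base-cased"
    pipe = FeatureExtractionPipelineWithStriding(
        model=AutoModel.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
    text = "This is a short example sentence. " * 100

    # default output: nested lists with shape [1, num_tokens (over all chunks), hidden_size]
    embeddings = pipe(text, stride=64)
    print("number of token embeddings:", len(embeddings[0]))

    # dictionary output with one embedding per token and the requested additional data
    features = pipe(
        text,
        stride=64,
        create_unique_embeddings_per_token=True,
        return_offset_mapping=True,
    )
    print("returned keys:", sorted(features.keys()))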