import types
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import torch
from transformers.pipelines.base import ArgumentHandler, ChunkPipeline, Dataset
from transformers.utils import is_tf_available, is_torch_available

if is_tf_available():
    import tensorflow as tf

    from transformers.models.auto.modeling_tf_auto import (
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    )

if is_torch_available():
    from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


def list_of_dicts2dict_of_lists(list_of_dicts: list[dict]) -> dict[str, list]:
    return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0].keys()}


class FeatureExtractionArgumentHandler(ArgumentHandler):
    """Handles arguments for feature extraction."""

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
            inputs = list(inputs)
            batch_size = len(inputs)
        elif isinstance(inputs, str):
            inputs = [inputs]
            batch_size = 1
        elif (
            Dataset is not None
            and isinstance(inputs, Dataset)
            or isinstance(inputs, types.GeneratorType)
        ):
            return inputs, None
        else:
            raise ValueError("At least one input is required.")

        offset_mapping = kwargs.get("offset_mapping")
        if offset_mapping:
            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
                offset_mapping = [offset_mapping]
            if len(offset_mapping) != batch_size:
                raise ValueError("offset_mapping should have the same batch size as the input")
        return inputs, offset_mapping


class FeatureExtractionPipelineWithStriding(ChunkPipeline):
    """Same as transformers.FeatureExtractionPipeline, but with long input handling. Inspired by
    transformers.TokenClassificationPipeline.

    The functionality is triggered by providing the "stride" parameter (can be 0). When passing
    "create_unique_embeddings_per_token=True", only one embedding (and other data, see flags
    below) per token is returned; this makes use of min_distance_to_border (see
    "return_min_distance_to_border" below for details). Note that this removes data for special
    token positions!

    By default, only the embeddings are returned. If any of the return_ADDITIONAL_RESULT flags is
    enabled (see below), dictionaries with "last_hidden_state" and all enabled ADDITIONAL_RESULT
    entries are returned instead.

    Flags to return additional results:
        return_offset_mapping: If enabled, return the offset mapping.
        return_special_tokens_mask: If enabled, return the special tokens mask.
        return_sequence_indices: If enabled, return the sequence indices.
        return_position_ids: If enabled, return the position ids; values are in
            [0, model_max_length).
        return_min_distance_to_border: If enabled, return the minimum distance to the "border" of
            the input that gets passed into the model. This is useful when striding is used, which
            may produce multiple embeddings for a token (compare values in offset_mapping). In
            this case, min_distance_to_border can be used to select the embedding that is more in
            the center of the input by choosing the entry with the *higher*
            min_distance_to_border.
    """

    default_input_names = "sequences"

    def __init__(self, args_parser=FeatureExtractionArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
        )
        self._args_parser = args_parser

    def _sanitize_parameters(
        self,
        offset_mapping: Optional[List[Tuple[int, int]]] = None,
        stride: Optional[int] = None,
        create_unique_embeddings_per_token: Optional[bool] = False,
        return_offset_mapping: Optional[bool] = None,
        return_special_tokens_mask: Optional[bool] = None,
        return_sequence_indices: Optional[bool] = None,
        return_position_ids: Optional[bool] = None,
        return_min_distance_to_border: Optional[bool] = None,
        return_tensors=None,
    ):
        preprocess_params = {}
        if offset_mapping is not None:
            preprocess_params["offset_mapping"] = offset_mapping

        if stride is not None:
            if stride >= self.tokenizer.model_max_length:
                raise ValueError(
                    "`stride` must be less than `tokenizer.model_max_length` (or even lower if "
                    "the tokenizer adds special tokens)"
                )
            if self.tokenizer.is_fast:
                tokenizer_params = {
                    "return_overflowing_tokens": True,
                    "padding": True,
                    "stride": stride,
                }
                preprocess_params["tokenizer_params"] = tokenizer_params  # type: ignore
            else:
                raise ValueError(
                    "`stride` was provided to process all the text but you're using a slow"
                    " tokenizer. Please use a fast tokenizer."
                )

        postprocess_params = {}
        if create_unique_embeddings_per_token is not None:
            postprocess_params["create_unique_embeddings_per_token"] = (
                create_unique_embeddings_per_token
            )
        if return_offset_mapping is not None:
            postprocess_params["return_offset_mapping"] = return_offset_mapping
        if return_special_tokens_mask is not None:
            postprocess_params["return_special_tokens_mask"] = return_special_tokens_mask
        if return_sequence_indices is not None:
            postprocess_params["return_sequence_indices"] = return_sequence_indices
        if return_position_ids is not None:
            postprocess_params["return_position_ids"] = return_position_ids
        if return_min_distance_to_border is not None:
            postprocess_params["return_min_distance_to_border"] = return_min_distance_to_border
        if return_tensors is not None:
            postprocess_params["return_tensors"] = return_tensors
        return preprocess_params, {}, postprocess_params

    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
        if offset_mapping:
            kwargs["offset_mapping"] = offset_mapping

        return super().__call__(inputs, **kwargs)

    def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
        # Tokenize the input; with striding enabled, `tokenizer_params` contains
        # `return_overflowing_tokens=True`, so long inputs are split into overlapping chunks.
        tokenizer_params = preprocess_params.pop("tokenizer_params", {})
        truncation = (
            True
            if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0
            else False
        )
        inputs = self.tokenizer(
            sentence,
            return_tensors=self.framework,
            truncation=truncation,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
            **tokenizer_params,
        )
        inputs.pop("overflow_to_sample_mapping", None)
        num_chunks = len(inputs["input_ids"])

        # Yield one set of model inputs per chunk; "sentence" is only attached to the first
        # chunk and "is_last" marks the final one.
        for i in range(num_chunks):
            if self.framework == "tf":
                model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
            else:
                model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
            if offset_mapping is not None:
                model_inputs["offset_mapping"] = offset_mapping
            model_inputs["sentence"] = sentence if i == 0 else None
            model_inputs["is_last"] = i == num_chunks - 1

            yield model_inputs
    def _forward(self, model_inputs, **kwargs):
        # Forward
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")
        is_last = model_inputs.pop("is_last")
        if self.framework == "tf":
            last_hidden_state = self.model(**model_inputs)[0]
        else:
            output = self.model(**model_inputs)
            last_hidden_state = (
                output["last_hidden_state"] if isinstance(output, dict) else output[0]
            )

        return {
            "last_hidden_state": last_hidden_state,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "sentence": sentence,
            "is_last": is_last,
            **model_inputs,
        }

    def postprocess_tensor(self, data, return_tensors=False):
        if return_tensors:
            return data
        if self.framework == "pt":
            return data.tolist()
        elif self.framework == "tf":
            return data.numpy().tolist()
        else:
            raise ValueError(f"unknown framework: {self.framework}")

    def make_embeddings_unique_per_token(
        self, data, offset_mapping, special_tokens_mask, min_distance_to_border
    ):
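        """Reduce the concatenated chunk outputs to one entry per (non-special) token.

        Tokens are grouped by their character offsets. When striding produces several embeddings
        for the same character span, the occurrence with the largest min_distance_to_border (i.e.
        the one closest to the center of its chunk) is kept. The selected positions are returned
        in character-offset order; special token positions are dropped. Expects batch size 1.
        """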
        char_offsets2token_pos = defaultdict(list)
        bs, seq_len = offset_mapping.shape[:2]
        if bs != 1:
            raise ValueError(f"expected result batch size 1, but it is: {bs}")
        for token_idx, ((char_start, char_end), is_special_token, min_dist) in enumerate(
            zip(
                offset_mapping[0].tolist(),
                special_tokens_mask[0].tolist(),
                min_distance_to_border[0].tolist(),
            )
        ):
            if not is_special_token:
                char_offsets2token_pos[(char_start, char_end)].append((token_idx, min_dist))
        # tokens_with_multiple_embeddings = {k: v for k, v in char_offsets2token_pos.items() if len(v) > 1}
        char_offsets2best_token_pos = {
            k: max(v, key=lambda pos_dist: pos_dist[1])[0]
            for k, v in char_offsets2token_pos.items()
        }
        # sort by char offsets (start and end)
        sorted_char_offsets_token_positions = sorted(
            char_offsets2best_token_pos.items(),
            key=lambda char_offsets_tok_pos: (
                char_offsets_tok_pos[0][0],
                char_offsets_tok_pos[0][1],
            ),
        )
        best_indices = [tok_pos for char_offsets, tok_pos in sorted_char_offsets_token_positions]
        result = {k: v[0][best_indices].unsqueeze(0) for k, v in data.items()}
        return result

    def postprocess(
        self,
        all_outputs,
        create_unique_embeddings_per_token: bool = False,
        return_offset_mapping: bool = False,
        return_special_tokens_mask: bool = False,
        return_sequence_indices: bool = False,
        return_position_ids: bool = False,
        return_min_distance_to_border: bool = False,
        return_tensors: bool = False,
    ):
        all_outputs_dict = list_of_dicts2dict_of_lists(all_outputs)
        if self.framework == "pt":
            result = {
                "last_hidden_state": torch.concat(all_outputs_dict["last_hidden_state"], axis=1)
            }
            if return_offset_mapping or create_unique_embeddings_per_token:
                result["offset_mapping"] = torch.concat(
                    all_outputs_dict["offset_mapping"], axis=1
                )
            if return_special_tokens_mask or create_unique_embeddings_per_token:
                result["special_tokens_mask"] = torch.concat(
                    all_outputs_dict["special_tokens_mask"], axis=1
                )
            if return_sequence_indices:
                sequence_indices = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    sequence_indices.append(torch.ones_like(model_outputs["input_ids"]) * seq_idx)
                result["sequence_indices"] = torch.concat(sequence_indices, axis=1)
            if return_position_ids:
                position_ids = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    seq_len = model_outputs["input_ids"].size(1)
                    position_ids.append(torch.arange(seq_len).unsqueeze(0))
                result["position_ids"] = torch.concat(position_ids, axis=1)
            if return_min_distance_to_border or create_unique_embeddings_per_token:
                min_distance_to_border = []
                for seq_idx, model_outputs in enumerate(all_outputs):
                    seq_len = model_outputs["input_ids"].size(1)
                    current_indices = torch.arange(seq_len).unsqueeze(0)
                    min_distance_to_border.append(
                        torch.minimum(current_indices, seq_len - current_indices)
                    )
                result["min_distance_to_border"] = torch.concat(min_distance_to_border, axis=1)
        elif self.framework == "tf":
            raise NotImplementedError()
        else:
            raise ValueError(f"unknown framework: {self.framework}")

        if create_unique_embeddings_per_token:
            offset_mapping = result["offset_mapping"]
            if not return_offset_mapping:
                del result["offset_mapping"]
            special_tokens_mask = result["special_tokens_mask"]
            if not return_special_tokens_mask:
                del result["special_tokens_mask"]
            min_distance_to_border = result["min_distance_to_border"]
            if not return_min_distance_to_border:
                del result["min_distance_to_border"]
            result = self.make_embeddings_unique_per_token(
                data=result,
                offset_mapping=offset_mapping,
                special_tokens_mask=special_tokens_mask,
                min_distance_to_border=min_distance_to_border,
            )

        result = {
            k: self.postprocess_tensor(v, return_tensors=return_tensors)
            for k, v in result.items()
        }

        if set(result) == {"last_hidden_state"}:
            return result["last_hidden_state"]
        else:
            return result
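

# Minimal usage sketch (not part of the pipeline itself). The checkpoint name is only
# illustrative; any PyTorch encoder with a fast tokenizer should work. Note that
# `check_model_type` compares against token classification architectures, so it may log a
# warning for a bare encoder model.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_name = "bert-base-cased"  # illustrative checkpoint
    pipe = FeatureExtractionPipelineWithStriding(
        model=AutoModel.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        framework="pt",
    )

    text = "A long document about striding pipelines. " * 200  # long enough to need chunking
    outputs = pipe(
        text,
        stride=64,
        create_unique_embeddings_per_token=True,
        return_offset_mapping=True,
    )
    # With additional flags enabled, the result is a dict; each entry holds one value per unique
    # (non-special) token, aligned across keys.
    print(len(outputs["last_hidden_state"][0]))  # number of unique token embeddings
    print(outputs["offset_mapping"][0][:5])  # character offsets of the first five tokens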