import json
import os
import warnings
from typing import Any, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUBlock(nn.Module):
    r"""
    SwiGLU activation using two separate linear layers.

        Input -> Linear (w1) -> Swish \
                                       * -> Output
        Input -> Linear (w3) -> Gate  /
    """

    def __init__(self, input_dim, hidden_dim, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bias = bias

        # Two parallel projections from the input: one feeds the Swish (SiLU) branch,
        # the other produces the multiplicative gate.
        self.in_proj_swish = nn.Linear(self.input_dim, self.hidden_dim, bias=self.bias)
        self.in_proj_gate = nn.Linear(self.input_dim, self.hidden_dim, bias=self.bias)

    def forward(self, x):
        hidden_states = self.in_proj_swish(x)
        gate = self.in_proj_gate(x)

        # SwiGLU: SiLU(x W1) element-wise multiplied by the gate projection (x W3).
        activated_hidden = F.silu(hidden_states) * gate
        return activated_hidden
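
# Illustrative usage sketch (added, not part of the original module; sizes are assumed
# example values): SwiGLUBlock maps (..., input_dim) -> (..., hidden_dim), so callers
# add their own down-projection back to the model dimension, as the MLP inside
# AdvancedWeightedPooling does below.
#
#     block = SwiGLUBlock(input_dim=768, hidden_dim=3072)
#     y = block(torch.randn(2, 16, 768))  # y.shape == (2, 16, 3072)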
class AdvancedWeightedPooling(nn.Module):
    """
    Performs attention pooling using the [CLS] token as the query.

    Args:
        embed_dim (int): The hidden dimension of the embeddings.
        num_heads (int): The number of attention heads.
        dropout (float, optional): Dropout probability for MHA and the MLP. Defaults to 0.0.
        bias (bool, optional): Whether to use bias in linear layers (MHA internal, MLP). Defaults to True.
        use_layernorm (bool, optional): Apply Layer Normalization after pooling (and the optional MLP/compression). Defaults to False.
        use_MLP (bool, optional): Apply an MLP block after attention pooling. Defaults to False.
        MLP_h_size (int, optional): Hidden size for the MLP. Defaults to the (possibly expanded) embedding dimension when left at -1.
        use_residual_MLP (str, optional): How to combine the MLP input with its output: 'add' for a residual sum, 'concat' for concatenation; any other value keeps only the MLP output. Defaults to 'add'.
        ignore_cls_as_kv (bool, optional): Exclude the [CLS] token from the key/value pairs in MHA. Defaults to True.
        expand_emb_dim_to (int, optional): Expand the embedding dimension before MHA/MLP. Defaults to 0 (no expansion).
        compress_output_dim_to (int, optional): Compress the final output dimension after all other steps. Defaults to 0 (no compression).
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 use_layernorm: bool = False, use_MLP: bool = False, MLP_h_size: int = -1,
                 use_residual_MLP: str = 'add', ignore_cls_as_kv: bool = True,
                 expand_emb_dim_to: int = 0, compress_output_dim_to: int = 0):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.bias = bias
        self.use_layernorm = use_layernorm
        self.use_MLP = use_MLP
        self.MLP_h_size = MLP_h_size
        self.use_residual_MLP = use_residual_MLP
        self.ignore_cls_as_kv = ignore_cls_as_kv
        self.expand_emb_dim_to = expand_emb_dim_to
        self.compress_output_dim_to = compress_output_dim_to

        # Working dimension after the optional expansion layer.
        self.current_embed_dim = self.embed_dim if self.expand_emb_dim_to == 0 else self.expand_emb_dim_to

        if self.MLP_h_size == -1:
            self.MLP_h_size = self.current_embed_dim

        # Compressing back to the input dimension is a no-op unless the embeddings are
        # expanded or the MLP output is concatenated (which doubles the dimension).
        if (self.compress_output_dim_to > 0 and self.expand_emb_dim_to == 0
                and self.compress_output_dim_to == self.embed_dim and self.use_residual_MLP != 'concat'):
            warnings.warn(f"input dim ({self.embed_dim}) == compress_output_dim_to ({self.compress_output_dim_to}) without any valid expand_emb_dim_to. Disabling compression.")
            self.compress_output_dim_to = 0

        if self.expand_emb_dim_to > 0 and self.expand_emb_dim_to != self.embed_dim:
            print(f"INFO: Expanding embedding dimension from {self.embed_dim} to {self.expand_emb_dim_to}")
            self.tokens_up_proj = nn.Linear(self.embed_dim, self.expand_emb_dim_to, bias=self.bias)
            self.cls_up_proj = nn.Linear(self.embed_dim, self.expand_emb_dim_to, bias=self.bias)
            self.current_embed_dim = self.expand_emb_dim_to
        elif self.expand_emb_dim_to > 0 and self.expand_emb_dim_to == self.embed_dim:
            warnings.warn(f"`expand_emb_dim_to` ({self.expand_emb_dim_to}) is the same as `embed_dim` ({self.embed_dim}). No expansion layer created.")
            self.expand_emb_dim_to = 0

        self.mha = nn.MultiheadAttention(
            embed_dim=self.current_embed_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            bias=self.bias,
            add_bias_kv=False,
            batch_first=True
        )

        if self.use_MLP:
            self.MLP = nn.Sequential(
                SwiGLUBlock(self.current_embed_dim, self.MLP_h_size, bias=self.bias),
                nn.Dropout(self.dropout),
                nn.Linear(self.MLP_h_size, self.current_embed_dim, bias=self.bias)
            )

        if self.compress_output_dim_to > 0:
            # 'concat' doubles the dimension fed into the compression layer.
            self.compression_layer_input_dims = self.current_embed_dim if self.use_residual_MLP != 'concat' else self.current_embed_dim * 2
            self.output_down_proj = nn.Linear(self.compression_layer_input_dims, self.compress_output_dim_to, bias=self.bias)

        if self.use_layernorm:
            if self.compress_output_dim_to != 0:
                self.LayerNorm_input_dims = self.compress_output_dim_to
            elif self.use_residual_MLP != 'concat':
                self.LayerNorm_input_dims = self.current_embed_dim
            else:
                self.LayerNorm_input_dims = self.current_embed_dim * 2
            self.layernorm = nn.LayerNorm(self.LayerNorm_input_dims, eps=1e-05, elementwise_affine=True)

        self.config_keys = ['embed_dim', 'num_heads', 'dropout', 'bias', 'use_layernorm', 'use_MLP', 'MLP_h_size', 'use_residual_MLP', 'ignore_cls_as_kv', 'expand_emb_dim_to', 'compress_output_dim_to']

    def _masked_mean_pooling(self, token_embeddings, attention_mask):
        """Helper for masked mean pooling (not used in the default forward path)."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        token_embeddings_all = features['token_embeddings']  # (batch, seq_len, dim)
        attention_mask = features.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones(token_embeddings_all.shape[:2], device=token_embeddings_all.device, dtype=torch.long)
        else:
            attention_mask = attention_mask.long()

        # [CLS] is assumed to be the first token; keep its sequence dim so it can serve as a length-1 query.
        cls_embedding = token_embeddings_all[:, 0:1, :]  # (batch, 1, dim)

        if self.ignore_cls_as_kv:
            token_embeddings_kv = token_embeddings_all[:, 1:, :]
            sequence_attention_mask = attention_mask[:, 1:]
        else:
            token_embeddings_kv = token_embeddings_all
            sequence_attention_mask = attention_mask

        if self.expand_emb_dim_to > 0:
            cls_embedding = self.cls_up_proj(cls_embedding)
            token_embeddings_kv = self.tokens_up_proj(token_embeddings_kv)

        if self.ignore_cls_as_kv and token_embeddings_kv.shape[1] == 0:
            warnings.warn("Input sequence only contains [CLS] token after slicing when ignore_cls_as_kv=True. "
                          "Attention pooling cannot be performed. Returning CLS embedding (potentially processed).")
            pooled_embedding = cls_embedding.squeeze(1)

            if self.use_MLP:
                mlp_input = pooled_embedding
                post_MLP_embedding = self.MLP(mlp_input)
                if self.use_residual_MLP == 'concat':
                    pooled_embedding = torch.cat([mlp_input, post_MLP_embedding], dim=-1)
                elif self.use_residual_MLP == 'add':
                    pooled_embedding = mlp_input + post_MLP_embedding
                else:
                    pooled_embedding = post_MLP_embedding

            # Same ordering as the main path (compress, then normalize) so the
            # LayerNorm dimensions configured in __init__ stay valid.
            if self.compress_output_dim_to > 0:
                pooled_embedding = self.output_down_proj(pooled_embedding)

            if self.use_layernorm:
                pooled_embedding = self.layernorm(pooled_embedding)

            return {'sentence_embedding': pooled_embedding}

        # Attention pooling: the [CLS] embedding queries the remaining tokens.
        query = cls_embedding        # (batch, 1, dim)
        key = token_embeddings_kv    # (batch, kv_len, dim)
        value = token_embeddings_kv

        # True marks padding positions that MHA must ignore.
        key_padding_mask = (sequence_attention_mask == 0)

        attn_output, _ = self.mha(
            query=query,
            key=key,
            value=value,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )

        pooled_embedding = attn_output.squeeze(1)  # (batch, dim)

        if self.use_MLP:
            mlp_input = pooled_embedding
            post_MLP_embedding = self.MLP(mlp_input)
            if self.use_residual_MLP == 'concat':
                pooled_embedding = torch.cat([mlp_input, post_MLP_embedding], dim=-1)
            elif self.use_residual_MLP == 'add':
                pooled_embedding = mlp_input + post_MLP_embedding
            else:
                pooled_embedding = post_MLP_embedding

        if self.compress_output_dim_to > 0:
            pooled_embedding = self.output_down_proj(pooled_embedding)

        if self.use_layernorm:
            pooled_embedding = self.layernorm(pooled_embedding)

        return {'sentence_embedding': pooled_embedding}

    def get_sentence_embedding_dimension(self) -> int:
        """Returns the final output dimension of the pooling layer."""
        final_dim = self.current_embed_dim

        if self.use_MLP and self.use_residual_MLP == 'concat':
            final_dim *= 2

        if self.compress_output_dim_to > 0:
            final_dim = self.compress_output_dim_to

        return final_dim
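
    # Illustrative bookkeeping with assumed example values: embed_dim=768,
    # expand_emb_dim_to=1024, use_MLP=True, use_residual_MLP='concat' and
    # compress_output_dim_to=512 give 768 -> 1024 (expansion) -> 2048 ('concat'
    # after the MLP) -> 512 (compression), so this method returns 512.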
    def get_config_dict(self) -> Dict[str, Any]:
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        os.makedirs(output_path, exist_ok=True)

        with open(os.path.join(output_path, 'config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

        model_path_st = os.path.join(output_path, 'model.safetensors')
        model_path_bin = os.path.join(output_path, 'pytorch_model.bin')

        state_dict = self.state_dict()
        if safe_serialization:
            try:
                from safetensors.torch import save_file

                save_file(state_dict, model_path_st)
                print(f"Saved state dict to {model_path_st}")

                # Remove a stale .bin file so only one weights file remains.
                if os.path.exists(model_path_bin):
                    os.remove(model_path_bin)
            except ImportError:
                warnings.warn("safetensors not available. Falling back to regular PyTorch serialization (pytorch_model.bin).", UserWarning)
                torch.save(state_dict, model_path_bin)
                print(f"Saved state dict to {model_path_bin}")
            except Exception as e:
                warnings.warn(f"Error saving safetensors file: {e}. Falling back to pytorch_model.bin", UserWarning)
                torch.save(state_dict, model_path_bin)
                print(f"Saved state dict to {model_path_bin}")
        else:
            torch.save(state_dict, model_path_bin)
            print(f"Saved state dict to {model_path_bin}")

            # Remove a stale safetensors file so only one weights file remains.
            if os.path.exists(model_path_st):
                os.remove(model_path_st)

    @staticmethod
    def load(input_path: str) -> 'AdvancedWeightedPooling':
        config_path = os.path.join(input_path, 'config.json')
        if not os.path.exists(config_path):
            raise OSError(f"config.json not found in {input_path}")
        with open(config_path) as fIn:
            config = json.load(fIn)

        model = AdvancedWeightedPooling(**config)

        safetensors_path = os.path.join(input_path, 'model.safetensors')
        pytorch_path = os.path.join(input_path, 'pytorch_model.bin')

        loaded_state_dict = None
        load_success = False

        # Prefer safetensors when available, then fall back to pytorch_model.bin.
        if os.path.exists(safetensors_path):
            try:
                from safetensors.torch import load_file
                loaded_state_dict = load_file(safetensors_path, device='cpu')
                print(f"Loaded state dict from {safetensors_path}")
                load_success = True
            except ImportError:
                warnings.warn("safetensors not available. Falling back to pytorch_model.bin if it exists.", UserWarning)
            except Exception as e:
                warnings.warn(f"Error loading safetensors file: {e}. Falling back to pytorch_model.bin if it exists.", UserWarning)

        if not load_success and os.path.exists(pytorch_path):
            try:
                loaded_state_dict = torch.load(pytorch_path, map_location=torch.device('cpu'))
                print(f"Loaded state dict from {pytorch_path}")
                load_success = True
            except Exception as e:
                warnings.warn(f"Error loading pytorch_model.bin: {e}", UserWarning)

        if loaded_state_dict is not None:
            load_result = model.load_state_dict(loaded_state_dict, strict=True)
            print(f"Model state loaded. Result: {load_result}")
        elif not load_success:
            warnings.warn(f"No model weights file found or loaded successfully at {safetensors_path} or {pytorch_path}. Model initialized randomly.", UserWarning)

        return model
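

# Added illustrative smoke test (a sketch with assumed example sizes, not part of the
# original module): builds a small AdvancedWeightedPooling layer, pools random token
# embeddings, and checks that the output width matches get_sentence_embedding_dimension().
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, seq_len, embed_dim = 2, 12, 32
    features = {
        'token_embeddings': torch.randn(batch_size, seq_len, embed_dim),
        'attention_mask': torch.ones(batch_size, seq_len, dtype=torch.long),
    }

    pooler = AdvancedWeightedPooling(embed_dim=embed_dim, num_heads=4, use_MLP=True, use_residual_MLP='add')
    sentence_embedding = pooler(features)['sentence_embedding']

    print(sentence_embedding.shape)  # expected: torch.Size([2, 32])
    assert sentence_embedding.shape == (batch_size, pooler.get_sentence_embedding_dimension())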