# Source: CustomPooler.py from the Hugging Face Hub (author: bobox).
# Last commit: da55ac3 "Update CustomPooler.py" (verified).
import os
import json
import torch
import torch.nn as nn
from typing import Dict, Any
import torch.nn.functional as F
import warnings
class SwiGLUBlock(nn.Module):
    """
    Gated feed-forward projection with a SiLU ("Swish") gate, i.e. SwiGLU.

    Two independent linear maps project the input from ``input_dim`` to
    ``hidden_dim``; the first is passed through SiLU and multiplied
    element-wise by the second:

        out = silu(in_proj_swish(x)) * in_proj_gate(x)

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Size of the last dimension of the output.
        bias: Whether both linear layers use a bias term.
    """

    def __init__(self, input_dim, hidden_dim, bias=True):
        super().__init__()
        self.input_dim, self.hidden_dim, self.bias = input_dim, hidden_dim, bias
        # Both projections share a shape; the swish branch is created (and
        # registered) first, which fixes init order and state_dict key order.
        self.in_proj_swish, self.in_proj_gate = (
            nn.Linear(input_dim, hidden_dim, bias=bias) for _ in range(2)
        )

    def forward(self, x):
        """Apply the gated projection over the last dimension of ``x``."""
        swish_branch = F.silu(self.in_proj_swish(x))
        return swish_branch * self.in_proj_gate(x)
class AdvancedWeightedPooling(nn.Module):
    """
    Performs Attention Pooling using the [CLS] token as the query.

    Pipeline: optional up-projection -> multi-head attention (CLS query over the
    token keys/values) -> optional SwiGLU MLP (combined with its input according
    to ``use_residual_MLP``) -> optional output compression -> optional LayerNorm.

    Args:
        embed_dim (int): The hidden dimension of the incoming token embeddings.
        num_heads (int): The number of attention heads (must divide the possibly
            expanded embedding dimension, as required by nn.MultiheadAttention).
        dropout (float, optional): Dropout probability for MHA and the MLP. Defaults to 0.0.
        bias (bool, optional): Whether to use bias in linear layers (MHA internal, MLP,
            up/down projections). Defaults to True.
        use_layernorm (bool, optional): Apply Layer Normalization as the very last step
            (after the MLP and any compression). Defaults to False.
        use_MLP (bool, optional): Apply a SwiGLU MLP after attention pooling. Defaults to False.
        MLP_h_size (int, optional): Hidden size for the MLP. -1 (default) means
            "use the (possibly expanded) embedding dimension".
        use_residual_MLP (str | bool, optional): How the MLP output is combined with the
            MLP input. 'add' (or True): residual sum; 'concat': concatenation, which
            doubles the dimension; 'none' (or False/None): MLP output only. Only used
            when use_MLP is True. Defaults to 'add'.
        ignore_cls_as_kv (bool, optional): Exclude the [CLS] token from the key/value
            pairs in MHA. Defaults to True.
        expand_emb_dim_to (int, optional): Project embeddings to this dimension before
            MHA/MLP. 0 (default), or a value equal to embed_dim, means no expansion.
        compress_output_dim_to (int, optional): Project the output to this dimension
            (before the LayerNorm). 0 (default) means no compression. A compression that
            would not change the dimension (no expansion, no MLP concat, value ==
            embed_dim) is disabled with a warning.

    NOTE(review): earlier revisions applied an additive residual even for 'concat';
    checkpoints trained with 'concat' + compression under that behaviour have a
    differently shaped ``output_down_proj`` and will not load strictly.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, use_layernorm: bool = False, use_MLP: bool = False, MLP_h_size: int = -1, use_residual_MLP: str = 'add', ignore_cls_as_kv: bool = True, expand_emb_dim_to: int = 0, compress_output_dim_to: int = 0):
        super(AdvancedWeightedPooling, self).__init__()
        # --- Config parameters (stored so get_config_dict() can round-trip them) ---
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.bias = bias
        self.use_layernorm = use_layernorm
        self.use_MLP = use_MLP
        self.MLP_h_size = MLP_h_size
        self.use_residual_MLP = use_residual_MLP
        self.ignore_cls_as_kv = ignore_cls_as_kv
        self.expand_emb_dim_to = expand_emb_dim_to
        self.compress_output_dim_to = compress_output_dim_to
        # Normalized form of use_residual_MLP: 'add' | 'concat' | 'none'.
        self._mlp_combination_mode = self._resolve_mlp_combination_mode(use_residual_MLP)

        # --- Expansion: normalized BEFORE the compression check, so the config saved
        # after normalization rebuilds exactly the same layers on load. ---
        if self.expand_emb_dim_to > 0 and self.expand_emb_dim_to == self.embed_dim:
            warnings.warn(f"`expand_emb_dim_to` ({self.expand_emb_dim_to}) is the same as `embed_dim` ({self.embed_dim}). No expansion layer created.")
            self.expand_emb_dim_to = 0  # Treat as no expansion needed
        self.current_embed_dim = self.expand_emb_dim_to if self.expand_emb_dim_to > 0 else self.embed_dim
        if self.MLP_h_size == -1:
            self.MLP_h_size = self.current_embed_dim
        if self.expand_emb_dim_to > 0:
            print(f"INFO: Expanding embedding dimension from {self.embed_dim} to {self.expand_emb_dim_to}")
            self.tokens_up_proj = nn.Linear(self.embed_dim, self.expand_emb_dim_to, bias=self.bias)
            self.cls_up_proj = nn.Linear(self.embed_dim, self.expand_emb_dim_to, bias=self.bias)

        # Dimension entering the compression / LayerNorm stage; MLP concat doubles it.
        concat_mlp = self.use_MLP and self._mlp_combination_mode == 'concat'
        self._pre_compression_dim = self.current_embed_dim * 2 if concat_mlp else self.current_embed_dim

        # A compression back to the unchanged input dim is a no-op dimension-wise.
        if (self.compress_output_dim_to > 0 and self.expand_emb_dim_to == 0
                and self.compress_output_dim_to == self.embed_dim and not concat_mlp):
            warnings.warn(f"input dim ({self.embed_dim}) == compress_output_dim_to ({self.compress_output_dim_to}) without any valid expand_emb_dim_to. Disabling compression.")
            self.compress_output_dim_to = 0

        # --- Sub-modules ---
        # MHA operates on the potentially expanded dimension
        self.mha = nn.MultiheadAttention(
            embed_dim=self.current_embed_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            bias=self.bias,
            add_bias_kv=False,  # Keep False if CLS is query only
            batch_first=True,
        )
        if self.use_MLP:
            self.MLP = nn.Sequential(
                SwiGLUBlock(self.current_embed_dim, self.MLP_h_size, bias=self.bias),
                nn.Dropout(self.dropout),
                nn.Linear(self.MLP_h_size, self.current_embed_dim, bias=self.bias),
            )
        if self.compress_output_dim_to > 0:
            self.compression_layer_input_dims = self._pre_compression_dim
            self.output_down_proj = nn.Linear(self.compression_layer_input_dims, self.compress_output_dim_to, bias=self.bias)
        if self.use_layernorm:
            # LayerNorm is the last step, so it sees the compressed dim when compression is on.
            self.LayerNorm_input_dims = self.compress_output_dim_to if self.compress_output_dim_to > 0 else self._pre_compression_dim
            self.layernorm = nn.LayerNorm(self.LayerNorm_input_dims, eps=1e-05, elementwise_affine=True)

        # --- Configuration for Saving/Loading ---
        # 'embed_dim' refers to the initial (un-expanded) config parameter.
        self.config_keys = ['embed_dim', 'num_heads', 'dropout', 'bias', 'use_layernorm', 'use_MLP', 'MLP_h_size', 'use_residual_MLP', 'ignore_cls_as_kv', 'expand_emb_dim_to', 'compress_output_dim_to']

    @staticmethod
    def _resolve_mlp_combination_mode(value: Any) -> str:
        """Normalize ``use_residual_MLP`` to one of 'add', 'concat' or 'none'.

        Booleans are accepted for backward compatibility (the option was once
        documented as a bool). Unknown values fall back to 'add', which is what
        the previous truthiness-based forward() effectively did, with a warning.
        """
        if value is True:
            return 'add'
        if not value:
            return 'none'
        if isinstance(value, str) and value.strip().lower() in ('add', 'concat', 'none'):
            return value.strip().lower()
        warnings.warn(f"Unrecognized use_residual_MLP={value!r}; falling back to additive residual ('add').")
        return 'add'

    def _masked_mean_pooling(self, token_embeddings, attention_mask):
        """Masked mean over the sequence axis (currently not called by forward())."""
        # Ensure mask is expanded correctly for broadcasting
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp sum_mask after summing to avoid division by zero
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _post_pool(self, pooled_embedding: torch.Tensor) -> torch.Tensor:
        """Apply the optional MLP, compression and LayerNorm (in that order) to a pooled vector."""
        if self.use_MLP:
            mlp_input = pooled_embedding
            post_MLP_embedding = self.MLP(mlp_input)
            if self._mlp_combination_mode == 'concat':
                pooled_embedding = torch.cat([mlp_input, post_MLP_embedding], dim=-1)
            elif self._mlp_combination_mode == 'add':
                pooled_embedding = mlp_input + post_MLP_embedding  # residual
            else:
                pooled_embedding = post_MLP_embedding
        if self.compress_output_dim_to > 0:
            pooled_embedding = self.output_down_proj(pooled_embedding)
        # LayerNorm last: it is sized for the compressed dim when compression is enabled.
        if self.use_layernorm:
            pooled_embedding = self.layernorm(pooled_embedding)
        return pooled_embedding

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Pool ``features['token_embeddings']`` into ``{'sentence_embedding': ...}``.

        Position 0 of the sequence is used as the attention query ([CLS]).
        ``features['attention_mask']`` is optional; when absent all tokens are
        treated as real. Entries equal to 0 are masked out of the keys/values.
        """
        token_embeddings_all = features['token_embeddings']  # indexed as (batch, seq_len, dim)
        attention_mask = features.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones(token_embeddings_all.shape[:2], device=token_embeddings_all.device, dtype=torch.long)
        else:
            attention_mask = attention_mask.long()

        # --- Prepare MHA inputs ---
        cls_embedding = token_embeddings_all[:, 0:1, :]  # (batch, 1, dim)
        if self.ignore_cls_as_kv:
            token_embeddings_kv = token_embeddings_all[:, 1:, :]  # Exclude CLS
            sequence_attention_mask = attention_mask[:, 1:]
        else:
            token_embeddings_kv = token_embeddings_all  # Include CLS
            sequence_attention_mask = attention_mask

        # --- Optional expansion of both the query and the keys/values ---
        if self.expand_emb_dim_to > 0:
            cls_embedding = self.cls_up_proj(cls_embedding)
            token_embeddings_kv = self.tokens_up_proj(token_embeddings_kv)

        if self.ignore_cls_as_kv and token_embeddings_kv.shape[1] == 0:
            # Nothing to attend over: fall back to the (projected) CLS vector.
            warnings.warn("Input sequence only contains [CLS] token after slicing when ignore_cls_as_kv=True. "
                          "Attention pooling cannot be performed. Returning CLS embedding (potentially processed).")
            pooled_embedding = cls_embedding.squeeze(1)
        else:
            # nn.MultiheadAttention expects True where a key should be ignored (padding).
            key_padding_mask = (sequence_attention_mask == 0)
            attn_output, _ = self.mha(
                query=cls_embedding,
                key=token_embeddings_kv,
                value=token_embeddings_kv,
                key_padding_mask=key_padding_mask,
                need_weights=False,
            )
            pooled_embedding = attn_output.squeeze(1)  # (batch, current_embed_dim)

        return {'sentence_embedding': self._post_pool(pooled_embedding)}

    def get_sentence_embedding_dimension(self) -> int:
        """Returns the final output dimension of the pooling layer."""
        if self.compress_output_dim_to > 0:
            return self.compress_output_dim_to
        return self._pre_compression_dim

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the constructor arguments (after normalization) needed to rebuild this module."""
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Write ``config.json`` plus the weights (``model.safetensors`` or ``pytorch_model.bin``).

        The weights file of the other format is removed so that load() cannot
        pick up a stale file.
        """
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, 'config.json'), 'w', encoding='utf-8') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)
        model_path_st = os.path.join(output_path, 'model.safetensors')
        model_path_bin = os.path.join(output_path, 'pytorch_model.bin')
        state_dict = self.state_dict()
        if safe_serialization:
            try:
                from safetensors.torch import save_file
                save_file(state_dict, model_path_st)
                print(f"Saved state dict to {model_path_st}")
                # Remove old bin file if it exists and we successfully saved safetensors
                if os.path.exists(model_path_bin):
                    os.remove(model_path_bin)
            except ImportError:
                warnings.warn("safetensors not available. Falling back to regular PyTorch serialization (pytorch_model.bin).", UserWarning)
                torch.save(state_dict, model_path_bin)
                print(f"Saved state dict to {model_path_bin}")
            except Exception as e:  # Deliberate best-effort fallback to the .bin format
                warnings.warn(f"Error saving safetensors file: {e}. Falling back to pytorch_model.bin", UserWarning)
                torch.save(state_dict, model_path_bin)
                print(f"Saved state dict to {model_path_bin}")
        else:
            torch.save(state_dict, model_path_bin)
            print(f"Saved state dict to {model_path_bin}")
            # Remove old safetensors file if it exists
            if os.path.exists(model_path_st):
                os.remove(model_path_st)

    @staticmethod
    def load(input_path: str) -> 'AdvancedWeightedPooling':
        """Rebuild the module from ``config.json`` and load its weights.

        ``model.safetensors`` is preferred; ``pytorch_model.bin`` is the fallback.
        If neither can be loaded, a warning is emitted and the randomly
        initialized model is returned.

        Raises:
            OSError: if ``config.json`` is missing.
        """
        config_path = os.path.join(input_path, 'config.json')
        if not os.path.exists(config_path):
            raise OSError(f"config.json not found in {input_path}")
        with open(config_path, encoding='utf-8') as fIn:
            config = json.load(fIn)
        # Build all layers from the *saved* configuration before loading weights.
        model = AdvancedWeightedPooling(**config)

        safetensors_path = os.path.join(input_path, 'model.safetensors')
        pytorch_path = os.path.join(input_path, 'pytorch_model.bin')
        loaded_state_dict = None

        # Prioritize safetensors
        if os.path.exists(safetensors_path):
            try:
                from safetensors.torch import load_file
                loaded_state_dict = load_file(safetensors_path, device='cpu')
                print(f"Loaded state dict from {safetensors_path}")
            except ImportError:
                warnings.warn("safetensors not available or error loading. Falling back to pytorch_model.bin if exists.", UserWarning)
            except Exception as e:
                warnings.warn(f"Error loading safetensors file: {e}. Falling back to pytorch_model.bin if exists.", UserWarning)

        # Fallback to pytorch_model.bin if safetensors failed or doesn't exist
        if loaded_state_dict is None and os.path.exists(pytorch_path):
            try:
                # NOTE(review): torch.load unpickles the file; only load checkpoints from
                # trusted sources (or pass weights_only=True where the torch version allows).
                loaded_state_dict = torch.load(pytorch_path, map_location=torch.device('cpu'))
                print(f"Loaded state dict from {pytorch_path}")
            except Exception as e:
                warnings.warn(f"Error loading pytorch_model.bin: {e}", UserWarning)

        if loaded_state_dict is not None:
            # strict=True surfaces missing/unexpected keys (config/weights mismatch).
            load_result = model.load_state_dict(loaded_state_dict, strict=True)
            print(f"Model state loaded. Result: {load_result}")
        else:
            warnings.warn(f"Warning: No model weights file found or loaded successfully at {safetensors_path} or {pytorch_path}. Model initialized randomly.", UserWarning)
        return model