bobox committed · Commit da55ac3 · verified · 1 Parent(s): 02a2718

Update CustomPooler.py

Files changed (1)
  1. CustomPooler.py +334 -334
CustomPooler.py CHANGED
@@ -1,335 +1,335 @@
import os
import json
import torch
import torch.nn as nn
from typing import Dict, Any
import torch.nn.functional as F
import warnings

class SwiGLUBlock(nn.Module):
    r"""
    SwiGLU activation using two separate linear layers.
    Input -> Linear (w1) -> Swish \
                                   * -> Output
    Input -> Linear (w3) -> Gate  /
    """
    def __init__(self, input_dim, hidden_dim, bias=True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.bias = bias

        # Layer 1: Input -> Hidden (for the main Swish path)
        self.in_proj_swish = nn.Linear(self.input_dim, self.hidden_dim, bias=self.bias)
        # Layer 3: Input -> Hidden (for the gate path)
        self.in_proj_gate = nn.Linear(self.input_dim, self.hidden_dim, bias=self.bias)

    def forward(self, x):
        # x shape: [..., input_dim]
        hidden_states = self.in_proj_swish(x)  # Output shape: [..., hidden_dim]
        gate = self.in_proj_gate(x)            # Output shape: [..., hidden_dim]

        # Apply SwiGLU activation: Swish(hidden_states) * gate
        activated_hidden = F.silu(hidden_states) * gate  # Output shape: [..., hidden_dim]
        return activated_hidden


class AdvancedWeightedPooling(nn.Module):
    """
    Performs attention pooling using the [CLS] token as the query.

    Args:
        embed_dim (int): The hidden dimension of the embeddings.
        num_heads (int): The number of attention heads.
        dropout (float, optional): Dropout probability for MHA. Defaults to 0.0.
        bias (bool, optional): Whether to use bias in linear layers (MHA internal, MLP). Defaults to True.
        use_layernorm (bool, optional): Apply Layer Normalization after pooling (and the optional MLP/compression). Defaults to False.
        use_MLP (bool, optional): Apply an MLP layer after attention pooling. Defaults to False.
        MLP_h_size (int, optional): Hidden size for the MLP. Defaults to the (possibly expanded) embedding dimension when set to -1.
        use_residual_MLP (str, optional): How to combine the MLP output with its input: 'add' for a residual sum, 'concat' to concatenate, anything else to use the MLP output alone. Defaults to 'add'.
        ignore_cls_as_kv (bool, optional): Exclude the [CLS] token from the key/value pairs in MHA. Defaults to True.
        expand_emb_dim_to (int, optional): Expand the embedding dimension before MHA/MLP. Defaults to 0 (no expansion).
        compress_output_dim_to (int, optional): Compress the final output dimension after all other steps. Defaults to 0 (no compression).
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, use_layernorm: bool = False, use_MLP: bool = False, MLP_h_size: int = -1, use_residual_MLP: str = 'add', ignore_cls_as_kv: bool = True, expand_emb_dim_to: int = 0, compress_output_dim_to: int = 0):
        super(AdvancedWeightedPooling, self).__init__()
        # --- Store initial embed_dim consistently ---
        self.embed_dim = embed_dim
        # --- Store other config parameters ---
        self.num_heads = num_heads
        self.dropout = dropout
        self.bias = bias
        self.use_layernorm = use_layernorm
        self.use_MLP = use_MLP
        self.MLP_h_size = MLP_h_size
        self.use_residual_MLP = use_residual_MLP
        self.ignore_cls_as_kv = ignore_cls_as_kv
        self.expand_emb_dim_to = expand_emb_dim_to
        self.compress_output_dim_to = compress_output_dim_to

        self.current_embed_dim = self.embed_dim if self.expand_emb_dim_to == 0 else self.expand_emb_dim_to

        if self.MLP_h_size == -1:
            self.MLP_h_size = self.current_embed_dim

        # Compression to the unchanged input dimension would be a no-op-sized projection, so disable it.
        if self.compress_output_dim_to > 0 and self.expand_emb_dim_to == 0 and self.compress_output_dim_to == self.embed_dim and self.use_residual_MLP != 'concat':
            warnings.warn(f"input dim ({self.embed_dim}) == compress_output_dim_to ({self.compress_output_dim_to}) without any valid expand_emb_dim_to. Disabling compression.")
            self.compress_output_dim_to = 0

        if self.expand_emb_dim_to > 0 and self.expand_emb_dim_to != self.embed_dim:
            print(f"INFO: Expanding embedding dimension from {self.embed_dim} to {self.expand_emb_dim_to}")
            self.tokens_up_proj = nn.Linear(self.embed_dim, self.expand_emb_dim_to, bias=self.bias)
            self.cls_up_proj = nn.Linear(self.embed_dim, self.expand_emb_dim_to, bias=self.bias)
            self.current_embed_dim = self.expand_emb_dim_to  # Update the dimension for subsequent layers
        elif self.expand_emb_dim_to > 0 and self.expand_emb_dim_to == self.embed_dim:
            warnings.warn(f"`expand_emb_dim_to` ({self.expand_emb_dim_to}) is the same as `embed_dim` ({self.embed_dim}). No expansion layer created.")
            self.expand_emb_dim_to = 0  # Treat as no expansion needed

        # --- Sub-modules ---
        # MHA operates on the potentially expanded dimension
        self.mha = nn.MultiheadAttention(
            embed_dim=self.current_embed_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            bias=self.bias,
            add_bias_kv=False,  # Keep False if CLS is query only
            batch_first=True
        )

        if self.use_MLP:
            self.MLP = nn.Sequential(
                SwiGLUBlock(self.current_embed_dim, self.MLP_h_size, bias=self.bias),
                nn.Dropout(self.dropout),
                nn.Linear(self.MLP_h_size, self.current_embed_dim, bias=self.bias)
            )

        if self.compress_output_dim_to > 0:
            # The compression layer sees a doubled dimension when the MLP output is concatenated to its input
            self.compression_layer_input_dims = self.current_embed_dim if self.use_residual_MLP != 'concat' else self.current_embed_dim * 2
            self.output_down_proj = nn.Linear(self.compression_layer_input_dims, self.compress_output_dim_to, bias=self.bias)

        if self.use_layernorm:
            if self.compress_output_dim_to != 0:
                self.LayerNorm_input_dims = self.compress_output_dim_to
            elif self.use_residual_MLP != 'concat':
                self.LayerNorm_input_dims = self.current_embed_dim
            else:
                self.LayerNorm_input_dims = self.current_embed_dim * 2

            self.layernorm = nn.LayerNorm(self.LayerNorm_input_dims, eps=1e-05, elementwise_affine=True)

        # --- Configuration for Saving/Loading ---
        # Keep 'embed_dim' here as it refers to the initial config parameter
        self.config_keys = ['embed_dim', 'num_heads', 'dropout', 'bias', 'use_layernorm', 'use_MLP', 'MLP_h_size', 'use_residual_MLP', 'ignore_cls_as_kv', 'expand_emb_dim_to', 'compress_output_dim_to']

    def _masked_mean_pooling(self, token_embeddings, attention_mask):
        """Helper function for masked mean pooling."""
        # Ensure mask is expanded correctly for broadcasting
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp sum_mask after summing to avoid division by zero
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        token_embeddings_all = features['token_embeddings']  # Shape: (batch, seq_len, initial_dim)
        attention_mask = features.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones(token_embeddings_all.shape[:2], device=token_embeddings_all.device, dtype=torch.long)
        else:
            attention_mask = attention_mask.long()

        # --- Prepare MHA Inputs ---
        cls_embedding = token_embeddings_all[:, 0:1, :]  # Shape: (batch, 1, initial_dim)

        # Decide which embeddings to use as K/V based on ignore_cls_as_kv
        if self.ignore_cls_as_kv:
            token_embeddings_kv = token_embeddings_all[:, 1:, :]  # Exclude CLS
            # Adjust attention mask for K/V if CLS is ignored
            sequence_attention_mask = attention_mask[:, 1:]
        else:
            token_embeddings_kv = token_embeddings_all  # Include CLS
            sequence_attention_mask = attention_mask

        # --- Optional Expansion ---
        if self.expand_emb_dim_to > 0:
            # Apply expansion to both CLS (query) and the K/V tokens
            cls_embedding = self.cls_up_proj(cls_embedding)  # Shape: (batch, 1, current_embed_dim)
            token_embeddings_kv = self.tokens_up_proj(token_embeddings_kv)  # Shape: (batch, kv_seq_len, current_embed_dim)

        # Check for empty sequence after slicing (if ignore_cls_as_kv is True)
        if self.ignore_cls_as_kv and token_embeddings_kv.shape[1] == 0:
            warnings.warn("Input sequence only contains [CLS] token after slicing when ignore_cls_as_kv=True. "
                          "Attention pooling cannot be performed. Returning CLS embedding (potentially processed).")
            # Process the CLS embedding as if it were the pooled output
            pooled_embedding = cls_embedding.squeeze(1)  # Shape: (batch, current_embed_dim)

            # Apply subsequent layers if configured
            if self.use_MLP:
                mlp_input = pooled_embedding
                post_MLP_embedding = self.MLP(mlp_input)
                if self.use_residual_MLP == 'concat':
                    pooled_embedding = torch.cat([mlp_input, post_MLP_embedding], dim=-1)
                elif self.use_residual_MLP == 'add':
                    pooled_embedding = mlp_input + post_MLP_embedding
                else:
                    pooled_embedding = post_MLP_embedding

            if self.compress_output_dim_to > 0:
                pooled_embedding = self.output_down_proj(pooled_embedding)  # Apply compression before LayerNorm, matching the main path

            if self.use_layernorm:
                pooled_embedding = self.layernorm(pooled_embedding)

            return {'sentence_embedding': pooled_embedding}


        # --- Multi-Head Attention ---
        query = cls_embedding  # Shape: (batch, 1, current_embed_dim)
        key = token_embeddings_kv  # Shape: (batch, kv_seq_len, current_embed_dim)
        value = token_embeddings_kv  # Shape: (batch, kv_seq_len, current_embed_dim)

        # Create boolean mask: True for padding (0), False for real tokens (1)
        # Mask shape should match (batch, kv_seq_len)
        key_padding_mask = (sequence_attention_mask == 0)

        attn_output, _ = self.mha(
            query=query,
            key=key,
            value=value,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )
        # attn_output shape: (batch, query_len=1, current_embed_dim)
        pooled_embedding = attn_output.squeeze(1)  # Shape: (batch, current_embed_dim)


        # --- Optional MLP ---
        if self.use_MLP:
            mlp_input = pooled_embedding  # Input to MLP
            post_MLP_embedding = self.MLP(mlp_input)
            if self.use_residual_MLP == 'concat':
                pooled_embedding = torch.cat([mlp_input, post_MLP_embedding], dim=-1)
            elif self.use_residual_MLP == 'add':
                pooled_embedding = mlp_input + post_MLP_embedding  # residual
            else:
                pooled_embedding = post_MLP_embedding

        # --- Optional Output Compression ---
        if self.compress_output_dim_to > 0:
            pooled_embedding = self.output_down_proj(pooled_embedding)

        # --- Optional LayerNorm ---
        if self.use_layernorm:
            pooled_embedding = self.layernorm(pooled_embedding)


        return {'sentence_embedding': pooled_embedding}

    def get_sentence_embedding_dimension(self) -> int:
        """Returns the final output dimension of the pooling layer."""
        # Start with the dimension after potential expansion
        final_dim = self.current_embed_dim

        # Account for MLP concatenation if used
        if self.use_MLP and self.use_residual_MLP == 'concat':
            final_dim *= 2

        # If compression is applied, that's the final dimension
        if self.compress_output_dim_to > 0:
            final_dim = self.compress_output_dim_to

        return final_dim

    def get_config_dict(self) -> Dict[str, Any]:
        # self.embed_dim exists and matches the key in config_keys
        return {key: getattr(self, key) for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:  # Default to safe serialization
        os.makedirs(output_path, exist_ok=True)
        # Save config using the initial parameters
        with open(os.path.join(output_path, 'config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

        model_path_st = os.path.join(output_path, 'model.safetensors')
        model_path_bin = os.path.join(output_path, 'pytorch_model.bin')

        state_dict = self.state_dict()
        if safe_serialization:
            try:
                from safetensors.torch import save_file
                # Need to ensure state_dict keys match what load_state_dict expects
                save_file(state_dict, model_path_st)
                print(f"Saved state dict to {model_path_st}")
                # Remove old bin file if it exists and we successfully saved safetensors
                if os.path.exists(model_path_bin):
                    os.remove(model_path_bin)
            except ImportError:
                warnings.warn("safetensors not available. Falling back to regular PyTorch serialization (pytorch_model.bin).", UserWarning)
                torch.save(state_dict, model_path_bin)
                print(f"Saved state dict to {model_path_bin}")
            except Exception as e:  # Catch potential errors during saving
                warnings.warn(f"Error saving safetensors file: {e}. Falling back to pytorch_model.bin", UserWarning)
                torch.save(state_dict, model_path_bin)
                print(f"Saved state dict to {model_path_bin}")
        else:
            torch.save(state_dict, model_path_bin)
            print(f"Saved state dict to {model_path_bin}")
            # Remove old safetensors file if it exists
            if os.path.exists(model_path_st):
                os.remove(model_path_st)


    @staticmethod
    def load(input_path: str) -> 'AdvancedWeightedPooling':
        # Load config first to initialize the model structure
        config_path = os.path.join(input_path, 'config.json')
        if not os.path.exists(config_path):
            raise OSError(f"config.json not found in {input_path}")
        with open(config_path) as fIn:
            config = json.load(fIn)

        # Instantiate the model using the loaded config
        # This ensures all layers (like up/down projections, MLP, LN) are created
        # based on the *saved* configuration before loading weights.
        model = AdvancedWeightedPooling(**config)

        # Determine paths for weights files
        safetensors_path = os.path.join(input_path, 'model.safetensors')
        pytorch_path = os.path.join(input_path, 'pytorch_model.bin')

        loaded_state_dict = None
        load_success = False
        # Prioritize safetensors
        if os.path.exists(safetensors_path):
            try:
                from safetensors.torch import load_file
                loaded_state_dict = load_file(safetensors_path, device='cpu')
                print(f"Loaded state dict from {safetensors_path}")
                load_success = True
            except ImportError:
                warnings.warn("safetensors not available or error loading. Falling back to pytorch_model.bin if it exists.", UserWarning)
            except Exception as e:
                warnings.warn(f"Error loading safetensors file: {e}. Falling back to pytorch_model.bin if it exists.", UserWarning)

        # Fall back to pytorch_model.bin if safetensors failed or doesn't exist
        if not load_success and os.path.exists(pytorch_path):
            try:
                loaded_state_dict = torch.load(pytorch_path, map_location=torch.device('cpu'))
                print(f"Loaded state dict from {pytorch_path}")
                load_success = True
            except Exception as e:
                warnings.warn(f"Error loading pytorch_model.bin: {e}", UserWarning)


        if loaded_state_dict:
            # Use strict=True for debugging missing/unexpected keys during development
            # Can be set to strict=False for more flexibility if needed, but True is safer
            load_result = model.load_state_dict(loaded_state_dict, strict=True)
            print(f"Model state loaded. Result: {load_result}")
        elif not load_success:  # Only warn if neither file could be loaded
            warnings.warn(f"No model weights file found or loaded successfully at {safetensors_path} or {pytorch_path}. Model initialized randomly.", UserWarning)

        return model
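
A minimal usage sketch (not part of the commit; the 768-dim embeddings, batch size, and sequence length below are made-up illustration values): the module expects a sentence-transformers-style features dict with token_embeddings and attention_mask, and returns the pooled vector under sentence_embedding.

import torch

# Hypothetical configuration for illustration only
pooler = AdvancedWeightedPooling(embed_dim=768, num_heads=8, use_MLP=True, use_layernorm=True)

batch, seq_len = 2, 16
features = {
    'token_embeddings': torch.randn(batch, seq_len, 768),          # [CLS] assumed at position 0
    'attention_mask': torch.ones(batch, seq_len, dtype=torch.long),
}

out = pooler(features)
print(out['sentence_embedding'].shape)            # torch.Size([2, 768])
print(pooler.get_sentence_embedding_dimension())  # 768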
 
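Likewise, a hedged sketch of the save/load round trip defined above, assuming a writable ./pooler_ckpt directory (the path is illustrative): save() writes config.json plus model.safetensors (or pytorch_model.bin as a fallback), and load() rebuilds the module from that config before restoring the weights.

pooler.save('./pooler_ckpt')                              # config.json + model.safetensors
restored = AdvancedWeightedPooling.load('./pooler_ckpt')  # same architecture, same weights
assert restored.get_config_dict() == pooler.get_config_dict()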