voidful committed
Commit 19649d5 · verified · 1 Parent(s): 14223c1

Upload 7 files

NOTES.md ADDED
@@ -0,0 +1,9 @@
+## Add to config.json to support trust_remote_code
+{
+  "model_type": "gemma_3_omni",
+  "architectures": ["Gemma3OmniForConditionalGeneration"],
+  "auto_map": {
+    "AutoConfig": "modeling_gemma_3_omni.Gemma3Config",
+    "AutoModel": "modeling_gemma_3_omni.Gemma3OmniForConditionalGeneration"
+  }
+}
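
With the `auto_map` above in config.json, the Auto classes resolve to the custom modeling code when `trust_remote_code=True` is passed. A minimal loading sketch (illustrative only; the repo id voidful/gemma-3-omni-4b-it is the one pushed by map_phi_audio_encoder.py below):

import torch
from transformers import AutoConfig, AutoModel

# auto_map routes these calls to the classes defined in modeling_gemma_3_omni.py
config = AutoConfig.from_pretrained("voidful/gemma-3-omni-4b-it", trust_remote_code=True)
model = AutoModel.from_pretrained("voidful/gemma-3-omni-4b-it", trust_remote_code=True)
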
chat_template.json ADDED
@@ -0,0 +1,3 @@
+{
+  "chat_template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set first_user_prefix = messages[0]['content'][0]['text'] + '\\n\\n' %}{% set loop_messages = messages[1:] %}{% else %}{% set first_user_prefix = '' %}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '<start_of_turn>' + role + '\\n' + (first_user_prefix if loop.first else '') }}{% if message['content'] is string %}{{ message['content'] | trim }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{{ '<start_of_image>' if item['type']=='image' else '<start_of_audio>' if item['type']=='audio' else item['text']|trim if item['type']=='text' else '' }}{% endfor %}{% else %}{{ raise_exception('Invalid content type') }}{% endif %}{{ '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\\n' }}{% endif %}"
+}
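
The template accepts either a plain string or a list of {type: text | image | audio} items, emitting <start_of_image> / <start_of_audio> placeholders for the non-text items. A minimal rendering sketch, assuming the Gemma-3 tokenizer and this chat_template.json on disk (the file path and repo id are illustrative):

import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
template = json.load(open("chat_template.json"))["chat_template"]

messages = [
    {"role": "user", "content": [
        {"type": "audio"},
        {"type": "text", "text": "Transcribe this clip."},
    ]},
]
# tokenize=False returns the rendered prompt string instead of token ids
prompt = tokenizer.apply_chat_template(
    messages, chat_template=template, tokenize=False, add_generation_prompt=True
)
# prompt -> "<bos><start_of_turn>user\n<start_of_audio>Transcribe this clip.<end_of_turn>\n<start_of_turn>model\n"
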
map_phi_audio_encoder.py ADDED
@@ -0,0 +1,10 @@
+from transformers import AutoModelForCausalLM
+
+from modeling_gemma_3_omni import Gemma3OmniForConditionalGeneration
+
+phi_audio_encoder = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True)
+phi_state_dict = phi_audio_encoder.model.embed_tokens_extend.audio_embed.encoder.state_dict()
+model = Gemma3OmniForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+model.eval()
+model.audio_projector.encoder.load_state_dict(phi_state_dict, strict=False)
+model.push_to_hub('voidful/gemma-3-omni-4b-it')
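
Because strict=False silently skips keys that do not line up, an optional follow-up check (not part of the script above) is to inspect what PyTorch reports back from load_state_dict:

# load_state_dict returns the incompatible keys; empty lists mean the Phi-4
# Conformer weights mapped cleanly onto the Gemma-3 audio encoder.
result = model.audio_projector.encoder.load_state_dict(phi_state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
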
modeling_gemma_3_omni.py ADDED
@@ -0,0 +1,421 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torchaudio
+from torch import nn
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    Cache,
+    Gemma3Config,
+    PreTrainedModel,
+    PretrainedConfig, StaticCache, HybridCache,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.gemma3.modeling_gemma3 import (
+    Gemma3CausalLMOutputWithPast,
+    Gemma3ForConditionalGeneration,
+    Gemma3RMSNorm,
+)
+from transformers.utils import is_torchdynamo_compiling, logging
+
+from .speech_conformer_encoder import ConformerEncoder
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3AudioProjectorConfig(PretrainedConfig):
+    model_type = "gemma3_audio"
+
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        num_hidden_layers: int = 24,
+        sample_rate: int = 16_000,
+        n_mels: int = 80,
+        audio_token_id: int = 0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.sample_rate = sample_rate
+        self.n_mels = n_mels
+        self.audio_token_id = audio_token_id
+
+
+class Gemma3AudioProjector(PreTrainedModel):
+    """Conformer-based audio encoder → project to LM hidden-dim."""
+
+    config_class = Gemma3AudioProjectorConfig
+    base_model_prefix = "audio_projector"
+
+    def __init__(self, config: Gemma3AudioProjectorConfig):
+        super().__init__(config)
+        # encoder_config = config.audio_processor.get("config", None)
+        encoder_config = {
+            "activation": "swish",
+            "activation_checkpointing": {
+                "interval": 1,
+                "module": "transformer",
+                "offload": False
+            },
+            "attention_dim": 1024,
+            "attention_heads": 16,
+            "batch_norm": False,
+            "bias_in_glu": True,
+            "causal": True,
+            "chunk_size": -1,
+            "cnn_layer_norm": True,
+            "conv_activation": "swish",
+            "conv_glu_type": "swish",
+            "depthwise_multiplier": 1,
+            "depthwise_seperable_out_channel": 1024,
+            "dropout_rate": 0.0,
+            "encoder_embedding_config": {
+                "input_size": 80
+            },
+            "ext_pw_kernel_size": 1,
+            "ext_pw_out_channel": 1024,
+            "input_layer": "nemo_conv",
+            "input_size": 80,
+            "kernel_size": 3,
+            "left_chunk": 18,
+            "linear_units": 1536,
+            "nemo_conv_settings": {
+                "conv_channels": 1024
+            },
+            "num_blocks": 24,
+            "relative_attention_bias_args": {
+                "t5_bias_max_distance": 500,
+                "type": "t5"
+            },
+            "time_reduction": 8
+        }
+        self.encoder = ConformerEncoder(**encoder_config)
+        self.mel = torchaudio.transforms.MelSpectrogram(
+            sample_rate=config.sample_rate, n_mels=config.n_mels
+        )
+        self.proj = nn.Linear(1024, config.hidden_size, bias=False)
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+        self.post_init()
+
+    # ---------- helpers ----------
+    def wav2mel(self, wav: torch.Tensor) -> torch.Tensor:
+        return self.mel(wav).clamp(min=1e-5).log().transpose(1, 2)
+
+    # ---------- forward ----------
+    @torch.no_grad()
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:  # (B,T) or (B,1,T)
+        if wav.dim() == 3:
+            wav = wav.squeeze(1)
+        mel = self.wav2mel(wav)
+        lengths = torch.full(
+            (mel.size(0),), mel.size(1), dtype=torch.long, device=mel.device
+        )
+        hidden = self.encoder(mel, lengths)
+        hidden = self.proj(hidden)
+        return self.layer_norm(hidden)
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Vision projector (same as the original implementation; only the dtype changes)
+# ──────────────────────────────────────────────────────────────────────────────
+class Gemma3VisionProjector(nn.Module):
+    def __init__(self, config: Gemma3Config):
+        super().__init__()
+        self.mm_input_projection_weight = nn.Parameter(
+            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
+        )
+        self.mm_soft_emb_norm = Gemma3RMSNorm(
+            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+        )
+        self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
+        self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
+
+    def forward(self, vision_outputs: torch.Tensor):
+        b, _, seq_len = vision_outputs.shape
+        x = vision_outputs.transpose(1, 2).reshape(
+            b, seq_len, self.patches_per_image, self.patches_per_image
+        )
+        x = self.avg_pool(x).flatten(2).transpose(1, 2)
+        x = self.mm_soft_emb_norm(x)
+        return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Gemma-3 Multimodal wrapper
+# ──────────────────────────────────────────────────────────────────────────────
+class Gemma3OmniForConditionalGeneration(Gemma3ForConditionalGeneration):
+    """Gemma-3 Omni: vision + audio + text causal LM."""
+
+    def __init__(self, config: Gemma3Config):
+        super().__init__(config)
+
+        # ---- sub-modules
+        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        self.multi_modal_projector = Gemma3VisionProjector(config)
+        self.audio_projector = Gemma3AudioProjector(
+            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
+        )
+        self.vocab_size = config.text_config.vocab_size
+
+        language_model = AutoModelForCausalLM.from_config(config=config.text_config)
+        if language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
+        self.language_model = language_model
+
+        self.pad_token_id = (
+            self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        )
+        self.post_init()
+
+    # ---------- helper ----------
+    def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
+        return self.audio_projector(audio_values)
+
+    def _update_causal_mask(
+        self,
+        attention_mask,
+        token_type_ids,
+        past_key_values,
+        cache_position,
+        input_tensor,
+        is_training: bool = False,
+    ):
+        if self.config.text_config._attn_implementation == "flash_attention_2":
+            return attention_mask
+
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted
+            # form and requires no inversion or slicing.
+            return attention_mask
+
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        min_dtype = torch.finfo(self.dtype).min
+        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+        if using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        elif isinstance(past_key_values, HybridCache):
+            target_length = past_key_values.get_max_cache_shape()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else cache_position[0] + sequence_length + 1
+            )
+
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            return attention_mask
+
+        causal_mask = torch.full(
+            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+        )
+
+        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+
+        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+
+        # Apply bidirectional mask on images if token type ids are provided
+        if token_type_ids is not None and sequence_length != 1:
+            token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
+            token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
+            token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
+            causal_mask = causal_mask.clone()
+            causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
+                token_type_mask, 0.0
+            )
+
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+
+            # Then apply padding mask (will mask pad tokens)
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+        return causal_mask
+
+    # ---------- forward ----------
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        audio_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **lm_kwargs,
+    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+
+        # === input validation ===
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("Exactly one of input_ids or inputs_embeds must be provided")
+
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        is_training = token_type_ids is not None and labels is not None
+
+        # OOV image token → pad
+        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_id
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+        # cache_position
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        # === merge image ===
+        if pixel_values is not None:
+            image_feat = self.get_image_features(pixel_values)
+            special_image_mask = (
+                (
+                    inputs_embeds
+                    == self.get_input_embeddings()(
+                        torch.tensor(self.config.image_token_id, device=inputs_embeds.device)
+                    )
+                )
+                if input_ids is None
+                else (
+                    input_ids == self.config.image_token_id
+                ).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            )
+            if (
+                not is_torchdynamo_compiling()
+                and inputs_embeds[special_image_mask].numel() != image_feat.numel()
+            ):
+                raise ValueError("#image tokens ≠ #embedding slots")
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_feat.to(inputs_embeds)
+            )
+
+        # === merge audio ===
+        if audio_values is not None:
+            audio_feat = self.get_audio_features(audio_values)
+            # special_audio_mask = (
+            #     (
+            #         inputs_embeds
+            #         == self.get_input_embeddings()(
+            #             torch.tensor(self.config.audio_token_id, device=inputs_embeds.device)
+            #         )
+            #     )
+            #     if input_ids is None
+            #     else (
+            #         input_ids == self.config.audio_token_id
+            #     ).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            # )
+            # if (
+            #     not is_torchdynamo_compiling()
+            #     and inputs_embeds[special_audio_mask].numel() != audio_feat.numel()
+            # ):
+            #     raise ValueError("#audio tokens ≠ #embedding slots")
+            # inputs_embeds = inputs_embeds.masked_scatter(
+            #     special_audio_mask, audio_feat.to(inputs_embeds)
+            # )
+            print(audio_feat.shape, inputs_embeds.shape)
+            inputs_embeds = torch.cat([audio_feat, inputs_embeds], dim=1)
+
+        # === label masking ===
+        if labels is not None and self.pad_token_id in labels:
+            logger.warning_once(
+                "`labels` contains `pad_token_id`; they will be masked out at loss computation."
+            )
+            labels = torch.where(
+                input_ids == self.pad_token_id, self.config.ignore_index, labels
+            )
+
+        causal_mask = self._update_causal_mask(
+            attention_mask,
+            token_type_ids,
+            past_key_values,
+            cache_position,
+            inputs_embeds,
+            is_training,
+        )
+
+        outputs: CausalLMOutputWithPast = self.language_model(
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **lm_kwargs,
+        )
+
+        # === loss ===
+        logits = outputs.logits
+        loss = None
+        if labels is not None:
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(
+                    logits.device
+                )
+                shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask != 0].contiguous()
+            loss = nn.CrossEntropyLoss()(
+                shift_logits.view(-1, self.config.text_config.vocab_size),
+                shift_labels.view(-1),
+            )
+
+        return Gemma3CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_feat if pixel_values is not None else None,
+        )
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# exports
+# ──────────────────────────────────────────────────────────────────────────────
+__all__ = [
+    "Gemma3AudioProjectorConfig",
+    "Gemma3AudioProjector",
+    "Gemma3VisionProjector",
+    "Gemma3OmniForConditionalGeneration",
+]
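
When neither pixel_values nor audio_values is passed, the forward above falls through to the plain Gemma-3 text path, so a text-only smoke test is possible. A hedged sketch (the repo id is the one pushed by map_phi_audio_encoder.py; the prompt format follows chat_template.json):

import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("voidful/gemma-3-omni-4b-it", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("voidful/gemma-3-omni-4b-it")

inputs = tokenizer("<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(out.logits.shape)  # (batch, sequence_length, vocab_size)
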
processing_gemma3_omni.py CHANGED
@@ -1,6 +1,7 @@
-import math
 import re
 from typing import List, Optional, Union, Dict, Any
+
+import math
 import numpy as np
 import scipy.signal
 import torch
@@ -9,7 +10,7 @@ from transformers.audio_utils import AudioInput
 from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.image_utils import make_nested_list_of_images
-from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, AudioKwargs, Unpack
+from transformers.processing_utils import ProcessorMixin, ProcessingKwargs, ImagesKwargs, Unpack
 from transformers.utils import TensorType, to_py_obj, logging
 
 # Constants
@@ -342,4 +343,13 @@ class Gemma3OmniProcessor(ProcessorMixin):
     def model_input_names(self):
         tokenizer_inputs = self.tokenizer.model_input_names + ["token_type_ids"]
         image_processor_inputs = self.image_processor.model_input_names
-        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs))
+        return list(dict.fromkeys(tokenizer_inputs + image_processor_inputs))
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# exports
+# ──────────────────────────────────────────────────────────────────────────────
+__all__ = [
+    "Gemma3OmniProcessor",
+    "Gemma3AudioFeatureExtractor"
+]
speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff