Safetensors
custom_code
gheinrich commited on
Commit
3819d57
·
verified ·
1 Parent(s): d661928

Update vit_patch_generator.py

Browse files
Files changed (1) hide show
  1. vit_patch_generator.py +0 -19
vit_patch_generator.py CHANGED
@@ -119,10 +119,6 @@ class ViTPatchGenerator(nn.Module):
119
  'pos_embed',
120
  ]
121
 
122
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
123
- if self.abs_pos:
124
- self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
125
-
126
  def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
127
  if src_embed.shape != targ_embed.shape:
128
  src_size = int(math.sqrt(src_embed.shape[1]))
@@ -285,18 +281,3 @@ class ViTPatchLinear(nn.Linear):
285
  **factory
286
  )
287
  self.patch_size = patch_size
288
-
289
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
290
- if self.bias is not None:
291
- self.bias.data.copy_(state_dict[f'{prefix}bias'])
292
-
293
- chk_weight = state_dict[f'{prefix}weight']
294
- if chk_weight.shape != self.weight.shape:
295
- src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
296
-
297
- assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
298
-
299
- chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
300
- chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
301
- chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
302
- self.weight.data.copy_(chk_weight)
 
119
  'pos_embed',
120
  ]
121
 
 
 
 
 
122
  def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
123
  if src_embed.shape != targ_embed.shape:
124
  src_size = int(math.sqrt(src_embed.shape[1]))
 
281
  **factory
282
  )
283
  self.patch_size = patch_size