def forward(self, input):
    """Run ``input`` through encode -> codebook lookup -> decode.

    The token indices produced by ``self.encode`` are turned back into
    codebook entries via ``self.quantize.get_codebook_entry`` and the
    result is handed to ``self.decode``.

    Args:
        input: Object exposing ``.shape``; only ``shape[0]`` is read here
            (presumably a batched image tensor -- confirm with callers).

    Returns:
        A 3-tuple ``(pixels, img_toks, quant)``: the output of
        ``self.decode``, the token indices from ``self.encode``, and the
        first element returned by ``self.encode``.
    """
    # Only the first item and the last entry of the third item are used;
    # the middle value returned by ``encode`` is ignored.
    quant, _, (_, _, img_toks) = self.encode(input)

    # NOTE(review): the lookup shape is built as
    # (input dim 0, quant dim -3, quant dim -1, quant dim -2), i.e. the
    # last two ``quant`` axes are passed in reversed order. Confirm this
    # matches the layout ``get_codebook_entry`` expects for non-square
    # latents.
    lookup_shape = (
        input.shape[0],
        quant.shape[-3],
        quant.shape[-1],
        quant.shape[-2],
    )

    codebook_entry = self.quantize.get_codebook_entry(img_toks, lookup_shape)
    return self.decode(codebook_entry), img_toks, quant