JonasGeiping commited on
Commit
2a364bd
·
verified ·
1 Parent(s): 7f3b64f

Update raven_modeling_minimal.py

Browse files
Files changed (1) hide show
  1. raven_modeling_minimal.py +6 -4
raven_modeling_minimal.py CHANGED
@@ -660,10 +660,12 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
660
 
661
  @torch.no_grad()
662
  def generate(self, *args, **kwargs):
663
- """Dispatcher - use HF generate in all normal cases.
664
- If BOTH `criterion` AND `exit_threshold` are provided as not None, we use adaptive compute.
665
- """
666
- if kwargs.get("criterion", None) is not None and kwargs.get("exit_threshold", None) is not None:
 
 
667
  print("Dispatching to custom generate function call")
668
  return self.generate_with_adaptive_compute(*args, **kwargs)
669
  else:
 
660
 
661
  @torch.no_grad()
662
  def generate(self, *args, **kwargs):
663
+ """Dispatcher - use HF generate in all normal cases."""
664
+ self.generation_config = args[1] if len(args) > 1 else self.generation_config
665
+ if any(
666
+ k in kwargs
667
+ for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
668
+ ):
669
  print("Dispatching to custom generate function call")
670
  return self.generate_with_adaptive_compute(*args, **kwargs)
671
  else: