Update raven_modeling_minimal.py
Browse files
raven_modeling_minimal.py
CHANGED
@@ -660,10 +660,12 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
|
|
660 |
|
661 |
@torch.no_grad()
|
662 |
def generate(self, *args, **kwargs):
|
663 |
-
"""Dispatcher - use HF generate in all normal cases.
|
664 |
-
|
665 |
-
|
666 |
-
|
|
|
|
|
667 |
print("Dispatching to custom generate function call")
|
668 |
return self.generate_with_adaptive_compute(*args, **kwargs)
|
669 |
else:
|
|
|
660 |
|
661 |
@torch.no_grad()
|
662 |
def generate(self, *args, **kwargs):
|
663 |
+
"""Dispatcher - use HF generate in all normal cases."""
|
664 |
+
self.generation_config = args[1] if len(args) > 1 else self.generation_config
|
665 |
+
if any(
|
666 |
+
k in kwargs
|
667 |
+
for k in ("continuous_compute", "latent_dampening", "criterion", "exit_threshold", "cache_kwargs")
|
668 |
+
):
|
669 |
print("Dispatching to custom generate function call")
|
670 |
return self.generate_with_adaptive_compute(*args, **kwargs)
|
671 |
else:
|