Commit d9ba731
1 Parent(s): 1536202
Update modeling_openmoe.py

modeling_openmoe.py CHANGED (+5 -2)
@@ -379,8 +379,11 @@ class OpenMoeAttention(nn.Module):
         value_states = repeat_kv(value_states, self.num_key_value_groups)

         if HAS_FLASH_ATTN and use_kernel:
-            from flash_attn import flash_attn_func
-
+            # If we use `from flash_attn import flash_attn_func` directly,
+            # AutoModelForCausalLM.from_pretrained will treat flash_attn as a compulsory dependency and raise error if cannot find.
+            # Here is a workaround to avoid the error.
+            exec("from flash_attn import flash_attn_func")
+
             query_states = query_states.transpose(1, 2)
             key_states = key_states.transpose(1, 2)
             value_states = value_states.transpose(1, 2)
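Why the exec wrapper works: before running remote code, AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) scans the downloaded module's source for import / from ... import statements (transformers.dynamic_module_utils.get_imports) and raises an error if any referenced package is not installed, even when the import sits behind a runtime guard like HAS_FLASH_ATTN. Because that scan is a textual regex pass, moving the import into a string executed by exec hides it, and flash_attn becomes a genuinely optional dependency. The sketch below demonstrates the effect; the regex is a simplified rendition of the one transformers uses, not a verbatim copy.

    import re

    # Simplified stand-in for transformers.dynamic_module_utils.get_imports,
    # which regex-scans remote code for import statements before loading it.
    def find_imports(source: str) -> list[str]:
        found = re.findall(r"^\s*import\s+(\S+)\s*$", source, flags=re.MULTILINE)
        found += re.findall(r"^\s*from\s+(\S+)\s+import", source, flags=re.MULTILINE)
        return found

    direct = "from flash_attn import flash_attn_func"
    hidden = 'exec("from flash_attn import flash_attn_func")'

    print(find_imports(direct))  # ['flash_attn'] -> flagged as a hard dependency
    print(find_imports(hidden))  # []              -> invisible to the static scan

An importlib.import_module("flash_attn") call would dodge the scan the same way. The surrounding transpose(1, 2) calls in the hunk are unrelated to the workaround: flash_attn_func expects tensors shaped (batch, seqlen, num_heads, head_dim), while the attention code otherwise keeps heads in dimension 1.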