OrionZheng committed
Commit d9ba731 · 1 Parent(s): 1536202

Update modeling_openmoe.py

Files changed (1): modeling_openmoe.py (+5 −2)
modeling_openmoe.py CHANGED

@@ -379,8 +379,11 @@ class OpenMoeAttention(nn.Module):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         if HAS_FLASH_ATTN and use_kernel:
-            from flash_attn import flash_attn_func
-
+            # If we use `from flash_attn import flash_attn_func` directly, AutoModelForCausalLM.from_pretrained
+            # will treat flash_attn as a compulsory dependency and raise an error if it cannot be found.
+            # Here is a workaround to avoid the error.
+            exec("from flash_attn import flash_attn_func")
+
             query_states = query_states.transpose(1, 2)
             key_states = key_states.transpose(1, 2)
             value_states = value_states.transpose(1, 2)
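
For context, a minimal sketch of why hiding the import in a string helps, assuming the usual behavior of transformers' remote-code dependency check (get_imports/check_imports in transformers.dynamic_module_utils): when a model is loaded with trust_remote_code=True, transformers scans the modeling file for literal `import x` / `from x import y` lines and raises an error for any package it cannot find, even if the import sits behind a runtime guard like HAS_FLASH_ATTN. An import expressed as a string (via exec) or performed through importlib never matches that scan, so it only runs when the branch actually executes. The snippet below illustrates the same pattern with importlib; it is a sketch, not code from this repository:

# Sketch only: an optional flash_attn import that a static scan of
# literal import statements will not detect.
import importlib

try:
    # Equivalent to `from flash_attn import flash_attn_func` at runtime,
    # but no import statement naming flash_attn appears in the source.
    flash_attn_func = importlib.import_module("flash_attn").flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    flash_attn_func = None
    HAS_FLASH_ATTN = False

Either way, flash_attn stays a soft dependency: users without it can still load the model through AutoModelForCausalLM.from_pretrained and fall back to the non-kernel attention path instead of failing at load time.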