Commit: Update model/openllama.py

Changed file: model/openllama.py (+1 line, −8 lines)
@@ -1,7 +1,6 @@
 from header import *
 import os
 import torch.nn.functional as F
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from .ImageBind import *
 from .ImageBind import data
 from .modeling_llama import LlamaForCausalLM
@@ -103,13 +102,7 @@ class OpenLLAMAPEFTModel(nn.Module):
             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
         )

-
-        config = LlamaConfig.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
-        self.llama_model = LlamaForCausalLM(config)
-
-        self.llama_model = load_checkpoint_and_dispatch(self.llama_model, vicuna_ckpt_path, device_map='sequential')
-
-        # self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
+        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
         self.llama_model = get_peft_model(self.llama_model, peft_config)
         self.llama_model.print_trainable_parameters()
|
|
Resulting file (excerpts after the change):

  1 | from header import *
  2 | import os
  3 | import torch.nn.functional as F
  4 | from .ImageBind import *
  5 | from .ImageBind import data
  6 | from .modeling_llama import LlamaForCausalLM
    | ...
102 |             target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
103 |         )
104 |
105 |         self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=os.environ['API_TOKEN'])
106 |         self.llama_model = get_peft_model(self.llama_model, peft_config)
107 |         self.llama_model.print_trainable_parameters()
108 |