Update model.py
Browse files
model.py
CHANGED
@@ -83,16 +83,30 @@ def load_classification_model(device, model_type="f", num_classes=6):
|
|
83 |
model.eval()
|
84 |
return model
|
85 |
|
|
|
86 |
def load_t5_model(device):
|
87 |
local_dir = snapshot_download(repo_id=T5_MODEL_REPO)
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
|
|
|
|
91 |
tokenizer = T5Tokenizer.from_pretrained(
|
92 |
-
|
|
|
93 |
)
|
94 |
model = T5ForConditionalGeneration.from_pretrained(
|
95 |
-
|
|
|
96 |
).to(device)
|
97 |
model.eval()
|
98 |
return tokenizer, model
|
@@ -147,4 +161,4 @@ def calculate_performance_metrics(model, device):
|
|
147 |
"speed_v100_b1": cpu_time / 2,
|
148 |
"params_million": params / 1e6,
|
149 |
"flops_billion": flops / 1e9
|
150 |
-
}
|
|
|
83 |
model.eval()
|
84 |
return model
|
85 |
|
86 |
+
|
87 |
def load_t5_model(device):
|
88 |
local_dir = snapshot_download(repo_id=T5_MODEL_REPO)
|
89 |
|
90 |
+
base_model_dir = os.path.join(local_dir, "model")
|
91 |
+
|
92 |
+
t5_path = None
|
93 |
+
if os.path.exists(os.path.join(base_model_dir, "spiece.model")):
|
94 |
+
t5_path = base_model_dir
|
95 |
+
else:
|
96 |
+
for root, dirs, files in os.walk(base_model_dir):
|
97 |
+
if "spiece.model" in files:
|
98 |
+
t5_path = root
|
99 |
+
break
|
100 |
|
101 |
+
if t5_path is None:
|
102 |
+
raise FileNotFoundError(f"Spiece model dosyası bulunamadı: {base_model_dir} içinde.")
|
103 |
tokenizer = T5Tokenizer.from_pretrained(
|
104 |
+
t5_path,
|
105 |
+
local_files_only=True
|
106 |
)
|
107 |
model = T5ForConditionalGeneration.from_pretrained(
|
108 |
+
t5_path,
|
109 |
+
local_files_only=True
|
110 |
).to(device)
|
111 |
model.eval()
|
112 |
return tokenizer, model
|
|
|
161 |
"speed_v100_b1": cpu_time / 2,
|
162 |
"params_million": params / 1e6,
|
163 |
"flops_billion": flops / 1e9
|
164 |
+
}
|