eyupipler commited on
Commit
293b179
·
verified ·
1 Parent(s): 4d39eda

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -27
model.py CHANGED
@@ -10,22 +10,13 @@ from thop import profile
10
  from transformers import T5ForConditionalGeneration, T5Tokenizer
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
 
13
- # ------------------------------------------------------------
14
- # 1) MODEL HUB PATHS
15
- # ------------------------------------------------------------
16
- # CNN modellerinin bulunduğu repo:
17
  CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
18
  CLASSIFICATION_MODEL_FILENAME_F = "Vbai-DPA 2.3f.pt"
19
  CLASSIFICATION_MODEL_FILENAME_C = "Vbai-DPA 2.3c.pt"
20
  CLASSIFICATION_MODEL_FILENAME_Q = "Vbai-DPA 2.3q.pt"
21
 
22
- # T5 modelinin bulunduğu repo (içinde “Tbai-DPA 1.0/model/…” ağaç yapısı var):
23
  T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
24
- # ------------------------------------------------------------
25
 
26
- # ------------------------------------------------------------
27
- # 2) SIMPLECNN TANIMI
28
- # ------------------------------------------------------------
29
  class SimpleCNN(nn.Module):
30
  def __init__(self, model_type="f", num_classes=6):
31
  super(SimpleCNN, self).__init__()
@@ -65,14 +56,7 @@ class SimpleCNN(nn.Module):
65
  x = self.fc2(x)
66
  return x
67
 
68
- # ------------------------------------------------------------
69
- # 3) MODEL YÜKLEME FONKSİYONLARI
70
- # ------------------------------------------------------------
71
  def load_classification_model(device, model_type="f", num_classes=6):
72
- """
73
- model_type: "f", "c" veya "q"
74
- Vbai-DPA-2.3 repo’sundan ilgili .pt dosyasını indirip SimpleCNN yapısına yükler.
75
- """
76
  if model_type == "f":
77
  filename = CLASSIFICATION_MODEL_FILENAME_F
78
  elif model_type == "c":
@@ -100,19 +84,12 @@ def load_classification_model(device, model_type="f", num_classes=6):
100
 
101
 
102
  def load_t5_model(device):
103
- """
104
- Tbai-DPA-1-0 repo’sunu indirir (snapshot_download).
105
- İndirilen tüm ağacı tarar, içinde “[.]model” uzantılı dosya barındıran
106
- ilk dizini bulur. O dizin, T5 tokenizer/model için kullanılır.
107
- """
108
  local_dir = snapshot_download(repo_id=T5_MODEL_REPO)
109
 
110
- # 1) “local_dir” içinde recursive olarak .model uzantılı bir dosya arayalım
111
  t5_path = None
112
  for root, dirs, files in os.walk(local_dir):
113
  for fname in files:
114
  if fname.endswith(".model"):
115
- # bulunduğu klasörü t5_path olarak al
116
  t5_path = root
117
  break
118
  if t5_path:
@@ -121,7 +98,6 @@ def load_t5_model(device):
121
  if t5_path is None:
122
  raise FileNotFoundError(f"Hiçbir '.model' dosyası bulunamadı: {local_dir} içinde.")
123
 
124
- # 2) Bulunan klasörü, T5Tokenizer ve T5ForConditionalGeneration için kullan
125
  tokenizer = T5Tokenizer.from_pretrained(
126
  t5_path,
127
  local_files_only=True
@@ -133,9 +109,6 @@ def load_t5_model(device):
133
  model.eval()
134
  return tokenizer, model
135
 
136
- # ------------------------------------------------------------
137
- # 4) GÖRÜNTÜ → TAHMİN ve YORUM İŞLEME FONKSİYONLARI
138
- # ------------------------------------------------------------
139
  transform = transforms.Compose([
140
  transforms.Resize((224, 224)),
141
  transforms.ToTensor(),
 
10
  from transformers import T5ForConditionalGeneration, T5Tokenizer
11
  from huggingface_hub import hf_hub_download, snapshot_download
12
 
 
 
 
 
13
  CLASSIFICATION_MODEL_REPO = "Neurazum/Vbai-DPA-2.3"
14
  CLASSIFICATION_MODEL_FILENAME_F = "Vbai-DPA 2.3f.pt"
15
  CLASSIFICATION_MODEL_FILENAME_C = "Vbai-DPA 2.3c.pt"
16
  CLASSIFICATION_MODEL_FILENAME_Q = "Vbai-DPA 2.3q.pt"
17
 
 
18
  T5_MODEL_REPO = "Neurazum/Tbai-DPA-1.0"
 
19
 
 
 
 
20
  class SimpleCNN(nn.Module):
21
  def __init__(self, model_type="f", num_classes=6):
22
  super(SimpleCNN, self).__init__()
 
56
  x = self.fc2(x)
57
  return x
58
 
 
 
 
59
  def load_classification_model(device, model_type="f", num_classes=6):
 
 
 
 
60
  if model_type == "f":
61
  filename = CLASSIFICATION_MODEL_FILENAME_F
62
  elif model_type == "c":
 
84
 
85
 
86
  def load_t5_model(device):
 
 
 
 
 
87
  local_dir = snapshot_download(repo_id=T5_MODEL_REPO)
88
 
 
89
  t5_path = None
90
  for root, dirs, files in os.walk(local_dir):
91
  for fname in files:
92
  if fname.endswith(".model"):
 
93
  t5_path = root
94
  break
95
  if t5_path:
 
98
  if t5_path is None:
99
  raise FileNotFoundError(f"Hiçbir '.model' dosyası bulunamadı: {local_dir} içinde.")
100
 
 
101
  tokenizer = T5Tokenizer.from_pretrained(
102
  t5_path,
103
  local_files_only=True
 
109
  model.eval()
110
  return tokenizer, model
111
 
 
 
 
112
  transform = transforms.Compose([
113
  transforms.Resize((224, 224)),
114
  transforms.ToTensor(),