ArneBinder committed (verified)
Commit 14a87b0 · 1 Parent(s): 4beb39d

fix load_model_with_adapter location
configs/retriever/related_span_retriever_with_relations_from_other_docs.yaml CHANGED
@@ -18,7 +18,7 @@ vectorstore:
   embedding:
     _target_: src.langchain_modules.HuggingFaceSpanEmbeddings
     model:
-      _target_: src.models.utils.adapters.load_model_with_adapter
+      _target_: src.models.utils.load_model_with_adapter
       model_kwargs:
         pretrained_model_name_or_path: allenai/specter2_base
       adapter_kwargs:
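
Note: the `_target_` convention suggests this config is resolved with Hydra's instantiate (an assumption; the config loader itself is not part of this diff). Under that assumption, the changed block amounts to a call of the relocated helper, which the new src/models/utils/__init__.py below re-exports. A minimal sketch, with adapter_kwargs shown here as a placeholder for the entries truncated above:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Sketch under the Hydra assumption; the adapter_kwargs value is a placeholder,
# the real entries are truncated in the diff above.
model_cfg = OmegaConf.create(
    {
        "_target_": "src.models.utils.load_model_with_adapter",
        "model_kwargs": {"pretrained_model_name_or_path": "allenai/specter2_base"},
        "adapter_kwargs": {"adapter_name_or_path": "allenai/specter2"},  # placeholder
    }
)
model = instantiate(model_cfg)  # -> load_model_with_adapter(model_kwargs=..., adapter_kwargs=...)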
src/models/__init__.py ADDED
File without changes
src/models/utils/__init__.py ADDED
@@ -0,0 +1 @@
+from .loading import load_model_from_pie_model, load_model_with_adapter, load_tokenizer_from_pie_taskmodule
src/models/utils/loading.py ADDED
@@ -0,0 +1,32 @@
+from typing import Any, Dict
+
+from pie_modules.models import *  # noqa: F403
+from pie_modules.taskmodules import *  # noqa: F403
+from pytorch_ie import AutoModel, AutoTaskModule, PyTorchIEModel, TaskModule
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.taskmodules import *  # noqa: F403
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+def load_model_from_pie_model(model_kwargs: Dict[str, Any]) -> PreTrainedModel:
+
+    pie_model: PyTorchIEModel = AutoModel.from_pretrained(**model_kwargs)
+
+    return pie_model.model.model
+
+
+def load_tokenizer_from_pie_taskmodule(taskmodule_kwargs: Dict[str, Any]) -> PreTrainedTokenizer:
+
+    pie_taskmodule: TaskModule = AutoTaskModule.from_pretrained(**taskmodule_kwargs)
+
+    return pie_taskmodule.tokenizer
+
+
+def load_model_with_adapter(
+    model_kwargs: Dict[str, Any], adapter_kwargs: Dict[str, Any]
+) -> "ModelAdaptersMixin":
+    from adapters import AutoAdapterModel, ModelAdaptersMixin
+
+    model = AutoAdapterModel.from_pretrained(**model_kwargs)
+    model.load_adapter(set_active=True, **adapter_kwargs)
+    return model
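
A minimal usage sketch of the new helper, not part of the commit: the adapter id allenai/specter2 and source="hf" are assumptions borrowed from the SPECTER2 model card, since the retriever config above only pins the base model and truncates adapter_kwargs.

from src.models.utils import load_model_with_adapter

# load_model_with_adapter builds the base model with adapters' AutoAdapterModel,
# then loads the given adapter and activates it (set_active=True).
model = load_model_with_adapter(
    model_kwargs={"pretrained_model_name_or_path": "allenai/specter2_base"},
    adapter_kwargs={"adapter_name_or_path": "allenai/specter2", "source": "hf"},  # assumed adapter
)
# The returned model (adapter set active) is what HuggingFaceSpanEmbeddings
# receives via the retriever config above.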