Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- README.md +135 -0
- config.json +44 -0
- generation_config.json +7 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +441 -0
- modeling_plamo.py +1707 -0
- special_tokens_map.json +24 -0
- tokenization_plamo.py +392 -0
- tokenizer.jsonl +3 -0
- tokenizer_config.json +56 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.jsonl filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: other
|
3 |
+
license_name: plamo-community-license
|
4 |
+
license_link: https://huggingface.co/pfnet/plamo-2-8b/blob/main/LICENSE/ja
|
5 |
+
language:
|
6 |
+
- en
|
7 |
+
- ja
|
8 |
+
pipeline_tag: text-generation
|
9 |
+
library_name: transformers
|
10 |
+
extra_gated_heading: PLaMo community license to download PLaMo 2 8B
|
11 |
+
extra_gated_description: To download PLaMo 2 8B, you have to agree to our license.
|
12 |
+
PLaMo 2 8B is released PLaMo community license. For non-commerical use, please contact
|
13 |
+
us via this [form](https://forms.gle/mTL8tBLrMYXKNZD56).
|
14 |
+
extra_gated_button_content: agree to PLaMo community license
|
15 |
+
extra_gated_prompt: "(English version is under construction. We apologize for the\
|
16 |
+
\ inconvenience.)\n### PLaMoコミュニティライセンス契約\nPLaMoコミュニティライセンス契約には、株式会社Preferred Networksが提供する別途定める大規模言語基盤モデルPLaMo及びその派生物を利用するためのライセンスの内容及びユーザーが遵守する事項等が定められている。ユーザーのPLaMo及びその派生物の利用には本契約が適用され、本契約に同意又は本モデル等を利用することにより、ユーザーは本契約に拘束される。\n\
|
17 |
+
#### 第1条(定義)\n(1) 「本契約」とは、PLaMoコミュニティライセンス契約を意味する。\n(2) 「PFN」とは、株式会社Preferred Networksを意味する。\n\
|
18 |
+
(3) 「本モデル」とは、別途定める「PLaMo」という名称のモデルの重み、モデルコード、トークナイザー、学習スクリプト及びこれらに付随してPFNが提供するものを意味する。\n\
|
19 |
+
(4) 「ユーザー」とは、本モデルを利用する個人又は法人を意味する。\n(5) 「派生モデル」とは、本モデルを改変又は利用し作成されるモデルの重み、モデルコード及びその他作成されたモデルの付随物を意味する。\n\
|
20 |
+
(6) 「生成物」とは、本モデル又は派生モデルの出力結果を意味する。\n(7) 「本モデル等」とは、本モデル、派生モデル及び生成物の総称を意味する。\n(8)\
|
21 |
+
\ 「本ライセンス」とは、PFNがユーザーに対して本契約に基づき本モデル等を利用することを許諾することを意味する。\n(9) 「商業目的」とは、 私的使用又は学術用途の範囲を超える、事業での利用又は営利を目的とする利用を意味する。なお、商業目的にはユーザーの製品、サービス又は事業の開発、変更又は提供(ホスティングサービスやAPI経由での提供を含む。)を目的とした使用及びユーザーの組織内部における利用も含まれる。\n\
|
22 |
+
#### 第2条(ユーザー)\nユーザーは、18歳以上又はその居住国で単独で契約を締結できる年齢に達していなければならない。但し、ユーザーの親権者又は法定代理人が本契約をユーザーが締結することに同意している場合はこの限りではない。\n\
|
23 |
+
#### 第3条(本ライセンス)\n(1) PFNは、ユーザーが本契約に同意しかつ本契約を遵守することを条件に、ユーザーに対して、本モデル等を本契約に定める条件及び範囲内で利用することを許諾する。\n\
|
24 |
+
(2) 本ライセンスは非独占、世界的、譲渡不可及びロイヤリティ無料とする。\n(3) ユーザーは、以下の条件をいずれも満たす場合に限り、商業目的を含む形で本モデル等を利用することができる。なお、ユーザーがこれらの条件のいずれかを満たさなくなった場合は、ユーザーはその時点で本モデル等を商業目的で利用することはできず、商業目的で本モデル等を利用したい場合は、新たにPFNから商業用のライセンスを取得しなければならない。\n\
|
25 |
+
\n (i) PFNの公式登録ページ https://forms.gle/mTL8tBLrMYXKNZD56 に事前に登録すること。\n\n (ii) ユーザー又はその関係会社の直近事業年度の収入又は売上が10億円(ユーザーの現地通貨換算額)を超えないこと。\n\
|
26 |
+
\n#### 第4条(再配布及び表示義務)\n(1) ユーザーが本モデル等(派生モデルやその生成物を含む)を第三者に提供する場合、以下の条件を満たさなければならない。\n\
|
27 |
+
\n (i) 本契約のコピーを提供し、本契約の条件を遵守させること。\n\n (ii) 「Built with PLaMo」と明示し、関連ウェブサイト、ユーザーインターフェース、ブログ記事、製品情報ページ又は製品ドキュメントに記載すること。\n\
|
28 |
+
\n (iii) 本モデル等を利用して作成した AI モデルの名称に「PLaMo」を含めること。\n\n#### 第5条(生成物の利用)\n(1) ユーザーは、生成物を本モデル又は派生モデルの生成物であることを明示することを条件に、公表することができる。\n\
|
29 |
+
(2) 生成物を利用してモデルを学習した場合、そのモデルは派生モデルとして本契約の条件が適用され、本契約のライセンス条件の下でのみ利用、配布及び商業化することができる。\n\
|
30 |
+
#### 第6条(その他利用条件)\nユーザーは、本モデル等の利用に関して、以下に定める行為をしてはならない。\n(1) 法令又は公序良俗に違反する行為\n(2)\
|
31 |
+
\ 犯罪行為又はこれを予告、関与、助長その他これらに関連する行為\n(3) PFN又は第三者の権利又は利益を侵害する行為\n(4) PFN又は第三者の名誉若しくは信用を毀損する行為\n\
|
32 |
+
(5) 生成物がPFNの公式見解等であるものという錯誤を生む情報を流布する行為\n(6) 虚偽の情報を発信する行為\n(7) 上記の他、PFNが不適切と合理的に判断する行為\n\
|
33 |
+
#### 第7条(保証の否認)\n(1) 本モデル及び生成物は、「現状有姿」で提供され、PFNは、これらに対して、正確性、真実性、商品性、品質、性能、特定目的への適合性、権利の非侵害など一切の保証をしない。\n\
|
34 |
+
(2) ユーザーは、法律、医療、金融又は人物評価その他重要な事項の決定に関して、生成物を唯一の証拠、評価又は意見として使用してはならない。\n(3) ユーザーは、本モデル等の使用及びその結果に関して全ての責任を負う。\n\
|
35 |
+
#### 第8条(責任の制限)\n(1) 契約責任、不法行為又は製造物責任その他の法的責任のいずれかであるかを問わず、PFNが本契約及び本モデル等に関してユーザーに対して負う損害賠償の責任は、通常かつ直接の損害に限り(逸失利益、特別損害、間接損害その他の損害については、その予見可能性の有無に関わらず、責任を負わない。)、損害賠償額の上限は、500円とする。但し、PFNに故意又は重過失が認められる場合はこの限りではない。\n\
|
36 |
+
(2) 前項に関わらず、ユーザーが本モデル等を事業のために利用する場合は、PFNは本契約及び本モデル等に関してユーザーに対して一切の損害賠償責任及びその他の責任を負わない。\n\
|
37 |
+
#### 第9条(ユーザーの責任)\n(1) ユーザーは、本モデル等の取得及び利用に関して、適用される法令(輸出入及び貿易に関連する法令を含む。)及び本契約を遵守する。\n\
|
38 |
+
(2) ユーザーは、本契約違反又は本モデル等の使用によって、PFNに損害を与えた場合は、その損害を賠償する。\n(3) ユーザーの本モデル等の使用に起因して、PFNが第三者から損害賠償請求その他請求を受けた場合、ユーザーは、当該請求からPFNを免責し、PFNに損害を与えないようにする。\n\
|
39 |
+
#### 第10条(権利の帰属)\n(1) 本モデルの一切の権利は、PFN又はPFNに本モデルのライセンスをしている第三者に帰属する。\n(2) 派生モデルのうち、ユーザーが本モデルを改変した部分の権利はユーザーに帰属し、その他の部分の権利はPFNに帰属する。\n\
|
40 |
+
(3) 生成物の一切の権利はユーザーに帰属する。\n#### 第11条(契約期間及び終了)\n(1) 本契約は、ユーザーが本契約に同意したとき又は本モデルにアクセスしたときから、本契約が解約されたときまでとする。\n\
|
41 |
+
(2) ユーザーが本契約のいずれかの条項に違反した場合、PFNは直ちに本契約を解除することができ、ユーザーは本モデル等のすべてのコピーを削除し、利用を即時に停止しなければならない。\n\
|
42 |
+
#### 第12条(契約の変更)\nPFNは、本契約(本モデル等に関するルールや諸規定等を含む。以下本条において同じ。)を変更できるものとする。PFNは、本契約を変更する場合には、変更の内容及び変更の効力発生時期を、当該効力発生時期までにPFN所定の方法で告知するものとする。\n\
|
43 |
+
#### 第13条(準拠法及び管轄裁判所)\n(1) 本契約の準拠法は日本法とする。\n(2) 本モデル等及び本契約に起因する紛争については、東京地方裁判所が専属的合意管轄裁判所とする。"
|
44 |
+
base_model: pfnet/plamo-2-8b
|
45 |
+
tags:
|
46 |
+
- plamo
|
47 |
+
- translation
|
48 |
+
---
|
49 |
+
|
50 |
+
# PLaMo Translation Model
|
51 |
+
|
52 |
+
PLaMo翻訳モデルはPreferred Networksによって開発された翻訳向け特化型大規模言語モデルです。
|
53 |
+
詳しくは[ブログ記事](https://tech.preferred.jp/ja/blog/plamo-translate/)および[プレスリリース](https://www.preferred.jp/ja/news/pr20250527/)を参照してください。
|
54 |
+
|
55 |
+
PLaMo Translation Model is a specialized large-scale language model developed by Preferred Networks for translation tasks.
|
56 |
+
For details, please refer to the [blog post](https://tech.preferred.jp/ja/blog/plamo-translate/) and [press release](https://www.preferred.jp/ja/news/pr20250527/).
|
57 |
+
|
58 |
+
List of models:
|
59 |
+
- [plamo-2-translate](http://huggingface.co/pfnet/plamo-2-translate) ... Post-trained model for translation
|
60 |
+
- [plamo-2-translate-base](http://huggingface.co/pfnet/plamo-2-translate-base) ... Base model for translation
|
61 |
+
- [plamo-2-translate-eval](http://huggingface.co/pfnet/plamo-2-translate-eval) ... Pair-wise evaluation model
|
62 |
+
|
63 |
+
PLaMo Translation Model is released under PLaMo community license. Please check the following license and agree to this before downloading.
|
64 |
+
|
65 |
+
- (EN) under construction: we apologize for the inconvenience
|
66 |
+
- (JA) https://www.preferred.jp/ja/plamo-community-license/
|
67 |
+
|
68 |
+
**NOTE**: This model has **NOT** been instruction-tuned for chat dialog or other downstream tasks.
|
69 |
+
|
70 |
+
### For *commercial* users
|
71 |
+
|
72 |
+
Please check the PLaMo community license and contact us via the following form to use commercial purpose.
|
73 |
+
|
74 |
+
- (EN/JA) https://forms.gle/mTL8tBLrMYXKNZD56
|
75 |
+
|
76 |
+
## Usage
|
77 |
+
|
78 |
+
### main/base model
|
79 |
+
|
80 |
+
```py
|
81 |
+
import vllm
|
82 |
+
|
83 |
+
# max_model_len/max_num_batched_tokens can be increased when running on a GPU with substantial memory.
|
84 |
+
# NOTE: Switch to "pfnet/plamo-2-translate-base" to try the base model.
|
85 |
+
llm = vllm.LLM(model="pfnet/plamo-2-translate", trust_remote_code=True, max_model_len=2000, max_num_batched_tokens=2000)
|
86 |
+
|
87 |
+
prompt = r'''<|plamo:op|>dataset
|
88 |
+
translation
|
89 |
+
<|plamo:op|>input lang=English
|
90 |
+
Write the text to be translated here.
|
91 |
+
<|plamo:op|>output lang=Japanese
|
92 |
+
'''
|
93 |
+
|
94 |
+
responses = llm.generate([prompt] * 1, sampling_params=vllm.SamplingParams(temperature=0, max_tokens=1024, stop=["<|plamo:op|>"]))
|
95 |
+
# NOTE: This outputs "ここに翻訳するテキストを入力してください。".
|
96 |
+
print(responses[0].outputs[0].text)
|
97 |
+
```
|
98 |
+
|
99 |
+
### evaluation model
|
100 |
+
|
101 |
+
```py
|
102 |
+
import vllm
|
103 |
+
|
104 |
+
# max_model_len/max_num_batched_tokens can be increased when running on a GPU with substantial memory.
|
105 |
+
llm = vllm.LLM(model="pfnet/plamo-2-translate-eval", trust_remote_code=True, max_model_len=2000, max_num_batched_tokens=2000)
|
106 |
+
|
107 |
+
prompt = r'''<|plamo:op|>dataset
|
108 |
+
translation evaluation
|
109 |
+
<|plamo:op|>input lang=English
|
110 |
+
This is an apple.
|
111 |
+
<|plamo:op|>output id=A lang=Japanese
|
112 |
+
これはりんごです。
|
113 |
+
<|plamo:op|>output id=B lang=Japanese
|
114 |
+
これはリンゴです。
|
115 |
+
<|plamo:op|>best
|
116 |
+
id='''
|
117 |
+
|
118 |
+
responses = llm.generate([prompt] * 1, sampling_params=vllm.SamplingParams(temperature=0, max_tokens=1, stop=["<|plamo:op|>"]))
|
119 |
+
# NOTE: This outputs "A".
|
120 |
+
print(responses[0].outputs[0].text)
|
121 |
+
```
|
122 |
+
|
123 |
+
## Bias, Risks, and Limitations
|
124 |
+
|
125 |
+
PLaMo Translation Model is a new technology that carries risks with use. Testing conducted to date has been in English and Japanese, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, PLaMo Translation Model’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. Therefore, before deploying any applications of PLaMo Translation Model, developers should perform safety testing and tuning tailored to their specific applications of the model.
|
126 |
+
|
127 |
+
## Acknowledgement
|
128 |
+
|
129 |
+
This model is trained under the project, “Research and Development Project of the Enhanced Infrastructures for Post 5G Information and Communication System” (JPNP 20017), subsidized by the New Energy and Industrial Technology Development Organization (NEDO).
|
130 |
+
|
131 |
+
|
132 |
+
## AI policies for Preferred Networks, Inc. group
|
133 |
+
|
134 |
+
- (EN) https://www.preferred.jp/en/company/aipolicy/
|
135 |
+
- (JA) https://www.preferred.jp/ja/company/aipolicy/
|
config.json
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"Plamo2ForCausalLM"
|
4 |
+
],
|
5 |
+
"attention_window_size": 32768,
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "modeling_plamo.Plamo2Config",
|
8 |
+
"AutoModelForCausalLM": "modeling_plamo.Plamo2ForCausalLM"
|
9 |
+
},
|
10 |
+
"bos_token_id": 1,
|
11 |
+
"eos_token_id": 1,
|
12 |
+
"eval_attention_n_bit": null,
|
13 |
+
"eval_mlp_n_bit": null,
|
14 |
+
"fp8_accum_dtype": "bfloat16",
|
15 |
+
"full_attention_idx": [],
|
16 |
+
"hidden_size": 4096,
|
17 |
+
"hidden_size_per_head": 128,
|
18 |
+
"image_feature_size": null,
|
19 |
+
"image_proj_type": "linear",
|
20 |
+
"image_token_id": null,
|
21 |
+
"intermediate_size": 16384,
|
22 |
+
"linear_type": "normal",
|
23 |
+
"mamba_chunk_size": 256,
|
24 |
+
"mamba_d_conv": 4,
|
25 |
+
"mamba_d_state": 64,
|
26 |
+
"mamba_enabled": true,
|
27 |
+
"mamba_num_heads": 64,
|
28 |
+
"mamba_step": 2,
|
29 |
+
"max_position_embeddings": 10485760,
|
30 |
+
"model_type": "plamo2",
|
31 |
+
"num_attention_heads": 32,
|
32 |
+
"num_hidden_layers": 32,
|
33 |
+
"num_key_value_heads": 4,
|
34 |
+
"pad_token_id": 3,
|
35 |
+
"rms_norm_eps": 1e-06,
|
36 |
+
"rope_local_theta": 1000000.0,
|
37 |
+
"rope_theta": 1000000.0,
|
38 |
+
"sliding_window": 32768,
|
39 |
+
"tokenizer_class": "Plamo2Tokenizer",
|
40 |
+
"torch_dtype": "bfloat16",
|
41 |
+
"transformers_version": "4.46.2",
|
42 |
+
"use_cache": false,
|
43 |
+
"vocab_size": 100032
|
44 |
+
}
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 1,
|
5 |
+
"pad_token_id": 3,
|
6 |
+
"transformers_version": "4.46.2"
|
7 |
+
}
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:707a7c1f9bb4982b098864d7215a3985fa3f026f2ebd0a82799fea72e807349f
|
3 |
+
size 4771172576
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ab65beeb0731f66b6542cbfb3b6967f72e12386f55bd2234d690cdcda991ac4f
|
3 |
+
size 4964802416
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:88a4bb84ec3e982f16497c375013e7efdc07890a8cd7126be0bcdefe6193a9f7
|
3 |
+
size 4832590896
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a90408e0d462b199de23f1e29730d1608c3c425d384c6a3c0dc34fbeea9b335f
|
3 |
+
size 3668492768
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 18237007872
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
7 |
+
"model.layers.layers.0.mixer.A_log": "model-00001-of-00004.safetensors",
|
8 |
+
"model.layers.layers.0.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
|
9 |
+
"model.layers.layers.0.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
|
10 |
+
"model.layers.layers.0.mixer.D": "model-00001-of-00004.safetensors",
|
11 |
+
"model.layers.layers.0.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
|
12 |
+
"model.layers.layers.0.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
|
13 |
+
"model.layers.layers.0.mixer.dt_bias": "model-00001-of-00004.safetensors",
|
14 |
+
"model.layers.layers.0.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
|
15 |
+
"model.layers.layers.0.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
|
16 |
+
"model.layers.layers.0.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
|
17 |
+
"model.layers.layers.0.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
|
18 |
+
"model.layers.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
19 |
+
"model.layers.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
20 |
+
"model.layers.layers.0.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"model.layers.layers.0.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
22 |
+
"model.layers.layers.0.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
23 |
+
"model.layers.layers.0.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
24 |
+
"model.layers.layers.1.mixer.k_weight": "model-00001-of-00004.safetensors",
|
25 |
+
"model.layers.layers.1.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
|
26 |
+
"model.layers.layers.1.mixer.q_weight": "model-00001-of-00004.safetensors",
|
27 |
+
"model.layers.layers.1.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
28 |
+
"model.layers.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
29 |
+
"model.layers.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
30 |
+
"model.layers.layers.1.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
31 |
+
"model.layers.layers.1.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
32 |
+
"model.layers.layers.1.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
33 |
+
"model.layers.layers.1.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
34 |
+
"model.layers.layers.10.mixer.A_log": "model-00002-of-00004.safetensors",
|
35 |
+
"model.layers.layers.10.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
|
36 |
+
"model.layers.layers.10.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
|
37 |
+
"model.layers.layers.10.mixer.D": "model-00002-of-00004.safetensors",
|
38 |
+
"model.layers.layers.10.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
|
39 |
+
"model.layers.layers.10.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
|
40 |
+
"model.layers.layers.10.mixer.dt_bias": "model-00002-of-00004.safetensors",
|
41 |
+
"model.layers.layers.10.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
|
42 |
+
"model.layers.layers.10.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
|
43 |
+
"model.layers.layers.10.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
|
44 |
+
"model.layers.layers.10.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
|
45 |
+
"model.layers.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
46 |
+
"model.layers.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
47 |
+
"model.layers.layers.10.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
48 |
+
"model.layers.layers.10.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
49 |
+
"model.layers.layers.10.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
50 |
+
"model.layers.layers.10.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
51 |
+
"model.layers.layers.11.mixer.k_weight": "model-00002-of-00004.safetensors",
|
52 |
+
"model.layers.layers.11.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
|
53 |
+
"model.layers.layers.11.mixer.q_weight": "model-00002-of-00004.safetensors",
|
54 |
+
"model.layers.layers.11.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
55 |
+
"model.layers.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
56 |
+
"model.layers.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
57 |
+
"model.layers.layers.11.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
58 |
+
"model.layers.layers.11.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
59 |
+
"model.layers.layers.11.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
60 |
+
"model.layers.layers.11.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
61 |
+
"model.layers.layers.12.mixer.A_log": "model-00002-of-00004.safetensors",
|
62 |
+
"model.layers.layers.12.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
|
63 |
+
"model.layers.layers.12.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
|
64 |
+
"model.layers.layers.12.mixer.D": "model-00002-of-00004.safetensors",
|
65 |
+
"model.layers.layers.12.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
|
66 |
+
"model.layers.layers.12.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
|
67 |
+
"model.layers.layers.12.mixer.dt_bias": "model-00002-of-00004.safetensors",
|
68 |
+
"model.layers.layers.12.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
|
69 |
+
"model.layers.layers.12.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
|
70 |
+
"model.layers.layers.12.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
|
71 |
+
"model.layers.layers.12.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
|
72 |
+
"model.layers.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
73 |
+
"model.layers.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
74 |
+
"model.layers.layers.12.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
75 |
+
"model.layers.layers.12.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
76 |
+
"model.layers.layers.12.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
77 |
+
"model.layers.layers.12.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
78 |
+
"model.layers.layers.13.mixer.k_weight": "model-00002-of-00004.safetensors",
|
79 |
+
"model.layers.layers.13.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
|
80 |
+
"model.layers.layers.13.mixer.q_weight": "model-00002-of-00004.safetensors",
|
81 |
+
"model.layers.layers.13.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
82 |
+
"model.layers.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
83 |
+
"model.layers.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
84 |
+
"model.layers.layers.13.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
85 |
+
"model.layers.layers.13.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
86 |
+
"model.layers.layers.13.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
87 |
+
"model.layers.layers.13.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
88 |
+
"model.layers.layers.14.mixer.A_log": "model-00002-of-00004.safetensors",
|
89 |
+
"model.layers.layers.14.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
|
90 |
+
"model.layers.layers.14.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
|
91 |
+
"model.layers.layers.14.mixer.D": "model-00002-of-00004.safetensors",
|
92 |
+
"model.layers.layers.14.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
|
93 |
+
"model.layers.layers.14.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
|
94 |
+
"model.layers.layers.14.mixer.dt_bias": "model-00002-of-00004.safetensors",
|
95 |
+
"model.layers.layers.14.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
|
96 |
+
"model.layers.layers.14.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
|
97 |
+
"model.layers.layers.14.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
|
98 |
+
"model.layers.layers.14.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
|
99 |
+
"model.layers.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
100 |
+
"model.layers.layers.14.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
101 |
+
"model.layers.layers.14.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
102 |
+
"model.layers.layers.14.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
103 |
+
"model.layers.layers.14.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
104 |
+
"model.layers.layers.14.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
105 |
+
"model.layers.layers.15.mixer.k_weight": "model-00002-of-00004.safetensors",
|
106 |
+
"model.layers.layers.15.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
|
107 |
+
"model.layers.layers.15.mixer.q_weight": "model-00002-of-00004.safetensors",
|
108 |
+
"model.layers.layers.15.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
109 |
+
"model.layers.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
110 |
+
"model.layers.layers.15.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
111 |
+
"model.layers.layers.15.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
112 |
+
"model.layers.layers.15.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
113 |
+
"model.layers.layers.15.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
114 |
+
"model.layers.layers.15.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
115 |
+
"model.layers.layers.16.mixer.A_log": "model-00002-of-00004.safetensors",
|
116 |
+
"model.layers.layers.16.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
|
117 |
+
"model.layers.layers.16.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
|
118 |
+
"model.layers.layers.16.mixer.D": "model-00002-of-00004.safetensors",
|
119 |
+
"model.layers.layers.16.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
|
120 |
+
"model.layers.layers.16.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
|
121 |
+
"model.layers.layers.16.mixer.dt_bias": "model-00002-of-00004.safetensors",
|
122 |
+
"model.layers.layers.16.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
|
123 |
+
"model.layers.layers.16.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
|
124 |
+
"model.layers.layers.16.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
|
125 |
+
"model.layers.layers.16.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
|
126 |
+
"model.layers.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
127 |
+
"model.layers.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
128 |
+
"model.layers.layers.16.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
129 |
+
"model.layers.layers.16.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
130 |
+
"model.layers.layers.16.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
131 |
+
"model.layers.layers.16.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
132 |
+
"model.layers.layers.17.mixer.k_weight": "model-00003-of-00004.safetensors",
|
133 |
+
"model.layers.layers.17.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
|
134 |
+
"model.layers.layers.17.mixer.q_weight": "model-00003-of-00004.safetensors",
|
135 |
+
"model.layers.layers.17.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
136 |
+
"model.layers.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
137 |
+
"model.layers.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
138 |
+
"model.layers.layers.17.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
139 |
+
"model.layers.layers.17.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
140 |
+
"model.layers.layers.17.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
141 |
+
"model.layers.layers.17.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
142 |
+
"model.layers.layers.18.mixer.A_log": "model-00003-of-00004.safetensors",
|
143 |
+
"model.layers.layers.18.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
|
144 |
+
"model.layers.layers.18.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
|
145 |
+
"model.layers.layers.18.mixer.D": "model-00003-of-00004.safetensors",
|
146 |
+
"model.layers.layers.18.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
|
147 |
+
"model.layers.layers.18.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
|
148 |
+
"model.layers.layers.18.mixer.dt_bias": "model-00003-of-00004.safetensors",
|
149 |
+
"model.layers.layers.18.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
|
150 |
+
"model.layers.layers.18.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
|
151 |
+
"model.layers.layers.18.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
|
152 |
+
"model.layers.layers.18.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
|
153 |
+
"model.layers.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
154 |
+
"model.layers.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
155 |
+
"model.layers.layers.18.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
156 |
+
"model.layers.layers.18.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
157 |
+
"model.layers.layers.18.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
158 |
+
"model.layers.layers.18.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
159 |
+
"model.layers.layers.19.mixer.k_weight": "model-00003-of-00004.safetensors",
|
160 |
+
"model.layers.layers.19.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
|
161 |
+
"model.layers.layers.19.mixer.q_weight": "model-00003-of-00004.safetensors",
|
162 |
+
"model.layers.layers.19.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
163 |
+
"model.layers.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
164 |
+
"model.layers.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
165 |
+
"model.layers.layers.19.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
166 |
+
"model.layers.layers.19.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
167 |
+
"model.layers.layers.19.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
168 |
+
"model.layers.layers.19.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
169 |
+
"model.layers.layers.2.mixer.A_log": "model-00001-of-00004.safetensors",
|
170 |
+
"model.layers.layers.2.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
|
171 |
+
"model.layers.layers.2.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
|
172 |
+
"model.layers.layers.2.mixer.D": "model-00001-of-00004.safetensors",
|
173 |
+
"model.layers.layers.2.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
|
174 |
+
"model.layers.layers.2.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
|
175 |
+
"model.layers.layers.2.mixer.dt_bias": "model-00001-of-00004.safetensors",
|
176 |
+
"model.layers.layers.2.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
|
177 |
+
"model.layers.layers.2.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
|
178 |
+
"model.layers.layers.2.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
|
179 |
+
"model.layers.layers.2.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
|
180 |
+
"model.layers.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
181 |
+
"model.layers.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
182 |
+
"model.layers.layers.2.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
183 |
+
"model.layers.layers.2.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
184 |
+
"model.layers.layers.2.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
185 |
+
"model.layers.layers.2.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
186 |
+
"model.layers.layers.20.mixer.A_log": "model-00003-of-00004.safetensors",
|
187 |
+
"model.layers.layers.20.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
|
188 |
+
"model.layers.layers.20.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
|
189 |
+
"model.layers.layers.20.mixer.D": "model-00003-of-00004.safetensors",
|
190 |
+
"model.layers.layers.20.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
|
191 |
+
"model.layers.layers.20.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
|
192 |
+
"model.layers.layers.20.mixer.dt_bias": "model-00003-of-00004.safetensors",
|
193 |
+
"model.layers.layers.20.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
|
194 |
+
"model.layers.layers.20.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
|
195 |
+
"model.layers.layers.20.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
|
196 |
+
"model.layers.layers.20.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
|
197 |
+
"model.layers.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
198 |
+
"model.layers.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
199 |
+
"model.layers.layers.20.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
200 |
+
"model.layers.layers.20.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
201 |
+
"model.layers.layers.20.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
202 |
+
"model.layers.layers.20.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
203 |
+
"model.layers.layers.21.mixer.k_weight": "model-00003-of-00004.safetensors",
|
204 |
+
"model.layers.layers.21.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
|
205 |
+
"model.layers.layers.21.mixer.q_weight": "model-00003-of-00004.safetensors",
|
206 |
+
"model.layers.layers.21.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
207 |
+
"model.layers.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
208 |
+
"model.layers.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
209 |
+
"model.layers.layers.21.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
210 |
+
"model.layers.layers.21.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
211 |
+
"model.layers.layers.21.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
212 |
+
"model.layers.layers.21.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
213 |
+
"model.layers.layers.22.mixer.A_log": "model-00003-of-00004.safetensors",
|
214 |
+
"model.layers.layers.22.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
|
215 |
+
"model.layers.layers.22.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
|
216 |
+
"model.layers.layers.22.mixer.D": "model-00003-of-00004.safetensors",
|
217 |
+
"model.layers.layers.22.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
|
218 |
+
"model.layers.layers.22.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
|
219 |
+
"model.layers.layers.22.mixer.dt_bias": "model-00003-of-00004.safetensors",
|
220 |
+
"model.layers.layers.22.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
|
221 |
+
"model.layers.layers.22.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
|
222 |
+
"model.layers.layers.22.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
|
223 |
+
"model.layers.layers.22.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
|
224 |
+
"model.layers.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
225 |
+
"model.layers.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
226 |
+
"model.layers.layers.22.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
227 |
+
"model.layers.layers.22.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
228 |
+
"model.layers.layers.22.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
229 |
+
"model.layers.layers.22.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
230 |
+
"model.layers.layers.23.mixer.k_weight": "model-00003-of-00004.safetensors",
|
231 |
+
"model.layers.layers.23.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
|
232 |
+
"model.layers.layers.23.mixer.q_weight": "model-00003-of-00004.safetensors",
|
233 |
+
"model.layers.layers.23.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
234 |
+
"model.layers.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
235 |
+
"model.layers.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
236 |
+
"model.layers.layers.23.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
237 |
+
"model.layers.layers.23.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
238 |
+
"model.layers.layers.23.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
239 |
+
"model.layers.layers.23.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
240 |
+
"model.layers.layers.24.mixer.A_log": "model-00003-of-00004.safetensors",
|
241 |
+
"model.layers.layers.24.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
|
242 |
+
"model.layers.layers.24.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
|
243 |
+
"model.layers.layers.24.mixer.D": "model-00003-of-00004.safetensors",
|
244 |
+
"model.layers.layers.24.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
|
245 |
+
"model.layers.layers.24.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
|
246 |
+
"model.layers.layers.24.mixer.dt_bias": "model-00003-of-00004.safetensors",
|
247 |
+
"model.layers.layers.24.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
|
248 |
+
"model.layers.layers.24.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
|
249 |
+
"model.layers.layers.24.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
|
250 |
+
"model.layers.layers.24.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
|
251 |
+
"model.layers.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
252 |
+
"model.layers.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
253 |
+
"model.layers.layers.24.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
254 |
+
"model.layers.layers.24.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
255 |
+
"model.layers.layers.24.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
|
256 |
+
"model.layers.layers.24.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
|
257 |
+
"model.layers.layers.25.mixer.k_weight": "model-00003-of-00004.safetensors",
|
258 |
+
"model.layers.layers.25.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
|
259 |
+
"model.layers.layers.25.mixer.q_weight": "model-00003-of-00004.safetensors",
|
260 |
+
"model.layers.layers.25.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
261 |
+
"model.layers.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
262 |
+
"model.layers.layers.25.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
263 |
+
"model.layers.layers.25.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
264 |
+
"model.layers.layers.25.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
265 |
+
"model.layers.layers.25.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
266 |
+
"model.layers.layers.25.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
267 |
+
"model.layers.layers.26.mixer.A_log": "model-00004-of-00004.safetensors",
|
268 |
+
"model.layers.layers.26.mixer.B_norm_weight": "model-00004-of-00004.safetensors",
|
269 |
+
"model.layers.layers.26.mixer.C_norm_weight": "model-00004-of-00004.safetensors",
|
270 |
+
"model.layers.layers.26.mixer.D": "model-00004-of-00004.safetensors",
|
271 |
+
"model.layers.layers.26.mixer.bcdt_proj.weight": "model-00004-of-00004.safetensors",
|
272 |
+
"model.layers.layers.26.mixer.conv1d.weight": "model-00004-of-00004.safetensors",
|
273 |
+
"model.layers.layers.26.mixer.dt_bias": "model-00004-of-00004.safetensors",
|
274 |
+
"model.layers.layers.26.mixer.dt_norm_weight": "model-00004-of-00004.safetensors",
|
275 |
+
"model.layers.layers.26.mixer.dt_proj.weight": "model-00004-of-00004.safetensors",
|
276 |
+
"model.layers.layers.26.mixer.in_proj.weight": "model-00004-of-00004.safetensors",
|
277 |
+
"model.layers.layers.26.mixer.out_proj.weight": "model-00004-of-00004.safetensors",
|
278 |
+
"model.layers.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
279 |
+
"model.layers.layers.26.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
280 |
+
"model.layers.layers.26.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
281 |
+
"model.layers.layers.26.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
282 |
+
"model.layers.layers.26.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
283 |
+
"model.layers.layers.26.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
284 |
+
"model.layers.layers.27.mixer.k_weight": "model-00004-of-00004.safetensors",
|
285 |
+
"model.layers.layers.27.mixer.o_proj.weight": "model-00004-of-00004.safetensors",
|
286 |
+
"model.layers.layers.27.mixer.q_weight": "model-00004-of-00004.safetensors",
|
287 |
+
"model.layers.layers.27.mixer.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
288 |
+
"model.layers.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
289 |
+
"model.layers.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
290 |
+
"model.layers.layers.27.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
291 |
+
"model.layers.layers.27.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
292 |
+
"model.layers.layers.27.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
293 |
+
"model.layers.layers.27.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
294 |
+
"model.layers.layers.28.mixer.A_log": "model-00004-of-00004.safetensors",
|
295 |
+
"model.layers.layers.28.mixer.B_norm_weight": "model-00004-of-00004.safetensors",
|
296 |
+
"model.layers.layers.28.mixer.C_norm_weight": "model-00004-of-00004.safetensors",
|
297 |
+
"model.layers.layers.28.mixer.D": "model-00004-of-00004.safetensors",
|
298 |
+
"model.layers.layers.28.mixer.bcdt_proj.weight": "model-00004-of-00004.safetensors",
|
299 |
+
"model.layers.layers.28.mixer.conv1d.weight": "model-00004-of-00004.safetensors",
|
300 |
+
"model.layers.layers.28.mixer.dt_bias": "model-00004-of-00004.safetensors",
|
301 |
+
"model.layers.layers.28.mixer.dt_norm_weight": "model-00004-of-00004.safetensors",
|
302 |
+
"model.layers.layers.28.mixer.dt_proj.weight": "model-00004-of-00004.safetensors",
|
303 |
+
"model.layers.layers.28.mixer.in_proj.weight": "model-00004-of-00004.safetensors",
|
304 |
+
"model.layers.layers.28.mixer.out_proj.weight": "model-00004-of-00004.safetensors",
|
305 |
+
"model.layers.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
306 |
+
"model.layers.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
307 |
+
"model.layers.layers.28.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
308 |
+
"model.layers.layers.28.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
309 |
+
"model.layers.layers.28.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
310 |
+
"model.layers.layers.28.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
311 |
+
"model.layers.layers.29.mixer.k_weight": "model-00004-of-00004.safetensors",
|
312 |
+
"model.layers.layers.29.mixer.o_proj.weight": "model-00004-of-00004.safetensors",
|
313 |
+
"model.layers.layers.29.mixer.q_weight": "model-00004-of-00004.safetensors",
|
314 |
+
"model.layers.layers.29.mixer.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
315 |
+
"model.layers.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
316 |
+
"model.layers.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
317 |
+
"model.layers.layers.29.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
318 |
+
"model.layers.layers.29.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
319 |
+
"model.layers.layers.29.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
320 |
+
"model.layers.layers.29.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
321 |
+
"model.layers.layers.3.mixer.k_weight": "model-00001-of-00004.safetensors",
|
322 |
+
"model.layers.layers.3.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
|
323 |
+
"model.layers.layers.3.mixer.q_weight": "model-00001-of-00004.safetensors",
|
324 |
+
"model.layers.layers.3.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
325 |
+
"model.layers.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
326 |
+
"model.layers.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
327 |
+
"model.layers.layers.3.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
328 |
+
"model.layers.layers.3.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
329 |
+
"model.layers.layers.3.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
330 |
+
"model.layers.layers.3.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
331 |
+
"model.layers.layers.30.mixer.A_log": "model-00004-of-00004.safetensors",
|
332 |
+
"model.layers.layers.30.mixer.B_norm_weight": "model-00004-of-00004.safetensors",
|
333 |
+
"model.layers.layers.30.mixer.C_norm_weight": "model-00004-of-00004.safetensors",
|
334 |
+
"model.layers.layers.30.mixer.D": "model-00004-of-00004.safetensors",
|
335 |
+
"model.layers.layers.30.mixer.bcdt_proj.weight": "model-00004-of-00004.safetensors",
|
336 |
+
"model.layers.layers.30.mixer.conv1d.weight": "model-00004-of-00004.safetensors",
|
337 |
+
"model.layers.layers.30.mixer.dt_bias": "model-00004-of-00004.safetensors",
|
338 |
+
"model.layers.layers.30.mixer.dt_norm_weight": "model-00004-of-00004.safetensors",
|
339 |
+
"model.layers.layers.30.mixer.dt_proj.weight": "model-00004-of-00004.safetensors",
|
340 |
+
"model.layers.layers.30.mixer.in_proj.weight": "model-00004-of-00004.safetensors",
|
341 |
+
"model.layers.layers.30.mixer.out_proj.weight": "model-00004-of-00004.safetensors",
|
342 |
+
"model.layers.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
343 |
+
"model.layers.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
344 |
+
"model.layers.layers.30.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
345 |
+
"model.layers.layers.30.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
346 |
+
"model.layers.layers.30.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
347 |
+
"model.layers.layers.30.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
348 |
+
"model.layers.layers.31.mixer.k_weight": "model-00004-of-00004.safetensors",
|
349 |
+
"model.layers.layers.31.mixer.o_proj.weight": "model-00004-of-00004.safetensors",
|
350 |
+
"model.layers.layers.31.mixer.q_weight": "model-00004-of-00004.safetensors",
|
351 |
+
"model.layers.layers.31.mixer.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
352 |
+
"model.layers.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
353 |
+
"model.layers.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
354 |
+
"model.layers.layers.31.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
355 |
+
"model.layers.layers.31.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
356 |
+
"model.layers.layers.31.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
|
357 |
+
"model.layers.layers.31.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
|
358 |
+
"model.layers.layers.4.mixer.A_log": "model-00001-of-00004.safetensors",
|
359 |
+
"model.layers.layers.4.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
|
360 |
+
"model.layers.layers.4.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
|
361 |
+
"model.layers.layers.4.mixer.D": "model-00001-of-00004.safetensors",
|
362 |
+
"model.layers.layers.4.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
|
363 |
+
"model.layers.layers.4.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
|
364 |
+
"model.layers.layers.4.mixer.dt_bias": "model-00001-of-00004.safetensors",
|
365 |
+
"model.layers.layers.4.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
|
366 |
+
"model.layers.layers.4.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
|
367 |
+
"model.layers.layers.4.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
|
368 |
+
"model.layers.layers.4.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
|
369 |
+
"model.layers.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
370 |
+
"model.layers.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
371 |
+
"model.layers.layers.4.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
372 |
+
"model.layers.layers.4.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
373 |
+
"model.layers.layers.4.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
374 |
+
"model.layers.layers.4.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
375 |
+
"model.layers.layers.5.mixer.k_weight": "model-00001-of-00004.safetensors",
|
376 |
+
"model.layers.layers.5.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
|
377 |
+
"model.layers.layers.5.mixer.q_weight": "model-00001-of-00004.safetensors",
|
378 |
+
"model.layers.layers.5.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
379 |
+
"model.layers.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
380 |
+
"model.layers.layers.5.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
381 |
+
"model.layers.layers.5.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
382 |
+
"model.layers.layers.5.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
383 |
+
"model.layers.layers.5.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
384 |
+
"model.layers.layers.5.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
385 |
+
"model.layers.layers.6.mixer.A_log": "model-00001-of-00004.safetensors",
|
386 |
+
"model.layers.layers.6.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
|
387 |
+
"model.layers.layers.6.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
|
388 |
+
"model.layers.layers.6.mixer.D": "model-00001-of-00004.safetensors",
|
389 |
+
"model.layers.layers.6.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
|
390 |
+
"model.layers.layers.6.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
|
391 |
+
"model.layers.layers.6.mixer.dt_bias": "model-00001-of-00004.safetensors",
|
392 |
+
"model.layers.layers.6.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
|
393 |
+
"model.layers.layers.6.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
|
394 |
+
"model.layers.layers.6.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
|
395 |
+
"model.layers.layers.6.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
|
396 |
+
"model.layers.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
397 |
+
"model.layers.layers.6.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
398 |
+
"model.layers.layers.6.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
399 |
+
"model.layers.layers.6.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
400 |
+
"model.layers.layers.6.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
|
401 |
+
"model.layers.layers.6.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
|
402 |
+
"model.layers.layers.7.mixer.k_weight": "model-00001-of-00004.safetensors",
|
403 |
+
"model.layers.layers.7.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
|
404 |
+
"model.layers.layers.7.mixer.q_weight": "model-00001-of-00004.safetensors",
|
405 |
+
"model.layers.layers.7.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
406 |
+
"model.layers.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
407 |
+
"model.layers.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
408 |
+
"model.layers.layers.7.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
409 |
+
"model.layers.layers.7.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
410 |
+
"model.layers.layers.7.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
411 |
+
"model.layers.layers.7.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
412 |
+
"model.layers.layers.8.mixer.A_log": "model-00002-of-00004.safetensors",
|
413 |
+
"model.layers.layers.8.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
|
414 |
+
"model.layers.layers.8.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
|
415 |
+
"model.layers.layers.8.mixer.D": "model-00002-of-00004.safetensors",
|
416 |
+
"model.layers.layers.8.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
|
417 |
+
"model.layers.layers.8.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
|
418 |
+
"model.layers.layers.8.mixer.dt_bias": "model-00002-of-00004.safetensors",
|
419 |
+
"model.layers.layers.8.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
|
420 |
+
"model.layers.layers.8.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
|
421 |
+
"model.layers.layers.8.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
|
422 |
+
"model.layers.layers.8.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
|
423 |
+
"model.layers.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
424 |
+
"model.layers.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
425 |
+
"model.layers.layers.8.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
426 |
+
"model.layers.layers.8.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
427 |
+
"model.layers.layers.8.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
428 |
+
"model.layers.layers.8.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
429 |
+
"model.layers.layers.9.mixer.k_weight": "model-00002-of-00004.safetensors",
|
430 |
+
"model.layers.layers.9.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
|
431 |
+
"model.layers.layers.9.mixer.q_weight": "model-00002-of-00004.safetensors",
|
432 |
+
"model.layers.layers.9.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
433 |
+
"model.layers.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
434 |
+
"model.layers.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
435 |
+
"model.layers.layers.9.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
436 |
+
"model.layers.layers.9.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
437 |
+
"model.layers.layers.9.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
|
438 |
+
"model.layers.layers.9.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
|
439 |
+
"model.norm.weight": "model-00004-of-00004.safetensors"
|
440 |
+
}
|
441 |
+
}
|
modeling_plamo.py
ADDED
@@ -0,0 +1,1707 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import enum
|
2 |
+
import math
|
3 |
+
import warnings
|
4 |
+
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
|
5 |
+
|
6 |
+
try:
|
7 |
+
# It is difficult to install mamba_ssm in login node because
|
8 |
+
# it requires GPU for installation
|
9 |
+
import mamba_ssm
|
10 |
+
except ModuleNotFoundError:
|
11 |
+
warnings.warn("mamba_ssm could not be imported", stacklevel=2)
|
12 |
+
try:
|
13 |
+
# It is difficult to install causal_conv1d in login node because
|
14 |
+
# it requires GPU for installation
|
15 |
+
import causal_conv1d.causal_conv1d_interface as causal_conv1d
|
16 |
+
except ModuleNotFoundError:
|
17 |
+
warnings.warn("causal_conv1d could not be imported", stacklevel=2)
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from torch.nn import functional as F
|
21 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
22 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
23 |
+
|
24 |
+
|
25 |
+
def _is_first_token(mask: torch.Tensor) -> torch.Tensor:
|
26 |
+
assert mask.dtype == torch.bool
|
27 |
+
B, Nh, q_len, kv_len = mask.shape
|
28 |
+
mask = mask[:, :, :, -q_len:]
|
29 |
+
cont = q_len != kv_len
|
30 |
+
v = False if cont else True
|
31 |
+
out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool())
|
32 |
+
out = torch.cat(
|
33 |
+
[
|
34 |
+
torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v),
|
35 |
+
out,
|
36 |
+
],
|
37 |
+
dim=-1,
|
38 |
+
)
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
def _swiglu(h: torch.Tensor) -> torch.Tensor:
|
43 |
+
h0, h1 = h.chunk(2, dim=-1)
|
44 |
+
return torch.nn.functional.silu(h0) * h1
|
45 |
+
|
46 |
+
|
47 |
+
class RotaryEmbedding(torch.nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
|
50 |
+
) -> None:
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
self.dim = dim
|
54 |
+
self.max_position_embeddings = max_position_embeddings
|
55 |
+
self.base = base
|
56 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
57 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
58 |
+
|
59 |
+
# Build here to make `torch.jit.trace` work.
|
60 |
+
self._set_cos_sin_cache(
|
61 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
62 |
+
)
|
63 |
+
|
64 |
+
def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
|
65 |
+
self.max_seq_len_cached = seq_len
|
66 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
|
67 |
+
|
68 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
69 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
70 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
71 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
72 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
73 |
+
|
74 |
+
def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
75 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
76 |
+
if seq_len > self.max_seq_len_cached:
|
77 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
78 |
+
|
79 |
+
return (
|
80 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
|
81 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
86 |
+
"""Rotates half the hidden dims of the input."""
|
87 |
+
x1 = x[..., : x.shape[-1] // 2]
|
88 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
89 |
+
return torch.cat((-x2, x1), dim=-1)
|
90 |
+
|
91 |
+
|
92 |
+
def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
|
93 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
94 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
95 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
96 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
97 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
98 |
+
x_embed = (x * cos) + (_rotate_half(x) * sin)
|
99 |
+
return x_embed
|
100 |
+
|
101 |
+
|
102 |
+
class LinearType(str, enum.Enum):
|
103 |
+
Normal = "normal"
|
104 |
+
Fp8 = "fp8"
|
105 |
+
Fp8Retain = "fp8-retain"
|
106 |
+
|
107 |
+
|
108 |
+
class Plamo2Config(PretrainedConfig): # type: ignore
|
109 |
+
model_type: str = "plamo2"
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
hidden_size: int = 4096,
|
114 |
+
num_hidden_layers: int = 32,
|
115 |
+
rms_norm_eps: float = 1e-6,
|
116 |
+
tie_word_embeddings: bool = True,
|
117 |
+
# Attention
|
118 |
+
num_attention_heads: int = 32,
|
119 |
+
num_key_value_heads: int = 4,
|
120 |
+
hidden_size_per_head: int = 128,
|
121 |
+
max_position_embeddings: int = 2048,
|
122 |
+
attention_window_size: int = 2048,
|
123 |
+
full_attention_idx: list[int] | None = None,
|
124 |
+
rope_theta: int = 10000,
|
125 |
+
rope_local_theta: int = 10000,
|
126 |
+
# Mamba
|
127 |
+
mamba_d_state: int = 64,
|
128 |
+
mamba_d_conv: int = 4,
|
129 |
+
mamba_num_heads: int = 64,
|
130 |
+
mamba_step: int = 2,
|
131 |
+
mamba_chunk_size: int = 256,
|
132 |
+
mamba_enabled: bool = True,
|
133 |
+
# MLP
|
134 |
+
intermediate_size: int = 13312,
|
135 |
+
# Tokenizer
|
136 |
+
vocab_size: int = 32000,
|
137 |
+
tokenizer_class: str = "Plamo2Tokenizer",
|
138 |
+
pad_token_id: Optional[int] = None,
|
139 |
+
bos_token_id: int = 1,
|
140 |
+
eos_token_id: int = 2,
|
141 |
+
# Multimodal
|
142 |
+
image_token_id: Optional[int] = None,
|
143 |
+
image_feature_size: Optional[int] = None,
|
144 |
+
image_proj_type: Literal["linear", "mlp"] = "linear",
|
145 |
+
# FP8
|
146 |
+
linear_type: LinearType = LinearType.Normal,
|
147 |
+
fp8_accum_dtype: Optional[str] = None,
|
148 |
+
# Evaluation
|
149 |
+
eval_attention_n_bit: Optional[int] = None,
|
150 |
+
eval_mlp_n_bit: Optional[int] = None,
|
151 |
+
use_cache: bool = True,
|
152 |
+
**kwargs: Any,
|
153 |
+
) -> None:
|
154 |
+
# max_position_embeddings is often used to determine the max length during inference,
|
155 |
+
# but samba should have extrapolation abilities
|
156 |
+
self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
|
157 |
+
self.hidden_size = hidden_size
|
158 |
+
self.rms_norm_eps = rms_norm_eps
|
159 |
+
|
160 |
+
self.num_hidden_layers = num_hidden_layers
|
161 |
+
self.num_attention_heads = num_attention_heads
|
162 |
+
self.hidden_size_per_head = hidden_size_per_head
|
163 |
+
self.num_key_value_heads = num_key_value_heads
|
164 |
+
self.attention_window_size = attention_window_size
|
165 |
+
self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
|
166 |
+
self.rope_theta = rope_theta
|
167 |
+
self.rope_local_theta = rope_local_theta
|
168 |
+
|
169 |
+
self.mamba_d_state = mamba_d_state
|
170 |
+
self.mamba_d_conv = mamba_d_conv
|
171 |
+
self.mamba_num_heads = mamba_num_heads
|
172 |
+
self.mamba_step = mamba_step
|
173 |
+
self.mamba_chunk_size = mamba_chunk_size
|
174 |
+
self.mamba_enabled = mamba_enabled
|
175 |
+
|
176 |
+
self.intermediate_size = intermediate_size
|
177 |
+
|
178 |
+
self.vocab_size = vocab_size
|
179 |
+
|
180 |
+
self.image_token_id = image_token_id
|
181 |
+
self.image_feature_size = image_feature_size
|
182 |
+
self.image_proj_type = image_proj_type
|
183 |
+
|
184 |
+
self.linear_type = linear_type
|
185 |
+
self.fp8_accum_dtype = fp8_accum_dtype
|
186 |
+
|
187 |
+
self.eval_attention_n_bit = eval_attention_n_bit
|
188 |
+
self.eval_mlp_n_bit = eval_mlp_n_bit
|
189 |
+
self.use_cache = use_cache
|
190 |
+
|
191 |
+
# fields for vLLM
|
192 |
+
self.sliding_window = attention_window_size
|
193 |
+
|
194 |
+
super().__init__(
|
195 |
+
tokenizer_class=tokenizer_class,
|
196 |
+
pad_token_id=pad_token_id,
|
197 |
+
bos_token_id=bos_token_id,
|
198 |
+
eos_token_id=eos_token_id,
|
199 |
+
tie_word_embeddings=tie_word_embeddings,
|
200 |
+
**kwargs,
|
201 |
+
)
|
202 |
+
|
203 |
+
@property
|
204 |
+
def layers_block_type(self) -> list[str]:
|
205 |
+
return ["mamba" if is_mamba(self, i) else "attention" for i in range(self.num_hidden_layers)]
|
206 |
+
|
207 |
+
@property
|
208 |
+
def rope_local_base_freq(self) -> int:
|
209 |
+
return self.rope_local_theta
|
210 |
+
|
211 |
+
|
212 |
+
class Plamo2AttentionCache(torch.nn.Module):
|
213 |
+
def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
|
214 |
+
super().__init__()
|
215 |
+
B, nh, L, c = key.shape
|
216 |
+
assert len(value.shape) == 4
|
217 |
+
assert value.shape[0] == B
|
218 |
+
assert value.shape[2] == L
|
219 |
+
self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False))
|
220 |
+
self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
|
221 |
+
|
222 |
+
|
223 |
+
class Plamo2MambaCache(torch.nn.Module):
|
224 |
+
def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
|
225 |
+
super().__init__()
|
226 |
+
# conv_state: [B, C, d_conv]
|
227 |
+
# ssm_state: [B, nhead, nchanel_per_head, d_state]
|
228 |
+
assert len(conv_state.shape) == 3
|
229 |
+
assert len(ssm_state.shape) == 4
|
230 |
+
assert conv_state.shape[0] == ssm_state.shape[0]
|
231 |
+
self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False))
|
232 |
+
self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
|
233 |
+
|
234 |
+
|
235 |
+
Plamo2LayerCache = Plamo2AttentionCache | Plamo2MambaCache
|
236 |
+
|
237 |
+
|
238 |
+
class Plamo2Cache(torch.nn.Module):
|
239 |
+
"""
|
240 |
+
stores states of the model for fast decoding.
|
241 |
+
`transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
|
242 |
+
deeply dependent on Transformers architecture (e.g., `key_states`) and it is difficult to use
|
243 |
+
other architectures (e.g., Mamba).
|
244 |
+
This class provides a similar interface to `transformers.Cache`, but is designed to also handle
|
245 |
+
the state of Mamba properly.
|
246 |
+
"""
|
247 |
+
|
248 |
+
def __init__(self, config: Plamo2Config) -> None:
|
249 |
+
super().__init__()
|
250 |
+
self.config = config
|
251 |
+
self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
|
252 |
+
|
253 |
+
def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
254 |
+
c = self.cache[layer_idx]
|
255 |
+
if c is None:
|
256 |
+
return key, value
|
257 |
+
assert isinstance(c, Plamo2AttentionCache)
|
258 |
+
|
259 |
+
def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
|
260 |
+
assert len(cache.shape) == 4
|
261 |
+
assert len(new_tensor.shape) == 4
|
262 |
+
assert cache.shape[0] == new_tensor.shape[0]
|
263 |
+
assert cache.shape[1] == new_tensor.shape[1]
|
264 |
+
assert cache.shape[3] == new_tensor.shape[3]
|
265 |
+
|
266 |
+
_validate(c.key, key)
|
267 |
+
_validate(c.value, value)
|
268 |
+
assert key.shape[2] == value.shape[2]
|
269 |
+
return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2)
|
270 |
+
|
271 |
+
def update_attention(
|
272 |
+
self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
|
273 |
+
) -> Plamo2AttentionCache:
|
274 |
+
full_attn = layer_idx in self.config.full_attention_idx
|
275 |
+
window_size = self.config.attention_window_size
|
276 |
+
|
277 |
+
if self.cache[layer_idx] is None:
|
278 |
+
if full_attn:
|
279 |
+
self.cache[layer_idx] = Plamo2AttentionCache(key_states, value_states)
|
280 |
+
else:
|
281 |
+
self.cache[layer_idx] = Plamo2AttentionCache(
|
282 |
+
key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
c = self.cache[layer_idx]
|
286 |
+
assert isinstance(c, Plamo2AttentionCache)
|
287 |
+
k, v = self.append_kv(key_states, value_states, layer_idx)
|
288 |
+
if full_attn:
|
289 |
+
c.key.data = k
|
290 |
+
c.value.data = v
|
291 |
+
else:
|
292 |
+
c.key.data = k[:, :, -window_size:, :]
|
293 |
+
c.value.data = v[:, :, -window_size:, :]
|
294 |
+
return self.cache[layer_idx] # type: ignore
|
295 |
+
|
296 |
+
def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> Plamo2MambaCache:
|
297 |
+
if self.cache[layer_idx] is None:
|
298 |
+
self.cache[layer_idx] = Plamo2MambaCache(conv_state, ssm_state)
|
299 |
+
else:
|
300 |
+
c = self.cache[layer_idx]
|
301 |
+
assert isinstance(c, Plamo2MambaCache)
|
302 |
+
assert c.conv_state.shape == conv_state.shape
|
303 |
+
assert c.ssm_state.shape == ssm_state.shape
|
304 |
+
c.conv_state.data = conv_state
|
305 |
+
c.ssm_state.data = ssm_state
|
306 |
+
return self.cache[layer_idx] # type: ignore
|
307 |
+
|
308 |
+
def __getitem__(self, layer_idx: int) -> Plamo2LayerCache | None:
|
309 |
+
assert layer_idx < len(self.cache)
|
310 |
+
layer_cache = self.cache[layer_idx]
|
311 |
+
return layer_cache # type: ignore
|
312 |
+
|
313 |
+
def __len__(self) -> int:
|
314 |
+
return len(self.cache)
|
315 |
+
|
316 |
+
def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
|
317 |
+
if layer_idx is not None:
|
318 |
+
c = self.cache[layer_idx]
|
319 |
+
assert isinstance(c, Plamo2AttentionCache)
|
320 |
+
return c.key.shape[2] # type: ignore
|
321 |
+
|
322 |
+
sequence_length: int | None = None
|
323 |
+
for layer_cache in self.cache:
|
324 |
+
if isinstance(layer_cache, Plamo2AttentionCache):
|
325 |
+
sequence_length = (
|
326 |
+
max(layer_cache.key.shape[2], sequence_length)
|
327 |
+
if sequence_length is not None
|
328 |
+
else layer_cache.key.shape[2]
|
329 |
+
)
|
330 |
+
assert sequence_length is not None
|
331 |
+
return sequence_length
|
332 |
+
|
333 |
+
def get_max_length(self) -> int | None:
|
334 |
+
return None
|
335 |
+
|
336 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
337 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
338 |
+
# Cache without size limit -> all cache is usable
|
339 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
340 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
341 |
+
max_length = self.get_max_length()
|
342 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
343 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
344 |
+
return max_length - new_seq_length
|
345 |
+
return previous_seq_length
|
346 |
+
|
347 |
+
def reorder_cache(self, beam_idx: torch.Tensor) -> None:
|
348 |
+
def _mamba(cache: Plamo2MambaCache) -> Plamo2MambaCache:
|
349 |
+
return Plamo2MambaCache(
|
350 |
+
conv_state=cache.conv_state.index_select(0, beam_idx),
|
351 |
+
ssm_state=cache.ssm_state.index_select(0, beam_idx),
|
352 |
+
)
|
353 |
+
|
354 |
+
def _attention(cache: Plamo2AttentionCache) -> Plamo2AttentionCache:
|
355 |
+
return Plamo2AttentionCache(
|
356 |
+
key=cache.key.index_select(0, beam_idx),
|
357 |
+
value=cache.value.index_select(0, beam_idx),
|
358 |
+
)
|
359 |
+
|
360 |
+
for i in range(len(self.cache)):
|
361 |
+
if self.cache[i] is None:
|
362 |
+
continue
|
363 |
+
layer_cache = self.cache[i]
|
364 |
+
if isinstance(layer_cache, Plamo2MambaCache):
|
365 |
+
self.cache[i] = _mamba(layer_cache)
|
366 |
+
else:
|
367 |
+
assert isinstance(layer_cache, Plamo2AttentionCache)
|
368 |
+
self.cache[i] = _attention(layer_cache)
|
369 |
+
|
370 |
+
@property
|
371 |
+
def seen_tokens(self) -> int | None:
|
372 |
+
return None
|
373 |
+
|
374 |
+
|
375 |
+
class DecoderInput(NamedTuple):
|
376 |
+
hidden_states: torch.Tensor
|
377 |
+
attention_mask: Optional[torch.Tensor] = None
|
378 |
+
past_states: Optional[Plamo2Cache] = None
|
379 |
+
output_hidden_states: Optional[bool] = False
|
380 |
+
output_attentions: Optional[bool] = False
|
381 |
+
gradient_checkpointing: bool = False
|
382 |
+
input_ids: Optional[torch.Tensor] = None
|
383 |
+
|
384 |
+
|
385 |
+
class DecoderOutput(NamedTuple):
|
386 |
+
hidden_states: torch.Tensor
|
387 |
+
all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
|
388 |
+
all_self_attns: Optional[Tuple[torch.Tensor, ...]]
|
389 |
+
|
390 |
+
|
391 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
392 |
+
def _make_causal_mask(
|
393 |
+
input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
394 |
+
) -> torch.Tensor:
|
395 |
+
"""
|
396 |
+
Make causal mask used for bi-directional self-attention.
|
397 |
+
"""
|
398 |
+
bsz, tgt_len = input_ids_shape
|
399 |
+
mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
|
400 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
401 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
402 |
+
mask = mask.to(dtype)
|
403 |
+
|
404 |
+
if past_key_values_length > 0:
|
405 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
406 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
407 |
+
|
408 |
+
|
409 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
410 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
|
411 |
+
"""
|
412 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
413 |
+
"""
|
414 |
+
bsz, src_len = mask.size()
|
415 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
416 |
+
|
417 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
418 |
+
|
419 |
+
inverted_mask = 1.0 - expanded_mask
|
420 |
+
|
421 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
|
422 |
+
|
423 |
+
|
424 |
+
def _rms_norm(
|
425 |
+
hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
|
426 |
+
) -> torch.Tensor:
|
427 |
+
input_dtype = hidden_states.dtype
|
428 |
+
hidden_states = hidden_states.to(torch.float32)
|
429 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
430 |
+
hidden_states = hidden_states * torch.rsqrt(variance + eps)
|
431 |
+
hidden_states = hidden_states.to(input_dtype)
|
432 |
+
if weight is not None:
|
433 |
+
hidden_states = (offset + weight) * hidden_states
|
434 |
+
return hidden_states
|
435 |
+
|
436 |
+
|
437 |
+
class RMSNorm(nn.Module):
|
438 |
+
def __init__(
|
439 |
+
self,
|
440 |
+
hidden_size: int,
|
441 |
+
eps: float = 1e-6,
|
442 |
+
offset: float = 1.0,
|
443 |
+
device: Optional[Union[torch.device, str]] = None,
|
444 |
+
) -> None:
|
445 |
+
super().__init__()
|
446 |
+
self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
|
447 |
+
self.variance_epsilon = eps
|
448 |
+
self.offset = offset
|
449 |
+
|
450 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
451 |
+
return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
|
452 |
+
|
453 |
+
|
454 |
+
def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
|
455 |
+
dt_min = 0.001
|
456 |
+
dt_max = 0.1
|
457 |
+
dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
|
458 |
+
dt = torch.clamp(dt, 1e-4)
|
459 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
460 |
+
return inv_dt
|
461 |
+
|
462 |
+
|
463 |
+
def get_initial_A(num_heads: int) -> torch.Tensor:
|
464 |
+
A = torch.arange(1, num_heads + 1, dtype=torch.float32)
|
465 |
+
return torch.log(A)
|
466 |
+
|
467 |
+
|
468 |
+
def _bf16_supported_in_triton() -> bool:
|
469 |
+
# newer torch (2.2.0 and later?) supports bfloat16 even when using Voltas
|
470 |
+
# but triton cannot compile bf16 kernels for Volta
|
471 |
+
major, _ = torch.cuda.get_device_capability()
|
472 |
+
return major >= 8
|
473 |
+
|
474 |
+
|
475 |
+
def _get_trition_dtype(dtype: torch.dtype) -> torch.dtype:
|
476 |
+
if dtype != torch.bfloat16:
|
477 |
+
return dtype
|
478 |
+
if _bf16_supported_in_triton():
|
479 |
+
return dtype
|
480 |
+
return torch.float32
|
481 |
+
|
482 |
+
|
483 |
+
def ssd_update_state(
|
484 |
+
ssm_state: torch.Tensor,
|
485 |
+
x: torch.Tensor,
|
486 |
+
dt: torch.Tensor,
|
487 |
+
A: torch.Tensor,
|
488 |
+
B: torch.Tensor,
|
489 |
+
C: torch.Tensor,
|
490 |
+
D: torch.Tensor,
|
491 |
+
z: torch.Tensor,
|
492 |
+
dt_bias: torch.Tensor,
|
493 |
+
dt_softplus: bool,
|
494 |
+
) -> torch.Tensor:
|
495 |
+
assert ssm_state.dtype == torch.float32
|
496 |
+
if dt.is_cuda:
|
497 |
+
dtype = _get_trition_dtype(x.dtype)
|
498 |
+
else:
|
499 |
+
dtype = x.dtype
|
500 |
+
if dt.is_cuda:
|
501 |
+
f = mamba_ssm.ops.triton.selective_state_update.selective_state_update
|
502 |
+
else:
|
503 |
+
f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref
|
504 |
+
|
505 |
+
hidden_size_per_head = x.shape[-1]
|
506 |
+
d_state = B.shape[-1]
|
507 |
+
A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float()
|
508 |
+
dt = dt[..., None].expand(-1, -1, hidden_size_per_head)
|
509 |
+
dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head)
|
510 |
+
D = D[:, None].expand(-1, hidden_size_per_head)
|
511 |
+
assert ssm_state.dtype == torch.float32
|
512 |
+
out = f(
|
513 |
+
ssm_state,
|
514 |
+
x.to(dtype),
|
515 |
+
dt.to(dtype),
|
516 |
+
A.float(),
|
517 |
+
B.to(dtype),
|
518 |
+
C.to(dtype),
|
519 |
+
D.float(),
|
520 |
+
z.to(dtype),
|
521 |
+
dt_bias.float(),
|
522 |
+
dt_softplus=dt_softplus,
|
523 |
+
)
|
524 |
+
return out[:, None] # type: ignore
|
525 |
+
|
526 |
+
|
527 |
+
def _ssd_chunk_scan_combined_naive(
|
528 |
+
x: torch.Tensor,
|
529 |
+
dt: torch.Tensor,
|
530 |
+
A: torch.Tensor,
|
531 |
+
B: torch.Tensor,
|
532 |
+
C: torch.Tensor,
|
533 |
+
D: torch.Tensor,
|
534 |
+
z: torch.Tensor,
|
535 |
+
dt_bias: torch.Tensor,
|
536 |
+
dt_softplus: bool,
|
537 |
+
seq_idx: torch.Tensor | None,
|
538 |
+
ssm_state: torch.Tensor,
|
539 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
540 |
+
assert ssm_state.dtype == torch.float32
|
541 |
+
length = x.shape[1]
|
542 |
+
ys = []
|
543 |
+
for i in range(length):
|
544 |
+
if i != 0 and seq_idx is not None:
|
545 |
+
ssm_state = torch.where(
|
546 |
+
(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
|
547 |
+
torch.zeros_like(ssm_state),
|
548 |
+
ssm_state,
|
549 |
+
)
|
550 |
+
y = ssd_update_state(
|
551 |
+
ssm_state,
|
552 |
+
x[:, i],
|
553 |
+
dt[:, i],
|
554 |
+
A,
|
555 |
+
B[:, i],
|
556 |
+
C[:, i],
|
557 |
+
D,
|
558 |
+
z=z[:, i],
|
559 |
+
dt_bias=dt_bias,
|
560 |
+
dt_softplus=dt_softplus,
|
561 |
+
)
|
562 |
+
ys.append(y)
|
563 |
+
return torch.cat(ys, dim=1), ssm_state
|
564 |
+
|
565 |
+
|
566 |
+
def _ssd_chunk_scan_combined_cpu(
|
567 |
+
x: torch.Tensor,
|
568 |
+
dt: torch.Tensor,
|
569 |
+
A: torch.Tensor,
|
570 |
+
B: torch.Tensor,
|
571 |
+
C: torch.Tensor,
|
572 |
+
chunk_size: int,
|
573 |
+
D: torch.Tensor,
|
574 |
+
z: torch.Tensor,
|
575 |
+
dt_bias: torch.Tensor,
|
576 |
+
dt_softplus: bool,
|
577 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
578 |
+
# (bsize, nhead, nchunk, chunk_size)
|
579 |
+
dt = dt.float() # We want high precision for this before cumsum
|
580 |
+
dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size)) # type: ignore
|
581 |
+
if dt_bias is not None:
|
582 |
+
dt = dt + dt_bias[None, :, None, None]
|
583 |
+
if dt_softplus:
|
584 |
+
dt = F.softplus(dt)
|
585 |
+
dA = dt * A[None, :, None, None]
|
586 |
+
dA_cumsum = torch.cumsum(dA, dim=-1)
|
587 |
+
|
588 |
+
_, _, nheads, _ = x.shape
|
589 |
+
dstate = B.shape[-1]
|
590 |
+
_ = dt.shape[2]
|
591 |
+
|
592 |
+
with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
|
593 |
+
# Following is equivalent to `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`
|
594 |
+
# But `einsum` in the above function is too slow in CPU.
|
595 |
+
x_ = torch.unflatten(x, 1, (-1, chunk_size))
|
596 |
+
assert B.shape[2] == nheads # B should be already expanded
|
597 |
+
B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype) # (bsize, nchunk, chunk_size, nheads, dstate)
|
598 |
+
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)
|
599 |
+
dt_ = dt.to(x.dtype)
|
600 |
+
|
601 |
+
# einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)
|
602 |
+
B_ = B_.permute(0, 1, 3, 4, 2) # bchnl
|
603 |
+
tmp = dt_ * decay_states # bhcl
|
604 |
+
tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None] # bch1l
|
605 |
+
tmp = B_ * tmp # bchnl
|
606 |
+
x_ = x_.permute(0, 1, 3, 2, 4) # bchlp
|
607 |
+
tmp = tmp @ x_ # bchnp
|
608 |
+
states = tmp.permute(0, 1, 2, 4, 3) # bchpn
|
609 |
+
|
610 |
+
states_dtype = states.dtype
|
611 |
+
if states.dtype not in [torch.float32, torch.float64]:
|
612 |
+
states = states.to(torch.float32)
|
613 |
+
with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
|
614 |
+
out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
|
615 |
+
states.flatten(start_dim=-2, end_dim=-1),
|
616 |
+
dA_cumsum[:, :, :, -1],
|
617 |
+
)
|
618 |
+
states = torch.unflatten(out, -1, (-1, dstate))
|
619 |
+
last_state = torch.unflatten(last_state, -1, (-1, dstate))
|
620 |
+
states = states.to(states_dtype)
|
621 |
+
with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
|
622 |
+
out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
|
623 |
+
|
624 |
+
return out, last_state
|
625 |
+
|
626 |
+
|
627 |
+
@torch.profiler.record_function("ssd_chunk_scan_combined")
|
628 |
+
def ssd_chunk_scan_combined(
|
629 |
+
x: torch.Tensor,
|
630 |
+
dt: torch.Tensor,
|
631 |
+
A: torch.Tensor,
|
632 |
+
B: torch.Tensor,
|
633 |
+
C: torch.Tensor,
|
634 |
+
chunk_size: int,
|
635 |
+
D: torch.Tensor,
|
636 |
+
z: torch.Tensor,
|
637 |
+
dt_bias: torch.Tensor,
|
638 |
+
dt_softplus: bool,
|
639 |
+
return_final_states: bool,
|
640 |
+
seq_idx: torch.Tensor | None,
|
641 |
+
ssm_state: torch.Tensor | None,
|
642 |
+
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
|
643 |
+
if seq_idx is not None:
|
644 |
+
assert seq_idx.dtype == torch.int32
|
645 |
+
assert ssm_state is None
|
646 |
+
assert not return_final_states
|
647 |
+
if ssm_state is not None:
|
648 |
+
assert ssm_state.dtype == torch.float32
|
649 |
+
assert seq_idx is None
|
650 |
+
|
651 |
+
length = x.shape[1]
|
652 |
+
|
653 |
+
"""
|
654 |
+
state will be updates by following:
|
655 |
+
```
|
656 |
+
dt = softplus(dt)
|
657 |
+
dA = exp(dt * A)
|
658 |
+
state_next = state * dA + dB * x
|
659 |
+
```
|
660 |
+
|
661 |
+
To avoid updating state, we set dt to -inf and x to 0
|
662 |
+
because `softplus(-inf) = 0` and `exp(0) = 1`
|
663 |
+
"""
|
664 |
+
pad = (chunk_size - length % chunk_size) % chunk_size
|
665 |
+
x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
666 |
+
dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
|
667 |
+
B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
668 |
+
C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
669 |
+
z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
|
670 |
+
if seq_idx is not None:
|
671 |
+
seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
|
672 |
+
|
673 |
+
length = x.shape[1]
|
674 |
+
assert length % chunk_size == 0, (length, chunk_size)
|
675 |
+
|
676 |
+
if dt.is_cuda:
|
677 |
+
dtype = _get_trition_dtype(x.dtype)
|
678 |
+
out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
|
679 |
+
x.to(dtype),
|
680 |
+
dt.to(dtype),
|
681 |
+
A.float(),
|
682 |
+
B.to(dtype),
|
683 |
+
C.to(dtype),
|
684 |
+
chunk_size,
|
685 |
+
D=D.float(),
|
686 |
+
z=z.to(dtype),
|
687 |
+
initial_states=ssm_state,
|
688 |
+
dt_bias=dt_bias.float(),
|
689 |
+
dt_softplus=dt_softplus,
|
690 |
+
seq_idx=seq_idx,
|
691 |
+
return_final_states=return_final_states,
|
692 |
+
)
|
693 |
+
if return_final_states:
|
694 |
+
return out[0][:, pad:], out[1]
|
695 |
+
else:
|
696 |
+
assert isinstance(out, torch.Tensor)
|
697 |
+
return out[:, pad:]
|
698 |
+
else:
|
699 |
+
if ssm_state is None and seq_idx is None:
|
700 |
+
tmp = _ssd_chunk_scan_combined_cpu(
|
701 |
+
x,
|
702 |
+
dt,
|
703 |
+
A,
|
704 |
+
B,
|
705 |
+
C,
|
706 |
+
chunk_size,
|
707 |
+
D=D,
|
708 |
+
z=z,
|
709 |
+
dt_bias=dt_bias.float(),
|
710 |
+
dt_softplus=dt_softplus,
|
711 |
+
)
|
712 |
+
else:
|
713 |
+
if ssm_state is None:
|
714 |
+
bsize, _, num_heads, channel = x.shape
|
715 |
+
state = B.shape[-1]
|
716 |
+
ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
|
717 |
+
tmp = _ssd_chunk_scan_combined_naive(
|
718 |
+
x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
|
719 |
+
)
|
720 |
+
tmp = (tmp[0][:, pad:], tmp[1])
|
721 |
+
if return_final_states:
|
722 |
+
return tmp
|
723 |
+
else:
|
724 |
+
return tmp[0]
|
725 |
+
|
726 |
+
|
727 |
+
def _causal_conv1d_update(
|
728 |
+
conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
|
729 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
730 |
+
dtype = conv_state.dtype
|
731 |
+
xBC = xBC.to(dtype)
|
732 |
+
weight = weight.to(dtype)
|
733 |
+
if conv_state.is_cuda:
|
734 |
+
x = causal_conv1d.causal_conv1d_update(
|
735 |
+
x=xBC,
|
736 |
+
conv_state=conv_state,
|
737 |
+
weight=weight[:, 0, :],
|
738 |
+
activation="silu",
|
739 |
+
)
|
740 |
+
return x, conv_state
|
741 |
+
else:
|
742 |
+
x = causal_conv1d.causal_conv1d_update_ref(
|
743 |
+
x=xBC,
|
744 |
+
conv_state=conv_state,
|
745 |
+
weight=weight[:, 0, :],
|
746 |
+
activation="silu",
|
747 |
+
)
|
748 |
+
return x, conv_state
|
749 |
+
|
750 |
+
|
751 |
+
def _causal_conv1d_naive(
|
752 |
+
conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
|
753 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
754 |
+
length = x.shape[-1]
|
755 |
+
out = torch.zeros_like(x)
|
756 |
+
for i in range(length):
|
757 |
+
if i != 0 and seq_idx is not None:
|
758 |
+
conv_state = torch.where(
|
759 |
+
(seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
|
760 |
+
torch.zeros_like(conv_state),
|
761 |
+
conv_state,
|
762 |
+
)
|
763 |
+
out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
|
764 |
+
return out, conv_state
|
765 |
+
|
766 |
+
|
767 |
+
@torch.profiler.record_function("causal_conv1d")
|
768 |
+
def _causal_conv1d(
|
769 |
+
conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
|
770 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
771 |
+
dtype = x.dtype
|
772 |
+
if conv_state is not None:
|
773 |
+
dtype = conv_state.dtype
|
774 |
+
assert seq_idx is None
|
775 |
+
if seq_idx is not None:
|
776 |
+
assert seq_idx.dtype == torch.int32
|
777 |
+
assert conv_state is None
|
778 |
+
weight = weight.to(dtype)
|
779 |
+
x = x.to(dtype)
|
780 |
+
|
781 |
+
return_final_states = conv_state is not None
|
782 |
+
if weight.is_cuda:
|
783 |
+
if x.stride(1) != 1:
|
784 |
+
# to channel-last format
|
785 |
+
x = x.transpose(-1, -2).contiguous().transpose(-1, -2)
|
786 |
+
if conv_state is not None:
|
787 |
+
if conv_state.stride(1) != 1:
|
788 |
+
# to channel-last format
|
789 |
+
conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2)
|
790 |
+
tmp = causal_conv1d.causal_conv1d_fn(
|
791 |
+
x=x,
|
792 |
+
weight=weight[:, 0, :],
|
793 |
+
initial_states=conv_state,
|
794 |
+
return_final_states=conv_state is not None,
|
795 |
+
activation="silu",
|
796 |
+
seq_idx=seq_idx,
|
797 |
+
)
|
798 |
+
if conv_state is not None:
|
799 |
+
x, conv_state = tmp
|
800 |
+
else:
|
801 |
+
x = tmp
|
802 |
+
else:
|
803 |
+
if seq_idx is None:
|
804 |
+
x, conv_state = causal_conv1d.causal_conv1d_ref(
|
805 |
+
x=x,
|
806 |
+
initial_states=conv_state,
|
807 |
+
return_final_states=True,
|
808 |
+
weight=weight[:, 0, :],
|
809 |
+
activation="silu",
|
810 |
+
)
|
811 |
+
else:
|
812 |
+
if conv_state is None:
|
813 |
+
bsize = x.shape[0]
|
814 |
+
dim = weight.shape[0]
|
815 |
+
d_conv = weight.shape[-1]
|
816 |
+
conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
|
817 |
+
x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
|
818 |
+
if return_final_states:
|
819 |
+
return x, conv_state
|
820 |
+
else:
|
821 |
+
return x, None
|
822 |
+
|
823 |
+
|
824 |
+
class Mamba(torch.nn.Module):
|
825 |
+
def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
|
826 |
+
super().__init__()
|
827 |
+
self.config = config
|
828 |
+
self.layer_idx = layer_idx
|
829 |
+
self.hidden_size = config.hidden_size
|
830 |
+
self.d_state = config.mamba_d_state
|
831 |
+
self.d_conv = config.mamba_d_conv
|
832 |
+
self.chunk_size = config.mamba_chunk_size
|
833 |
+
self.num_heads = config.mamba_num_heads
|
834 |
+
# TODO add mamba_hidden_size_per_head config (?)
|
835 |
+
self.hidden_size_per_head = config.hidden_size_per_head
|
836 |
+
|
837 |
+
self.intermediate_size = self.num_heads * self.hidden_size_per_head
|
838 |
+
|
839 |
+
self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
|
840 |
+
self.conv1d = torch.nn.Conv1d(
|
841 |
+
in_channels=self.intermediate_size,
|
842 |
+
out_channels=self.intermediate_size,
|
843 |
+
bias=False, # TODO the original implementation uses bias
|
844 |
+
kernel_size=self.d_conv,
|
845 |
+
groups=self.intermediate_size,
|
846 |
+
padding=0,
|
847 |
+
)
|
848 |
+
self.dt_dim = max(64, self.hidden_size // 16)
|
849 |
+
# Notes:
|
850 |
+
# Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
|
851 |
+
# but it may degrade the ability of content-length extrapolation.
|
852 |
+
self.bcdt_proj = torch.nn.Linear(
|
853 |
+
self.intermediate_size,
|
854 |
+
self.dt_dim + 2 * self.d_state,
|
855 |
+
bias=False,
|
856 |
+
)
|
857 |
+
self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False)
|
858 |
+
|
859 |
+
self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
|
860 |
+
self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads))
|
861 |
+
self.D = torch.nn.Parameter(torch.ones(self.num_heads))
|
862 |
+
|
863 |
+
# TODO norm weight before gating like Mamba2
|
864 |
+
self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim))
|
865 |
+
self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
|
866 |
+
self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
|
867 |
+
|
868 |
+
self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
869 |
+
|
870 |
+
def _no_weight_decay_param_names(self) -> set[str]:
|
871 |
+
return set(["D", "dt_bias", "A_log"])
|
872 |
+
|
873 |
+
def forward(
|
874 |
+
self,
|
875 |
+
hidden_states: torch.Tensor,
|
876 |
+
attention_mask: Optional[torch.Tensor] = None,
|
877 |
+
past_states: Optional[Plamo2Cache] = None,
|
878 |
+
) -> Tuple[torch.Tensor, Optional[Plamo2Cache]]:
|
879 |
+
bsize, length, _ = hidden_states.shape
|
880 |
+
is_update = length == 1 and past_states is not None
|
881 |
+
|
882 |
+
bool_mask: torch.Tensor | None = None
|
883 |
+
seq_idx: torch.Tensor | None = None
|
884 |
+
if attention_mask is not None:
|
885 |
+
if len(attention_mask.shape) == 2:
|
886 |
+
attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1)
|
887 |
+
assert len(attention_mask.shape) == 4
|
888 |
+
|
889 |
+
if past_states is None:
|
890 |
+
# TODO: support seq_idx with cache
|
891 |
+
bool_mask_4d = attention_mask == 0
|
892 |
+
is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
|
893 |
+
seq_idx = torch.cumsum(is_first_token, dim=-1) - 1
|
894 |
+
seq_idx = seq_idx.to(torch.int32)
|
895 |
+
|
896 |
+
# `generate` function creates attention mask that contains past tokens,
|
897 |
+
# but mamba does not use them
|
898 |
+
attention_mask = attention_mask[:, 0, -length:, -length:]
|
899 |
+
bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0
|
900 |
+
|
901 |
+
conv_state: torch.Tensor | None
|
902 |
+
ssm_state: torch.Tensor | None
|
903 |
+
if past_states is None:
|
904 |
+
conv_state = None
|
905 |
+
ssm_state = None
|
906 |
+
elif past_states[self.layer_idx] is None:
|
907 |
+
conv_state = torch.zeros(
|
908 |
+
bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device
|
909 |
+
)
|
910 |
+
ssm_state = torch.zeros(
|
911 |
+
bsize,
|
912 |
+
self.num_heads,
|
913 |
+
self.hidden_size_per_head,
|
914 |
+
self.d_state,
|
915 |
+
dtype=torch.float32,
|
916 |
+
device=hidden_states.device,
|
917 |
+
)
|
918 |
+
else:
|
919 |
+
c = past_states[self.layer_idx]
|
920 |
+
assert isinstance(c, Plamo2MambaCache)
|
921 |
+
conv_state = c.conv_state
|
922 |
+
ssm_state = c.ssm_state
|
923 |
+
|
924 |
+
zx = self.in_proj(hidden_states)
|
925 |
+
zx = zx.reshape(bsize, length, self.num_heads, -1)
|
926 |
+
# z: (bsize, length, num_heads, hidden_size_per_head)
|
927 |
+
# x: (bsize, length, num_heads, hidden_size_per_head)
|
928 |
+
z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1)
|
929 |
+
|
930 |
+
# conv
|
931 |
+
x = x.reshape(bsize, length, -1).transpose(1, 2) # (bsize, intermediate_size, length)
|
932 |
+
if bool_mask is not None:
|
933 |
+
x = torch.where(bool_mask[:, None, :], x, 0.0)
|
934 |
+
if is_update:
|
935 |
+
assert conv_state is not None
|
936 |
+
x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
|
937 |
+
else:
|
938 |
+
x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
|
939 |
+
x = x.to(dtype=hidden_states.dtype)
|
940 |
+
x = x.transpose(1, 2) # (bsize, length, intermediate_size)
|
941 |
+
x = x.reshape(bsize, length, -1)
|
942 |
+
# x: (bsize, length, num_heads, hidden_size_per_head)
|
943 |
+
# B: (bsize, length, 1, d_state)
|
944 |
+
# C: (bsize, length, 1, d_state)
|
945 |
+
# dt: (bsize, length, dt_dim)
|
946 |
+
BCdt = self.bcdt_proj(x)
|
947 |
+
x = x.reshape(bsize, length, self.num_heads, -1)
|
948 |
+
B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1)
|
949 |
+
B = B[:, :, None, :]
|
950 |
+
C = C[:, :, None, :]
|
951 |
+
|
952 |
+
A = -torch.exp(self.A_log.float()) # (num_heads,)
|
953 |
+
dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
|
954 |
+
B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
|
955 |
+
C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
|
956 |
+
|
957 |
+
# (bsize, length, num_heads, 1)
|
958 |
+
dt = self.dt_proj(dt)[..., None]
|
959 |
+
|
960 |
+
# TODO it may not be required
|
961 |
+
B = B.expand(-1, -1, self.num_heads, -1)
|
962 |
+
C = C.expand(-1, -1, self.num_heads, -1)
|
963 |
+
|
964 |
+
if bool_mask is not None:
|
965 |
+
"""
|
966 |
+
state will be updates by following:
|
967 |
+
```
|
968 |
+
dt = softplus(dt)
|
969 |
+
dA = exp(dt * A)
|
970 |
+
state_next = state * dA + dB * x
|
971 |
+
```
|
972 |
+
|
973 |
+
To avoid updating state, we set dt to -inf and x to 0
|
974 |
+
because `softplus(-inf) = 0` and `exp(0) = 1`
|
975 |
+
"""
|
976 |
+
dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf"))
|
977 |
+
x = torch.where(bool_mask[:, :, None, None], x, 0.0)
|
978 |
+
|
979 |
+
# ssm
|
980 |
+
if is_update:
|
981 |
+
assert ssm_state is not None
|
982 |
+
out = ssd_update_state(
|
983 |
+
ssm_state,
|
984 |
+
x[:, 0],
|
985 |
+
dt[:, 0].reshape(bsize, -1),
|
986 |
+
A,
|
987 |
+
B[:, 0],
|
988 |
+
C[:, 0],
|
989 |
+
D=self.D,
|
990 |
+
z=z[:, 0],
|
991 |
+
dt_bias=self.dt_bias,
|
992 |
+
dt_softplus=True,
|
993 |
+
)
|
994 |
+
else:
|
995 |
+
tmp = ssd_chunk_scan_combined(
|
996 |
+
x,
|
997 |
+
dt.reshape(bsize, length, -1),
|
998 |
+
A,
|
999 |
+
B,
|
1000 |
+
C,
|
1001 |
+
self.chunk_size,
|
1002 |
+
D=self.D,
|
1003 |
+
z=z,
|
1004 |
+
dt_bias=self.dt_bias,
|
1005 |
+
dt_softplus=True,
|
1006 |
+
return_final_states=past_states is not None,
|
1007 |
+
seq_idx=seq_idx,
|
1008 |
+
ssm_state=ssm_state,
|
1009 |
+
)
|
1010 |
+
if past_states is not None:
|
1011 |
+
out, ssm_state = tmp
|
1012 |
+
else:
|
1013 |
+
assert isinstance(tmp, torch.Tensor)
|
1014 |
+
out = tmp
|
1015 |
+
|
1016 |
+
y = self.out_proj(out.reshape(bsize, length, -1))
|
1017 |
+
|
1018 |
+
if past_states is not None:
|
1019 |
+
assert ssm_state is not None
|
1020 |
+
assert conv_state is not None
|
1021 |
+
past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
|
1022 |
+
|
1023 |
+
return y, past_states
|
1024 |
+
|
1025 |
+
|
1026 |
+
def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
|
1027 |
+
max_len = max(q_len, kv_len)
|
1028 |
+
mask = (
|
1029 |
+
torch.ones(max_len, max_len, dtype=torch.bool, device=device)
|
1030 |
+
.triu(diagonal=-window_size)
|
1031 |
+
.tril(diagonal=window_size)
|
1032 |
+
)
|
1033 |
+
return mask[-q_len:, -kv_len:]
|
1034 |
+
|
1035 |
+
|
1036 |
+
class Attention(torch.nn.Module):
|
1037 |
+
def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
|
1038 |
+
super().__init__()
|
1039 |
+
self.config = config
|
1040 |
+
self.layer_idx = layer_idx
|
1041 |
+
self.hidden_size = config.hidden_size
|
1042 |
+
head_dim = config.hidden_size_per_head
|
1043 |
+
self.max_position_embeddings = config.max_position_embeddings
|
1044 |
+
|
1045 |
+
self.q_num_heads = config.num_attention_heads
|
1046 |
+
self.qk_dim = self.v_dim = head_dim
|
1047 |
+
self.k_num_heads = self.v_num_heads = config.num_key_value_heads
|
1048 |
+
assert self.q_num_heads % self.k_num_heads == 0
|
1049 |
+
self.n_group = self.q_num_heads // self.k_num_heads
|
1050 |
+
|
1051 |
+
self.q_proj_dim = self.q_num_heads * self.qk_dim
|
1052 |
+
self.k_proj_dim = self.k_num_heads * self.qk_dim
|
1053 |
+
self.v_proj_dim = self.k_num_heads * self.v_dim
|
1054 |
+
self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
|
1055 |
+
self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
|
1056 |
+
|
1057 |
+
self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
|
1058 |
+
self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
|
1059 |
+
|
1060 |
+
self.full_attn = self.layer_idx in self.config.full_attention_idx
|
1061 |
+
base = self.config.rope_theta if self.full_attn else self.config.rope_local_theta
|
1062 |
+
self.rotary_emb = RotaryEmbedding(
|
1063 |
+
self.qk_dim, max_position_embeddings=self.config.attention_window_size, base=base
|
1064 |
+
)
|
1065 |
+
|
1066 |
+
def forward(
|
1067 |
+
self,
|
1068 |
+
hidden_states: torch.Tensor,
|
1069 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1070 |
+
past_states: Optional[Plamo2Cache] = None,
|
1071 |
+
output_attentions: bool = False,
|
1072 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Plamo2Cache]]:
|
1073 |
+
bsz, q_len, _ = hidden_states.size()
|
1074 |
+
|
1075 |
+
qkv = self.qkv_proj(hidden_states)
|
1076 |
+
query_states, key_states, value_states = torch.split(
|
1077 |
+
qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
|
1078 |
+
)
|
1079 |
+
query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
|
1080 |
+
key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
|
1081 |
+
value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
|
1082 |
+
|
1083 |
+
attn_dtype = query_states.dtype
|
1084 |
+
|
1085 |
+
query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
|
1086 |
+
key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
|
1087 |
+
|
1088 |
+
if past_states is not None:
|
1089 |
+
# reuse k, v, self_attention
|
1090 |
+
key_states_new = key_states
|
1091 |
+
value_states_new = value_states
|
1092 |
+
key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
|
1093 |
+
past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
|
1094 |
+
|
1095 |
+
kv_seq_len = key_states.shape[-2]
|
1096 |
+
device = hidden_states.device
|
1097 |
+
position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
|
1098 |
+
q_position_ids = position_ids[:, -query_states.shape[2] :]
|
1099 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
1100 |
+
query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
|
1101 |
+
key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
|
1102 |
+
# [bsz, nh, t, hd]
|
1103 |
+
|
1104 |
+
def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
|
1105 |
+
t = torch.repeat_interleave(t, repeat, dim=1)
|
1106 |
+
return t[:, :target]
|
1107 |
+
|
1108 |
+
# expand shared kv
|
1109 |
+
assert self.k_num_heads == self.v_num_heads
|
1110 |
+
key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
|
1111 |
+
value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
|
1112 |
+
|
1113 |
+
query_states = query_states.to(attn_dtype)
|
1114 |
+
key_states = key_states.to(attn_dtype)
|
1115 |
+
value_states = value_states.to(attn_dtype)
|
1116 |
+
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
1117 |
+
attention_mask = attention_mask.to(attn_dtype)
|
1118 |
+
if attention_mask is None:
|
1119 |
+
if not self.full_attn:
|
1120 |
+
assert key_states.shape[2] <= self.config.attention_window_size + 1
|
1121 |
+
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
|
1122 |
+
else:
|
1123 |
+
if attention_mask.dtype == torch.bool:
|
1124 |
+
attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
|
1125 |
+
if len(attention_mask.shape) == 2:
|
1126 |
+
attention_mask = attention_mask[None, None]
|
1127 |
+
assert len(attention_mask.shape) == 4
|
1128 |
+
|
1129 |
+
if not self.full_attn:
|
1130 |
+
m_swa = swa_mask(
|
1131 |
+
query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
|
1132 |
+
)
|
1133 |
+
# `generate` function creates attention mask that does not consider sliding window
|
1134 |
+
m_swa = m_swa[None, None]
|
1135 |
+
attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
|
1136 |
+
attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
|
1137 |
+
|
1138 |
+
# like AttentionMaskConverter._unmask_unattended in huggingface.transfoermers,
|
1139 |
+
# we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
|
1140 |
+
bool_mask = torch.logical_not(torch.isneginf(attention_mask))
|
1141 |
+
valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
|
1142 |
+
attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
|
1143 |
+
attn_output = F.scaled_dot_product_attention(
|
1144 |
+
query_states, key_states, value_states, attn_mask=attention_mask
|
1145 |
+
)
|
1146 |
+
|
1147 |
+
attn_output = attn_output.transpose(1, 2)
|
1148 |
+
|
1149 |
+
attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
|
1150 |
+
attn_output = self.o_proj(attn_output)
|
1151 |
+
|
1152 |
+
if not output_attentions:
|
1153 |
+
attn_weights = None
|
1154 |
+
|
1155 |
+
return attn_output, attn_weights, past_states
|
1156 |
+
|
1157 |
+
|
1158 |
+
class MLP(nn.Module):
|
1159 |
+
def __init__(self, config: Plamo2Config) -> None:
|
1160 |
+
super().__init__()
|
1161 |
+
self.config = config
|
1162 |
+
self.hidden_size = config.hidden_size
|
1163 |
+
self.intermediate_size = config.intermediate_size
|
1164 |
+
self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
|
1165 |
+
self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
1166 |
+
|
1167 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
1168 |
+
h = self.gate_up_proj(x)
|
1169 |
+
h = _swiglu(h)
|
1170 |
+
return self.down_proj(h) # type: ignore
|
1171 |
+
|
1172 |
+
|
1173 |
+
class Plamo2DecoderLayer(torch.nn.Module):
|
1174 |
+
def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
|
1175 |
+
super().__init__()
|
1176 |
+
self.config = config
|
1177 |
+
self.hidden_size = config.hidden_size
|
1178 |
+
self.is_mamba = config.layers_block_type[layer_idx] == "mamba"
|
1179 |
+
self.mixer: torch.nn.Module
|
1180 |
+
if self.is_mamba:
|
1181 |
+
self.mixer = Mamba(config, layer_idx)
|
1182 |
+
else:
|
1183 |
+
self.mixer = Attention(config, layer_idx)
|
1184 |
+
self.mlp = MLP(config)
|
1185 |
+
"""
|
1186 |
+
Notes: The model performance was degraded when setting all offsets to 1.
|
1187 |
+
"""
|
1188 |
+
self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
|
1189 |
+
self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
|
1190 |
+
self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
|
1191 |
+
self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
|
1192 |
+
|
1193 |
+
def forward(
|
1194 |
+
self,
|
1195 |
+
hidden_states: torch.Tensor,
|
1196 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1197 |
+
past_state: Optional[Plamo2Cache] = None,
|
1198 |
+
output_attentions: Optional[bool] = False,
|
1199 |
+
) -> Tuple[Any, ...]:
|
1200 |
+
# from LlamaDecoder
|
1201 |
+
residual = hidden_states
|
1202 |
+
hidden_states = self.pre_mixer_norm(hidden_states)
|
1203 |
+
|
1204 |
+
# Self Attention
|
1205 |
+
if self.is_mamba:
|
1206 |
+
hidden_states_sa, present_key_value = self.mixer(
|
1207 |
+
hidden_states=hidden_states,
|
1208 |
+
attention_mask=attention_mask,
|
1209 |
+
past_states=past_state,
|
1210 |
+
)
|
1211 |
+
self_attn_weights = None
|
1212 |
+
else:
|
1213 |
+
hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
|
1214 |
+
hidden_states=hidden_states,
|
1215 |
+
attention_mask=attention_mask,
|
1216 |
+
past_states=past_state,
|
1217 |
+
output_attentions=output_attentions,
|
1218 |
+
)
|
1219 |
+
|
1220 |
+
hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
|
1221 |
+
hidden_states = residual + hidden_states_sa
|
1222 |
+
|
1223 |
+
residual = hidden_states
|
1224 |
+
hidden_states = self.pre_mlp_norm(hidden_states)
|
1225 |
+
|
1226 |
+
# Fully Connected
|
1227 |
+
hidden_states_mlp = self.mlp(hidden_states)
|
1228 |
+
|
1229 |
+
# Residual
|
1230 |
+
hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
|
1231 |
+
hidden_states = residual + hidden_states_mlp
|
1232 |
+
|
1233 |
+
outputs: Any = (hidden_states,)
|
1234 |
+
|
1235 |
+
if output_attentions:
|
1236 |
+
outputs += (self_attn_weights,)
|
1237 |
+
|
1238 |
+
return outputs # type: ignore
|
1239 |
+
|
1240 |
+
|
1241 |
+
def is_mamba(config: Plamo2Config, i: int) -> bool:
|
1242 |
+
if not config.mamba_enabled:
|
1243 |
+
return False
|
1244 |
+
assert config.mamba_step > 1
|
1245 |
+
assert i < config.num_hidden_layers
|
1246 |
+
|
1247 |
+
if config.num_hidden_layers <= (config.mamba_step // 2):
|
1248 |
+
# use attention in last layer
|
1249 |
+
return i != config.num_hidden_layers - 1
|
1250 |
+
return (i % config.mamba_step) != (config.mamba_step // 2)
|
1251 |
+
|
1252 |
+
|
1253 |
+
class Plamo2Decoder(torch.nn.Module):
|
1254 |
+
def __init__(self, config: Plamo2Config) -> None:
|
1255 |
+
super().__init__()
|
1256 |
+
|
1257 |
+
self.layers = torch.nn.ModuleList(
|
1258 |
+
[Plamo2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
1259 |
+
)
|
1260 |
+
self.gradient_checkpointing = False
|
1261 |
+
|
1262 |
+
def forward(self, x: DecoderInput) -> DecoderOutput:
|
1263 |
+
all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
|
1264 |
+
all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
|
1265 |
+
hidden_states = x.hidden_states
|
1266 |
+
|
1267 |
+
for decoder_layer in self.layers:
|
1268 |
+
if x.output_hidden_states:
|
1269 |
+
assert all_hidden_states is not None
|
1270 |
+
all_hidden_states += (hidden_states,)
|
1271 |
+
|
1272 |
+
if self.training and x.gradient_checkpointing:
|
1273 |
+
layer_outputs = self._gradient_checkpointing_func(
|
1274 |
+
decoder_layer.__call__,
|
1275 |
+
hidden_states,
|
1276 |
+
x.attention_mask,
|
1277 |
+
x.past_states,
|
1278 |
+
x.output_attentions,
|
1279 |
+
)
|
1280 |
+
else:
|
1281 |
+
layer_outputs = decoder_layer(
|
1282 |
+
hidden_states,
|
1283 |
+
attention_mask=x.attention_mask,
|
1284 |
+
past_state=x.past_states,
|
1285 |
+
output_attentions=x.output_attentions,
|
1286 |
+
)
|
1287 |
+
|
1288 |
+
hidden_states = layer_outputs[0]
|
1289 |
+
|
1290 |
+
if x.output_attentions:
|
1291 |
+
assert layer_outputs[1] is not None
|
1292 |
+
assert all_self_attns is not None
|
1293 |
+
all_self_attns += (layer_outputs[1],)
|
1294 |
+
return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
|
1295 |
+
|
1296 |
+
|
1297 |
+
class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
|
1298 |
+
config_class = Plamo2Config
|
1299 |
+
_no_split_modules: List[str]
|
1300 |
+
base_model_prefix = "model"
|
1301 |
+
supports_gradient_checkpointing = True
|
1302 |
+
_no_split_modules = ["PlamoDecoderLayer"]
|
1303 |
+
_skip_keys_device_placement = "past_key_values"
|
1304 |
+
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
1305 |
+
|
1306 |
+
def _init_weights(self, module: torch.nn.Module) -> None:
|
1307 |
+
std = 0.02
|
1308 |
+
if isinstance(module, nn.Linear):
|
1309 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
1310 |
+
if module.bias is not None:
|
1311 |
+
module.bias.data.zero_()
|
1312 |
+
elif isinstance(module, nn.Embedding):
|
1313 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
1314 |
+
if module.padding_idx is not None:
|
1315 |
+
module.weight.data[module.padding_idx].zero_()
|
1316 |
+
|
1317 |
+
|
1318 |
+
class Plamo2Model(Plamo2PreTrainedModel):
|
1319 |
+
def __init__(self, config: Plamo2Config):
|
1320 |
+
super().__init__(config)
|
1321 |
+
assert config.eval_attention_n_bit is None
|
1322 |
+
assert config.eval_mlp_n_bit is None
|
1323 |
+
|
1324 |
+
self.padding_idx = config.pad_token_id
|
1325 |
+
self.vocab_size = config.vocab_size
|
1326 |
+
|
1327 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
1328 |
+
if config.image_feature_size is not None:
|
1329 |
+
if config.image_proj_type == "mlp":
|
1330 |
+
self.image_proj = MLPImageProjector(config) # type: ignore
|
1331 |
+
elif config.image_proj_type == "linear":
|
1332 |
+
self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
|
1333 |
+
else:
|
1334 |
+
raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
|
1335 |
+
self.layers = Plamo2Decoder(config) # type: ignore
|
1336 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1337 |
+
|
1338 |
+
self.gradient_checkpointing = False
|
1339 |
+
# Initialize weights and apply final processing
|
1340 |
+
self.post_init()
|
1341 |
+
|
1342 |
+
def get_input_embeddings(self) -> torch.nn.Embedding:
|
1343 |
+
return self.embed_tokens
|
1344 |
+
|
1345 |
+
def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
|
1346 |
+
self.embed_tokens = value
|
1347 |
+
|
1348 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
1349 |
+
def _prepare_decoder_attention_mask(
|
1350 |
+
self,
|
1351 |
+
attention_mask: torch.Tensor,
|
1352 |
+
input_shape: Tuple[int, int],
|
1353 |
+
inputs_embeds: Optional[torch.Tensor],
|
1354 |
+
past_key_values_length: int,
|
1355 |
+
) -> Optional[torch.Tensor]:
|
1356 |
+
# create causal mask
|
1357 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1358 |
+
combined_attention_mask: Optional[torch.Tensor] = None
|
1359 |
+
if input_shape[-1] > 1:
|
1360 |
+
assert inputs_embeds is not None
|
1361 |
+
combined_attention_mask = _make_causal_mask(
|
1362 |
+
input_shape,
|
1363 |
+
inputs_embeds.dtype,
|
1364 |
+
device=inputs_embeds.device,
|
1365 |
+
past_key_values_length=past_key_values_length,
|
1366 |
+
)
|
1367 |
+
input_shape = (input_shape[0], combined_attention_mask.shape[2])
|
1368 |
+
|
1369 |
+
if attention_mask is not None:
|
1370 |
+
if attention_mask.dim() == 4:
|
1371 |
+
# Custom 4D attention mask
|
1372 |
+
expanded_attn_mask = attention_mask
|
1373 |
+
else:
|
1374 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1375 |
+
assert inputs_embeds is not None
|
1376 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
1377 |
+
inputs_embeds.device
|
1378 |
+
)
|
1379 |
+
combined_attention_mask = (
|
1380 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
1381 |
+
)
|
1382 |
+
|
1383 |
+
return combined_attention_mask
|
1384 |
+
|
1385 |
+
def forward(
|
1386 |
+
self,
|
1387 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1388 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1389 |
+
position_ids: Optional[torch.Tensor] = None,
|
1390 |
+
past_key_values: Optional[Plamo2Cache] = None,
|
1391 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1392 |
+
image_features: Optional[torch.Tensor] = None,
|
1393 |
+
use_cache: Optional[bool] = None,
|
1394 |
+
output_attentions: Optional[bool] = None,
|
1395 |
+
output_hidden_states: Optional[bool] = None,
|
1396 |
+
return_dict: Optional[bool] = None,
|
1397 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1398 |
+
**kwargs: Any,
|
1399 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1400 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1401 |
+
output_hidden_states = (
|
1402 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1403 |
+
)
|
1404 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1405 |
+
|
1406 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1407 |
+
|
1408 |
+
# retrieve input_ids and inputs_embeds
|
1409 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
1410 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
1411 |
+
|
1412 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
1413 |
+
use_cache = False
|
1414 |
+
|
1415 |
+
if inputs_embeds is None:
|
1416 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1417 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
1418 |
+
|
1419 |
+
seq_length_with_past = seq_length
|
1420 |
+
past_key_values_length = 0
|
1421 |
+
if past_key_values is not None:
|
1422 |
+
past_key_values_length = past_key_values.get_seq_length()
|
1423 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
1424 |
+
assert cache_position is None, "cache_position is not supported yet"
|
1425 |
+
|
1426 |
+
if image_features is not None:
|
1427 |
+
assert self.config.image_token_id is not None
|
1428 |
+
image_embeds = self.image_proj(image_features)
|
1429 |
+
assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
|
1430 |
+
mask = input_ids == self.config.image_token_id
|
1431 |
+
inputs_embeds[mask] = image_embeds[mask]
|
1432 |
+
|
1433 |
+
# embed positions
|
1434 |
+
require_attn_mask = False
|
1435 |
+
if not self.training or past_key_values is not None:
|
1436 |
+
require_attn_mask = True
|
1437 |
+
if seq_length_with_past >= self.config.attention_window_size:
|
1438 |
+
require_attn_mask = True
|
1439 |
+
if require_attn_mask and attention_mask is None:
|
1440 |
+
attention_mask = torch.ones(
|
1441 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
1442 |
+
)
|
1443 |
+
if attention_mask is not None:
|
1444 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
1445 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
1446 |
+
)
|
1447 |
+
|
1448 |
+
hidden_states = inputs_embeds
|
1449 |
+
|
1450 |
+
if use_cache and past_key_values is None:
|
1451 |
+
past_key_values = Plamo2Cache(self.config)
|
1452 |
+
|
1453 |
+
# decoder layers
|
1454 |
+
out = self.layers(
|
1455 |
+
DecoderInput(
|
1456 |
+
hidden_states,
|
1457 |
+
attention_mask,
|
1458 |
+
past_key_values,
|
1459 |
+
output_hidden_states,
|
1460 |
+
output_attentions,
|
1461 |
+
self.gradient_checkpointing,
|
1462 |
+
)
|
1463 |
+
)
|
1464 |
+
assert isinstance(out, DecoderOutput)
|
1465 |
+
hidden_states = out.hidden_states
|
1466 |
+
all_hidden_states = out.all_hidden_states
|
1467 |
+
all_self_attns = out.all_self_attns
|
1468 |
+
|
1469 |
+
hidden_states = self.norm(hidden_states)
|
1470 |
+
|
1471 |
+
# add hidden states from the last decoder layer
|
1472 |
+
if output_hidden_states:
|
1473 |
+
assert all_hidden_states is not None
|
1474 |
+
all_hidden_states += (hidden_states,)
|
1475 |
+
|
1476 |
+
if not return_dict:
|
1477 |
+
return tuple(
|
1478 |
+
v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
|
1479 |
+
)
|
1480 |
+
return BaseModelOutputWithPast(
|
1481 |
+
last_hidden_state=hidden_states,
|
1482 |
+
past_key_values=past_key_values,
|
1483 |
+
hidden_states=all_hidden_states,
|
1484 |
+
attentions=all_self_attns,
|
1485 |
+
)
|
1486 |
+
|
1487 |
+
|
1488 |
+
class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
1489 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1490 |
+
|
1491 |
+
# Without this, the model cannot be loaded into a meta device.
|
1492 |
+
# Relevant code:
|
1493 |
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
|
1494 |
+
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
|
1495 |
+
# https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
|
1496 |
+
_supports_param_buffer_assignment = False
|
1497 |
+
|
1498 |
+
def __init__(self, config: Plamo2Config) -> None:
|
1499 |
+
super().__init__(config)
|
1500 |
+
self.model = Plamo2Model(config)
|
1501 |
+
|
1502 |
+
self.vocab_size = config.vocab_size
|
1503 |
+
vocab_size = ((self.vocab_size + 15) // 16) * 16
|
1504 |
+
self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
|
1505 |
+
|
1506 |
+
# Initialize weights and apply final processing
|
1507 |
+
self.post_init()
|
1508 |
+
|
1509 |
+
def get_input_embeddings(self) -> torch.nn.Embedding:
|
1510 |
+
return self.model.embed_tokens
|
1511 |
+
|
1512 |
+
def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
|
1513 |
+
self.model.embed_tokens = value
|
1514 |
+
|
1515 |
+
def get_output_embeddings(self) -> torch.nn.Module:
|
1516 |
+
return self.lm_head
|
1517 |
+
|
1518 |
+
def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
|
1519 |
+
self.lm_head = new_embeddings
|
1520 |
+
|
1521 |
+
def set_decoder(self, decoder: Plamo2Model) -> None:
|
1522 |
+
self.model = decoder
|
1523 |
+
|
1524 |
+
def get_decoder(self) -> Plamo2Model:
|
1525 |
+
return self.model
|
1526 |
+
|
1527 |
+
def forward( # type: ignore
|
1528 |
+
self,
|
1529 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1530 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1531 |
+
position_ids: Optional[torch.Tensor] = None,
|
1532 |
+
past_key_values: Optional[Plamo2Cache] = None,
|
1533 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1534 |
+
image_features: Optional[torch.Tensor] = None,
|
1535 |
+
labels: Optional[torch.LongTensor] = None,
|
1536 |
+
use_cache: Optional[bool] = None,
|
1537 |
+
output_attentions: Optional[bool] = None,
|
1538 |
+
output_hidden_states: Optional[bool] = None,
|
1539 |
+
return_dict: Optional[bool] = None,
|
1540 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1541 |
+
logits_to_keep: int | torch.Tensor = 0,
|
1542 |
+
**kwargs: Any,
|
1543 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1544 |
+
r"""
|
1545 |
+
Args:
|
1546 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1547 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1548 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1549 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1550 |
+
|
1551 |
+
Returns:
|
1552 |
+
|
1553 |
+
Example:
|
1554 |
+
|
1555 |
+
```python
|
1556 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
1557 |
+
|
1558 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1559 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1560 |
+
|
1561 |
+
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
1562 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1563 |
+
|
1564 |
+
>>> # Generate
|
1565 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1566 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1567 |
+
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
1568 |
+
```"""
|
1569 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1570 |
+
output_hidden_states = (
|
1571 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1572 |
+
)
|
1573 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1574 |
+
|
1575 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1576 |
+
outputs = self.model(
|
1577 |
+
input_ids=input_ids,
|
1578 |
+
attention_mask=attention_mask,
|
1579 |
+
position_ids=position_ids,
|
1580 |
+
past_key_values=past_key_values,
|
1581 |
+
inputs_embeds=inputs_embeds,
|
1582 |
+
image_features=image_features,
|
1583 |
+
use_cache=use_cache,
|
1584 |
+
output_attentions=output_attentions,
|
1585 |
+
output_hidden_states=output_hidden_states,
|
1586 |
+
return_dict=return_dict,
|
1587 |
+
cache_position=cache_position,
|
1588 |
+
**kwargs,
|
1589 |
+
)
|
1590 |
+
|
1591 |
+
hidden_states = outputs[0]
|
1592 |
+
logits = self.lm_head(hidden_states)
|
1593 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
1594 |
+
logits = logits[:, slice_indices, : self.vocab_size]
|
1595 |
+
|
1596 |
+
loss = None
|
1597 |
+
if labels is not None:
|
1598 |
+
if len(kwargs) > 0 and set(kwargs.keys()) != set(["ignore_index"]):
|
1599 |
+
warnings.warn(
|
1600 |
+
f"The following kwargs may not be supported: {', '.join(kwargs.keys())}. ",
|
1601 |
+
stacklevel=2,
|
1602 |
+
)
|
1603 |
+
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
1604 |
+
|
1605 |
+
if not return_dict:
|
1606 |
+
output = (logits,) + outputs[1:]
|
1607 |
+
return (loss,) + output if loss is not None else output
|
1608 |
+
|
1609 |
+
return CausalLMOutputWithPast(
|
1610 |
+
loss=loss,
|
1611 |
+
logits=logits,
|
1612 |
+
past_key_values=outputs.past_key_values,
|
1613 |
+
hidden_states=outputs.hidden_states,
|
1614 |
+
attentions=outputs.attentions,
|
1615 |
+
)
|
1616 |
+
|
1617 |
+
def prepare_inputs_for_generation(
|
1618 |
+
self,
|
1619 |
+
input_ids: torch.Tensor,
|
1620 |
+
past_key_values: Optional[Plamo2Cache] = None,
|
1621 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1622 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1623 |
+
image_features: Optional[torch.Tensor] = None,
|
1624 |
+
**kwargs: Any,
|
1625 |
+
) -> Dict[str, Any]:
|
1626 |
+
if past_key_values:
|
1627 |
+
input_ids = input_ids[:, -1:]
|
1628 |
+
if image_features is not None:
|
1629 |
+
image_features = image_features[:, -1:, :]
|
1630 |
+
|
1631 |
+
position_ids = kwargs.get("position_ids", None)
|
1632 |
+
if attention_mask is not None and position_ids is None:
|
1633 |
+
# create position_ids on the fly for batch generation
|
1634 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1635 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1636 |
+
if past_key_values:
|
1637 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1638 |
+
|
1639 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1640 |
+
if inputs_embeds is not None and past_key_values is None:
|
1641 |
+
model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
|
1642 |
+
else:
|
1643 |
+
model_inputs = {"input_ids": input_ids}
|
1644 |
+
|
1645 |
+
model_inputs.update(
|
1646 |
+
{
|
1647 |
+
"position_ids": position_ids,
|
1648 |
+
"past_key_values": past_key_values,
|
1649 |
+
"use_cache": kwargs.get("use_cache"),
|
1650 |
+
"attention_mask": attention_mask,
|
1651 |
+
"image_features": image_features,
|
1652 |
+
}
|
1653 |
+
)
|
1654 |
+
return model_inputs
|
1655 |
+
|
1656 |
+
@staticmethod
|
1657 |
+
def _reorder_cache(past_key_values: Plamo2Cache, beam_idx: torch.Tensor) -> Plamo2Cache:
|
1658 |
+
past_key_values.reorder_cache(beam_idx)
|
1659 |
+
return past_key_values
|
1660 |
+
|
1661 |
+
|
1662 |
+
class MLPImageProjector(nn.Module):
|
1663 |
+
def __init__(self, config: Plamo2Config) -> None:
|
1664 |
+
super().__init__()
|
1665 |
+
self.config = config
|
1666 |
+
|
1667 |
+
assert config.image_feature_size is not None # for typing
|
1668 |
+
|
1669 |
+
# nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
|
1670 |
+
self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
|
1671 |
+
self.bias0 = Bias(config.image_feature_size)
|
1672 |
+
|
1673 |
+
# PFVM doesn't support Linear with bias, so add bias manually afterwards.
|
1674 |
+
self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
|
1675 |
+
self.bias1 = Bias(config.hidden_size)
|
1676 |
+
self.act1 = nn.GELU()
|
1677 |
+
|
1678 |
+
self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
1679 |
+
self.bias2 = Bias(config.hidden_size)
|
1680 |
+
|
1681 |
+
def forward(
|
1682 |
+
self,
|
1683 |
+
hidden_states: torch.Tensor,
|
1684 |
+
) -> torch.Tensor:
|
1685 |
+
hidden_states = self.norm0(hidden_states)
|
1686 |
+
hidden_states = self.bias0(hidden_states)
|
1687 |
+
|
1688 |
+
hidden_states = self.linear1(hidden_states)
|
1689 |
+
hidden_states = self.bias1(hidden_states)
|
1690 |
+
hidden_states = self.act1(hidden_states)
|
1691 |
+
|
1692 |
+
hidden_states = self.linear2(hidden_states)
|
1693 |
+
hidden_states = self.bias2(hidden_states)
|
1694 |
+
|
1695 |
+
return hidden_states
|
1696 |
+
|
1697 |
+
|
1698 |
+
class Bias(nn.Module):
|
1699 |
+
def __init__(self, num_features: int) -> None:
|
1700 |
+
super().__init__()
|
1701 |
+
self._bias = nn.Parameter(torch.zeros((num_features,)))
|
1702 |
+
|
1703 |
+
def forward(
|
1704 |
+
self,
|
1705 |
+
x: torch.Tensor,
|
1706 |
+
) -> torch.Tensor:
|
1707 |
+
return x + self._bias
|
special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|plamo:bos|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": "<|plamo:bos|>",
|
10 |
+
"pad_token": {
|
11 |
+
"content": "<|plamo:pad|>",
|
12 |
+
"lstrip": false,
|
13 |
+
"normalized": false,
|
14 |
+
"rstrip": false,
|
15 |
+
"single_word": false
|
16 |
+
},
|
17 |
+
"unk_token": {
|
18 |
+
"content": "<|plamo:unk|>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": false,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
}
|
24 |
+
}
|
tokenization_plamo.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
from shutil import copyfile
|
5 |
+
from typing import Any, Optional, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
# NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
|
10 |
+
from numba import njit # type: ignore[attr-defined]
|
11 |
+
from numba.core import types
|
12 |
+
from numba.typed import Dict, List
|
13 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
14 |
+
from transformers.utils import logging
|
15 |
+
|
16 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
|
17 |
+
logger = logging.get_logger(__name__)
|
18 |
+
|
19 |
+
INVALID_SCORE = -20000000
|
20 |
+
UNKNOWN_SCORE = -10000000
|
21 |
+
|
22 |
+
TABLE_PIECE_LENGTH = 0
|
23 |
+
TABLE_TOKEN_ID = 1
|
24 |
+
TABLE_SCORE = 2
|
25 |
+
TABLE_PIECE_ID = 3
|
26 |
+
|
27 |
+
PATH_TOKEN_LENGTH = 0
|
28 |
+
PATH_TOKEN_ID = 1
|
29 |
+
PATH_NUM_TOKENS = 2
|
30 |
+
|
31 |
+
|
32 |
+
class AhoCorasick:
|
33 |
+
def __init__(self) -> None:
|
34 |
+
# List of tokens in the vocabulary.
|
35 |
+
self._tokens: list[str]
|
36 |
+
|
37 |
+
# A mapping from a byte code point to a token ID, used for byte fallback.
|
38 |
+
self._bytes: np.ndarray
|
39 |
+
|
40 |
+
# A mapping from a suffix's piece code to a suffix ID.
|
41 |
+
#
|
42 |
+
# Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
|
43 |
+
# of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
|
44 |
+
# a piece code to an edge (in other words, a pair of a node and the next character).
|
45 |
+
#
|
46 |
+
# A piece code is a 64-bit integer:
|
47 |
+
# - The upper 32 bits store the Unicode code point of the first character.
|
48 |
+
# - The lower 32 bits store the suffix ID of the remaining suffix.
|
49 |
+
#
|
50 |
+
# A suffix ID is an integer indicating the starting position in the _table.
|
51 |
+
self._to_suffix_id: Dict[types.int64, types.int32]
|
52 |
+
|
53 |
+
# Flattened table representing the Trie structure for the Aho-Corasick algorithm.
|
54 |
+
# It stores information including scores for each piece (prefix) within each suffix.
|
55 |
+
# It is flattened for memory efficiency and performance. Suffixes are stored in
|
56 |
+
# lexicographical order of their reversed strings, which improves memory access locality
|
57 |
+
# when exploring new characters starting from the string's end. Pieces within a suffix are
|
58 |
+
# stored in the decreasing order of their lengths.
|
59 |
+
#
|
60 |
+
# Each piece (a prefix fo the suffix) contains four pieces of information:
|
61 |
+
# - TABLE_PIECE_LENGTH: Length of the piece.
|
62 |
+
# - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
|
63 |
+
# - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
|
64 |
+
# - TABLE_PIECE_ID: Piece ID of the suffix.
|
65 |
+
#
|
66 |
+
# Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
|
67 |
+
# and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
|
68 |
+
self._table: np.ndarray
|
69 |
+
|
70 |
+
def build(self, vocab: list[Any]) -> None:
|
71 |
+
self._bytes = np.zeros(256, dtype=np.int32)
|
72 |
+
self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
|
73 |
+
|
74 |
+
# Build suffix_to_score and token_to_token_id.
|
75 |
+
# The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
|
76 |
+
# of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
|
77 |
+
# valid token, its score is set to math.nan.
|
78 |
+
# The token_to_token_id dictionary maps a token to its token ID.
|
79 |
+
suffix_to_score: dict[str, float] = {}
|
80 |
+
token_to_token_id: dict[str, int] = {}
|
81 |
+
self._tokens = []
|
82 |
+
for token_id, row in enumerate(vocab):
|
83 |
+
assert isinstance(row[0], str), row
|
84 |
+
assert isinstance(row[1], (int, float)), row
|
85 |
+
|
86 |
+
token = str(row[0])
|
87 |
+
self._tokens.append(token)
|
88 |
+
token_to_token_id[token] = token_id
|
89 |
+
|
90 |
+
# Special handling for byte tokens.
|
91 |
+
if len(row) > 2 and row[2] == "BYTE":
|
92 |
+
assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
|
93 |
+
self._bytes[int(row[0][3:5], 16)] = token_id
|
94 |
+
continue
|
95 |
+
|
96 |
+
suffix_to_score[token] = float(row[1])
|
97 |
+
# Ensure that all suffixes are included in suffix_to_score.
|
98 |
+
for i in range(1, len(token)):
|
99 |
+
suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
|
100 |
+
|
101 |
+
# Ensure all byte tokens are set.
|
102 |
+
for i in range(256):
|
103 |
+
assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
|
104 |
+
|
105 |
+
# List suffixes in lexicographical order of their reversed strings.
|
106 |
+
suffixes = list(suffix_to_score.keys())
|
107 |
+
suffixes.append("")
|
108 |
+
suffixes.sort(key=lambda x: x[::-1])
|
109 |
+
|
110 |
+
# Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
|
111 |
+
# which is a mapping from a piece code to a suffix ID.
|
112 |
+
suffix_to_id: dict[str, int] = {}
|
113 |
+
num_pieces = 0
|
114 |
+
for s in suffixes:
|
115 |
+
suffix_to_id[s] = num_pieces
|
116 |
+
if s != "":
|
117 |
+
self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
|
118 |
+
num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
|
119 |
+
assert suffix_to_id[""] == 0, suffix_to_id[""]
|
120 |
+
|
121 |
+
# Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
|
122 |
+
self._table = np.zeros((num_pieces, 4), dtype=np.int32)
|
123 |
+
i = 0
|
124 |
+
for suffix in suffixes:
|
125 |
+
# Add all prefixes of the suffix to the table.
|
126 |
+
for piece_length in range(len(suffix), 0, -1):
|
127 |
+
piece = suffix[:piece_length]
|
128 |
+
score = suffix_to_score.get(piece, None)
|
129 |
+
if score is None:
|
130 |
+
continue
|
131 |
+
self._table[i, TABLE_PIECE_LENGTH] = piece_length
|
132 |
+
self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
|
133 |
+
self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
|
134 |
+
self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
|
135 |
+
i += 1
|
136 |
+
|
137 |
+
# Add a sentinel row.
|
138 |
+
self._table[i, TABLE_PIECE_LENGTH] = 1
|
139 |
+
self._table[i, TABLE_TOKEN_ID] = -1
|
140 |
+
self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
|
141 |
+
i += 1
|
142 |
+
assert i == num_pieces, (i, num_pieces)
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
@njit
|
146 |
+
def _encode(
|
147 |
+
to_suffix_id: Dict[types.int64, types.int32],
|
148 |
+
table: np.ndarray,
|
149 |
+
bytes: np.ndarray,
|
150 |
+
data: np.ndarray,
|
151 |
+
) -> np.ndarray:
|
152 |
+
# Initialize scores array with a high value and set the score at the end to 0.
|
153 |
+
# This array keeps track of the minimum cost (best score) to encode from each position to the end.
|
154 |
+
scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
|
155 |
+
scores[-1] = 0
|
156 |
+
|
157 |
+
# Path array to store the best path information.
|
158 |
+
# The path array keeps track of token length, token ID, and number of tokens needed to encode.
|
159 |
+
path = np.zeros((len(data) + 1, 3), dtype=np.int32)
|
160 |
+
|
161 |
+
# Initialize suffix_id to 0, which represents the root of the Trie.
|
162 |
+
suffix_id = 0
|
163 |
+
|
164 |
+
# Process the input data from the end to the beginning.
|
165 |
+
for i in range(len(data) - 1, -1, -1):
|
166 |
+
c = data[i]
|
167 |
+
|
168 |
+
# Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
|
169 |
+
# NOTE: If no suffix ID is found, suffix_id will be set to 0.
|
170 |
+
for p in range(suffix_id, len(table)):
|
171 |
+
suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
|
172 |
+
# If a next suffix ID is found or a sentinel row is reached, break the loop.
|
173 |
+
if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
|
174 |
+
break
|
175 |
+
|
176 |
+
# Update the best path to the current position. If multiple paths have the same score,
|
177 |
+
# this chooses the longest prefix as the best path (table is sorted in the decreasing
|
178 |
+
# order of piece length).
|
179 |
+
for p in range(suffix_id, len(table)):
|
180 |
+
score = table[p, TABLE_SCORE]
|
181 |
+
if score > INVALID_SCORE:
|
182 |
+
piece_length = table[p, TABLE_PIECE_LENGTH]
|
183 |
+
s = scores[i + piece_length] - score
|
184 |
+
if s < scores[i]:
|
185 |
+
scores[i] = s
|
186 |
+
path[i, PATH_TOKEN_LENGTH] = piece_length
|
187 |
+
path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
|
188 |
+
path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
|
189 |
+
if score == UNKNOWN_SCORE:
|
190 |
+
# Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
|
191 |
+
# added above).
|
192 |
+
path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
|
193 |
+
|
194 |
+
# If it reaches a sentinel row, break the loop.
|
195 |
+
if score == UNKNOWN_SCORE:
|
196 |
+
break
|
197 |
+
|
198 |
+
# Decode the best path from the beginning to get the token IDs.
|
199 |
+
pos = 0
|
200 |
+
token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
|
201 |
+
token_pos = 0
|
202 |
+
while pos < len(data):
|
203 |
+
if path[pos, PATH_TOKEN_ID] >= 0:
|
204 |
+
token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
|
205 |
+
token_pos += 1
|
206 |
+
else:
|
207 |
+
# Fall back to byte tokens.
|
208 |
+
c = data[pos]
|
209 |
+
s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
|
210 |
+
# Add byte tokens representing UTF-8 bytes.
|
211 |
+
for i in range(s):
|
212 |
+
b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
|
213 |
+
token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
|
214 |
+
token_pos += 1
|
215 |
+
|
216 |
+
# Ensure that pos should increase by at least 1.
|
217 |
+
assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
|
218 |
+
pos += path[pos, PATH_TOKEN_LENGTH]
|
219 |
+
|
220 |
+
return token_ids
|
221 |
+
|
222 |
+
def encode(self, data: str) -> np.ndarray:
|
223 |
+
"""Encodes a string into a sequence of token IDs."""
|
224 |
+
return np.asarray(
|
225 |
+
self._encode(
|
226 |
+
self._to_suffix_id,
|
227 |
+
self._table,
|
228 |
+
self._bytes,
|
229 |
+
# Convert a string into a numpy array of Unicode code points.
|
230 |
+
# NOTE: This skips UTF-32 BOM.
|
231 |
+
np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
|
232 |
+
)
|
233 |
+
)
|
234 |
+
|
235 |
+
def encode_as_tokens(self, data: str) -> list[str]:
|
236 |
+
"""Encodes a string into a sequence of tokens."""
|
237 |
+
return [self._tokens[token_id] for token_id in self.encode(data)]
|
238 |
+
|
239 |
+
|
240 |
+
class Plamo2Tokenizer(PreTrainedTokenizer): # type: ignore
|
241 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
242 |
+
model_input_names = ["input_ids", "attention_mask"]
|
243 |
+
|
244 |
+
_save_files = [
|
245 |
+
"special_tokens_map.json",
|
246 |
+
"tokenization_plamo.py",
|
247 |
+
"tokenizer.jsonl",
|
248 |
+
"tokenizer_config.json",
|
249 |
+
]
|
250 |
+
|
251 |
+
def __init__(
|
252 |
+
self,
|
253 |
+
vocab_file: str,
|
254 |
+
unk_token: str = "<|plamo:unk|>",
|
255 |
+
bos_token: str = "<|plamo:bos|>",
|
256 |
+
eos_token: str = "<|plamo:eos|>",
|
257 |
+
pad_token: str = "<|plamo:pad|>",
|
258 |
+
cls_token: Optional[str] = None,
|
259 |
+
sep_token: Optional[str] = None,
|
260 |
+
mask_token: Optional[str] = None,
|
261 |
+
clean_up_tokenization_spaces: bool = False,
|
262 |
+
**kwargs: Any,
|
263 |
+
) -> None:
|
264 |
+
"""Tokenizer for PLaMo.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
vocab_file (str): Vocabrary file path.
|
268 |
+
unk_token (str): Unknown token.
|
269 |
+
bos_token (str): Beginning of sentence token.
|
270 |
+
eos_token (str): End of sentence token.
|
271 |
+
pad_token (str): Padding token.
|
272 |
+
cls_token (str):
|
273 |
+
Classification token, to extract a summary of an input sequence leveraging self-attention along the
|
274 |
+
full depth of the model.
|
275 |
+
sep_token (str): Separation token, to separate context and query in an input sequence.
|
276 |
+
mask_token (str): Mask token, to use when training a model with masked-language modeling.
|
277 |
+
clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
|
278 |
+
num_threads (int):
|
279 |
+
Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
|
280 |
+
`RAYON_NUM_THREADS` is set as an environment variable.
|
281 |
+
"""
|
282 |
+
if "add_bos_token" not in kwargs:
|
283 |
+
kwargs["add_bos_token"] = False
|
284 |
+
if "add_eos_token" not in kwargs:
|
285 |
+
kwargs["add_eos_token"] = False
|
286 |
+
self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
|
287 |
+
self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
|
288 |
+
self.aho_corasick = AhoCorasick()
|
289 |
+
self.aho_corasick.build(self.data)
|
290 |
+
self.vocab_file = vocab_file
|
291 |
+
self.add_bos_token = kwargs["add_bos_token"]
|
292 |
+
self.add_eos_token = kwargs["add_eos_token"]
|
293 |
+
|
294 |
+
super().__init__(
|
295 |
+
vocab_file=vocab_file,
|
296 |
+
unk_token=unk_token,
|
297 |
+
bos_token=bos_token,
|
298 |
+
eos_token=eos_token,
|
299 |
+
pad_token=pad_token,
|
300 |
+
cls_token=cls_token,
|
301 |
+
sep_token=sep_token,
|
302 |
+
mask_token=mask_token,
|
303 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
304 |
+
**kwargs,
|
305 |
+
)
|
306 |
+
|
307 |
+
# the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
|
308 |
+
# https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
|
309 |
+
|
310 |
+
def __getstate__(self) -> dict[str, Any]:
|
311 |
+
state = self.__dict__.copy()
|
312 |
+
state["aho_corasick"] = None
|
313 |
+
return state
|
314 |
+
|
315 |
+
def __setstate__(self, d: dict[str, Any]) -> None:
|
316 |
+
self.__dict__ = d
|
317 |
+
self.aho_corasick = AhoCorasick()
|
318 |
+
self.aho_corasick.build(self.data)
|
319 |
+
|
320 |
+
@property
|
321 |
+
def vocab_size(self) -> Any:
|
322 |
+
"""Returns vocab size"""
|
323 |
+
return len(self.data)
|
324 |
+
|
325 |
+
def token_to_score(self, token: str) -> Optional[float]:
|
326 |
+
"""Returns score of the token"""
|
327 |
+
token_id = self.vocab.get(token, None)
|
328 |
+
return None if token_id is None else self.data[token_id][1]
|
329 |
+
|
330 |
+
def get_vocab(self) -> dict[str, int]:
|
331 |
+
"""Returns vocab as a dict"""
|
332 |
+
vocab = self.vocab.copy()
|
333 |
+
vocab.update(self.added_tokens_encoder)
|
334 |
+
return vocab
|
335 |
+
|
336 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
337 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
338 |
+
return b"".join(
|
339 |
+
[bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
|
340 |
+
).decode("utf-8", errors="replace")
|
341 |
+
|
342 |
+
def _tokenize(self, text: str) -> Any:
|
343 |
+
"""Returns a tokenized string."""
|
344 |
+
return self.aho_corasick.encode_as_tokens(text)
|
345 |
+
|
346 |
+
def _convert_token_to_id(self, token: str) -> Any:
|
347 |
+
"""Converts a token (str) in an id using the vocab."""
|
348 |
+
return self.vocab.get(token, 0)
|
349 |
+
|
350 |
+
def _convert_id_to_token(self, index: int) -> Any:
|
351 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
352 |
+
return self.data[index][0]
|
353 |
+
|
354 |
+
def build_inputs_with_special_tokens(
|
355 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
356 |
+
) -> List[int]:
|
357 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
358 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
359 |
+
|
360 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
361 |
+
|
362 |
+
if token_ids_1 is not None:
|
363 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
364 |
+
|
365 |
+
return output
|
366 |
+
|
367 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
368 |
+
"""
|
369 |
+
Save the vocabulary and special tokens file to a directory.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
save_directory (`str`):
|
373 |
+
The directory in which to save the vocabulary.
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
`Tuple(str)`: Paths to the files saved.
|
377 |
+
"""
|
378 |
+
if not os.path.isdir(save_directory):
|
379 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
380 |
+
return ("",)
|
381 |
+
out_vocab_file = os.path.join(
|
382 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
383 |
+
)
|
384 |
+
|
385 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
386 |
+
copyfile(self.vocab_file, out_vocab_file)
|
387 |
+
elif not os.path.isfile(self.vocab_file):
|
388 |
+
with open(out_vocab_file, "w") as f:
|
389 |
+
for token in self.data:
|
390 |
+
print(json.dumps(token, ensure_ascii=False), file=f)
|
391 |
+
|
392 |
+
return (out_vocab_file,)
|
tokenizer.jsonl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3af94303d6c676e875e13165abf4e4a9d42c663ef247a8985cdeedb223cc0681
|
3 |
+
size 10573745
|
tokenizer_config.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<|plamo:unk|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<|plamo:bos|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "<|plamo:eos|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"3": {
|
30 |
+
"content": "<|plamo:pad|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"auto_map": {
|
39 |
+
"AutoTokenizer": [
|
40 |
+
"tokenization_plamo.Plamo2Tokenizer",
|
41 |
+
null
|
42 |
+
]
|
43 |
+
},
|
44 |
+
"bos_token": "<|plamo:bos|>",
|
45 |
+
"clean_up_tokenization_spaces": false,
|
46 |
+
"cls_token": null,
|
47 |
+
"eos_token": "<|plamo:bos|>",
|
48 |
+
"extra_special_tokens": {},
|
49 |
+
"local_file_only": true,
|
50 |
+
"mask_token": null,
|
51 |
+
"model_max_length": 1000000000000000019884624838656,
|
52 |
+
"pad_token": "<|plamo:pad|>",
|
53 |
+
"sep_token": null,
|
54 |
+
"tokenizer_class": "Plamo2Tokenizer",
|
55 |
+
"unk_token": "<|plamo:unk|>"
|
56 |
+
}
|