imos committed (verified)
Commit e10be72 · 1 Parent(s): b3176ac

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.jsonl filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,135 @@
1
+ ---
2
+ license: other
3
+ license_name: plamo-community-license
4
+ license_link: https://huggingface.co/pfnet/plamo-2-8b/blob/main/LICENSE/ja
5
+ language:
6
+ - en
7
+ - ja
8
+ pipeline_tag: text-generation
9
+ library_name: transformers
10
+ extra_gated_heading: PLaMo community license to download PLaMo 2 8B
11
+ extra_gated_description: To download PLaMo 2 8B, you have to agree to our license.
12
+ PLaMo 2 8B is released under the PLaMo community license. For non-commercial use, please contact
13
+ us via this [form](https://forms.gle/mTL8tBLrMYXKNZD56).
14
+ extra_gated_button_content: agree to PLaMo community license
15
+ extra_gated_prompt: "(English version is under construction. We apologize for the\
16
+ \ inconvenience.)\n### PLaMoコミュニティライセンス契約\nPLaMoコミュニティライセンス契約には、株式会社Preferred Networksが提供する別途定める大規模言語基盤モデルPLaMo及びその派生物を利用するためのライセンスの内容及びユーザーが遵守する事項等が定められている。ユーザーのPLaMo及びその派生物の利用には本契約が適用され、本契約に同意又は本モデル等を利用することにより、ユーザーは本契約に拘束される。\n\
17
+ #### 第1条(定義)\n(1) 「本契約」とは、PLaMoコミュニティライセンス契約を意味する。\n(2) 「PFN」とは、株式会社Preferred Networksを意味する。\n\
18
+ (3) 「本モデル」とは、別途定める「PLaMo」という名称のモデルの重み、モデルコード、トークナイザー、学習スクリプト及びこれらに付随してPFNが提供するものを意味する。\n\
19
+ (4) 「ユーザー」とは、本モデルを利用する個人又は法人を意味する。\n(5) 「派生モデル」とは、本モデルを改変又は利用し作成されるモデルの重み、モデルコード及びその他作成されたモデルの付随物を意味する。\n\
20
+ (6) 「生成物」とは、本モデル又は派生モデルの出力結果を意味する。\n(7) 「本モデル等」とは、本モデル、派生モデル及び生成物の総称を意味する。\n(8)\
21
+ \ 「本ライセンス」とは、PFNがユーザーに対して本契約に基づき本モデル等を利用することを許諾することを意味する。\n(9) 「商業目的」とは、 私的使用又は学術用途の範囲を超える、事業での利用又は営利を目的とする利用を意味する。なお、商業目的にはユーザーの製品、サービス又は事業の開発、変更又は提供(ホスティングサービスやAPI経由での提供を含む。)を目的とした使用及びユーザーの組織内部における利用も含まれる。\n\
22
+ #### 第2条(ユーザー)\nユーザーは、18歳以上又はその居住国で単独で契約を締結できる年齢に達していなければならない。但し、ユーザーの親権者又は法定代理人が本契約をユーザーが締結することに同意している場合はこの限りではない。\n\
23
+ #### 第3条(本ライセンス)\n(1) PFNは、ユーザーが本契約に同意しかつ本契約を遵守することを条件に、ユーザーに対して、本モデル等を本契約に定める条件及び範囲内で利用することを許諾する。\n\
24
+ (2) 本ライセンスは非独占、世界的、譲渡不可及びロイヤリティ無料とする。\n(3) ユーザーは、以下の条件をいずれも満たす場合に限り、商業目的を含む形で本モデル等を利用することができる。なお、ユーザーがこれらの条件のいずれかを満たさなくなった場合は、ユーザーはその時点で本モデル等を商業目的で利用することはできず、商業目的で本モデル等を利用したい場合は、新たにPFNから商業用のライセンスを取得しなければならない。\n\
25
+ \n (i) PFNの公式登録ページ https://forms.gle/mTL8tBLrMYXKNZD56 に事前に登録すること。\n\n (ii) ユーザー又はその関係会社の直近事業年度の収入又は売上が10億円(ユーザーの現地通貨換算額)を超えないこと。\n\
26
+ \n#### 第4条(再配布及び表示義務)\n(1) ユーザーが本モデル等(派生モデルやその生成物を含む)を第三者に提供する場合、以下の条件を満たさなければならない。\n\
27
+ \n (i) 本契約のコピーを提供し、本契約の条件を遵守させること。\n\n (ii) 「Built with PLaMo」と明示し、関連ウェブサイト、ユーザーインターフェース、ブログ記事、製品情報ページ又は製品ドキュメントに記載すること。\n\
28
+ \n (iii) 本モデル等を利用して作成した AI モデルの名称に「PLaMo」を含めること。\n\n#### 第5条(生成物の利用)\n(1) ユーザーは、生成物を本モデル又は派生モデルの生成物であることを明示することを条件に、公表することができる。\n\
29
+ (2) 生成物を利用してモデルを学習した場合、そのモデルは派生モデルとして本契約の条件が適用され、本契約のライセンス条件の下でのみ利用、配布及び商業化することができる。\n\
30
+ #### 第6条(その他利用条件)\nユーザーは、本モデル等の利用に関して、以下に定める行為をしてはならない。\n(1) 法令又は公序良俗に違反する行為\n(2)\
31
+ \ 犯罪行為又はこれを予告、関与、助長その他これらに関連する行為\n(3) PFN又は第三者の権利又は利益を侵害する行為\n(4) PFN又は第三者の名誉若しくは信用を毀損する行為\n\
32
+ (5) 生成物がPFNの公式見解等であるものという錯誤を生む情報を流布する行為\n(6) 虚偽の情報を発信する行為\n(7) 上記の他、PFNが不適切と合理的に判断する行為\n\
33
+ #### 第7条(保証の否認)\n(1) 本モデル及び生成物は、「現状有姿」で提供され、PFNは、これらに対して、正確性、真実性、商品性、品質、性能、特定目的への適合性、権利の非侵害など一切の保証をしない。\n\
34
+ (2) ユーザーは、法律、医療、金融又は人物評価その他重要な事項の決定に関して、生成物を唯一の証拠、評価又は意見として使用してはならない。\n(3) ユーザーは、本モデル等の使用及びその結果に関して全ての責任を負う。\n\
35
+ #### 第8条(責任の制限)\n(1) 契約責任、不法行為又は製造物責任その他の法的責任のいずれかであるかを問わず、PFNが本契約及び本モデル等に関してユーザーに対して負う損害賠償の責任は、通常かつ直接の損害に限り(逸失利益、特別損害、間接損害その他の損害については、その予見可能性の有無に関わらず、責任を負わない。)、損害賠償額の上限は、500円とする。但し、PFNに故意又は重過失が認められる場合はこの限りではない。\n\
36
+ (2) 前項に関わらず、ユーザーが本モデル等を事業のために利用する場合は、PFNは本契約及び本モデル等に関してユーザーに対して一切の損害賠償責任及びその他の責任を負わない。\n\
37
+ #### 第9条(ユーザーの責任)\n(1) ユーザーは、本モデル等の取得及び利用に関して、適用される法令(輸出入及び貿易に関連する法令を含む。)及び本契約を遵守する。\n\
38
+ (2) ユーザーは、本契約違反又は本モデル等の使用によって、PFNに損害を与えた場合は、その損害を賠償する。\n(3) ユーザーの本モデル等の使用に起因して、PFNが第三者から損害賠償請求その他請求を受けた場合、ユーザーは、当該請求からPFNを免責し、PFNに損害を与えないようにする。\n\
39
+ #### 第10条(権利の帰属)\n(1) 本モデルの一切の権利は、PFN又はPFNに本モデルのライセンスをしている第三者に帰属する。\n(2) 派生モデルのうち、ユーザーが本モデルを改変した部分の権利はユーザーに帰属し、その他の部分の権利はPFNに帰属する。\n\
40
+ (3) 生成物の一切の権利はユーザーに帰属する。\n#### 第11条(契約期間及び終了)\n(1) 本契約は、ユーザーが本契約に同意したとき又は本モデルにアクセスしたときから、本契約が解約されたときまでとする。\n\
41
+ (2) ユーザーが本契約のいずれかの条項に違反した場合、PFNは直ちに本契約を解除することができ、ユーザーは本モデル等のすべてのコピーを削除し、利用を即時に停止しなければならない。\n\
42
+ #### 第12条(契約の変更)\nPFNは、本契約(本モデル等に関するルールや諸規定等を含む。以下本条において同じ。)を変更できるものとする。PFNは、本契約を変更する場合には、変更の内容及び変更の効力発生時期を、当該効力発生時期までにPFN所定の方法で告知するものとする。\n\
43
+ #### 第13条(準拠法及び管轄裁判所)\n(1) 本契約の準拠法は日本法とする。\n(2) 本モデル等及び本契約に起因する紛争については、東京地方裁判所が専属的合意管轄裁判所とする。"
44
+ base_model: pfnet/plamo-2-8b
45
+ tags:
46
+ - plamo
47
+ - translation
48
+ ---
49
+
50
+ # PLaMo Translation Model
51
+
52
+ PLaMo翻訳モデルはPreferred Networksによって開発された翻訳向け特化型大規模言語モデルです。
53
+ 詳しくは[ブログ記事](https://tech.preferred.jp/ja/blog/plamo-translate/)および[プレスリリース](https://www.preferred.jp/ja/news/pr20250527/)を参照してください。
54
+
55
+ PLaMo Translation Model is a specialized large-scale language model developed by Preferred Networks for translation tasks.
56
+ For details, please refer to the [blog post](https://tech.preferred.jp/ja/blog/plamo-translate/) and [press release](https://www.preferred.jp/ja/news/pr20250527/).
57
+
58
+ List of models:
59
+ - [plamo-2-translate](http://huggingface.co/pfnet/plamo-2-translate) ... Post-trained model for translation
60
+ - [plamo-2-translate-base](http://huggingface.co/pfnet/plamo-2-translate-base) ... Base model for translation
61
+ - [plamo-2-translate-eval](http://huggingface.co/pfnet/plamo-2-translate-eval) ... Pair-wise evaluation model
62
+
63
+ PLaMo Translation Model is released under the PLaMo community license. Please check the following license and agree to it before downloading.
64
+
65
+ - (EN) under construction: we apologize for the inconvenience
66
+ - (JA) https://www.preferred.jp/ja/plamo-community-license/
67
+
68
+ **NOTE**: This model has **NOT** been instruction-tuned for chat dialog or other downstream tasks.
69
+
70
+ ### For *commercial* users
71
+
72
+ Please check the PLaMo community license and contact us via the following form to use it for commercial purposes.
73
+
74
+ - (EN/JA) https://forms.gle/mTL8tBLrMYXKNZD56
75
+
76
+ ## Usage
77
+
78
+ ### main/base model
79
+
80
+ ```py
81
+ import vllm
82
+
83
+ # max_model_len/max_num_batched_tokens can be increased when running on a GPU with substantial memory.
84
+ # NOTE: Switch to "pfnet/plamo-2-translate-base" to try the base model.
85
+ llm = vllm.LLM(model="pfnet/plamo-2-translate", trust_remote_code=True, max_model_len=2000, max_num_batched_tokens=2000)
86
+
87
+ prompt = r'''<|plamo:op|>dataset
88
+ translation
89
+ <|plamo:op|>input lang=English
90
+ Write the text to be translated here.
91
+ <|plamo:op|>output lang=Japanese
92
+ '''
93
+
94
+ responses = llm.generate([prompt] * 1, sampling_params=vllm.SamplingParams(temperature=0, max_tokens=1024, stop=["<|plamo:op|>"]))
95
+ # NOTE: This outputs "ここに翻訳するテキストを入力してください。".
96
+ print(responses[0].outputs[0].text)
97
+ ```
98
+
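The vLLM snippet above is the documented path. As a hedged aside (not from the model card): since `config.json` maps `AutoModelForCausalLM` to this repository's `modeling_plamo.Plamo2ForCausalLM`, the same prompt format should also be usable through Hugging Face `transformers` with `trust_remote_code=True`. The following is a minimal sketch under that assumption; the generation arguments and the manual stop handling are illustrative, not an official recipe.

```py
# Minimal sketch, assuming the remote code loads via AutoModelForCausalLM
# (see auto_map in config.json). Requires transformers, torch, and accelerate
# (for device_map="auto").
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "pfnet/plamo-2-translate"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)

prompt = (
    "<|plamo:op|>dataset\n"
    "translation\n"
    "<|plamo:op|>input lang=English\n"
    "Write the text to be translated here.\n"
    "<|plamo:op|>output lang=Japanese\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
generated = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
# vLLM stops at "<|plamo:op|>"; here the output is trimmed manually instead.
print(generated.split("<|plamo:op|>")[0])
```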
99
+ ### evaluation model
100
+
101
+ ```py
102
+ import vllm
103
+
104
+ # max_model_len/max_num_batched_tokens can be increased when running on a GPU with substantial memory.
105
+ llm = vllm.LLM(model="pfnet/plamo-2-translate-eval", trust_remote_code=True, max_model_len=2000, max_num_batched_tokens=2000)
106
+
107
+ prompt = r'''<|plamo:op|>dataset
108
+ translation evaluation
109
+ <|plamo:op|>input lang=English
110
+ This is an apple.
111
+ <|plamo:op|>output id=A lang=Japanese
112
+ これはりんごです。
113
+ <|plamo:op|>output id=B lang=Japanese
114
+ これはリンゴです。
115
+ <|plamo:op|>best
116
+ id='''
117
+
118
+ responses = llm.generate([prompt] * 1, sampling_params=vllm.SamplingParams(temperature=0, max_tokens=1, stop=["<|plamo:op|>"]))
119
+ # NOTE: This outputs "A".
120
+ print(responses[0].outputs[0].text)
121
+ ```
122
+
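For repeated comparisons, the single-token answer above can be wrapped in a small helper. This is only a sketch built from the prompt format shown in the example; the function name `prefer` and its defaults are invented here and are not part of the repository.

```py
# Illustrative helper around the pair-wise evaluation prompt shown above.
import vllm

llm = vllm.LLM(model="pfnet/plamo-2-translate-eval", trust_remote_code=True,
               max_model_len=2000, max_num_batched_tokens=2000)

def prefer(source: str, candidate_a: str, candidate_b: str,
           src_lang: str = "English", tgt_lang: str = "Japanese") -> str:
    prompt = (
        "<|plamo:op|>dataset\n"
        "translation evaluation\n"
        f"<|plamo:op|>input lang={src_lang}\n{source}\n"
        f"<|plamo:op|>output id=A lang={tgt_lang}\n{candidate_a}\n"
        f"<|plamo:op|>output id=B lang={tgt_lang}\n{candidate_b}\n"
        "<|plamo:op|>best\nid="
    )
    responses = llm.generate([prompt], sampling_params=vllm.SamplingParams(
        temperature=0, max_tokens=1, stop=["<|plamo:op|>"]))
    return responses[0].outputs[0].text.strip()  # expected to be "A" or "B"

print(prefer("This is an apple.", "これはりんごです。", "これはリンゴです。"))
```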
123
+ ## Bias, Risks, and Limitations
124
+
125
+ PLaMo Translation Model is a new technology that carries risks with use. Testing conducted to date has been in English and Japanese, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, PLaMo Translation Model’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. Therefore, before deploying any applications of PLaMo Translation Model, developers should perform safety testing and tuning tailored to their specific applications of the model.
126
+
127
+ ## Acknowledgement
128
+
129
+ This model was trained under the project “Research and Development Project of the Enhanced Infrastructures for Post 5G Information and Communication System” (JPNP 20017), subsidized by the New Energy and Industrial Technology Development Organization (NEDO).
130
+
131
+
132
+ ## AI policies for Preferred Networks, Inc. group
133
+
134
+ - (EN) https://www.preferred.jp/en/company/aipolicy/
135
+ - (JA) https://www.preferred.jp/ja/company/aipolicy/
config.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "architectures": [
3
+ "Plamo2ForCausalLM"
4
+ ],
5
+ "attention_window_size": 32768,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_plamo.Plamo2Config",
8
+ "AutoModelForCausalLM": "modeling_plamo.Plamo2ForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 1,
12
+ "eval_attention_n_bit": null,
13
+ "eval_mlp_n_bit": null,
14
+ "fp8_accum_dtype": "bfloat16",
15
+ "full_attention_idx": [],
16
+ "hidden_size": 4096,
17
+ "hidden_size_per_head": 128,
18
+ "image_feature_size": null,
19
+ "image_proj_type": "linear",
20
+ "image_token_id": null,
21
+ "intermediate_size": 16384,
22
+ "linear_type": "normal",
23
+ "mamba_chunk_size": 256,
24
+ "mamba_d_conv": 4,
25
+ "mamba_d_state": 64,
26
+ "mamba_enabled": true,
27
+ "mamba_num_heads": 64,
28
+ "mamba_step": 2,
29
+ "max_position_embeddings": 10485760,
30
+ "model_type": "plamo2",
31
+ "num_attention_heads": 32,
32
+ "num_hidden_layers": 32,
33
+ "num_key_value_heads": 4,
34
+ "pad_token_id": 3,
35
+ "rms_norm_eps": 1e-06,
36
+ "rope_local_theta": 1000000.0,
37
+ "rope_theta": 1000000.0,
38
+ "sliding_window": 32768,
39
+ "tokenizer_class": "Plamo2Tokenizer",
40
+ "torch_dtype": "bfloat16",
41
+ "transformers_version": "4.46.2",
42
+ "use_cache": false,
43
+ "vocab_size": 100032
44
+ }
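A hedged reading of the configuration above: the `auto_map` entries route `AutoConfig` and `AutoModelForCausalLM` to the custom classes in `modeling_plamo.py` when `trust_remote_code=True` is passed, and `mamba_enabled`/`mamba_step` describe the hybrid layout in which Mamba (SSM) mixers alternate with sliding-window attention layers (the alternation is visible in the weight map further below). The snippet is only a sketch of inspecting those fields, not part of the repository.

```py
# Sketch: load and inspect the custom Plamo2Config via the auto_map above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("pfnet/plamo-2-translate", trust_remote_code=True)
print(config.model_type)             # "plamo2"
print(config.mamba_enabled)          # True: Mamba (SSM) mixers are enabled
print(config.mamba_step)             # 2: Mamba and attention layers alternate
print(config.attention_window_size)  # 32768-token sliding window for attention layers
print(config.num_hidden_layers, config.hidden_size, config.vocab_size)
```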
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 3,
6
+ "transformers_version": "4.46.2"
7
+ }
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:707a7c1f9bb4982b098864d7215a3985fa3f026f2ebd0a82799fea72e807349f
3
+ size 4771172576
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab65beeb0731f66b6542cbfb3b6967f72e12386f55bd2234d690cdcda991ac4f
3
+ size 4964802416
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88a4bb84ec3e982f16497c375013e7efdc07890a8cd7126be0bcdefe6193a9f7
3
+ size 4832590896
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a90408e0d462b199de23f1e29730d1608c3c425d384c6a3c0dc34fbeea9b335f
3
+ size 3668492768
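Each of the four `model-*.safetensors` entries above is a Git LFS pointer (spec v1): the repository records only the `oid sha256` and `size`, and the actual shard is fetched by Git LFS or `huggingface_hub`. As an illustrative aside (the local path is hypothetical), a downloaded shard can be checked against the recorded digest and size:

```py
# Sketch: verify a downloaded shard against the sha256 and size from its LFS pointer.
import hashlib
import os

def verify_shard(path: str, expected_sha256: str, expected_size: int) -> bool:
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_sha256 and os.path.getsize(path) == expected_size

print(verify_shard(
    "model-00004-of-00004.safetensors",
    "a90408e0d462b199de23f1e29730d1608c3c425d384c6a3c0dc34fbeea9b335f",
    3668492768,
))
```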
model.safetensors.index.json ADDED
@@ -0,0 +1,441 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18237007872
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
7
+ "model.layers.layers.0.mixer.A_log": "model-00001-of-00004.safetensors",
8
+ "model.layers.layers.0.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.layers.0.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.layers.0.mixer.D": "model-00001-of-00004.safetensors",
11
+ "model.layers.layers.0.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.layers.0.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.layers.0.mixer.dt_bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.layers.0.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.layers.0.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.layers.0.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
17
+ "model.layers.layers.0.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
19
+ "model.layers.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.layers.0.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.layers.0.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.layers.0.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.layers.0.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.layers.1.mixer.k_weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.layers.1.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
26
+ "model.layers.layers.1.mixer.q_weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.layers.1.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
29
+ "model.layers.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
30
+ "model.layers.layers.1.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
31
+ "model.layers.layers.1.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.layers.1.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
33
+ "model.layers.layers.1.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
34
+ "model.layers.layers.10.mixer.A_log": "model-00002-of-00004.safetensors",
35
+ "model.layers.layers.10.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.layers.10.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.layers.10.mixer.D": "model-00002-of-00004.safetensors",
38
+ "model.layers.layers.10.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.layers.10.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.layers.10.mixer.dt_bias": "model-00002-of-00004.safetensors",
41
+ "model.layers.layers.10.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.layers.10.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
43
+ "model.layers.layers.10.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.layers.10.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.layers.10.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.layers.10.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.layers.10.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
50
+ "model.layers.layers.10.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.layers.11.mixer.k_weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.layers.11.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
53
+ "model.layers.layers.11.mixer.q_weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.layers.11.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
55
+ "model.layers.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.layers.11.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.layers.11.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.layers.11.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.layers.11.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.layers.12.mixer.A_log": "model-00002-of-00004.safetensors",
62
+ "model.layers.layers.12.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.layers.12.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.layers.12.mixer.D": "model-00002-of-00004.safetensors",
65
+ "model.layers.layers.12.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.layers.12.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
67
+ "model.layers.layers.12.mixer.dt_bias": "model-00002-of-00004.safetensors",
68
+ "model.layers.layers.12.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.layers.12.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.layers.12.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.layers.12.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
74
+ "model.layers.layers.12.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.layers.12.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.layers.12.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
77
+ "model.layers.layers.12.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.layers.13.mixer.k_weight": "model-00002-of-00004.safetensors",
79
+ "model.layers.layers.13.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.layers.13.mixer.q_weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.layers.13.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.layers.13.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.layers.13.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
86
+ "model.layers.layers.13.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.layers.13.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.layers.14.mixer.A_log": "model-00002-of-00004.safetensors",
89
+ "model.layers.layers.14.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.layers.14.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
91
+ "model.layers.layers.14.mixer.D": "model-00002-of-00004.safetensors",
92
+ "model.layers.layers.14.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.layers.14.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.layers.14.mixer.dt_bias": "model-00002-of-00004.safetensors",
95
+ "model.layers.layers.14.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.layers.14.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.layers.14.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
98
+ "model.layers.layers.14.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.layers.14.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
101
+ "model.layers.layers.14.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.layers.14.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
103
+ "model.layers.layers.14.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.layers.14.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.layers.15.mixer.k_weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.layers.15.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.layers.15.mixer.q_weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.layers.15.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
110
+ "model.layers.layers.15.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.layers.15.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.layers.15.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
113
+ "model.layers.layers.15.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
114
+ "model.layers.layers.15.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
115
+ "model.layers.layers.16.mixer.A_log": "model-00002-of-00004.safetensors",
116
+ "model.layers.layers.16.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.layers.16.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.layers.16.mixer.D": "model-00002-of-00004.safetensors",
119
+ "model.layers.layers.16.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.layers.16.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.layers.16.mixer.dt_bias": "model-00002-of-00004.safetensors",
122
+ "model.layers.layers.16.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.layers.16.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.layers.16.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
125
+ "model.layers.layers.16.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
126
+ "model.layers.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
127
+ "model.layers.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
128
+ "model.layers.layers.16.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
129
+ "model.layers.layers.16.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
130
+ "model.layers.layers.16.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
131
+ "model.layers.layers.16.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
132
+ "model.layers.layers.17.mixer.k_weight": "model-00003-of-00004.safetensors",
133
+ "model.layers.layers.17.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
134
+ "model.layers.layers.17.mixer.q_weight": "model-00003-of-00004.safetensors",
135
+ "model.layers.layers.17.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
136
+ "model.layers.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
137
+ "model.layers.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
138
+ "model.layers.layers.17.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
139
+ "model.layers.layers.17.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
140
+ "model.layers.layers.17.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.layers.17.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
142
+ "model.layers.layers.18.mixer.A_log": "model-00003-of-00004.safetensors",
143
+ "model.layers.layers.18.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.layers.18.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
145
+ "model.layers.layers.18.mixer.D": "model-00003-of-00004.safetensors",
146
+ "model.layers.layers.18.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.layers.18.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.layers.18.mixer.dt_bias": "model-00003-of-00004.safetensors",
149
+ "model.layers.layers.18.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.layers.18.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
151
+ "model.layers.layers.18.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
152
+ "model.layers.layers.18.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
153
+ "model.layers.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
154
+ "model.layers.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
155
+ "model.layers.layers.18.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
156
+ "model.layers.layers.18.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
157
+ "model.layers.layers.18.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
158
+ "model.layers.layers.18.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
159
+ "model.layers.layers.19.mixer.k_weight": "model-00003-of-00004.safetensors",
160
+ "model.layers.layers.19.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
161
+ "model.layers.layers.19.mixer.q_weight": "model-00003-of-00004.safetensors",
162
+ "model.layers.layers.19.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
163
+ "model.layers.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
164
+ "model.layers.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.layers.19.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.layers.19.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.layers.19.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.layers.19.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
169
+ "model.layers.layers.2.mixer.A_log": "model-00001-of-00004.safetensors",
170
+ "model.layers.layers.2.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
171
+ "model.layers.layers.2.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
172
+ "model.layers.layers.2.mixer.D": "model-00001-of-00004.safetensors",
173
+ "model.layers.layers.2.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
174
+ "model.layers.layers.2.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
175
+ "model.layers.layers.2.mixer.dt_bias": "model-00001-of-00004.safetensors",
176
+ "model.layers.layers.2.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
177
+ "model.layers.layers.2.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
178
+ "model.layers.layers.2.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
179
+ "model.layers.layers.2.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
180
+ "model.layers.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
181
+ "model.layers.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
182
+ "model.layers.layers.2.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
183
+ "model.layers.layers.2.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
184
+ "model.layers.layers.2.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
185
+ "model.layers.layers.2.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
186
+ "model.layers.layers.20.mixer.A_log": "model-00003-of-00004.safetensors",
187
+ "model.layers.layers.20.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.layers.20.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.layers.20.mixer.D": "model-00003-of-00004.safetensors",
190
+ "model.layers.layers.20.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.layers.20.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.layers.20.mixer.dt_bias": "model-00003-of-00004.safetensors",
193
+ "model.layers.layers.20.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
194
+ "model.layers.layers.20.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.layers.20.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.layers.20.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
197
+ "model.layers.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
199
+ "model.layers.layers.20.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.layers.20.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.layers.20.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.layers.20.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.layers.21.mixer.k_weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.layers.21.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.layers.21.mixer.q_weight": "model-00003-of-00004.safetensors",
206
+ "model.layers.layers.21.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
209
+ "model.layers.layers.21.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.layers.21.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
211
+ "model.layers.layers.21.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.layers.21.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.layers.22.mixer.A_log": "model-00003-of-00004.safetensors",
214
+ "model.layers.layers.22.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.layers.22.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.layers.22.mixer.D": "model-00003-of-00004.safetensors",
217
+ "model.layers.layers.22.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
218
+ "model.layers.layers.22.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.layers.22.mixer.dt_bias": "model-00003-of-00004.safetensors",
220
+ "model.layers.layers.22.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
221
+ "model.layers.layers.22.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.layers.22.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
223
+ "model.layers.layers.22.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
224
+ "model.layers.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
226
+ "model.layers.layers.22.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.layers.22.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.layers.22.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.layers.22.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
230
+ "model.layers.layers.23.mixer.k_weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.layers.23.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.layers.23.mixer.q_weight": "model-00003-of-00004.safetensors",
233
+ "model.layers.layers.23.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
235
+ "model.layers.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.layers.23.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.layers.23.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.layers.23.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.layers.23.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.layers.24.mixer.A_log": "model-00003-of-00004.safetensors",
241
+ "model.layers.layers.24.mixer.B_norm_weight": "model-00003-of-00004.safetensors",
242
+ "model.layers.layers.24.mixer.C_norm_weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.layers.24.mixer.D": "model-00003-of-00004.safetensors",
244
+ "model.layers.layers.24.mixer.bcdt_proj.weight": "model-00003-of-00004.safetensors",
245
+ "model.layers.layers.24.mixer.conv1d.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.layers.24.mixer.dt_bias": "model-00003-of-00004.safetensors",
247
+ "model.layers.layers.24.mixer.dt_norm_weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.layers.24.mixer.dt_proj.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.layers.24.mixer.in_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.layers.24.mixer.out_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.layers.24.post_mixer_norm.weight": "model-00003-of-00004.safetensors",
254
+ "model.layers.layers.24.post_mlp_norm.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.layers.24.pre_mixer_norm.weight": "model-00003-of-00004.safetensors",
256
+ "model.layers.layers.24.pre_mlp_norm.weight": "model-00003-of-00004.safetensors",
257
+ "model.layers.layers.25.mixer.k_weight": "model-00003-of-00004.safetensors",
258
+ "model.layers.layers.25.mixer.o_proj.weight": "model-00003-of-00004.safetensors",
259
+ "model.layers.layers.25.mixer.q_weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.layers.25.mixer.qkv_proj.weight": "model-00003-of-00004.safetensors",
261
+ "model.layers.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
262
+ "model.layers.layers.25.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
263
+ "model.layers.layers.25.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
264
+ "model.layers.layers.25.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
265
+ "model.layers.layers.25.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
266
+ "model.layers.layers.25.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
267
+ "model.layers.layers.26.mixer.A_log": "model-00004-of-00004.safetensors",
268
+ "model.layers.layers.26.mixer.B_norm_weight": "model-00004-of-00004.safetensors",
269
+ "model.layers.layers.26.mixer.C_norm_weight": "model-00004-of-00004.safetensors",
270
+ "model.layers.layers.26.mixer.D": "model-00004-of-00004.safetensors",
271
+ "model.layers.layers.26.mixer.bcdt_proj.weight": "model-00004-of-00004.safetensors",
272
+ "model.layers.layers.26.mixer.conv1d.weight": "model-00004-of-00004.safetensors",
273
+ "model.layers.layers.26.mixer.dt_bias": "model-00004-of-00004.safetensors",
274
+ "model.layers.layers.26.mixer.dt_norm_weight": "model-00004-of-00004.safetensors",
275
+ "model.layers.layers.26.mixer.dt_proj.weight": "model-00004-of-00004.safetensors",
276
+ "model.layers.layers.26.mixer.in_proj.weight": "model-00004-of-00004.safetensors",
277
+ "model.layers.layers.26.mixer.out_proj.weight": "model-00004-of-00004.safetensors",
278
+ "model.layers.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
279
+ "model.layers.layers.26.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
280
+ "model.layers.layers.26.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
281
+ "model.layers.layers.26.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
282
+ "model.layers.layers.26.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
283
+ "model.layers.layers.26.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
284
+ "model.layers.layers.27.mixer.k_weight": "model-00004-of-00004.safetensors",
285
+ "model.layers.layers.27.mixer.o_proj.weight": "model-00004-of-00004.safetensors",
286
+ "model.layers.layers.27.mixer.q_weight": "model-00004-of-00004.safetensors",
287
+ "model.layers.layers.27.mixer.qkv_proj.weight": "model-00004-of-00004.safetensors",
288
+ "model.layers.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
289
+ "model.layers.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
290
+ "model.layers.layers.27.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
291
+ "model.layers.layers.27.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
292
+ "model.layers.layers.27.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
293
+ "model.layers.layers.27.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
294
+ "model.layers.layers.28.mixer.A_log": "model-00004-of-00004.safetensors",
295
+ "model.layers.layers.28.mixer.B_norm_weight": "model-00004-of-00004.safetensors",
296
+ "model.layers.layers.28.mixer.C_norm_weight": "model-00004-of-00004.safetensors",
297
+ "model.layers.layers.28.mixer.D": "model-00004-of-00004.safetensors",
298
+ "model.layers.layers.28.mixer.bcdt_proj.weight": "model-00004-of-00004.safetensors",
299
+ "model.layers.layers.28.mixer.conv1d.weight": "model-00004-of-00004.safetensors",
300
+ "model.layers.layers.28.mixer.dt_bias": "model-00004-of-00004.safetensors",
301
+ "model.layers.layers.28.mixer.dt_norm_weight": "model-00004-of-00004.safetensors",
302
+ "model.layers.layers.28.mixer.dt_proj.weight": "model-00004-of-00004.safetensors",
303
+ "model.layers.layers.28.mixer.in_proj.weight": "model-00004-of-00004.safetensors",
304
+ "model.layers.layers.28.mixer.out_proj.weight": "model-00004-of-00004.safetensors",
305
+ "model.layers.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
306
+ "model.layers.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
307
+ "model.layers.layers.28.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
308
+ "model.layers.layers.28.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
309
+ "model.layers.layers.28.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
310
+ "model.layers.layers.28.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
311
+ "model.layers.layers.29.mixer.k_weight": "model-00004-of-00004.safetensors",
312
+ "model.layers.layers.29.mixer.o_proj.weight": "model-00004-of-00004.safetensors",
313
+ "model.layers.layers.29.mixer.q_weight": "model-00004-of-00004.safetensors",
314
+ "model.layers.layers.29.mixer.qkv_proj.weight": "model-00004-of-00004.safetensors",
315
+ "model.layers.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
316
+ "model.layers.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
317
+ "model.layers.layers.29.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
318
+ "model.layers.layers.29.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
319
+ "model.layers.layers.29.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
320
+ "model.layers.layers.29.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
321
+ "model.layers.layers.3.mixer.k_weight": "model-00001-of-00004.safetensors",
322
+ "model.layers.layers.3.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
323
+ "model.layers.layers.3.mixer.q_weight": "model-00001-of-00004.safetensors",
324
+ "model.layers.layers.3.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
325
+ "model.layers.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
326
+ "model.layers.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
327
+ "model.layers.layers.3.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
328
+ "model.layers.layers.3.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
329
+ "model.layers.layers.3.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
330
+ "model.layers.layers.3.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
331
+ "model.layers.layers.30.mixer.A_log": "model-00004-of-00004.safetensors",
332
+ "model.layers.layers.30.mixer.B_norm_weight": "model-00004-of-00004.safetensors",
333
+ "model.layers.layers.30.mixer.C_norm_weight": "model-00004-of-00004.safetensors",
334
+ "model.layers.layers.30.mixer.D": "model-00004-of-00004.safetensors",
335
+ "model.layers.layers.30.mixer.bcdt_proj.weight": "model-00004-of-00004.safetensors",
336
+ "model.layers.layers.30.mixer.conv1d.weight": "model-00004-of-00004.safetensors",
337
+ "model.layers.layers.30.mixer.dt_bias": "model-00004-of-00004.safetensors",
338
+ "model.layers.layers.30.mixer.dt_norm_weight": "model-00004-of-00004.safetensors",
339
+ "model.layers.layers.30.mixer.dt_proj.weight": "model-00004-of-00004.safetensors",
340
+ "model.layers.layers.30.mixer.in_proj.weight": "model-00004-of-00004.safetensors",
341
+ "model.layers.layers.30.mixer.out_proj.weight": "model-00004-of-00004.safetensors",
342
+ "model.layers.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
343
+ "model.layers.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
344
+ "model.layers.layers.30.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
345
+ "model.layers.layers.30.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
346
+ "model.layers.layers.30.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
347
+ "model.layers.layers.30.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
348
+ "model.layers.layers.31.mixer.k_weight": "model-00004-of-00004.safetensors",
349
+ "model.layers.layers.31.mixer.o_proj.weight": "model-00004-of-00004.safetensors",
350
+ "model.layers.layers.31.mixer.q_weight": "model-00004-of-00004.safetensors",
351
+ "model.layers.layers.31.mixer.qkv_proj.weight": "model-00004-of-00004.safetensors",
352
+ "model.layers.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
353
+ "model.layers.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
354
+ "model.layers.layers.31.post_mixer_norm.weight": "model-00004-of-00004.safetensors",
355
+ "model.layers.layers.31.post_mlp_norm.weight": "model-00004-of-00004.safetensors",
356
+ "model.layers.layers.31.pre_mixer_norm.weight": "model-00004-of-00004.safetensors",
357
+ "model.layers.layers.31.pre_mlp_norm.weight": "model-00004-of-00004.safetensors",
358
+ "model.layers.layers.4.mixer.A_log": "model-00001-of-00004.safetensors",
359
+ "model.layers.layers.4.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
360
+ "model.layers.layers.4.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
361
+ "model.layers.layers.4.mixer.D": "model-00001-of-00004.safetensors",
362
+ "model.layers.layers.4.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
363
+ "model.layers.layers.4.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
364
+ "model.layers.layers.4.mixer.dt_bias": "model-00001-of-00004.safetensors",
365
+ "model.layers.layers.4.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
366
+ "model.layers.layers.4.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
367
+ "model.layers.layers.4.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
368
+ "model.layers.layers.4.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
369
+ "model.layers.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
370
+ "model.layers.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
371
+ "model.layers.layers.4.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
372
+ "model.layers.layers.4.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
373
+ "model.layers.layers.4.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
374
+ "model.layers.layers.4.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
375
+ "model.layers.layers.5.mixer.k_weight": "model-00001-of-00004.safetensors",
376
+ "model.layers.layers.5.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
377
+ "model.layers.layers.5.mixer.q_weight": "model-00001-of-00004.safetensors",
378
+ "model.layers.layers.5.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
379
+ "model.layers.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
380
+ "model.layers.layers.5.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
381
+ "model.layers.layers.5.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
382
+ "model.layers.layers.5.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
383
+ "model.layers.layers.5.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
384
+ "model.layers.layers.5.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
385
+ "model.layers.layers.6.mixer.A_log": "model-00001-of-00004.safetensors",
386
+ "model.layers.layers.6.mixer.B_norm_weight": "model-00001-of-00004.safetensors",
387
+ "model.layers.layers.6.mixer.C_norm_weight": "model-00001-of-00004.safetensors",
388
+ "model.layers.layers.6.mixer.D": "model-00001-of-00004.safetensors",
389
+ "model.layers.layers.6.mixer.bcdt_proj.weight": "model-00001-of-00004.safetensors",
390
+ "model.layers.layers.6.mixer.conv1d.weight": "model-00001-of-00004.safetensors",
391
+ "model.layers.layers.6.mixer.dt_bias": "model-00001-of-00004.safetensors",
392
+ "model.layers.layers.6.mixer.dt_norm_weight": "model-00001-of-00004.safetensors",
393
+ "model.layers.layers.6.mixer.dt_proj.weight": "model-00001-of-00004.safetensors",
394
+ "model.layers.layers.6.mixer.in_proj.weight": "model-00001-of-00004.safetensors",
395
+ "model.layers.layers.6.mixer.out_proj.weight": "model-00001-of-00004.safetensors",
396
+ "model.layers.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
397
+ "model.layers.layers.6.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
398
+ "model.layers.layers.6.post_mixer_norm.weight": "model-00001-of-00004.safetensors",
399
+ "model.layers.layers.6.post_mlp_norm.weight": "model-00001-of-00004.safetensors",
400
+ "model.layers.layers.6.pre_mixer_norm.weight": "model-00001-of-00004.safetensors",
401
+ "model.layers.layers.6.pre_mlp_norm.weight": "model-00001-of-00004.safetensors",
402
+ "model.layers.layers.7.mixer.k_weight": "model-00001-of-00004.safetensors",
403
+ "model.layers.layers.7.mixer.o_proj.weight": "model-00001-of-00004.safetensors",
404
+ "model.layers.layers.7.mixer.q_weight": "model-00001-of-00004.safetensors",
405
+ "model.layers.layers.7.mixer.qkv_proj.weight": "model-00001-of-00004.safetensors",
406
+ "model.layers.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
407
+ "model.layers.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
408
+ "model.layers.layers.7.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
409
+ "model.layers.layers.7.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
410
+ "model.layers.layers.7.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
411
+ "model.layers.layers.7.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
412
+ "model.layers.layers.8.mixer.A_log": "model-00002-of-00004.safetensors",
413
+ "model.layers.layers.8.mixer.B_norm_weight": "model-00002-of-00004.safetensors",
414
+ "model.layers.layers.8.mixer.C_norm_weight": "model-00002-of-00004.safetensors",
415
+ "model.layers.layers.8.mixer.D": "model-00002-of-00004.safetensors",
416
+ "model.layers.layers.8.mixer.bcdt_proj.weight": "model-00002-of-00004.safetensors",
417
+ "model.layers.layers.8.mixer.conv1d.weight": "model-00002-of-00004.safetensors",
418
+ "model.layers.layers.8.mixer.dt_bias": "model-00002-of-00004.safetensors",
419
+ "model.layers.layers.8.mixer.dt_norm_weight": "model-00002-of-00004.safetensors",
420
+ "model.layers.layers.8.mixer.dt_proj.weight": "model-00002-of-00004.safetensors",
421
+ "model.layers.layers.8.mixer.in_proj.weight": "model-00002-of-00004.safetensors",
422
+ "model.layers.layers.8.mixer.out_proj.weight": "model-00002-of-00004.safetensors",
423
+ "model.layers.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
424
+ "model.layers.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
425
+ "model.layers.layers.8.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
426
+ "model.layers.layers.8.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
427
+ "model.layers.layers.8.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
428
+ "model.layers.layers.8.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
429
+ "model.layers.layers.9.mixer.k_weight": "model-00002-of-00004.safetensors",
430
+ "model.layers.layers.9.mixer.o_proj.weight": "model-00002-of-00004.safetensors",
431
+ "model.layers.layers.9.mixer.q_weight": "model-00002-of-00004.safetensors",
432
+ "model.layers.layers.9.mixer.qkv_proj.weight": "model-00002-of-00004.safetensors",
433
+ "model.layers.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
434
+ "model.layers.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
435
+ "model.layers.layers.9.post_mixer_norm.weight": "model-00002-of-00004.safetensors",
436
+ "model.layers.layers.9.post_mlp_norm.weight": "model-00002-of-00004.safetensors",
437
+ "model.layers.layers.9.pre_mixer_norm.weight": "model-00002-of-00004.safetensors",
438
+ "model.layers.layers.9.pre_mlp_norm.weight": "model-00002-of-00004.safetensors",
439
+ "model.norm.weight": "model-00004-of-00004.safetensors"
440
+ }
441
+ }
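The `weight_map` above tells loaders which shard stores each parameter; `transformers` consumes this index automatically when loading the sharded checkpoint. For illustration only (the tensor name is taken from the map, and `safe_open` is the standard safetensors API), a single tensor could be read by hand like this:

```py
# Sketch: locate and load one tensor from the sharded checkpoint using the index.
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.layers.0.mixer.A_log"
shard = index["weight_map"][name]  # -> "model-00001-of-00004.safetensors"
with safe_open(shard, framework="pt", device="cpu") as f:
    tensor = f.get_tensor(name)
print(shard, tuple(tensor.shape))
```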
modeling_plamo.py ADDED
@@ -0,0 +1,1707 @@
1
+ import enum
2
+ import math
3
+ import warnings
4
+ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
5
+
6
+ try:
7
+ # It is difficult to install mamba_ssm on a login node because
8
+ # it requires a GPU for installation
9
+ import mamba_ssm
10
+ except ModuleNotFoundError:
11
+ warnings.warn("mamba_ssm could not be imported", stacklevel=2)
12
+ try:
13
+ # It is difficult to install causal_conv1d on a login node because
14
+ # it requires a GPU for installation
15
+ import causal_conv1d.causal_conv1d_interface as causal_conv1d
16
+ except ModuleNotFoundError:
17
+ warnings.warn("causal_conv1d could not be imported", stacklevel=2)
18
+ import torch
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+
24
+
25
+ def _is_first_token(mask: torch.Tensor) -> torch.Tensor:
26
+ assert mask.dtype == torch.bool
27
+ B, Nh, q_len, kv_len = mask.shape
28
+ mask = mask[:, :, :, -q_len:]
29
+ cont = q_len != kv_len
30
+ v = False if cont else True
31
+ out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool())
32
+ out = torch.cat(
33
+ [
34
+ torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v),
35
+ out,
36
+ ],
37
+ dim=-1,
38
+ )
39
+ return out
40
+
41
+
42
+ def _swiglu(h: torch.Tensor) -> torch.Tensor:
43
+ h0, h1 = h.chunk(2, dim=-1)
44
+ return torch.nn.functional.silu(h0) * h1
45
+
46
+
47
+ class RotaryEmbedding(torch.nn.Module):
48
+ def __init__(
49
+ self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
50
+ ) -> None:
51
+ super().__init__()
52
+
53
+ self.dim = dim
54
+ self.max_position_embeddings = max_position_embeddings
55
+ self.base = base
56
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
57
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
58
+
59
+ # Build here to make `torch.jit.trace` work.
60
+ self._set_cos_sin_cache(
61
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
62
+ )
63
+
64
+ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
65
+ self.max_seq_len_cached = seq_len
66
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
67
+
68
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
69
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
70
+ emb = torch.cat((freqs, freqs), dim=-1)
71
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
72
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
73
+
74
+ def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
75
+ # x: [bs, num_attention_heads, seq_len, head_size]
76
+ if seq_len > self.max_seq_len_cached:
77
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
78
+
79
+ return (
80
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
81
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
82
+ )
83
+
84
+
85
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
86
+ """Rotates half the hidden dims of the input."""
87
+ x1 = x[..., : x.shape[-1] // 2]
88
+ x2 = x[..., x.shape[-1] // 2 :]
89
+ return torch.cat((-x2, x1), dim=-1)
90
+
91
+
92
+ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
93
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
94
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
95
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
96
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
97
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
98
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
99
+ return x_embed
100
+
101
+
102
+ class LinearType(str, enum.Enum):
103
+ Normal = "normal"
104
+ Fp8 = "fp8"
105
+ Fp8Retain = "fp8-retain"
106
+
107
+
108
+ class Plamo2Config(PretrainedConfig): # type: ignore
109
+ model_type: str = "plamo2"
110
+
111
+ def __init__(
112
+ self,
113
+ hidden_size: int = 4096,
114
+ num_hidden_layers: int = 32,
115
+ rms_norm_eps: float = 1e-6,
116
+ tie_word_embeddings: bool = True,
117
+ # Attention
118
+ num_attention_heads: int = 32,
119
+ num_key_value_heads: int = 4,
120
+ hidden_size_per_head: int = 128,
121
+ max_position_embeddings: int = 2048,
122
+ attention_window_size: int = 2048,
123
+ full_attention_idx: list[int] | None = None,
124
+ rope_theta: int = 10000,
125
+ rope_local_theta: int = 10000,
126
+ # Mamba
127
+ mamba_d_state: int = 64,
128
+ mamba_d_conv: int = 4,
129
+ mamba_num_heads: int = 64,
130
+ mamba_step: int = 2,
131
+ mamba_chunk_size: int = 256,
132
+ mamba_enabled: bool = True,
133
+ # MLP
134
+ intermediate_size: int = 13312,
135
+ # Tokenizer
136
+ vocab_size: int = 32000,
137
+ tokenizer_class: str = "Plamo2Tokenizer",
138
+ pad_token_id: Optional[int] = None,
139
+ bos_token_id: int = 1,
140
+ eos_token_id: int = 2,
141
+ # Multimodal
142
+ image_token_id: Optional[int] = None,
143
+ image_feature_size: Optional[int] = None,
144
+ image_proj_type: Literal["linear", "mlp"] = "linear",
145
+ # FP8
146
+ linear_type: LinearType = LinearType.Normal,
147
+ fp8_accum_dtype: Optional[str] = None,
148
+ # Evaluation
149
+ eval_attention_n_bit: Optional[int] = None,
150
+ eval_mlp_n_bit: Optional[int] = None,
151
+ use_cache: bool = True,
152
+ **kwargs: Any,
153
+ ) -> None:
154
+ # max_position_embeddings is often used to determine the max length during inference,
155
+ # but Samba-style hybrid models should be able to extrapolate beyond it.
156
+ self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
157
+ self.hidden_size = hidden_size
158
+ self.rms_norm_eps = rms_norm_eps
159
+
160
+ self.num_hidden_layers = num_hidden_layers
161
+ self.num_attention_heads = num_attention_heads
162
+ self.hidden_size_per_head = hidden_size_per_head
163
+ self.num_key_value_heads = num_key_value_heads
164
+ self.attention_window_size = attention_window_size
165
+ self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
166
+ self.rope_theta = rope_theta
167
+ self.rope_local_theta = rope_local_theta
168
+
169
+ self.mamba_d_state = mamba_d_state
170
+ self.mamba_d_conv = mamba_d_conv
171
+ self.mamba_num_heads = mamba_num_heads
172
+ self.mamba_step = mamba_step
173
+ self.mamba_chunk_size = mamba_chunk_size
174
+ self.mamba_enabled = mamba_enabled
175
+
176
+ self.intermediate_size = intermediate_size
177
+
178
+ self.vocab_size = vocab_size
179
+
180
+ self.image_token_id = image_token_id
181
+ self.image_feature_size = image_feature_size
182
+ self.image_proj_type = image_proj_type
183
+
184
+ self.linear_type = linear_type
185
+ self.fp8_accum_dtype = fp8_accum_dtype
186
+
187
+ self.eval_attention_n_bit = eval_attention_n_bit
188
+ self.eval_mlp_n_bit = eval_mlp_n_bit
189
+ self.use_cache = use_cache
190
+
191
+ # fields for vLLM
192
+ self.sliding_window = attention_window_size
193
+
194
+ super().__init__(
195
+ tokenizer_class=tokenizer_class,
196
+ pad_token_id=pad_token_id,
197
+ bos_token_id=bos_token_id,
198
+ eos_token_id=eos_token_id,
199
+ tie_word_embeddings=tie_word_embeddings,
200
+ **kwargs,
201
+ )
202
+
203
+ @property
204
+ def layers_block_type(self) -> list[str]:
205
+ return ["mamba" if is_mamba(self, i) else "attention" for i in range(self.num_hidden_layers)]
206
+
207
+ @property
208
+ def rope_local_base_freq(self) -> int:
209
+ return self.rope_local_theta
210
+
211
+
212
+ class Plamo2AttentionCache(torch.nn.Module):
213
+ def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
214
+ super().__init__()
215
+ B, nh, L, c = key.shape
216
+ assert len(value.shape) == 4
217
+ assert value.shape[0] == B
218
+ assert value.shape[2] == L
219
+ self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False))
220
+ self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
221
+
222
+
223
+ class Plamo2MambaCache(torch.nn.Module):
224
+ def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
225
+ super().__init__()
226
+ # conv_state: [B, C, d_conv]
227
+ # ssm_state: [B, nhead, nchannel_per_head, d_state]
228
+ assert len(conv_state.shape) == 3
229
+ assert len(ssm_state.shape) == 4
230
+ assert conv_state.shape[0] == ssm_state.shape[0]
231
+ self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False))
232
+ self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
233
+
234
+
235
+ Plamo2LayerCache = Plamo2AttentionCache | Plamo2MambaCache
236
+
237
+
238
+ class Plamo2Cache(torch.nn.Module):
239
+ """
240
+ Stores the states of the model for fast decoding.
241
+ `transformers` uses `transformers.Cache` for this purpose, but its interface and variable names are
242
+ deeply tied to the Transformer architecture (e.g., `key_states`), which makes it difficult to use
243
+ with other architectures (e.g., Mamba).
244
+ This class provides a similar interface to `transformers.Cache`, but is designed to also handle
245
+ the state of Mamba properly.
246
+ """
247
+
248
+ def __init__(self, config: Plamo2Config) -> None:
249
+ super().__init__()
250
+ self.config = config
251
+ self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
252
+
253
+ def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
254
+ c = self.cache[layer_idx]
255
+ if c is None:
256
+ return key, value
257
+ assert isinstance(c, Plamo2AttentionCache)
258
+
259
+ def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
260
+ assert len(cache.shape) == 4
261
+ assert len(new_tensor.shape) == 4
262
+ assert cache.shape[0] == new_tensor.shape[0]
263
+ assert cache.shape[1] == new_tensor.shape[1]
264
+ assert cache.shape[3] == new_tensor.shape[3]
265
+
266
+ _validate(c.key, key)
267
+ _validate(c.value, value)
268
+ assert key.shape[2] == value.shape[2]
269
+ return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2)
270
+
271
+ def update_attention(
272
+ self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
273
+ ) -> Plamo2AttentionCache:
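+ # Append the new key/value states to this layer's cache; sliding-window layers keep only
+ # the most recent `attention_window_size` positions, full-attention layers keep everything.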
274
+ full_attn = layer_idx in self.config.full_attention_idx
275
+ window_size = self.config.attention_window_size
276
+
277
+ if self.cache[layer_idx] is None:
278
+ if full_attn:
279
+ self.cache[layer_idx] = Plamo2AttentionCache(key_states, value_states)
280
+ else:
281
+ self.cache[layer_idx] = Plamo2AttentionCache(
282
+ key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
283
+ )
284
+ else:
285
+ c = self.cache[layer_idx]
286
+ assert isinstance(c, Plamo2AttentionCache)
287
+ k, v = self.append_kv(key_states, value_states, layer_idx)
288
+ if full_attn:
289
+ c.key.data = k
290
+ c.value.data = v
291
+ else:
292
+ c.key.data = k[:, :, -window_size:, :]
293
+ c.value.data = v[:, :, -window_size:, :]
294
+ return self.cache[layer_idx] # type: ignore
295
+
296
+ def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> Plamo2MambaCache:
297
+ if self.cache[layer_idx] is None:
298
+ self.cache[layer_idx] = Plamo2MambaCache(conv_state, ssm_state)
299
+ else:
300
+ c = self.cache[layer_idx]
301
+ assert isinstance(c, Plamo2MambaCache)
302
+ assert c.conv_state.shape == conv_state.shape
303
+ assert c.ssm_state.shape == ssm_state.shape
304
+ c.conv_state.data = conv_state
305
+ c.ssm_state.data = ssm_state
306
+ return self.cache[layer_idx] # type: ignore
307
+
308
+ def __getitem__(self, layer_idx: int) -> Plamo2LayerCache | None:
309
+ assert layer_idx < len(self.cache)
310
+ layer_cache = self.cache[layer_idx]
311
+ return layer_cache # type: ignore
312
+
313
+ def __len__(self) -> int:
314
+ return len(self.cache)
315
+
316
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
317
+ if layer_idx is not None:
318
+ c = self.cache[layer_idx]
319
+ assert isinstance(c, Plamo2AttentionCache)
320
+ return c.key.shape[2] # type: ignore
321
+
322
+ sequence_length: int | None = None
323
+ for layer_cache in self.cache:
324
+ if isinstance(layer_cache, Plamo2AttentionCache):
325
+ sequence_length = (
326
+ max(layer_cache.key.shape[2], sequence_length)
327
+ if sequence_length is not None
328
+ else layer_cache.key.shape[2]
329
+ )
330
+ assert sequence_length is not None
331
+ return sequence_length
332
+
333
+ def get_max_length(self) -> int | None:
334
+ return None
335
+
336
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
337
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
338
+ # Cache without size limit -> all cache is usable
339
+ # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum cache
340
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
341
+ max_length = self.get_max_length()
342
+ previous_seq_length = self.get_seq_length(layer_idx)
343
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
344
+ return max_length - new_seq_length
345
+ return previous_seq_length
346
+
347
+ def reorder_cache(self, beam_idx: torch.Tensor) -> None:
348
+ def _mamba(cache: Plamo2MambaCache) -> Plamo2MambaCache:
349
+ return Plamo2MambaCache(
350
+ conv_state=cache.conv_state.index_select(0, beam_idx),
351
+ ssm_state=cache.ssm_state.index_select(0, beam_idx),
352
+ )
353
+
354
+ def _attention(cache: Plamo2AttentionCache) -> Plamo2AttentionCache:
355
+ return Plamo2AttentionCache(
356
+ key=cache.key.index_select(0, beam_idx),
357
+ value=cache.value.index_select(0, beam_idx),
358
+ )
359
+
360
+ for i in range(len(self.cache)):
361
+ if self.cache[i] is None:
362
+ continue
363
+ layer_cache = self.cache[i]
364
+ if isinstance(layer_cache, Plamo2MambaCache):
365
+ self.cache[i] = _mamba(layer_cache)
366
+ else:
367
+ assert isinstance(layer_cache, Plamo2AttentionCache)
368
+ self.cache[i] = _attention(layer_cache)
369
+
370
+ @property
371
+ def seen_tokens(self) -> int | None:
372
+ return None
373
+
374
+
375
+ class DecoderInput(NamedTuple):
376
+ hidden_states: torch.Tensor
377
+ attention_mask: Optional[torch.Tensor] = None
378
+ past_states: Optional[Plamo2Cache] = None
379
+ output_hidden_states: Optional[bool] = False
380
+ output_attentions: Optional[bool] = False
381
+ gradient_checkpointing: bool = False
382
+ input_ids: Optional[torch.Tensor] = None
383
+
384
+
385
+ class DecoderOutput(NamedTuple):
386
+ hidden_states: torch.Tensor
387
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
388
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]]
389
+
390
+
391
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
392
+ def _make_causal_mask(
393
+ input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
394
+ ) -> torch.Tensor:
395
+ """
396
+ Make causal mask used for bi-directional self-attention.
397
+ """
398
+ bsz, tgt_len = input_ids_shape
399
+ mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
400
+ mask_cond = torch.arange(mask.size(-1), device=device)
401
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
402
+ mask = mask.to(dtype)
403
+
404
+ if past_key_values_length > 0:
405
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
406
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
407
+
408
+
409
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
410
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
411
+ """
412
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
413
+ """
414
+ bsz, src_len = mask.size()
415
+ tgt_len = tgt_len if tgt_len is not None else src_len
416
+
417
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
418
+
419
+ inverted_mask = 1.0 - expanded_mask
420
+
421
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
422
+
423
+
424
+ def _rms_norm(
425
+ hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
426
+ ) -> torch.Tensor:
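+ # Note: `weight` is stored as a deviation from `offset`, so a zero-initialized weight
+ # corresponds to an effective scale of `offset` (1.0 by default).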
427
+ input_dtype = hidden_states.dtype
428
+ hidden_states = hidden_states.to(torch.float32)
429
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
430
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
431
+ hidden_states = hidden_states.to(input_dtype)
432
+ if weight is not None:
433
+ hidden_states = (offset + weight) * hidden_states
434
+ return hidden_states
435
+
436
+
437
+ class RMSNorm(nn.Module):
438
+ def __init__(
439
+ self,
440
+ hidden_size: int,
441
+ eps: float = 1e-6,
442
+ offset: float = 1.0,
443
+ device: Optional[Union[torch.device, str]] = None,
444
+ ) -> None:
445
+ super().__init__()
446
+ self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
447
+ self.variance_epsilon = eps
448
+ self.offset = offset
449
+
450
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
451
+ return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
452
+
453
+
454
+ def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
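+ # Standard Mamba dt initialization: sample dt log-uniformly in [dt_min, dt_max] and return
+ # its inverse softplus, so that softplus(dt_bias) recovers the sampled dt at initialization.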
455
+ dt_min = 0.001
456
+ dt_max = 0.1
457
+ dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
458
+ dt = torch.clamp(dt, 1e-4)
459
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
460
+ return inv_dt
461
+
462
+
463
+ def get_initial_A(num_heads: int) -> torch.Tensor:
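+ # A_log is initialized to log(1..num_heads); the forward pass uses A = -exp(A_log),
+ # i.e. one negative scalar decay rate per head (Mamba2-style scalar A).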
464
+ A = torch.arange(1, num_heads + 1, dtype=torch.float32)
465
+ return torch.log(A)
466
+
467
+
468
+ def _bf16_supported_in_triton() -> bool:
469
+ # Newer torch (2.2.0 and later?) supports bfloat16 even on Volta GPUs,
470
+ # but triton cannot compile bf16 kernels for Volta.
471
+ major, _ = torch.cuda.get_device_capability()
472
+ return major >= 8
473
+
474
+
475
+ def _get_trition_dtype(dtype: torch.dtype) -> torch.dtype:
476
+ if dtype != torch.bfloat16:
477
+ return dtype
478
+ if _bf16_supported_in_triton():
479
+ return dtype
480
+ return torch.float32
481
+
482
+
483
+ def ssd_update_state(
484
+ ssm_state: torch.Tensor,
485
+ x: torch.Tensor,
486
+ dt: torch.Tensor,
487
+ A: torch.Tensor,
488
+ B: torch.Tensor,
489
+ C: torch.Tensor,
490
+ D: torch.Tensor,
491
+ z: torch.Tensor,
492
+ dt_bias: torch.Tensor,
493
+ dt_softplus: bool,
494
+ ) -> torch.Tensor:
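+ # Single-token recurrent step: advances `ssm_state` in place by one timestep and returns
+ # the gated output for that step (wraps mamba_ssm's selective_state_update / its reference impl).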
495
+ assert ssm_state.dtype == torch.float32
496
+ if dt.is_cuda:
497
+ dtype = _get_trition_dtype(x.dtype)
498
+ else:
499
+ dtype = x.dtype
500
+ if dt.is_cuda:
501
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update
502
+ else:
503
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref
504
+
505
+ hidden_size_per_head = x.shape[-1]
506
+ d_state = B.shape[-1]
507
+ A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float()
508
+ dt = dt[..., None].expand(-1, -1, hidden_size_per_head)
509
+ dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head)
510
+ D = D[:, None].expand(-1, hidden_size_per_head)
511
+ assert ssm_state.dtype == torch.float32
512
+ out = f(
513
+ ssm_state,
514
+ x.to(dtype),
515
+ dt.to(dtype),
516
+ A.float(),
517
+ B.to(dtype),
518
+ C.to(dtype),
519
+ D.float(),
520
+ z.to(dtype),
521
+ dt_bias.float(),
522
+ dt_softplus=dt_softplus,
523
+ )
524
+ return out[:, None] # type: ignore
525
+
526
+
527
+ def _ssd_chunk_scan_combined_naive(
528
+ x: torch.Tensor,
529
+ dt: torch.Tensor,
530
+ A: torch.Tensor,
531
+ B: torch.Tensor,
532
+ C: torch.Tensor,
533
+ D: torch.Tensor,
534
+ z: torch.Tensor,
535
+ dt_bias: torch.Tensor,
536
+ dt_softplus: bool,
537
+ seq_idx: torch.Tensor | None,
538
+ ssm_state: torch.Tensor,
539
+ ) -> tuple[torch.Tensor, torch.Tensor]:
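+ # Reference implementation: steps through the sequence one token at a time with
+ # `ssd_update_state`, resetting `ssm_state` whenever `seq_idx` indicates a new sequence.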
540
+ assert ssm_state.dtype == torch.float32
541
+ length = x.shape[1]
542
+ ys = []
543
+ for i in range(length):
544
+ if i != 0 and seq_idx is not None:
545
+ ssm_state = torch.where(
546
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
547
+ torch.zeros_like(ssm_state),
548
+ ssm_state,
549
+ )
550
+ y = ssd_update_state(
551
+ ssm_state,
552
+ x[:, i],
553
+ dt[:, i],
554
+ A,
555
+ B[:, i],
556
+ C[:, i],
557
+ D,
558
+ z=z[:, i],
559
+ dt_bias=dt_bias,
560
+ dt_softplus=dt_softplus,
561
+ )
562
+ ys.append(y)
563
+ return torch.cat(ys, dim=1), ssm_state
564
+
565
+
566
+ def _ssd_chunk_scan_combined_cpu(
567
+ x: torch.Tensor,
568
+ dt: torch.Tensor,
569
+ A: torch.Tensor,
570
+ B: torch.Tensor,
571
+ C: torch.Tensor,
572
+ chunk_size: int,
573
+ D: torch.Tensor,
574
+ z: torch.Tensor,
575
+ dt_bias: torch.Tensor,
576
+ dt_softplus: bool,
577
+ ) -> tuple[torch.Tensor, torch.Tensor]:
578
+ # (bsize, nhead, nchunk, chunk_size)
579
+ dt = dt.float() # We want high precision for this before cumsum
580
+ dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size)) # type: ignore
581
+ if dt_bias is not None:
582
+ dt = dt + dt_bias[None, :, None, None]
583
+ if dt_softplus:
584
+ dt = F.softplus(dt)
585
+ dA = dt * A[None, :, None, None]
586
+ dA_cumsum = torch.cumsum(dA, dim=-1)
587
+
588
+ _, _, nheads, _ = x.shape
589
+ dstate = B.shape[-1]
590
+ _ = dt.shape[2]
591
+
592
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
593
+ # The following is equivalent to `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`,
594
+ # but the `einsum` in that function is too slow on CPU.
595
+ x_ = torch.unflatten(x, 1, (-1, chunk_size))
596
+ assert B.shape[2] == nheads # B should be already expanded
597
+ B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype) # (bsize, nchunk, chunk_size, nheads, dstate)
598
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)
599
+ dt_ = dt.to(x.dtype)
600
+
601
+ # einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)
602
+ B_ = B_.permute(0, 1, 3, 4, 2) # bchnl
603
+ tmp = dt_ * decay_states # bhcl
604
+ tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None] # bch1l
605
+ tmp = B_ * tmp # bchnl
606
+ x_ = x_.permute(0, 1, 3, 2, 4) # bchlp
607
+ tmp = tmp @ x_ # bchnp
608
+ states = tmp.permute(0, 1, 2, 4, 3) # bchpn
609
+
610
+ states_dtype = states.dtype
611
+ if states.dtype not in [torch.float32, torch.float64]:
612
+ states = states.to(torch.float32)
613
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
614
+ out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
615
+ states.flatten(start_dim=-2, end_dim=-1),
616
+ dA_cumsum[:, :, :, -1],
617
+ )
618
+ states = torch.unflatten(out, -1, (-1, dstate))
619
+ last_state = torch.unflatten(last_state, -1, (-1, dstate))
620
+ states = states.to(states_dtype)
621
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
622
+ out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
623
+
624
+ return out, last_state
625
+
626
+
627
+ @torch.profiler.record_function("ssd_chunk_scan_combined")
628
+ def ssd_chunk_scan_combined(
629
+ x: torch.Tensor,
630
+ dt: torch.Tensor,
631
+ A: torch.Tensor,
632
+ B: torch.Tensor,
633
+ C: torch.Tensor,
634
+ chunk_size: int,
635
+ D: torch.Tensor,
636
+ z: torch.Tensor,
637
+ dt_bias: torch.Tensor,
638
+ dt_softplus: bool,
639
+ return_final_states: bool,
640
+ seq_idx: torch.Tensor | None,
641
+ ssm_state: torch.Tensor | None,
642
+ ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
643
+ if seq_idx is not None:
644
+ assert seq_idx.dtype == torch.int32
645
+ assert ssm_state is None
646
+ assert not return_final_states
647
+ if ssm_state is not None:
648
+ assert ssm_state.dtype == torch.float32
649
+ assert seq_idx is None
650
+
651
+ length = x.shape[1]
652
+
653
+ """
654
+ The state is updated as follows:
655
+ ```
656
+ dt = softplus(dt)
657
+ dA = exp(dt * A)
658
+ state_next = state * dA + dB * x
659
+ ```
660
+
661
+ To avoid updating state, we set dt to -inf and x to 0
662
+ because `softplus(-inf) = 0` and `exp(0) = 1`
663
+ """
664
+ pad = (chunk_size - length % chunk_size) % chunk_size
665
+ x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
666
+ dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
667
+ B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
668
+ C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
669
+ z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
670
+ if seq_idx is not None:
671
+ seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
672
+
673
+ length = x.shape[1]
674
+ assert length % chunk_size == 0, (length, chunk_size)
675
+
676
+ if dt.is_cuda:
677
+ dtype = _get_trition_dtype(x.dtype)
678
+ out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
679
+ x.to(dtype),
680
+ dt.to(dtype),
681
+ A.float(),
682
+ B.to(dtype),
683
+ C.to(dtype),
684
+ chunk_size,
685
+ D=D.float(),
686
+ z=z.to(dtype),
687
+ initial_states=ssm_state,
688
+ dt_bias=dt_bias.float(),
689
+ dt_softplus=dt_softplus,
690
+ seq_idx=seq_idx,
691
+ return_final_states=return_final_states,
692
+ )
693
+ if return_final_states:
694
+ return out[0][:, pad:], out[1]
695
+ else:
696
+ assert isinstance(out, torch.Tensor)
697
+ return out[:, pad:]
698
+ else:
699
+ if ssm_state is None and seq_idx is None:
700
+ tmp = _ssd_chunk_scan_combined_cpu(
701
+ x,
702
+ dt,
703
+ A,
704
+ B,
705
+ C,
706
+ chunk_size,
707
+ D=D,
708
+ z=z,
709
+ dt_bias=dt_bias.float(),
710
+ dt_softplus=dt_softplus,
711
+ )
712
+ else:
713
+ if ssm_state is None:
714
+ bsize, _, num_heads, channel = x.shape
715
+ state = B.shape[-1]
716
+ ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
717
+ tmp = _ssd_chunk_scan_combined_naive(
718
+ x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
719
+ )
720
+ tmp = (tmp[0][:, pad:], tmp[1])
721
+ if return_final_states:
722
+ return tmp
723
+ else:
724
+ return tmp[0]
725
+
726
+
727
+ def _causal_conv1d_update(
728
+ conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
729
+ ) -> tuple[torch.Tensor, torch.Tensor]:
730
+ dtype = conv_state.dtype
731
+ xBC = xBC.to(dtype)
732
+ weight = weight.to(dtype)
733
+ if conv_state.is_cuda:
734
+ x = causal_conv1d.causal_conv1d_update(
735
+ x=xBC,
736
+ conv_state=conv_state,
737
+ weight=weight[:, 0, :],
738
+ activation="silu",
739
+ )
740
+ return x, conv_state
741
+ else:
742
+ x = causal_conv1d.causal_conv1d_update_ref(
743
+ x=xBC,
744
+ conv_state=conv_state,
745
+ weight=weight[:, 0, :],
746
+ activation="silu",
747
+ )
748
+ return x, conv_state
749
+
750
+
751
+ def _causal_conv1d_naive(
752
+ conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
753
+ ) -> tuple[torch.Tensor, torch.Tensor]:
754
+ length = x.shape[-1]
755
+ out = torch.zeros_like(x)
756
+ for i in range(length):
757
+ if i != 0 and seq_idx is not None:
758
+ conv_state = torch.where(
759
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
760
+ torch.zeros_like(conv_state),
761
+ conv_state,
762
+ )
763
+ out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
764
+ return out, conv_state
765
+
766
+
767
+ @torch.profiler.record_function("causal_conv1d")
768
+ def _causal_conv1d(
769
+ conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
770
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
771
+ dtype = x.dtype
772
+ if conv_state is not None:
773
+ dtype = conv_state.dtype
774
+ assert seq_idx is None
775
+ if seq_idx is not None:
776
+ assert seq_idx.dtype == torch.int32
777
+ assert conv_state is None
778
+ weight = weight.to(dtype)
779
+ x = x.to(dtype)
780
+
781
+ return_final_states = conv_state is not None
782
+ if weight.is_cuda:
783
+ if x.stride(1) != 1:
784
+ # to channel-last format
785
+ x = x.transpose(-1, -2).contiguous().transpose(-1, -2)
786
+ if conv_state is not None:
787
+ if conv_state.stride(1) != 1:
788
+ # to channel-last format
789
+ conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2)
790
+ tmp = causal_conv1d.causal_conv1d_fn(
791
+ x=x,
792
+ weight=weight[:, 0, :],
793
+ initial_states=conv_state,
794
+ return_final_states=conv_state is not None,
795
+ activation="silu",
796
+ seq_idx=seq_idx,
797
+ )
798
+ if conv_state is not None:
799
+ x, conv_state = tmp
800
+ else:
801
+ x = tmp
802
+ else:
803
+ if seq_idx is None:
804
+ x, conv_state = causal_conv1d.causal_conv1d_ref(
805
+ x=x,
806
+ initial_states=conv_state,
807
+ return_final_states=True,
808
+ weight=weight[:, 0, :],
809
+ activation="silu",
810
+ )
811
+ else:
812
+ if conv_state is None:
813
+ bsize = x.shape[0]
814
+ dim = weight.shape[0]
815
+ d_conv = weight.shape[-1]
816
+ conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
817
+ x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
818
+ if return_final_states:
819
+ return x, conv_state
820
+ else:
821
+ return x, None
822
+
823
+
824
+ class Mamba(torch.nn.Module):
825
+ def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
826
+ super().__init__()
827
+ self.config = config
828
+ self.layer_idx = layer_idx
829
+ self.hidden_size = config.hidden_size
830
+ self.d_state = config.mamba_d_state
831
+ self.d_conv = config.mamba_d_conv
832
+ self.chunk_size = config.mamba_chunk_size
833
+ self.num_heads = config.mamba_num_heads
834
+ # TODO add mamba_hidden_size_per_head config (?)
835
+ self.hidden_size_per_head = config.hidden_size_per_head
836
+
837
+ self.intermediate_size = self.num_heads * self.hidden_size_per_head
838
+
839
+ self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
840
+ self.conv1d = torch.nn.Conv1d(
841
+ in_channels=self.intermediate_size,
842
+ out_channels=self.intermediate_size,
843
+ bias=False, # TODO the original implementation uses bias
844
+ kernel_size=self.d_conv,
845
+ groups=self.intermediate_size,
846
+ padding=0,
847
+ )
848
+ self.dt_dim = max(64, self.hidden_size // 16)
849
+ # Notes:
850
+ # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
851
+ # but it may degrade the ability of context-length extrapolation.
852
+ self.bcdt_proj = torch.nn.Linear(
853
+ self.intermediate_size,
854
+ self.dt_dim + 2 * self.d_state,
855
+ bias=False,
856
+ )
857
+ self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False)
858
+
859
+ self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
860
+ self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads))
861
+ self.D = torch.nn.Parameter(torch.ones(self.num_heads))
862
+
863
+ # TODO norm weight before gating like Mamba2
864
+ self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim))
865
+ self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
866
+ self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
867
+
868
+ self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
869
+
870
+ def _no_weight_decay_param_names(self) -> set[str]:
871
+ return set(["D", "dt_bias", "A_log"])
872
+
873
+ def forward(
874
+ self,
875
+ hidden_states: torch.Tensor,
876
+ attention_mask: Optional[torch.Tensor] = None,
877
+ past_states: Optional[Plamo2Cache] = None,
878
+ ) -> Tuple[torch.Tensor, Optional[Plamo2Cache]]:
879
+ bsize, length, _ = hidden_states.shape
880
+ is_update = length == 1 and past_states is not None
881
+
882
+ bool_mask: torch.Tensor | None = None
883
+ seq_idx: torch.Tensor | None = None
884
+ if attention_mask is not None:
885
+ if len(attention_mask.shape) == 2:
886
+ attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1)
887
+ assert len(attention_mask.shape) == 4
888
+
889
+ if past_states is None:
890
+ # TODO: support seq_idx with cache
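+ # Build per-token sequence ids from the packed attention mask: `seq_idx` increases by one
+ # at every "first token", and the conv/SSM scans reset their state whenever it changes.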
891
+ bool_mask_4d = attention_mask == 0
892
+ is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
893
+ seq_idx = torch.cumsum(is_first_token, dim=-1) - 1
894
+ seq_idx = seq_idx.to(torch.int32)
895
+
896
+ # The `generate` function creates an attention mask that also covers past tokens,
897
+ # but Mamba does not use them.
898
+ attention_mask = attention_mask[:, 0, -length:, -length:]
899
+ bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0
900
+
901
+ conv_state: torch.Tensor | None
902
+ ssm_state: torch.Tensor | None
903
+ if past_states is None:
904
+ conv_state = None
905
+ ssm_state = None
906
+ elif past_states[self.layer_idx] is None:
907
+ conv_state = torch.zeros(
908
+ bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device
909
+ )
910
+ ssm_state = torch.zeros(
911
+ bsize,
912
+ self.num_heads,
913
+ self.hidden_size_per_head,
914
+ self.d_state,
915
+ dtype=torch.float32,
916
+ device=hidden_states.device,
917
+ )
918
+ else:
919
+ c = past_states[self.layer_idx]
920
+ assert isinstance(c, Plamo2MambaCache)
921
+ conv_state = c.conv_state
922
+ ssm_state = c.ssm_state
923
+
924
+ zx = self.in_proj(hidden_states)
925
+ zx = zx.reshape(bsize, length, self.num_heads, -1)
926
+ # z: (bsize, length, num_heads, hidden_size_per_head)
927
+ # x: (bsize, length, num_heads, hidden_size_per_head)
928
+ z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1)
929
+
930
+ # conv
931
+ x = x.reshape(bsize, length, -1).transpose(1, 2) # (bsize, intermediate_size, length)
932
+ if bool_mask is not None:
933
+ x = torch.where(bool_mask[:, None, :], x, 0.0)
934
+ if is_update:
935
+ assert conv_state is not None
936
+ x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
937
+ else:
938
+ x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
939
+ x = x.to(dtype=hidden_states.dtype)
940
+ x = x.transpose(1, 2) # (bsize, length, intermediate_size)
941
+ x = x.reshape(bsize, length, -1)
942
+ # x: (bsize, length, num_heads, hidden_size_per_head)
943
+ # B: (bsize, length, 1, d_state)
944
+ # C: (bsize, length, 1, d_state)
945
+ # dt: (bsize, length, dt_dim)
946
+ BCdt = self.bcdt_proj(x)
947
+ x = x.reshape(bsize, length, self.num_heads, -1)
948
+ B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1)
949
+ B = B[:, :, None, :]
950
+ C = C[:, :, None, :]
951
+
952
+ A = -torch.exp(self.A_log.float()) # (num_heads,)
953
+ dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
954
+ B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
955
+ C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
956
+
957
+ # (bsize, length, num_heads, 1)
958
+ dt = self.dt_proj(dt)[..., None]
959
+
960
+ # TODO it may not be required
961
+ B = B.expand(-1, -1, self.num_heads, -1)
962
+ C = C.expand(-1, -1, self.num_heads, -1)
963
+
964
+ if bool_mask is not None:
965
+ """
966
+ The state is updated as follows:
967
+ ```
968
+ dt = softplus(dt)
969
+ dA = exp(dt * A)
970
+ state_next = state * dA + dB * x
971
+ ```
972
+
973
+ To avoid updating state, we set dt to -inf and x to 0
974
+ because `softplus(-inf) = 0` and `exp(0) = 1`
975
+ """
976
+ dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf"))
977
+ x = torch.where(bool_mask[:, :, None, None], x, 0.0)
978
+
979
+ # ssm
980
+ if is_update:
981
+ assert ssm_state is not None
982
+ out = ssd_update_state(
983
+ ssm_state,
984
+ x[:, 0],
985
+ dt[:, 0].reshape(bsize, -1),
986
+ A,
987
+ B[:, 0],
988
+ C[:, 0],
989
+ D=self.D,
990
+ z=z[:, 0],
991
+ dt_bias=self.dt_bias,
992
+ dt_softplus=True,
993
+ )
994
+ else:
995
+ tmp = ssd_chunk_scan_combined(
996
+ x,
997
+ dt.reshape(bsize, length, -1),
998
+ A,
999
+ B,
1000
+ C,
1001
+ self.chunk_size,
1002
+ D=self.D,
1003
+ z=z,
1004
+ dt_bias=self.dt_bias,
1005
+ dt_softplus=True,
1006
+ return_final_states=past_states is not None,
1007
+ seq_idx=seq_idx,
1008
+ ssm_state=ssm_state,
1009
+ )
1010
+ if past_states is not None:
1011
+ out, ssm_state = tmp
1012
+ else:
1013
+ assert isinstance(tmp, torch.Tensor)
1014
+ out = tmp
1015
+
1016
+ y = self.out_proj(out.reshape(bsize, length, -1))
1017
+
1018
+ if past_states is not None:
1019
+ assert ssm_state is not None
1020
+ assert conv_state is not None
1021
+ past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
1022
+
1023
+ return y, past_states
1024
+
1025
+
1026
+ def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
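+ # Boolean band mask of shape (q_len, kv_len): True where the query and key positions
+ # (aligned at the sequence end) are within `window_size` of each other. Causality is
+ # enforced separately by the causal attention mask this is combined with.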
1027
+ max_len = max(q_len, kv_len)
1028
+ mask = (
1029
+ torch.ones(max_len, max_len, dtype=torch.bool, device=device)
1030
+ .triu(diagonal=-window_size)
1031
+ .tril(diagonal=window_size)
1032
+ )
1033
+ return mask[-q_len:, -kv_len:]
1034
+
1035
+
1036
+ class Attention(torch.nn.Module):
1037
+ def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
1038
+ super().__init__()
1039
+ self.config = config
1040
+ self.layer_idx = layer_idx
1041
+ self.hidden_size = config.hidden_size
1042
+ head_dim = config.hidden_size_per_head
1043
+ self.max_position_embeddings = config.max_position_embeddings
1044
+
1045
+ self.q_num_heads = config.num_attention_heads
1046
+ self.qk_dim = self.v_dim = head_dim
1047
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
1048
+ assert self.q_num_heads % self.k_num_heads == 0
1049
+ self.n_group = self.q_num_heads // self.k_num_heads
1050
+
1051
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
1052
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
1053
+ self.v_proj_dim = self.k_num_heads * self.v_dim
1054
+ self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
1055
+ self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
1056
+
1057
+ self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
1058
+ self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
1059
+
1060
+ self.full_attn = self.layer_idx in self.config.full_attention_idx
1061
+ base = self.config.rope_theta if self.full_attn else self.config.rope_local_theta
1062
+ self.rotary_emb = RotaryEmbedding(
1063
+ self.qk_dim, max_position_embeddings=self.config.attention_window_size, base=base
1064
+ )
1065
+
1066
+ def forward(
1067
+ self,
1068
+ hidden_states: torch.Tensor,
1069
+ attention_mask: Optional[torch.Tensor] = None,
1070
+ past_states: Optional[Plamo2Cache] = None,
1071
+ output_attentions: bool = False,
1072
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Plamo2Cache]]:
1073
+ bsz, q_len, _ = hidden_states.size()
1074
+
1075
+ qkv = self.qkv_proj(hidden_states)
1076
+ query_states, key_states, value_states = torch.split(
1077
+ qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
1078
+ )
1079
+ query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
1080
+ key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
1081
+ value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
1082
+
1083
+ attn_dtype = query_states.dtype
1084
+
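+ # Per-head QK normalization: RMS-normalize queries and keys, then scale each head by a
+ # learned weight (commonly used to stabilize attention logits).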
1085
+ query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
1086
+ key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
1087
+
1088
+ if past_states is not None:
1089
+ # reuse k, v, self_attention
1090
+ key_states_new = key_states
1091
+ value_states_new = value_states
1092
+ key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
1093
+ past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
1094
+
1095
+ kv_seq_len = key_states.shape[-2]
1096
+ device = hidden_states.device
1097
+ position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
1098
+ q_position_ids = position_ids[:, -query_states.shape[2] :]
1099
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1100
+ query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
1101
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
1102
+ # [bsz, nh, t, hd]
1103
+
1104
+ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
1105
+ t = torch.repeat_interleave(t, repeat, dim=1)
1106
+ return t[:, :target]
1107
+
1108
+ # expand shared kv
1109
+ assert self.k_num_heads == self.v_num_heads
1110
+ key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
1111
+ value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
1112
+
1113
+ query_states = query_states.to(attn_dtype)
1114
+ key_states = key_states.to(attn_dtype)
1115
+ value_states = value_states.to(attn_dtype)
1116
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
1117
+ attention_mask = attention_mask.to(attn_dtype)
1118
+ if attention_mask is None:
1119
+ if not self.full_attn:
1120
+ assert key_states.shape[2] <= self.config.attention_window_size + 1
1121
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
1122
+ else:
1123
+ if attention_mask.dtype == torch.bool:
1124
+ attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
1125
+ if len(attention_mask.shape) == 2:
1126
+ attention_mask = attention_mask[None, None]
1127
+ assert len(attention_mask.shape) == 4
1128
+
1129
+ if not self.full_attn:
1130
+ m_swa = swa_mask(
1131
+ query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
1132
+ )
1133
+ # The `generate` function creates an attention mask that does not account for the sliding window.
1134
+ m_swa = m_swa[None, None]
1135
+ attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
1136
+ attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
1137
+
1138
+ # Like AttentionMaskConverter._unmask_unattended in huggingface transformers,
1139
+ # we need to attend to all tokens in fully masked rows for `scaled_dot_product_attention`.
1140
+ bool_mask = torch.logical_not(torch.isneginf(attention_mask))
1141
+ valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
1142
+ attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
1143
+ attn_output = F.scaled_dot_product_attention(
1144
+ query_states, key_states, value_states, attn_mask=attention_mask
1145
+ )
1146
+
1147
+ attn_output = attn_output.transpose(1, 2)
1148
+
1149
+ attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
1150
+ attn_output = self.o_proj(attn_output)
1151
+
1152
+ if not output_attentions:
1153
+ attn_weights = None
1154
+
1155
+ return attn_output, attn_weights, past_states
1156
+
1157
+
1158
+ class MLP(nn.Module):
1159
+ def __init__(self, config: Plamo2Config) -> None:
1160
+ super().__init__()
1161
+ self.config = config
1162
+ self.hidden_size = config.hidden_size
1163
+ self.intermediate_size = config.intermediate_size
1164
+ self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
1165
+ self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
1166
+
1167
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1168
+ h = self.gate_up_proj(x)
1169
+ h = _swiglu(h)
1170
+ return self.down_proj(h) # type: ignore
1171
+
1172
+
1173
+ class Plamo2DecoderLayer(torch.nn.Module):
1174
+ def __init__(self, config: Plamo2Config, layer_idx: int) -> None:
1175
+ super().__init__()
1176
+ self.config = config
1177
+ self.hidden_size = config.hidden_size
1178
+ self.is_mamba = config.layers_block_type[layer_idx] == "mamba"
1179
+ self.mixer: torch.nn.Module
1180
+ if self.is_mamba:
1181
+ self.mixer = Mamba(config, layer_idx)
1182
+ else:
1183
+ self.mixer = Attention(config, layer_idx)
1184
+ self.mlp = MLP(config)
1185
+ """
1186
+ Notes: Model performance degraded when all offsets were set to 1.
1187
+ """
1188
+ self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1189
+ self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
1190
+ self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1191
+ self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
1192
+
1193
+ def forward(
1194
+ self,
1195
+ hidden_states: torch.Tensor,
1196
+ attention_mask: Optional[torch.Tensor] = None,
1197
+ past_state: Optional[Plamo2Cache] = None,
1198
+ output_attentions: Optional[bool] = False,
1199
+ ) -> Tuple[Any, ...]:
1200
+ # from LlamaDecoder
1201
+ residual = hidden_states
1202
+ hidden_states = self.pre_mixer_norm(hidden_states)
1203
+
1204
+ # Self Attention
1205
+ if self.is_mamba:
1206
+ hidden_states_sa, present_key_value = self.mixer(
1207
+ hidden_states=hidden_states,
1208
+ attention_mask=attention_mask,
1209
+ past_states=past_state,
1210
+ )
1211
+ self_attn_weights = None
1212
+ else:
1213
+ hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
1214
+ hidden_states=hidden_states,
1215
+ attention_mask=attention_mask,
1216
+ past_states=past_state,
1217
+ output_attentions=output_attentions,
1218
+ )
1219
+
1220
+ hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
1221
+ hidden_states = residual + hidden_states_sa
1222
+
1223
+ residual = hidden_states
1224
+ hidden_states = self.pre_mlp_norm(hidden_states)
1225
+
1226
+ # Fully Connected
1227
+ hidden_states_mlp = self.mlp(hidden_states)
1228
+
1229
+ # Residual
1230
+ hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
1231
+ hidden_states = residual + hidden_states_mlp
1232
+
1233
+ outputs: Any = (hidden_states,)
1234
+
1235
+ if output_attentions:
1236
+ outputs += (self_attn_weights,)
1237
+
1238
+ return outputs # type: ignore
1239
+
1240
+
1241
+ def is_mamba(config: Plamo2Config, i: int) -> bool:
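+ # Layer pattern: with the default mamba_step=2, Mamba and attention layers alternate
+ # (attention at indices where i % mamba_step == mamba_step // 2, Mamba elsewhere);
+ # models with at most mamba_step // 2 layers use attention only in the final layer.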
1242
+ if not config.mamba_enabled:
1243
+ return False
1244
+ assert config.mamba_step > 1
1245
+ assert i < config.num_hidden_layers
1246
+
1247
+ if config.num_hidden_layers <= (config.mamba_step // 2):
1248
+ # use attention in last layer
1249
+ return i != config.num_hidden_layers - 1
1250
+ return (i % config.mamba_step) != (config.mamba_step // 2)
1251
+
1252
+
1253
+ class Plamo2Decoder(torch.nn.Module):
1254
+ def __init__(self, config: Plamo2Config) -> None:
1255
+ super().__init__()
1256
+
1257
+ self.layers = torch.nn.ModuleList(
1258
+ [Plamo2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
1259
+ )
1260
+ self.gradient_checkpointing = False
1261
+
1262
+ def forward(self, x: DecoderInput) -> DecoderOutput:
1263
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
1264
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
1265
+ hidden_states = x.hidden_states
1266
+
1267
+ for decoder_layer in self.layers:
1268
+ if x.output_hidden_states:
1269
+ assert all_hidden_states is not None
1270
+ all_hidden_states += (hidden_states,)
1271
+
1272
+ if self.training and x.gradient_checkpointing:
1273
+ layer_outputs = self._gradient_checkpointing_func(
1274
+ decoder_layer.__call__,
1275
+ hidden_states,
1276
+ x.attention_mask,
1277
+ x.past_states,
1278
+ x.output_attentions,
1279
+ )
1280
+ else:
1281
+ layer_outputs = decoder_layer(
1282
+ hidden_states,
1283
+ attention_mask=x.attention_mask,
1284
+ past_state=x.past_states,
1285
+ output_attentions=x.output_attentions,
1286
+ )
1287
+
1288
+ hidden_states = layer_outputs[0]
1289
+
1290
+ if x.output_attentions:
1291
+ assert layer_outputs[1] is not None
1292
+ assert all_self_attns is not None
1293
+ all_self_attns += (layer_outputs[1],)
1294
+ return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1295
+
1296
+
1297
+ class Plamo2PreTrainedModel(PreTrainedModel): # type: ignore
1298
+ config_class = Plamo2Config
1299
+ _no_split_modules: List[str]
1300
+ base_model_prefix = "model"
1301
+ supports_gradient_checkpointing = True
1302
+ _no_split_modules = ["PlamoDecoderLayer"]
1303
+ _skip_keys_device_placement = "past_key_values"
1304
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1305
+
1306
+ def _init_weights(self, module: torch.nn.Module) -> None:
1307
+ std = 0.02
1308
+ if isinstance(module, nn.Linear):
1309
+ module.weight.data.normal_(mean=0.0, std=std)
1310
+ if module.bias is not None:
1311
+ module.bias.data.zero_()
1312
+ elif isinstance(module, nn.Embedding):
1313
+ module.weight.data.normal_(mean=0.0, std=std)
1314
+ if module.padding_idx is not None:
1315
+ module.weight.data[module.padding_idx].zero_()
1316
+
1317
+
1318
+ class Plamo2Model(Plamo2PreTrainedModel):
1319
+ def __init__(self, config: Plamo2Config):
1320
+ super().__init__(config)
1321
+ assert config.eval_attention_n_bit is None
1322
+ assert config.eval_mlp_n_bit is None
1323
+
1324
+ self.padding_idx = config.pad_token_id
1325
+ self.vocab_size = config.vocab_size
1326
+
1327
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1328
+ if config.image_feature_size is not None:
1329
+ if config.image_proj_type == "mlp":
1330
+ self.image_proj = MLPImageProjector(config) # type: ignore
1331
+ elif config.image_proj_type == "linear":
1332
+ self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
1333
+ else:
1334
+ raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
1335
+ self.layers = Plamo2Decoder(config) # type: ignore
1336
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1337
+
1338
+ self.gradient_checkpointing = False
1339
+ # Initialize weights and apply final processing
1340
+ self.post_init()
1341
+
1342
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1343
+ return self.embed_tokens
1344
+
1345
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1346
+ self.embed_tokens = value
1347
+
1348
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1349
+ def _prepare_decoder_attention_mask(
1350
+ self,
1351
+ attention_mask: torch.Tensor,
1352
+ input_shape: Tuple[int, int],
1353
+ inputs_embeds: Optional[torch.Tensor],
1354
+ past_key_values_length: int,
1355
+ ) -> Optional[torch.Tensor]:
1356
+ # create causal mask
1357
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1358
+ combined_attention_mask: Optional[torch.Tensor] = None
1359
+ if input_shape[-1] > 1:
1360
+ assert inputs_embeds is not None
1361
+ combined_attention_mask = _make_causal_mask(
1362
+ input_shape,
1363
+ inputs_embeds.dtype,
1364
+ device=inputs_embeds.device,
1365
+ past_key_values_length=past_key_values_length,
1366
+ )
1367
+ input_shape = (input_shape[0], combined_attention_mask.shape[2])
1368
+
1369
+ if attention_mask is not None:
1370
+ if attention_mask.dim() == 4:
1371
+ # Custom 4D attention mask
1372
+ expanded_attn_mask = attention_mask
1373
+ else:
1374
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1375
+ assert inputs_embeds is not None
1376
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
1377
+ inputs_embeds.device
1378
+ )
1379
+ combined_attention_mask = (
1380
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1381
+ )
1382
+
1383
+ return combined_attention_mask
1384
+
1385
+ def forward(
1386
+ self,
1387
+ input_ids: Optional[torch.LongTensor] = None,
1388
+ attention_mask: Optional[torch.Tensor] = None,
1389
+ position_ids: Optional[torch.Tensor] = None,
1390
+ past_key_values: Optional[Plamo2Cache] = None,
1391
+ inputs_embeds: Optional[torch.Tensor] = None,
1392
+ image_features: Optional[torch.Tensor] = None,
1393
+ use_cache: Optional[bool] = None,
1394
+ output_attentions: Optional[bool] = None,
1395
+ output_hidden_states: Optional[bool] = None,
1396
+ return_dict: Optional[bool] = None,
1397
+ cache_position: Optional[torch.LongTensor] = None,
1398
+ **kwargs: Any,
1399
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1400
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1401
+ output_hidden_states = (
1402
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1403
+ )
1404
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1405
+
1406
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1407
+
1408
+ # retrieve input_ids and inputs_embeds
1409
+ if (input_ids is None) ^ (inputs_embeds is not None):
1410
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1411
+
1412
+ if self.gradient_checkpointing and self.training and use_cache:
1413
+ use_cache = False
1414
+
1415
+ if inputs_embeds is None:
1416
+ inputs_embeds = self.embed_tokens(input_ids)
1417
+ batch_size, seq_length, _ = inputs_embeds.shape
1418
+
1419
+ seq_length_with_past = seq_length
1420
+ past_key_values_length = 0
1421
+ if past_key_values is not None:
1422
+ past_key_values_length = past_key_values.get_seq_length()
1423
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1424
+ assert cache_position is None, "cache_position is not supported yet"
1425
+
1426
+ if image_features is not None:
1427
+ assert self.config.image_token_id is not None
1428
+ image_embeds = self.image_proj(image_features)
1429
+ assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
1430
+ mask = input_ids == self.config.image_token_id
1431
+ inputs_embeds[mask] = image_embeds[mask]
1432
+
1433
+ # embed positions
1434
+ require_attn_mask = False
1435
+ if not self.training or past_key_values is not None:
1436
+ require_attn_mask = True
1437
+ if seq_length_with_past >= self.config.attention_window_size:
1438
+ require_attn_mask = True
1439
+ if require_attn_mask and attention_mask is None:
1440
+ attention_mask = torch.ones(
1441
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1442
+ )
1443
+ if attention_mask is not None:
1444
+ attention_mask = self._prepare_decoder_attention_mask(
1445
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1446
+ )
1447
+
1448
+ hidden_states = inputs_embeds
1449
+
1450
+ if use_cache and past_key_values is None:
1451
+ past_key_values = Plamo2Cache(self.config)
1452
+
1453
+ # decoder layers
1454
+ out = self.layers(
1455
+ DecoderInput(
1456
+ hidden_states,
1457
+ attention_mask,
1458
+ past_key_values,
1459
+ output_hidden_states,
1460
+ output_attentions,
1461
+ self.gradient_checkpointing,
1462
+ )
1463
+ )
1464
+ assert isinstance(out, DecoderOutput)
1465
+ hidden_states = out.hidden_states
1466
+ all_hidden_states = out.all_hidden_states
1467
+ all_self_attns = out.all_self_attns
1468
+
1469
+ hidden_states = self.norm(hidden_states)
1470
+
1471
+ # add hidden states from the last decoder layer
1472
+ if output_hidden_states:
1473
+ assert all_hidden_states is not None
1474
+ all_hidden_states += (hidden_states,)
1475
+
1476
+ if not return_dict:
1477
+ return tuple(
1478
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
1479
+ )
1480
+ return BaseModelOutputWithPast(
1481
+ last_hidden_state=hidden_states,
1482
+ past_key_values=past_key_values,
1483
+ hidden_states=all_hidden_states,
1484
+ attentions=all_self_attns,
1485
+ )
1486
+
1487
+
1488
+ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1489
+ _tied_weights_keys = ["lm_head.weight"]
1490
+
1491
+ # Without this, the model cannot be loaded into a meta device.
1492
+ # Relevant code:
1493
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
1494
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
1495
+ # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1496
+ _supports_param_buffer_assignment = False
1497
+
1498
+ def __init__(self, config: Plamo2Config) -> None:
1499
+ super().__init__(config)
1500
+ self.model = Plamo2Model(config)
1501
+
1502
+ self.vocab_size = config.vocab_size
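+ # The lm_head output dimension is padded up to a multiple of 16 (the extra logits are
+ # sliced off again in `forward`); this is likely for hardware-friendly matmul shapes.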
1503
+ vocab_size = ((self.vocab_size + 15) // 16) * 16
1504
+ self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
1505
+
1506
+ # Initialize weights and apply final processing
1507
+ self.post_init()
1508
+
1509
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1510
+ return self.model.embed_tokens
1511
+
1512
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1513
+ self.model.embed_tokens = value
1514
+
1515
+ def get_output_embeddings(self) -> torch.nn.Module:
1516
+ return self.lm_head
1517
+
1518
+ def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
1519
+ self.lm_head = new_embeddings
1520
+
1521
+ def set_decoder(self, decoder: Plamo2Model) -> None:
1522
+ self.model = decoder
1523
+
1524
+ def get_decoder(self) -> Plamo2Model:
1525
+ return self.model
1526
+
1527
+ def forward( # type: ignore
1528
+ self,
1529
+ input_ids: Optional[torch.LongTensor] = None,
1530
+ attention_mask: Optional[torch.Tensor] = None,
1531
+ position_ids: Optional[torch.Tensor] = None,
1532
+ past_key_values: Optional[Plamo2Cache] = None,
1533
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1534
+ image_features: Optional[torch.Tensor] = None,
1535
+ labels: Optional[torch.LongTensor] = None,
1536
+ use_cache: Optional[bool] = None,
1537
+ output_attentions: Optional[bool] = None,
1538
+ output_hidden_states: Optional[bool] = None,
1539
+ return_dict: Optional[bool] = None,
1540
+ cache_position: Optional[torch.LongTensor] = None,
1541
+ logits_to_keep: int | torch.Tensor = 0,
1542
+ **kwargs: Any,
1543
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1544
+ r"""
1545
+ Args:
1546
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1547
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1548
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1549
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1550
+
1551
+ Returns:
1552
+
1553
+ Example:
1554
+
1555
+ ```python
1556
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
1557
+
1558
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, trust_remote_code=True)
1559
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code=True)
1560
+
1561
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
1562
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1563
+
1564
+ >>> # Generate
1565
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1566
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1567
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
1568
+ ```"""
1569
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1570
+ output_hidden_states = (
1571
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1572
+ )
1573
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1574
+
1575
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1576
+ outputs = self.model(
1577
+ input_ids=input_ids,
1578
+ attention_mask=attention_mask,
1579
+ position_ids=position_ids,
1580
+ past_key_values=past_key_values,
1581
+ inputs_embeds=inputs_embeds,
1582
+ image_features=image_features,
1583
+ use_cache=use_cache,
1584
+ output_attentions=output_attentions,
1585
+ output_hidden_states=output_hidden_states,
1586
+ return_dict=return_dict,
1587
+ cache_position=cache_position,
1588
+ **kwargs,
1589
+ )
1590
+
1591
+ hidden_states = outputs[0]
1592
+ logits = self.lm_head(hidden_states)
1593
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1594
+ logits = logits[:, slice_indices, : self.vocab_size]
1595
+
1596
+ loss = None
1597
+ if labels is not None:
1598
+ if len(kwargs) > 0 and set(kwargs.keys()) != set(["ignore_index"]):
1599
+ warnings.warn(
1600
+ f"The following kwargs may not be supported: {', '.join(kwargs.keys())}. ",
1601
+ stacklevel=2,
1602
+ )
1603
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1604
+
1605
+ if not return_dict:
1606
+ output = (logits,) + outputs[1:]
1607
+ return (loss,) + output if loss is not None else output
1608
+
1609
+ return CausalLMOutputWithPast(
1610
+ loss=loss,
1611
+ logits=logits,
1612
+ past_key_values=outputs.past_key_values,
1613
+ hidden_states=outputs.hidden_states,
1614
+ attentions=outputs.attentions,
1615
+ )
1616
+
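A minimal usage sketch of the `forward` defined above, assuming `model` is a loaded `Plamo2ForCausalLM` and `tokenizer` its matching tokenizer (both loaded elsewhere with `trust_remote_code=True`). Passing `labels` returns the built-in causal-LM loss, and the logits come back already sliced to `config.vocab_size`:

```python
import torch

# Assumes `model` and `tokenizer` were loaded elsewhere with trust_remote_code=True.
inputs = tokenizer("PLaMo is a large language model.", return_tensors="pt")
with torch.no_grad():
    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        labels=inputs["input_ids"],   # built-in shifted cross-entropy loss
    )
print(out.loss)          # scalar loss tensor
print(out.logits.shape)  # (batch, seq_len, config.vocab_size)
```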
1617
+ def prepare_inputs_for_generation(
1618
+ self,
1619
+ input_ids: torch.Tensor,
1620
+ past_key_values: Optional[Plamo2Cache] = None,
1621
+ attention_mask: Optional[torch.Tensor] = None,
1622
+ inputs_embeds: Optional[torch.Tensor] = None,
1623
+ image_features: Optional[torch.Tensor] = None,
1624
+ **kwargs: Any,
1625
+ ) -> Dict[str, Any]:
1626
+ if past_key_values:
1627
+ input_ids = input_ids[:, -1:]
1628
+ if image_features is not None:
1629
+ image_features = image_features[:, -1:, :]
1630
+
1631
+ position_ids = kwargs.get("position_ids", None)
1632
+ if attention_mask is not None and position_ids is None:
1633
+ # create position_ids on the fly for batch generation
1634
+ position_ids = attention_mask.long().cumsum(-1) - 1
1635
+ position_ids.masked_fill_(attention_mask == 0, 1)
1636
+ if past_key_values:
1637
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1638
+
1639
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1640
+ if inputs_embeds is not None and past_key_values is None:
1641
+ model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
1642
+ else:
1643
+ model_inputs = {"input_ids": input_ids}
1644
+
1645
+ model_inputs.update(
1646
+ {
1647
+ "position_ids": position_ids,
1648
+ "past_key_values": past_key_values,
1649
+ "use_cache": kwargs.get("use_cache"),
1650
+ "attention_mask": attention_mask,
1651
+ "image_features": image_features,
1652
+ }
1653
+ )
1654
+ return model_inputs
1655
+
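`prepare_inputs_for_generation` derives `position_ids` from the attention mask with a cumulative sum, so left padding does not shift the positions of real tokens. A small standalone check of that arithmetic:

```python
import torch

# Left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])  -> real tokens count from 0; padding gets a dummy 1
```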
1656
+ @staticmethod
1657
+ def _reorder_cache(past_key_values: Plamo2Cache, beam_idx: torch.Tensor) -> Plamo2Cache:
1658
+ past_key_values.reorder_cache(beam_idx)
1659
+ return past_key_values
1660
+
1661
+
1662
+ class MLPImageProjector(nn.Module):
1663
+ def __init__(self, config: Plamo2Config) -> None:
1664
+ super().__init__()
1665
+ self.config = config
1666
+
1667
+ assert config.image_feature_size is not None # for typing
1668
+
1669
+ # nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
1670
+ self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
1671
+ self.bias0 = Bias(config.image_feature_size)
1672
+
1673
+ # PFVM doesn't support Linear with bias, so add bias manually afterwards.
1674
+ self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
1675
+ self.bias1 = Bias(config.hidden_size)
1676
+ self.act1 = nn.GELU()
1677
+
1678
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
1679
+ self.bias2 = Bias(config.hidden_size)
1680
+
1681
+ def forward(
1682
+ self,
1683
+ hidden_states: torch.Tensor,
1684
+ ) -> torch.Tensor:
1685
+ hidden_states = self.norm0(hidden_states)
1686
+ hidden_states = self.bias0(hidden_states)
1687
+
1688
+ hidden_states = self.linear1(hidden_states)
1689
+ hidden_states = self.bias1(hidden_states)
1690
+ hidden_states = self.act1(hidden_states)
1691
+
1692
+ hidden_states = self.linear2(hidden_states)
1693
+ hidden_states = self.bias2(hidden_states)
1694
+
1695
+ return hidden_states
1696
+
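The projector's comments note that `nn.LayerNorm` is replaced by `RMSNorm` plus a learned bias because of a PFVM limitation. As a rough sanity check of why that substitution approximates LayerNorm: on mean-centered inputs, RMS normalization coincides with LayerNorm's normalization step (learned scale and bias omitted; this sketch uses plain tensor ops, not the repo's `RMSNorm` class):

```python
import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMS normalization: no mean subtraction, no learned parameters.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8)
ln = torch.nn.functional.layer_norm(x, (8,), eps=1e-6)
approx = rms_normalize(x - x.mean(-1, keepdim=True))
print((ln - approx).abs().max())  # ~0: identical once the input is mean-centered
```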
1697
+
1698
+ class Bias(nn.Module):
1699
+ def __init__(self, num_features: int) -> None:
1700
+ super().__init__()
1701
+ self._bias = nn.Parameter(torch.zeros((num_features,)))
1702
+
1703
+ def forward(
1704
+ self,
1705
+ x: torch.Tensor,
1706
+ ) -> torch.Tensor:
1707
+ return x + self._bias
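The classes above (`Plamo2ForCausalLM`, `MLPImageProjector`, `Bias`) complete the modeling file. A hedged end-to-end sketch: the custom classes are normally reached through the `auto_map` machinery with `trust_remote_code=True`; the repo id below is an assumption taken from the model card's license link and may need adjusting:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "pfnet/plamo-2-8b"  # assumption: replace with the repository you actually use
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("これからの人工知能技術は", return_tensors="pt")
generated = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```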
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|plamo:bos|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": "<|plamo:bos|>",
10
+ "pad_token": {
11
+ "content": "<|plamo:pad|>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ "unk_token": {
18
+ "content": "<|plamo:unk|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenization_plamo.py ADDED
@@ -0,0 +1,392 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from shutil import copyfile
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+ # NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
10
+ from numba import njit # type: ignore[attr-defined]
11
+ from numba.core import types
12
+ from numba.typed import Dict, List
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+ from transformers.utils import logging
15
+
16
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
17
+ logger = logging.get_logger(__name__)
18
+
19
+ INVALID_SCORE = -20000000
20
+ UNKNOWN_SCORE = -10000000
21
+
22
+ TABLE_PIECE_LENGTH = 0
23
+ TABLE_TOKEN_ID = 1
24
+ TABLE_SCORE = 2
25
+ TABLE_PIECE_ID = 3
26
+
27
+ PATH_TOKEN_LENGTH = 0
28
+ PATH_TOKEN_ID = 1
29
+ PATH_NUM_TOKENS = 2
30
+
31
+
32
+ class AhoCorasick:
33
+ def __init__(self) -> None:
34
+ # List of tokens in the vocabulary.
35
+ self._tokens: list[str]
36
+
37
+ # A mapping from a byte code point to a token ID, used for byte fallback.
38
+ self._bytes: np.ndarray
39
+
40
+ # A mapping from a suffix's piece code to a suffix ID.
41
+ #
42
+ # Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
43
+ # of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
44
+ # a piece code to an edge (in other words, a pair of a node and the next character).
45
+ #
46
+ # A piece code is a 64-bit integer:
47
+ # - The upper 32 bits store the Unicode code point of the first character.
48
+ # - The lower 32 bits store the suffix ID of the remaining suffix.
49
+ #
50
+ # A suffix ID is an integer indicating the starting position in the _table.
51
+ self._to_suffix_id: Dict[types.int64, types.int32]
52
+
53
+ # Flattened table representing the Trie structure for the Aho-Corasick algorithm.
54
+ # It stores information including scores for each piece (prefix) within each suffix.
55
+ # It is flattened for memory efficiency and performance. Suffixes are stored in
56
+ # lexicographical order of their reversed strings, which improves memory access locality
57
+ # when exploring new characters starting from the string's end. Pieces within a suffix are
58
+ # stored in the decreasing order of their lengths.
59
+ #
60
+ # Each piece (a prefix of the suffix) contains four pieces of information:
61
+ # - TABLE_PIECE_LENGTH: Length of the piece.
62
+ # - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
63
+ # - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
64
+ # - TABLE_PIECE_ID: Piece ID of the suffix.
65
+ #
66
+ # Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
67
+ # and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
68
+ self._table: np.ndarray
69
+
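The comments above define a "piece code" as a 64-bit integer: the upper 32 bits hold the Unicode code point of the first character, the lower 32 bits the suffix ID of the remainder. A tiny illustration of that packing (the helper names are illustrative, not part of the class):

```python
# Illustrative pack/unpack of the 64-bit piece code described above.
def make_piece_code(first_char: str, suffix_id: int) -> int:
    return (ord(first_char) << 32) | suffix_id

def split_piece_code(code: int) -> tuple[int, int]:
    return code >> 32, code & 0xFFFFFFFF

code = make_piece_code("あ", 42)
assert split_piece_code(code) == (ord("あ"), 42)
```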
70
+ def build(self, vocab: list[Any]) -> None:
71
+ self._bytes = np.zeros(256, dtype=np.int32)
72
+ self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
73
+
74
+ # Build suffix_to_score and token_to_token_id.
75
+ # The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
76
+ # of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
77
+ # valid token, its score is set to math.nan.
78
+ # The token_to_token_id dictionary maps a token to its token ID.
79
+ suffix_to_score: dict[str, float] = {}
80
+ token_to_token_id: dict[str, int] = {}
81
+ self._tokens = []
82
+ for token_id, row in enumerate(vocab):
83
+ assert isinstance(row[0], str), row
84
+ assert isinstance(row[1], (int, float)), row
85
+
86
+ token = str(row[0])
87
+ self._tokens.append(token)
88
+ token_to_token_id[token] = token_id
89
+
90
+ # Special handling for byte tokens.
91
+ if len(row) > 2 and row[2] == "BYTE":
92
+ assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
93
+ self._bytes[int(row[0][3:5], 16)] = token_id
94
+ continue
95
+
96
+ suffix_to_score[token] = float(row[1])
97
+ # Ensure that all suffixes are included in suffix_to_score.
98
+ for i in range(1, len(token)):
99
+ suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
100
+
101
+ # Ensure all byte tokens are set.
102
+ for i in range(256):
103
+ assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
104
+
105
+ # List suffixes in lexicographical order of their reversed strings.
106
+ suffixes = list(suffix_to_score.keys())
107
+ suffixes.append("")
108
+ suffixes.sort(key=lambda x: x[::-1])
109
+
110
+ # Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
111
+ # which is a mapping from a piece code to a suffix ID.
112
+ suffix_to_id: dict[str, int] = {}
113
+ num_pieces = 0
114
+ for s in suffixes:
115
+ suffix_to_id[s] = num_pieces
116
+ if s != "":
117
+ self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
118
+ num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
119
+ assert suffix_to_id[""] == 0, suffix_to_id[""]
120
+
121
+ # Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
122
+ self._table = np.zeros((num_pieces, 4), dtype=np.int32)
123
+ i = 0
124
+ for suffix in suffixes:
125
+ # Add all prefixes of the suffix to the table.
126
+ for piece_length in range(len(suffix), 0, -1):
127
+ piece = suffix[:piece_length]
128
+ score = suffix_to_score.get(piece, None)
129
+ if score is None:
130
+ continue
131
+ self._table[i, TABLE_PIECE_LENGTH] = piece_length
132
+ self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
133
+ self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
134
+ self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
135
+ i += 1
136
+
137
+ # Add a sentinel row.
138
+ self._table[i, TABLE_PIECE_LENGTH] = 1
139
+ self._table[i, TABLE_TOKEN_ID] = -1
140
+ self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
141
+ i += 1
142
+ assert i == num_pieces, (i, num_pieces)
143
+
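`build()` lays suffixes out in lexicographical order of their reversed strings, so suffixes that end the same way sit next to each other in `_table` (the memory-locality argument in the comments above). A quick standalone illustration of that ordering:

```python
# Sorting by the reversed string groups suffixes that share an ending.
suffixes = ["ab", "b", "cab", "a", ""]
suffixes.sort(key=lambda s: s[::-1])
print(suffixes)  # ['', 'a', 'b', 'ab', 'cab']
```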
144
+ @staticmethod
145
+ @njit
146
+ def _encode(
147
+ to_suffix_id: Dict[types.int64, types.int32],
148
+ table: np.ndarray,
149
+ bytes: np.ndarray,
150
+ data: np.ndarray,
151
+ ) -> np.ndarray:
152
+ # Initialize scores array with a high value and set the score at the end to 0.
153
+ # This array keeps track of the minimum cost (best score) to encode from each position to the end.
154
+ scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
155
+ scores[-1] = 0
156
+
157
+ # Path array to store the best path information.
158
+ # The path array keeps track of token length, token ID, and number of tokens needed to encode.
159
+ path = np.zeros((len(data) + 1, 3), dtype=np.int32)
160
+
161
+ # Initialize suffix_id to 0, which represents the root of the Trie.
162
+ suffix_id = 0
163
+
164
+ # Process the input data from the end to the beginning.
165
+ for i in range(len(data) - 1, -1, -1):
166
+ c = data[i]
167
+
168
+ # Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
169
+ # NOTE: If no suffix ID is found, suffix_id will be set to 0.
170
+ for p in range(suffix_id, len(table)):
171
+ suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
172
+ # If a next suffix ID is found or a sentinel row is reached, break the loop.
173
+ if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
174
+ break
175
+
176
+ # Update the best path to the current position. If multiple paths have the same score,
177
+ # this chooses the longest prefix as the best path (table is sorted in the decreasing
178
+ # order of piece length).
179
+ for p in range(suffix_id, len(table)):
180
+ score = table[p, TABLE_SCORE]
181
+ if score > INVALID_SCORE:
182
+ piece_length = table[p, TABLE_PIECE_LENGTH]
183
+ s = scores[i + piece_length] - score
184
+ if s < scores[i]:
185
+ scores[i] = s
186
+ path[i, PATH_TOKEN_LENGTH] = piece_length
187
+ path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
188
+ path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
189
+ if score == UNKNOWN_SCORE:
190
+ # Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
191
+ # added above).
192
+ path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
193
+
194
+ # If it reaches a sentinel row, break the loop.
195
+ if score == UNKNOWN_SCORE:
196
+ break
197
+
198
+ # Decode the best path from the beginning to get the token IDs.
199
+ pos = 0
200
+ token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
201
+ token_pos = 0
202
+ while pos < len(data):
203
+ if path[pos, PATH_TOKEN_ID] >= 0:
204
+ token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
205
+ token_pos += 1
206
+ else:
207
+ # Fall back to byte tokens.
208
+ c = data[pos]
209
+ s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
210
+ # Add byte tokens representing UTF-8 bytes.
211
+ for i in range(s):
212
+ b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
213
+ token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
214
+ token_pos += 1
215
+
216
+ # Ensure that pos should increase by at least 1.
217
+ assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
218
+ pos += path[pos, PATH_TOKEN_LENGTH]
219
+
220
+ return token_ids
221
+
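The byte-fallback branch above counts the UTF-8 length of a code point with `1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)`, which agrees with Python's own UTF-8 encoder:

```python
# Check the UTF-8 length formula used in the byte-fallback path.
def utf8_len(codepoint: int) -> int:
    return 1 + (codepoint >= 0x80) + (codepoint >= 0x800) + (codepoint >= 0x10000)

for ch in ["A", "é", "あ", "😀"]:   # 1-, 2-, 3- and 4-byte characters
    assert utf8_len(ord(ch)) == len(ch.encode("utf-8")), ch
```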
222
+ def encode(self, data: str) -> np.ndarray:
223
+ """Encodes a string into a sequence of token IDs."""
224
+ return np.asarray(
225
+ self._encode(
226
+ self._to_suffix_id,
227
+ self._table,
228
+ self._bytes,
229
+ # Convert a string into a numpy array of Unicode code points.
230
+ # NOTE: This skips UTF-32 BOM.
231
+ np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
232
+ )
233
+ )
234
+
235
+ def encode_as_tokens(self, data: str) -> list[str]:
236
+ """Encodes a string into a sequence of tokens."""
237
+ return [self._tokens[token_id] for token_id in self.encode(data)]
238
+
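`encode()` converts the input string to an array of Unicode code points by encoding it as UTF-32 and dropping the BOM, which is equivalent to taking `ord()` per character; a short check:

```python
import numpy as np

text = "PLaMoの分かち書き"
via_utf32 = np.frombuffer(text.encode("utf-32"), dtype=np.int32)[1:]  # [1:] drops the BOM
via_ord = np.array([ord(c) for c in text], dtype=np.int32)
assert np.array_equal(via_utf32, via_ord)
```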
239
+
240
+ class Plamo2Tokenizer(PreTrainedTokenizer): # type: ignore
241
+ vocab_files_names = VOCAB_FILES_NAMES
242
+ model_input_names = ["input_ids", "attention_mask"]
243
+
244
+ _save_files = [
245
+ "special_tokens_map.json",
246
+ "tokenization_plamo.py",
247
+ "tokenizer.jsonl",
248
+ "tokenizer_config.json",
249
+ ]
250
+
251
+ def __init__(
252
+ self,
253
+ vocab_file: str,
254
+ unk_token: str = "<|plamo:unk|>",
255
+ bos_token: str = "<|plamo:bos|>",
256
+ eos_token: str = "<|plamo:eos|>",
257
+ pad_token: str = "<|plamo:pad|>",
258
+ cls_token: Optional[str] = None,
259
+ sep_token: Optional[str] = None,
260
+ mask_token: Optional[str] = None,
261
+ clean_up_tokenization_spaces: bool = False,
262
+ **kwargs: Any,
263
+ ) -> None:
264
+ """Tokenizer for PLaMo.
265
+
266
+ Args:
267
+ vocab_file (str): Vocabulary file path.
268
+ unk_token (str): Unknown token.
269
+ bos_token (str): Beginning of sentence token.
270
+ eos_token (str): End of sentence token.
271
+ pad_token (str): Padding token.
272
+ cls_token (str):
273
+ Classification token, to extract a summary of an input sequence leveraging self-attention along the
274
+ full depth of the model.
275
+ sep_token (str): Separation token, to separate context and query in an input sequence.
276
+ mask_token (str): Mask token, to use when training a model with masked-language modeling.
277
+ clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
281
+ """
282
+ if "add_bos_token" not in kwargs:
283
+ kwargs["add_bos_token"] = False
284
+ if "add_eos_token" not in kwargs:
285
+ kwargs["add_eos_token"] = False
286
+ self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
287
+ self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
288
+ self.aho_corasick = AhoCorasick()
289
+ self.aho_corasick.build(self.data)
290
+ self.vocab_file = vocab_file
291
+ self.add_bos_token = kwargs["add_bos_token"]
292
+ self.add_eos_token = kwargs["add_eos_token"]
293
+
294
+ super().__init__(
295
+ vocab_file=vocab_file,
296
+ unk_token=unk_token,
297
+ bos_token=bos_token,
298
+ eos_token=eos_token,
299
+ pad_token=pad_token,
300
+ cls_token=cls_token,
301
+ sep_token=sep_token,
302
+ mask_token=mask_token,
303
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
304
+ **kwargs,
305
+ )
306
+
307
+ # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
308
+ # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
309
+
310
+ def __getstate__(self) -> dict[str, Any]:
311
+ state = self.__dict__.copy()
312
+ state["aho_corasick"] = None
313
+ return state
314
+
315
+ def __setstate__(self, d: dict[str, Any]) -> None:
316
+ self.__dict__ = d
317
+ self.aho_corasick = AhoCorasick()
318
+ self.aho_corasick.build(self.data)
319
+
320
+ @property
321
+ def vocab_size(self) -> Any:
322
+ """Returns vocab size"""
323
+ return len(self.data)
324
+
325
+ def token_to_score(self, token: str) -> Optional[float]:
326
+ """Returns score of the token"""
327
+ token_id = self.vocab.get(token, None)
328
+ return None if token_id is None else self.data[token_id][1]
329
+
330
+ def get_vocab(self) -> dict[str, int]:
331
+ """Returns vocab as a dict"""
332
+ vocab = self.vocab.copy()
333
+ vocab.update(self.added_tokens_encoder)
334
+ return vocab
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ """Converts a sequence of tokens (string) in a single string."""
338
+ return b"".join(
339
+ [bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
340
+ ).decode("utf-8", errors="replace")
341
+
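`convert_tokens_to_string` rebuilds raw bytes from `<0xNN>` byte-fallback tokens and re-encodes ordinary tokens as UTF-8 before decoding everything together. A hedged mini-example with a hand-written token list (real token boundaries depend on the vocabulary; the three byte tokens happen to spell "あ"):

```python
# Illustrative only: one ordinary token plus three byte-fallback tokens (UTF-8 of "あ").
tokens = ["Hello ", "<0xE3>", "<0x81>", "<0x82>"]
text = b"".join(
    bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens
).decode("utf-8", errors="replace")
print(text)  # Hello あ
```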
342
+ def _tokenize(self, text: str) -> Any:
343
+ """Returns a tokenized string."""
344
+ return self.aho_corasick.encode_as_tokens(text)
345
+
346
+ def _convert_token_to_id(self, token: str) -> Any:
347
+ """Converts a token (str) in an id using the vocab."""
348
+ return self.vocab.get(token, 0)
349
+
350
+ def _convert_id_to_token(self, index: int) -> Any:
351
+ """Converts an index (integer) in a token (str) using the vocab."""
352
+ return self.data[index][0]
353
+
354
+ def build_inputs_with_special_tokens(
355
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
356
+ ) -> List[int]:
357
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
358
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
359
+
360
+ output = bos_token_id + token_ids_0 + eos_token_id
361
+
362
+ if token_ids_1 is not None:
363
+ output = output + bos_token_id + token_ids_1 + eos_token_id
364
+
365
+ return output
366
+
367
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
368
+ """
369
+ Save the vocabulary and special tokens file to a directory.
370
+
371
+ Args:
372
+ save_directory (`str`):
373
+ The directory in which to save the vocabulary.
374
+
375
+ Returns:
376
+ `Tuple(str)`: Paths to the files saved.
377
+ """
378
+ if not os.path.isdir(save_directory):
379
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
380
+ return ("",)
381
+ out_vocab_file = os.path.join(
382
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
383
+ )
384
+
385
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
386
+ copyfile(self.vocab_file, out_vocab_file)
387
+ elif not os.path.isfile(self.vocab_file):
388
+ with open(out_vocab_file, "w") as f:
389
+ for token in self.data:
390
+ print(json.dumps(token, ensure_ascii=False), file=f)
391
+
392
+ return (out_vocab_file,)
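A hedged usage sketch for this tokenizer once it has been loaded through `AutoTokenizer` with `trust_remote_code=True` (see the loading sketch after the modeling code). With `add_bos_token: true` from `tokenizer_config.json` below, a BOS id should be prepended by default:

```python
# Assumes `tokenizer` is the Plamo2Tokenizer loaded with trust_remote_code=True.
ids = tokenizer("こんにちは、世界")["input_ids"]
print(ids[0] == tokenizer.bos_token_id)          # True when add_bos_token is enabled
print(tokenizer.convert_ids_to_tokens(ids)[:8])  # tokens, including byte fallbacks if any
print(tokenizer.decode(ids, skip_special_tokens=True))  # should round-trip to the input
```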
tokenizer.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3af94303d6c676e875e13165abf4e4a9d42c663ef247a8985cdeedb223cc0681
3
+ size 10573745
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|plamo:unk|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<|plamo:bos|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<|plamo:eos|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|plamo:pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.Plamo2Tokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|plamo:bos|>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": null,
47
+ "eos_token": "<|plamo:bos|>",
48
+ "extra_special_tokens": {},
49
+ "local_file_only": true,
50
+ "mask_token": null,
51
+ "model_max_length": 1000000000000000019884624838656,
52
+ "pad_token": "<|plamo:pad|>",
53
+ "sep_token": null,
54
+ "tokenizer_class": "Plamo2Tokenizer",
55
+ "unk_token": "<|plamo:unk|>"
56
+ }