kishizaki-sci committed on
Commit
e167b9d
·
verified ·
1 Parent(s): 3b87c55

Upload llama4_inference.ipynb

Browse files
Files changed (1) hide show
  1. llama4_inference.ipynb +287 -0
llama4_inference.ipynb ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "3008bf1f-df10-484b-9662-96cda681b93b",
6
+ "metadata": {},
7
+ "source": [
8
+ "```bash\n",
9
+ "git clone https://github.com/kIshizaki-sci/AutoAWQ.git\n",
10
+ "pip install -U transformers\n",
11
+ "pip install -e ./AutoAWQ\n",
12
+ "```"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 1,
18
+ "id": "a388d190-a611-46b0-aa8c-dcf5c97b0c1b",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "from awq import AutoAWQForCausalLM\n",
23
+ "import torch\n",
24
+ "import transformers\n",
25
+ "from transformers import AutoProcessor"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 2,
31
+ "id": "de7f888d-fe64-47f1-a106-75a7b911354f",
32
+ "metadata": {},
33
+ "outputs": [
34
+ {
35
+ "name": "stdout",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "torch version : 2.4.1+cu124\n",
39
+ "transformers version : 4.51.3\n"
40
+ ]
41
+ }
42
+ ],
43
+ "source": [
44
+ "print('torch version : ', torch.__version__)\n",
45
+ "print('transformers version : ', transformers.__version__)"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 3,
51
+ "id": "61623ebd-d8a4-45b6-aff3-3c16a78ecfa5",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "quant_path = '/workspace/hf_cache/Llama-4-Scout-17B-16E-Instruct-AWQ'"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 4,
61
+ "id": "c13b72e4-f6cd-4642-a110-040844127541",
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stderr",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "/workspace/llama4-awq/AutoAWQ/awq/models/llama4.py:313: UserWarning: Multimodal input has not been implemented in Llama4AWQForConditionalGeneration yet.\n",
69
+ " warnings.warn(\"Multimodal input has not been implemented in Llama4AWQForConditionalGeneration yet.\", UserWarning)\n",
70
+ "You have set `use_cache` to `False`, but cache_implementation is set to hybrid. cache_implementation will have no effect.\n"
71
+ ]
72
+ },
73
+ {
74
+ "data": {
75
+ "application/vnd.jupyter.widget-view+json": {
76
+ "model_id": "15b41a3c3e154516b93b4f2b90e976fb",
77
+ "version_major": 2,
78
+ "version_minor": 0
79
+ },
80
+ "text/plain": [
81
+ "Replacing MoE Block...: 0%| | 0/48 [00:00<?, ?it/s]"
82
+ ]
83
+ },
84
+ "metadata": {},
85
+ "output_type": "display_data"
86
+ },
87
+ {
88
+ "data": {
89
+ "application/vnd.jupyter.widget-view+json": {
90
+ "model_id": "38fe24c9ae5549a7a7c674292b0e4f95",
91
+ "version_major": 2,
92
+ "version_minor": 0
93
+ },
94
+ "text/plain": [
95
+ "Replacing layers...: 0%| | 0/48 [00:00<?, ?it/s]"
96
+ ]
97
+ },
98
+ "metadata": {},
99
+ "output_type": "display_data"
100
+ },
101
+ {
102
+ "name": "stderr",
103
+ "output_type": "stream",
104
+ "text": [
105
+ "/workspace/llama4-awq/AutoAWQ/awq/models/base.py:540: UserWarning: Skipping fusing modules because AWQ extension is not installed.No module named 'awq_ext'\n",
106
+ " warnings.warn(\"Skipping fusing modules because AWQ extension is not installed.\" + msg)\n"
107
+ ]
108
+ },
109
+ {
110
+ "name": "stdout",
111
+ "output_type": "stream",
112
+ "text": [
113
+ "CPU times: user 1h 36min 24s, sys: 24min 59s, total: 2h 1min 23s\n",
114
+ "Wall time: 30min 59s\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "%%time\n",
120
+ "model = AutoAWQForCausalLM.from_quantized(quant_path, torch_dtype=torch.float16, use_cache=True, device_map='auto')\n",
121
+ "processor = AutoProcessor.from_pretrained(quant_path)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 5,
127
+ "id": "665ae24c-cb73-489d-bb03-35d760460070",
128
+ "metadata": {},
129
+ "outputs": [
130
+ {
131
+ "name": "stdout",
132
+ "output_type": "stream",
133
+ "text": [
134
+ "Sat Apr 19 17:03:15 2025 \n",
135
+ "+-----------------------------------------------------------------------------------------+\n",
136
+ "| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 |\n",
137
+ "|-----------------------------------------+------------------------+----------------------+\n",
138
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
139
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
140
+ "| | | MIG M. |\n",
141
+ "|=========================================+========================+======================|\n",
142
+ "| 0 NVIDIA A100-SXM4-80GB On | 00000000:B7:00.0 Off | 0 |\n",
143
+ "| N/A 25C P0 75W / 400W | 60415MiB / 81920MiB | 0% Default |\n",
144
+ "| | | Disabled |\n",
145
+ "+-----------------------------------------+------------------------+----------------------+\n",
146
+ " \n",
147
+ "+-----------------------------------------------------------------------------------------+\n",
148
+ "| Processes: |\n",
149
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
150
+ "| ID ID Usage |\n",
151
+ "|=========================================================================================|\n",
152
+ "+-----------------------------------------------------------------------------------------+\n"
153
+ ]
154
+ }
155
+ ],
156
+ "source": [
157
+ "!nvidia-smi"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 6,
163
+ "id": "39c8154a-08c4-485e-b136-955d1b4fbec9",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "messages = [\n",
168
+ " {\n",
169
+ " \"role\": \"user\",\n",
170
+ " \"content\": [\n",
171
+ " {\"type\": \"text\", \"text\": \"What does means the torsion free in the general relativit?\"},\n",
172
+ " ]\n",
173
+ " },\n",
174
+ "]"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": 7,
180
+ "id": "44c4ae97-96cf-47f7-a8e4-16d2c34bdc1b",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "inputs = processor.apply_chat_template(\n",
185
+ " messages,\n",
186
+ " add_generation_prompt=True,\n",
187
+ " tokenize=True,\n",
188
+ " return_dict=True,\n",
189
+ " return_tensors=\"pt\",\n",
190
+ ").to(model.model.device)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 8,
196
+ "id": "c6dddd7b-5382-45e8-8dd1-6ca719200f64",
197
+ "metadata": {},
198
+ "outputs": [
199
+ {
200
+ "name": "stdout",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "CPU times: user 5min 36s, sys: 5min 9s, total: 10min 45s\n",
204
+ "Wall time: 10min 45s\n"
205
+ ]
206
+ }
207
+ ],
208
+ "source": [
209
+ "%%time\n",
210
+ "outputs = model.generate(\n",
211
+ " **inputs,\n",
212
+ " max_new_tokens=2048,\n",
213
+ ")"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": 9,
219
+ "id": "ab1d4fe0-e315-40ee-b3e0-10db1e6b2023",
220
+ "metadata": {},
221
+ "outputs": [
222
+ {
223
+ "name": "stdout",
224
+ "output_type": "stream",
225
+ "text": [
226
+ "A question from the realm of differential geometry and general relativity!\n",
227
+ "\n",
228
+ "In general relativity, \"torsion-free\" refers to a property of a connection on a manifold, specifically in the context of Riemannian geometry.\n",
229
+ "\n",
230
+ "**Torsion** is a measure of how much a connection \"twists\" or \"turns\" a vector as it is parallel-transported around a closed loop. In other words, it's a measure of how much the connection deviates from being \"flat\" or \"Euclidean\".\n",
231
+ "\n",
232
+ "A **torsion-free connection**, also known as a **symmetric connection**, is a connection that has zero torsion. This means that when you parallel-transport a vector around a closed loop, it returns to its original orientation, without any twisting or turning.\n",
233
+ "\n",
234
+ "In mathematical terms, a torsion-free connection satisfies the following condition:\n",
235
+ "\n",
236
+ "$$\\Gamma^i_{jk} = \\Gamma^i_{kj}$$\n",
237
+ "\n",
238
+ "where $\\Gamma^i_{jk}$ are the Christoffel symbols of the second kind, which define the connection.\n",
239
+ "\n",
240
+ "In general relativity, the Levi-Civita connection is a fundamental concept, and it is assumed to be torsion-free. This connection is used to define the covariant derivative of tensors, which is essential for describing the curvature of spacetime.\n",
241
+ "\n",
242
+ "The assumption of a torsion-free connection has important implications:\n",
243
+ "\n",
244
+ "1. **Geodesic equation**: The geodesic equation, which describes the shortest path in curved spacetime, is derived from the Levi-Civita connection. A torsion-free connection ensures that geodesics are symmetric, meaning that they have no \"twist\" or \"turn\".\n",
245
+ "2. **Riemannian geometry**: The Levi-Civita connection is a fundamental ingredient in Riemannian geometry, which is the mathematical framework for describing curved spacetime in general relativity.\n",
246
+ "3. **Einstein's field equations**: The Einstein field equations, which relate the curvature of spacetime to the distribution of mass and energy, rely on the Levi-Civita connection.\n",
247
+ "\n",
248
+ "In summary, a torsion-free connection in general relativity means that the connection used to describe the curvature of spacetime has zero torsion, which is a fundamental assumption in Riemannian geometry and leads to the Levi-Civita connection. This assumption is crucial for the mathematical formulation of general relativity, including the geodesic equation and Einstein's field equations.<|eot|>\n"
249
+ ]
250
+ }
251
+ ],
252
+ "source": [
253
+ "response = processor.batch_decode(outputs[:, inputs[\"input_ids\"].shape[-1]:])[0]\n",
254
+ "print(response)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "5f3f1be7-6960-474e-b19d-afd2efa56174",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": []
264
+ }
265
+ ],
266
+ "metadata": {
267
+ "kernelspec": {
268
+ "display_name": "Python 3 (ipykernel)",
269
+ "language": "python",
270
+ "name": "python3"
271
+ },
272
+ "language_info": {
273
+ "codemirror_mode": {
274
+ "name": "ipython",
275
+ "version": 3
276
+ },
277
+ "file_extension": ".py",
278
+ "mimetype": "text/x-python",
279
+ "name": "python",
280
+ "nbconvert_exporter": "python",
281
+ "pygments_lexer": "ipython3",
282
+ "version": "3.11.11"
283
+ }
284
+ },
285
+ "nbformat": 4,
286
+ "nbformat_minor": 5
287
+ }