jenbenarye committed on
Commit
2d2463b
·
unverified ·
2 Parent(s): 62aa801 d151abe

Merge pull request #1 from jenbenarye/train

Browse files
ml/dataset_training.ipynb DELETED
@@ -1,398 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 43,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "#dependencies:\n",
10
- "import pandas as pd\n",
11
- "\n",
12
- "import torch\n",
13
- "from transformers import GPT2Tokenizer\n",
14
- "\n",
15
- "from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer"
16
- ]
17
- },
18
- {
19
- "cell_type": "code",
20
- "execution_count": 44,
21
- "metadata": {},
22
- "outputs": [
23
- {
24
- "data": {
25
- "application/vnd.jupyter.widget-view+json": {
26
- "model_id": "b8a22b8d60c0417eafbf554832398287",
27
- "version_major": 2,
28
- "version_minor": 0
29
- },
30
- "text/plain": [
31
- "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
32
- ]
33
- },
34
- "metadata": {},
35
- "output_type": "display_data"
36
- },
37
- {
38
- "data": {
39
- "application/vnd.jupyter.widget-view+json": {
40
- "model_id": "b83d2624c2b14986a8297821460225ab",
41
- "version_major": 2,
42
- "version_minor": 0
43
- },
44
- "text/plain": [
45
- "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
46
- ]
47
- },
48
- "metadata": {},
49
- "output_type": "display_data"
50
- },
51
- {
52
- "data": {
53
- "application/vnd.jupyter.widget-view+json": {
54
- "model_id": "b4304c0f48cb472589b5e80d3a42cba2",
55
- "version_major": 2,
56
- "version_minor": 0
57
- },
58
- "text/plain": [
59
- "Resolving data files: 0%| | 0/18 [00:00<?, ?it/s]"
60
- ]
61
- },
62
- "metadata": {},
63
- "output_type": "display_data"
64
- }
65
- ],
66
- "source": [
67
- "#loading datasets:\n",
68
- "from datasets import load_dataset\n",
69
- "\n",
70
- "ds = load_dataset(\"stanfordnlp/SHP\", split='train')"
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": 45,
76
- "metadata": {},
77
- "outputs": [
78
- {
79
- "name": "stdout",
80
- "output_type": "stream",
81
- "text": [
82
- "Index(['post_id', 'domain', 'upvote_ratio', 'history', 'c_root_id_A',\n",
83
- " 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'score_A',\n",
84
- " 'score_B', 'human_ref_A', 'human_ref_B', 'labels', 'seconds_difference',\n",
85
- " 'score_ratio'],\n",
86
- " dtype='object')\n"
87
- ]
88
- }
89
- ],
90
- "source": [
91
- "df = ds.to_pandas()\n",
92
- "print(df.columns)\n"
93
- ]
94
- },
95
- {
96
- "cell_type": "code",
97
- "execution_count": 46,
98
- "metadata": {},
99
- "outputs": [
100
- {
101
- "data": {
102
- "text/html": [
103
- "<div>\n",
104
- "<style scoped>\n",
105
- " .dataframe tbody tr th:only-of-type {\n",
106
- " vertical-align: middle;\n",
107
- " }\n",
108
- "\n",
109
- " .dataframe tbody tr th {\n",
110
- " vertical-align: top;\n",
111
- " }\n",
112
- "\n",
113
- " .dataframe thead th {\n",
114
- " text-align: right;\n",
115
- " }\n",
116
- "</style>\n",
117
- "<table border=\"1\" class=\"dataframe\">\n",
118
- " <thead>\n",
119
- " <tr style=\"text-align: right;\">\n",
120
- " <th></th>\n",
121
- " <th>upvote_ratio</th>\n",
122
- " <th>history</th>\n",
123
- " <th>score_A</th>\n",
124
- " <th>score_B</th>\n",
125
- " <th>human_ref_A</th>\n",
126
- " <th>human_ref_B</th>\n",
127
- " <th>labels</th>\n",
128
- " <th>score_ratio</th>\n",
129
- " </tr>\n",
130
- " </thead>\n",
131
- " <tbody>\n",
132
- " <tr>\n",
133
- " <th>0</th>\n",
134
- " <td>0.99</td>\n",
135
- " <td>In an interview right before receiving the 201...</td>\n",
136
- " <td>52</td>\n",
137
- " <td>54</td>\n",
138
- " <td>Currently wrapping up my PhD. There is a stark...</td>\n",
139
- " <td>It’s ironic to me that research has shown that...</td>\n",
140
- " <td>0</td>\n",
141
- " <td>1.038462</td>\n",
142
- " </tr>\n",
143
- " <tr>\n",
144
- " <th>1</th>\n",
145
- " <td>0.95</td>\n",
146
- " <td>If any professor is reading this: please do no...</td>\n",
147
- " <td>5</td>\n",
148
- " <td>17</td>\n",
149
- " <td>And when your teacher doesn't listen or pay at...</td>\n",
150
- " <td>I'm pretty strict on time, to the point where ...</td>\n",
151
- " <td>0</td>\n",
152
- " <td>3.400000</td>\n",
153
- " </tr>\n",
154
- " <tr>\n",
155
- " <th>2</th>\n",
156
- " <td>0.95</td>\n",
157
- " <td>If any professor is reading this: please do no...</td>\n",
158
- " <td>5</td>\n",
159
- " <td>7</td>\n",
160
- " <td>Profs can be oblivious? What’s new!</td>\n",
161
- " <td>This sounds like a problem with a specific pro...</td>\n",
162
- " <td>0</td>\n",
163
- " <td>1.400000</td>\n",
164
- " </tr>\n",
165
- " <tr>\n",
166
- " <th>3</th>\n",
167
- " <td>0.95</td>\n",
168
- " <td>If any professor is reading this: please do no...</td>\n",
169
- " <td>7</td>\n",
170
- " <td>5</td>\n",
171
- " <td>This sounds like a problem with a specific pro...</td>\n",
172
- " <td>And when your teacher doesn't listen or pay at...</td>\n",
173
- " <td>1</td>\n",
174
- " <td>1.400000</td>\n",
175
- " </tr>\n",
176
- " <tr>\n",
177
- " <th>4</th>\n",
178
- " <td>0.95</td>\n",
179
- " <td>If any professor is reading this: please do no...</td>\n",
180
- " <td>6</td>\n",
181
- " <td>7</td>\n",
182
- " <td>This would be totally unacceptable in my class...</td>\n",
183
- " <td>This sounds like a problem with a specific pro...</td>\n",
184
- " <td>0</td>\n",
185
- " <td>1.166667</td>\n",
186
- " </tr>\n",
187
- " <tr>\n",
188
- " <th>...</th>\n",
189
- " <td>...</td>\n",
190
- " <td>...</td>\n",
191
- " <td>...</td>\n",
192
- " <td>...</td>\n",
193
- " <td>...</td>\n",
194
- " <td>...</td>\n",
195
- " <td>...</td>\n",
196
- " <td>...</td>\n",
197
- " </tr>\n",
198
- " <tr>\n",
199
- " <th>348713</th>\n",
200
- " <td>0.94</td>\n",
201
- " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
202
- " <td>7</td>\n",
203
- " <td>25</td>\n",
204
- " <td>Just put up a fence. Legally he isn't responsi...</td>\n",
205
- " <td>Whatever you do, don't cut his trees down.</td>\n",
206
- " <td>0</td>\n",
207
- " <td>3.571429</td>\n",
208
- " </tr>\n",
209
- " <tr>\n",
210
- " <th>348714</th>\n",
211
- " <td>0.94</td>\n",
212
- " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
213
- " <td>2</td>\n",
214
- " <td>25</td>\n",
215
- " <td>If OP pays someone to clean his yard, and then...</td>\n",
216
- " <td>Whatever you do, don't cut his trees down.</td>\n",
217
- " <td>0</td>\n",
218
- " <td>12.500000</td>\n",
219
- " </tr>\n",
220
- " <tr>\n",
221
- " <th>348715</th>\n",
222
- " <td>0.94</td>\n",
223
- " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
224
- " <td>9</td>\n",
225
- " <td>7</td>\n",
226
- " <td>My observation is that both of you are idiots...</td>\n",
227
- " <td>Are you Rand Paul's neighbor? https://www.gq....</td>\n",
228
- " <td>1</td>\n",
229
- " <td>1.285714</td>\n",
230
- " </tr>\n",
231
- " <tr>\n",
232
- " <th>348716</th>\n",
233
- " <td>0.94</td>\n",
234
- " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
235
- " <td>9</td>\n",
236
- " <td>7</td>\n",
237
- " <td>My observation is that both of you are idiots...</td>\n",
238
- " <td>Just put up a fence. Legally he isn't responsi...</td>\n",
239
- " <td>1</td>\n",
240
- " <td>1.285714</td>\n",
241
- " </tr>\n",
242
- " <tr>\n",
243
- " <th>348717</th>\n",
244
- " <td>0.94</td>\n",
245
- " <td>Can I get in trouble for giving my neighbor hi...</td>\n",
246
- " <td>7</td>\n",
247
- " <td>2</td>\n",
248
- " <td>Capture his acts on camera. Collect and bag l...</td>\n",
249
- " <td>If OP pays someone to clean his yard, and then...</td>\n",
250
- " <td>1</td>\n",
251
- " <td>3.500000</td>\n",
252
- " </tr>\n",
253
- " </tbody>\n",
254
- "</table>\n",
255
- "<p>348718 rows × 8 columns</p>\n",
256
- "</div>"
257
- ],
258
- "text/plain": [
259
- " upvote_ratio history \\\n",
260
- "0 0.99 In an interview right before receiving the 201... \n",
261
- "1 0.95 If any professor is reading this: please do no... \n",
262
- "2 0.95 If any professor is reading this: please do no... \n",
263
- "3 0.95 If any professor is reading this: please do no... \n",
264
- "4 0.95 If any professor is reading this: please do no... \n",
265
- "... ... ... \n",
266
- "348713 0.94 Can I get in trouble for giving my neighbor hi... \n",
267
- "348714 0.94 Can I get in trouble for giving my neighbor hi... \n",
268
- "348715 0.94 Can I get in trouble for giving my neighbor hi... \n",
269
- "348716 0.94 Can I get in trouble for giving my neighbor hi... \n",
270
- "348717 0.94 Can I get in trouble for giving my neighbor hi... \n",
271
- "\n",
272
- " score_A score_B human_ref_A \\\n",
273
- "0 52 54 Currently wrapping up my PhD. There is a stark... \n",
274
- "1 5 17 And when your teacher doesn't listen or pay at... \n",
275
- "2 5 7 Profs can be oblivious? What’s new! \n",
276
- "3 7 5 This sounds like a problem with a specific pro... \n",
277
- "4 6 7 This would be totally unacceptable in my class... \n",
278
- "... ... ... ... \n",
279
- "348713 7 25 Just put up a fence. Legally he isn't responsi... \n",
280
- "348714 2 25 If OP pays someone to clean his yard, and then... \n",
281
- "348715 9 7 My observation is that both of you are idiots... \n",
282
- "348716 9 7 My observation is that both of you are idiots... \n",
283
- "348717 7 2 Capture his acts on camera. Collect and bag l... \n",
284
- "\n",
285
- " human_ref_B labels score_ratio \n",
286
- "0 It’s ironic to me that research has shown that... 0 1.038462 \n",
287
- "1 I'm pretty strict on time, to the point where ... 0 3.400000 \n",
288
- "2 This sounds like a problem with a specific pro... 0 1.400000 \n",
289
- "3 And when your teacher doesn't listen or pay at... 1 1.400000 \n",
290
- "4 This sounds like a problem with a specific pro... 0 1.166667 \n",
291
- "... ... ... ... \n",
292
- "348713 Whatever you do, don't cut his trees down. 0 3.571429 \n",
293
- "348714 Whatever you do, don't cut his trees down. 0 12.500000 \n",
294
- "348715 Are you Rand Paul's neighbor? https://www.gq.... 1 1.285714 \n",
295
- "348716 Just put up a fence. Legally he isn't responsi... 1 1.285714 \n",
296
- "348717 If OP pays someone to clean his yard, and then... 1 3.500000 \n",
297
- "\n",
298
- "[348718 rows x 8 columns]"
299
- ]
300
- },
301
- "execution_count": 46,
302
- "metadata": {},
303
- "output_type": "execute_result"
304
- }
305
- ],
306
- "source": [
307
- "# df['response_length'] = df['history'].apply(len)\n",
308
- "# df['label'] = df['response_length'].apply(lambda x: 'long' if x > 100 else 'short')\n",
309
- "df.drop(columns=['post_id', 'domain', 'c_root_id_A', 'c_root_id_B', 'created_at_utc_A', 'created_at_utc_B', 'seconds_difference'])"
310
- ]
311
- },
312
- {
313
- "cell_type": "code",
314
- "execution_count": 47,
315
- "metadata": {},
316
- "outputs": [
317
- {
318
- "name": "stderr",
319
- "output_type": "stream",
320
- "text": [
321
- "/Users/riddhib/.pyenv/versions/3.10.13/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
322
- " warnings.warn(\n"
323
- ]
324
- }
325
- ],
326
- "source": [
327
- "model = AutoModelForCausalLMWithValueHead.from_pretrained(\"gpt2\")\n",
328
- "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(\"gpt2\")\n",
329
- "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
330
- "tokenizer.pad_token = tokenizer.eos_token"
331
- ]
332
- },
333
- {
334
- "cell_type": "code",
335
- "execution_count": 48,
336
- "metadata": {},
337
- "outputs": [],
338
- "source": [
339
- "from trl_rlhf_data import runner, ScriptArguments\n",
340
- "import re\n",
341
- "from dataclasses import dataclass\n",
342
- "from typing import Dict, List, Optional\n",
343
- "\n",
344
- "from datasets import load_dataset\n",
345
- "from transformers import HfArgumentParser"
346
- ]
347
- },
348
- {
349
- "cell_type": "code",
350
- "execution_count": 49,
351
- "metadata": {},
352
- "outputs": [
353
- {
354
- "ename": "TypeError",
355
- "evalue": "runner() takes 0 positional arguments but 1 was given",
356
- "output_type": "error",
357
- "traceback": [
358
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
359
- "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
360
- "Cell \u001b[0;32mIn[49], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mrunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mScriptArguments\u001b[49m\u001b[43m)\u001b[49m\n",
361
- "\u001b[0;31mTypeError\u001b[0m: runner() takes 0 positional arguments but 1 was given"
362
- ]
363
- }
364
- ],
365
- "source": [
366
- "dataset = runner(ScriptArguments)"
367
- ]
368
- },
369
- {
370
- "cell_type": "code",
371
- "execution_count": null,
372
- "metadata": {},
373
- "outputs": [],
374
- "source": []
375
- }
376
- ],
377
- "metadata": {
378
- "kernelspec": {
379
- "display_name": "Python 3",
380
- "language": "python",
381
- "name": "python3"
382
- },
383
- "language_info": {
384
- "codemirror_mode": {
385
- "name": "ipython",
386
- "version": 3
387
- },
388
- "file_extension": ".py",
389
- "mimetype": "text/x-python",
390
- "name": "python",
391
- "nbconvert_exporter": "python",
392
- "pygments_lexer": "ipython3",
393
- "version": "3.10.13"
394
- }
395
- },
396
- "nbformat": 4,
397
- "nbformat_minor": 2
398
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ml/kto_dataset_processor.py CHANGED
@@ -1,65 +1,210 @@
1
- from datasets import load_dataset, Dataset
2
  import pandas as pd
3
- from pdb import set_trace as st
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def process_dataset_ultrafeedback():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
- Processes the 'train_prefs' and 'test_prefs' splits of the 'HuggingFaceH4/ultrafeedback_binarized' dataset
9
- into a unified format for preference modeling.
 
 
 
 
 
10
 
11
  Returns:
12
- dict: A dictionary containing the unified 'train' and 'test' splits of the dataset in the KTO format.
13
- Each split is a Hugging Face Dataset object.
 
 
14
  """
15
- # Load the relevant splits of the dataset
16
- dataset_name = "HuggingFaceH4/ultrafeedback_binarized"
17
- train_prefs = load_dataset(dataset_name, split="train_prefs")
18
- test_prefs = load_dataset(dataset_name, split="test_prefs")
19
-
20
- # Function to transform a single example into the desired schema
21
- def transform_data(example):
22
- data_points = []
23
- # Chosen completion
24
- chosen_completion = example["chosen"][1]["content"]
25
- if chosen_completion.strip(): # Check for non-empty completions
26
- data_points.append({
27
- "prompt": example["prompt"],
28
- "completion": chosen_completion.strip(),
29
- "label": True
30
- })
31
- # Rejected completion
32
- rejected_completion = example["rejected"][1]["content"]
33
- if rejected_completion.strip(): # Check for non-empty completions
34
- data_points.append({
35
- "prompt": example["prompt"],
36
- "completion": rejected_completion.strip(),
37
- "label": False
38
- })
39
- return data_points
40
-
41
- # Process train and test splits
42
- train_data = []
43
- test_data = []
44
-
45
- for example in train_prefs:
46
- train_data.extend(transform_data(example))
47
-
48
- for example in test_prefs:
49
- test_data.extend(transform_data(example))
50
-
51
- # Convert unified data to DataFrames
52
- train_df = pd.DataFrame(train_data)
53
- test_df = pd.DataFrame(test_data)
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # Convert to Hugging Face Dataset
57
- unified_train = Dataset.from_pandas(train_df)
58
- unified_test = Dataset.from_pandas(test_df)
59
 
60
- return {"train": unified_train, "test": unified_test}
 
 
61
 
 
62
 
63
  if __name__ == "__main__":
64
- kto_dataset = process_dataset_ultrafeedback()
65
- st()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import Dataset, load_dataset
2
  import pandas as pd
3
+ from sklearn.model_selection import train_test_split
4
+ import json
5
+ from ipdb import set_trace as st
6
+ from transformers import AutoTokenizer
7
+ from enum import Enum
8
 
9
class SupportedLanguages(str, Enum):
    """Closed set of language names accepted when filtering the feedback data.

    Members mix in ``str``, so each one compares equal to its plain string
    value, and ``SupportedLanguages(name)`` raises ``ValueError`` for any
    name that is not listed here.
    """

    # Values are full English language names, not ISO codes.
    ENGLISH = "English"
    DUTCH = "Dutch"
    ITALIAN = "Italian"
    SPANISH = "Spanish"
    FRENCH = "French"
    GERMAN = "German"
    PORTUGUESE = "Portuguese"
    RUSSIAN = "Russian"
    CHINESE = "Chinese"
    JAPANESE = "Japanese"
    KOREAN = "Korean"
22
 
23
# Tokenizers already loaded by transform_conversation, keyed by model name, so
# the (slow) from_pretrained call happens once per model instead of once per
# conversation.
_TOKENIZER_CACHE = {}


def transform_conversation(
    entry: dict,
    model_name: str,
    max_history_turns: int = 10,
    max_history_tokens: int = 4000,
    tokenizer=None,
) -> list:
    """Transform one conversation entry into KTO data points with history.

    One data point is produced for every assistant message rated 1 or -1.
    Its prompt is the chat-template rendering of the preceding
    (user, assistant) exchanges, oldest first, followed by the user message
    that directly precedes the rated response.

    Args:
        entry: Record with a "conversation" list of messages (each having
            "role" and "content", and "rating" on rated messages) plus
            "timestamp", "session_id", "conversation_id" and "language".
        model_name: Name of the model whose tokenizer/chat template is used
            when no ``tokenizer`` is supplied.
        max_history_turns: Maximum number of earlier (user, assistant)
            exchanges to include in the prompt.
        max_history_tokens: Maximum number of tokens the rendered prompt may
            have for an additional earlier exchange to be included.
        tokenizer: Optional pre-loaded tokenizer exposing
            ``apply_chat_template`` and ``encode``. When omitted, the
            tokenizer for ``model_name`` is loaded once and cached.

    Returns:
        list: Dicts with "prompt", "completion", "label" (True for rating 1)
        and the entry's "timestamp", "session_id", "conversation_id" and
        "language".
    """
    if tokenizer is None:
        tokenizer = _TOKENIZER_CACHE.get(model_name)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            _TOKENIZER_CACHE[model_name] = tokenizer

    conversation = entry["conversation"]
    data_points = []

    for i, message in enumerate(conversation):
        # Only create data points for assistant messages that have ratings
        if message["role"] != "assistant" or message.get("rating") not in (1, -1):
            continue

        # The user message that prompted the rated response always closes the
        # prompt; it is kept apart from the history so it appears exactly once.
        if i > 0 and conversation[i - 1]["role"] == "user":
            tail = [conversation[i - 1]]
        else:
            tail = []

        formatted_prompt = ""
        if tail:
            try:
                formatted_prompt = tokenizer.apply_chat_template(tail, tokenize=False)
            except Exception:
                # Best effort: templates may reject some inputs; keep the
                # data point with an empty prompt rather than aborting.
                formatted_prompt = ""

        # Walk backwards collecting complete (user, assistant) exchanges.
        history = []
        pairs = 0
        idx = i - 1 - len(tail)
        while idx >= 1 and pairs < max_history_turns:
            older, newer = conversation[idx - 1], conversation[idx]
            if older["role"] != "user" or newer["role"] != "assistant":
                # Not aligned on an exchange boundary; slide back one message.
                idx -= 1
                continue

            candidate = [older, newer] + history
            try:
                rendered = tokenizer.apply_chat_template(candidate + tail, tokenize=False)
                token_count = len(tokenizer.encode(rendered))
            except Exception:
                # Template application failed; keep the last valid prompt.
                break

            if token_count > max_history_tokens:
                # Adding this exchange would exceed the token budget.
                break

            history = candidate
            formatted_prompt = rendered
            pairs += 1
            idx -= 2

        data_points.append({
            "prompt": formatted_prompt.strip(),
            "completion": message["content"].strip(),
            "label": message["rating"] == 1,
            "timestamp": entry["timestamp"],
            "session_id": entry["session_id"],
            "conversation_id": entry["conversation_id"],
            "language": entry["language"]
        })

    return data_points
94
+
95
def process_feel_dataset(
    language: str,
    model_name: str = "CohereForAI/aya-expanse-8b",
    max_history_turns: int = 10,
    max_history_tokens: int = 4000
):
    """
    Processes the feel dataset into a format suitable for KTO training using TRL.

    Args:
        language: Language to filter the dataset for (must be one of SupportedLanguages)
        model_name: Name of the model to format for
        max_history_turns: Maximum number of previous turns to include in history
        max_history_tokens: Maximum number of tokens allowed in history

    Returns:
        dict: A dictionary containing the 'train' and 'test' splits of the dataset in KTO format

    Raises:
        ValueError: If language is not provided or not in SupportedLanguages,
            if the dataset has no rows for the language, or if fewer than two
            training examples can be built (a train/test split needs two).
    """
    # Validate language
    if not language:
        raise ValueError("Language parameter is required")

    try:
        # Validate that it's a supported language
        SupportedLanguages(language)
    except ValueError:
        supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
        # `from None` hides the enum's internal lookup error, which would
        # otherwise be printed as a confusing chained traceback.
        raise ValueError(
            f"Invalid language: '{language}'\n"
            f"Supported languages are:\n- {supported_langs}"
        ) from None

    # Load feel dataset from HuggingFace
    feel_dataset = load_dataset("feel-fl/feel-feedback")["train"]

    # Filter dataset by language
    feel_dataset = feel_dataset.filter(lambda x: x["language"] == language)

    if len(feel_dataset) == 0:
        raise ValueError(f"No data found for language: {language}")

    kto_data = []

    # Process all conversations in the filtered dataset
    for entry in feel_dataset:
        kto_data.extend(transform_conversation(
            entry,
            model_name,
            max_history_turns,
            max_history_tokens
        ))

    if len(kto_data) == 0:
        raise ValueError(f"No valid training examples found for language: {language}")

    if len(kto_data) < 2:
        # A 70/30 split of a single example would leave the train set empty.
        raise ValueError(
            f"Need at least 2 training examples to create a train/test split "
            f"for language: {language}, found {len(kto_data)}"
        )

    # Convert to DataFrame
    kto_df = pd.DataFrame(kto_data)

    # Split into train and test sets (70% train, 30% test)
    train_df, test_df = train_test_split(kto_df, test_size=0.3, random_state=42)

    # Reset index to remove '__index_level_0__'
    train_df = train_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)

    # Convert to Hugging Face Dataset
    train_dataset = Dataset.from_pandas(train_df)
    test_dataset = Dataset.from_pandas(test_df)

    print(f"Processed {len(kto_data)} examples for language: {language}")
    print(f"Train set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")

    return {"train": train_dataset, "test": test_dataset}
172
 
173
if __name__ == "__main__":
    # Build the KTO splits for English feedback.
    datasets = process_feel_dataset("English")

    # Fetch both label columns up front, then report each split's balance.
    label_columns = (
        ("Train set:", datasets['train']['label']),
        ("\nTest set:", datasets['test']['label']),
    )

    print("\nLabel Distribution:")
    for heading, labels in label_columns:
        print(heading)
        print(f"Positive feedback: {sum(labels)}")
        print(f"Negative feedback: {len(labels) - sum(labels)}")
        print(f"Positive ratio: {sum(labels)/len(labels):.2%}")

    # Show one raw record from the source dataset for comparison.
    raw_feedback = load_dataset("feel-fl/feel-feedback", split="train")
    print("\nOriginal conversation from FEEL dataset:")
    print(json.dumps(raw_feedback[0], indent=2))

    # NOTE(review): only the headers are printed here; no sample rows follow.
    print("\nSample entries from processed KTO dataset:")
    print("\n" + "="*80 + "\nTRAIN SET SAMPLES\n" + "="*80)

    # Convert both splits before writing either CSV file.
    train_frame = datasets['train'].to_pandas()
    test_frame = datasets['test'].to_pandas()
    train_frame.to_csv('kto_train_dataset.csv', index=False)
    test_frame.to_csv('kto_test_dataset.csv', index=False)

    print("\nDatasets exported to 'kto_train_dataset.csv' and 'kto_test_dataset.csv'")
ml/kto_lora.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from dataclasses import dataclass
4
+ from accelerate import PartialState
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
6
+ from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
7
+ from kto_dataset_processor import process_feel_dataset
8
+ from datetime import datetime
9
+ import wandb
10
+
11
+ # PEFT library: attach and load adapters
12
+ from peft import get_peft_model, PeftModel
13
+
14
+ ####################################
15
+ # CONFIGURATION
16
+ ####################################
17
+
18
@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    # Function that builds the {"train", "test"} dataset dict.
    process_dataset_func: callable = process_feel_dataset
    # Checkpoint path if needed (not referenced by main() in this file).
    checkpoint_path: str = None
    # Whether to push the adapter to the HF Hub after training.
    push_to_hub: bool = False
    # Language whose adapter is trained; also names the adapter directory.
    # Must be a name the dataset processor accepts: its SupportedLanguages
    # enum uses full names ("English", "French", ...), so the previous
    # default "en" could never pass validation.
    language: str = "English"
27
+
28
@dataclass
class ModelArguments(ModelConfig):
    """
    Base-model and LoRA settings, layered on top of TRL's ModelConfig.
    """
    # Hub id of the base model that the adapter is attached to.
    model_name: str = "CohereForAI/aya-expanse-8b"
    # Train a PEFT adapter rather than the full model.
    use_peft: bool = True
    # LoRA hyper-parameters.
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    # Forwarded to from_pretrained when loading the model and tokenizer.
    trust_remote_code: bool = True
39
+
40
@dataclass
class TrainingArguments(KTOConfig):
    """
    Hyper-parameters for the KTO trainer, layered on top of TRL's KTOConfig.
    """
    # Computed once, when this class is defined. The model name contains a
    # "/", so the resulting path is a nested directory.
    output_dir: str = f"kto_{ModelArguments.model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    # Optimisation schedule.
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    # Logging / evaluation cadence, in steps.
    logging_steps: int = 10
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    # bfloat16 mixed precision.
    bf16: bool = True
    logging_first_step: bool = True
56
+
57
# Module-level configuration objects, built from the dataclass defaults above
# (evaluated left to right: script, training, model).
script_args, training_args, model_args = (
    ScriptArguments(),
    TrainingArguments(),
    ModelArguments(),
)
61
+
62
+ ####################################
63
+ # HELPER FUNCTIONS
64
+ ####################################
65
+
66
def load_model_and_tokenizer(model_args):
    """
    Fetch the base causal LM and its tokenizer from the Hugging Face Hub.

    Reads `model_name` and `trust_remote_code` from `model_args`. The model
    is loaded in float16 with `device_map="auto"`. Returns a
    `(model, tokenizer)` tuple.
    """
    name = model_args.model_name
    allow_remote_code = model_args.trust_remote_code

    base_model = AutoModelForCausalLM.from_pretrained(
        name,
        trust_remote_code=allow_remote_code,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tok = AutoTokenizer.from_pretrained(name, trust_remote_code=allow_remote_code)

    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Only install a chat format when the tokenizer has no chat template.
    has_template = bool(getattr(tok, "chat_template", None))
    if not has_template:
        base_model, tok = setup_chat_format(base_model, tok)

    return base_model, tok
90
+
91
+ ####################################
92
+ # MAIN LOGIC
93
+ ####################################
94
+
95
def main():
    """
    Train (or continue training) the LoRA adapter for `script_args.language`.

    Loads the base model twice (policy and reference), attaches a new or
    existing adapter from `adapters/<language>`, builds the dataset via
    `script_args.process_dataset_func`, runs KTO training and evaluation,
    logs the evaluation metrics to wandb, and saves the adapter back to the
    adapter directory (optionally pushing it to the Hub).
    """
    # Initialize wandb for logging
    wandb.init(project="kto")

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # -----------------------------
    # Adapter Loading or Initialization
    # -----------------------------
    # Configure the PEFT / LoRA adapter settings
    peft_config = get_peft_config(model_args)
    adapter_dir = os.path.join("adapters", script_args.language)

    if os.path.isdir(adapter_dir):
        # If an adapter for this language already exists, load it into the base model.
        # is_trainable=True is required to keep training it: by default PEFT
        # loads adapters frozen, for inference only.
        model = PeftModel.from_pretrained(model, adapter_dir, is_trainable=True)
        print(f"Loaded existing adapter for language '{script_args.language}' from {adapter_dir}.")
    else:
        # Otherwise, initialize a new LoRA adapter.
        model = get_peft_model(model, peft_config)
        print(f"No adapter found for language '{script_args.language}'. Initialized new adapter.")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    # The dataset function requires the language; the model name is passed so
    # prompts are formatted with the chat template of the model being trained.
    dataset = script_args.process_dataset_func(
        language=script_args.language,
        model_name=model_args.model_name,
    )
    print("Dataset processed.")

    print("Initializing trainer...")
    # `model` already carries its adapter, so no peft_config is passed here:
    # given both, the trainer would merge the adapter into the base weights
    # and wrap the model in a fresh one, discarding the adapter set up above.
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log metrics to wandb. evaluate() prefixes its keys (e.g. "eval_loss"),
    # so the dict is logged as returned instead of looking up bare names
    # such as "loss", which would all come back as None.
    wandb.log(metrics)

    # -----------------------------
    # Adapter Saving
    # -----------------------------
    print("Saving adapter...")
    os.makedirs(adapter_dir, exist_ok=True)
    model.save_pretrained(adapter_dir)
    print(f"Adapter for language '{script_args.language}' saved to: {adapter_dir}")

    if script_args.push_to_hub:
        print("Pushing adapter to Hugging Face Hub...")
        model.push_to_hub(repo_id=f"your_hf_org/{script_args.language}-adapter")

    print("Process completed.")

    # Finish wandb run
    wandb.finish()


if __name__ == "__main__":
    main()