# File size: 4,919 Bytes
# 3d06b91
import os
import re
from typing import cast, Any
from datasets import load_dataset, Dataset as HfDataset
from unfat.extract import Extractor
from unfat.client import OpenAiCompatClient
from unfat.datasets import Dataset, Prompts, hub_prompts, HubSplit
from unfat.together import llama_3_1_70b_together
from unfat.lora import LoraSettings
def gen_prompts(
    ds_name: str,
    text_field: str,
    start_regex: re.Pattern | None = None,
    end_regex: re.Pattern | None = None,
):
    """Build a ``Prompts`` source from one text column of a Hub dataset.

    The ``train`` split of ``ds_name`` is loaded as soon as this function is
    called. Every row then contributes exactly one prompt: the value stored
    under ``text_field`` with all matches of ``start_regex`` removed first
    and all matches of ``end_regex`` removed second. A pattern that is not
    supplied is simply skipped.

    Args:
        ds_name: Dataset identifier handed to ``load_dataset``; also used to
            build the output path.
        text_field: Key of the row entry that holds the prompt text.
        start_regex: Optional compiled pattern whose matches are deleted
            from the text before ``end_regex`` is applied.
        end_regex: Optional compiled pattern whose matches are deleted from
            the text after ``start_regex`` has been applied.

    Returns:
        A ``Prompts`` with ``output_path`` set to ``hub/<ds_name>.jsonl``,
        a ``count`` callable returning the number of rows in the split, and
        an ``items`` callable producing a generator of cleaned prompts.
    """
    train_split = cast(HfDataset, load_dataset(ds_name, split="train"))

    # Keep only the patterns that were actually given, in application order
    # (start before end), so the per-row loop needs no branching.
    active_patterns = [
        pattern for pattern in (start_regex, end_regex) if pattern
    ]

    def items():
        for record in train_split:
            prompt = cast(dict[Any, Any], record)[text_field]
            for pattern in active_patterns:
                prompt = pattern.sub("", prompt)
            yield prompt

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(train_split),
        items=items,
    )
def extract_prompts_from_convos(
    ds_name: str,
    messages_field: str,
    role_field: str,
    content_field: str,
    user_role: str,
):
    """Build a ``Prompts`` source from the first user turn of each conversation.

    The ``train`` split of ``ds_name`` is loaded as soon as this function is
    called. For every row, the messages stored under ``messages_field`` are
    scanned in order and the content of the first message whose
    ``role_field`` equals ``user_role`` is emitted as a prompt. Later user
    messages in the same conversation are ignored.

    Args:
        ds_name: Dataset identifier handed to ``load_dataset``; also used to
            build the output path.
        messages_field: Key of the row entry that holds the message list.
        role_field: Key inside each message that holds the speaker role.
        content_field: Key inside each message that holds the text.
        user_role: Role value that marks a message as coming from the user.

    Returns:
        A ``Prompts`` with ``output_path`` set to ``hub/<ds_name>.jsonl``,
        a ``count`` callable returning the number of rows in the split, and
        an ``items`` callable producing a generator of prompts. Rows with no
        ``user_role`` message produce nothing, so ``items`` can yield fewer
        entries than ``count`` reports.
    """
    train_split = cast(HfDataset, load_dataset(ds_name, split="train"))

    # Unique marker for "no user turn found"; a plain None could collide
    # with a message whose content really is None.
    no_user_turn = object()

    def items():
        for record in train_split:
            conversation = cast(dict[Any, Any], record)[messages_field]
            # The generator is lazy, so scanning stops at the first match.
            first_prompt = next(
                (
                    turn[content_field]
                    for turn in conversation
                    if turn[role_field] == user_role
                ),
                no_user_turn,
            )
            if first_prompt is not no_user_turn:
                yield first_prompt

    return Prompts(
        output_path=f"hub/{ds_name}.jsonl",
        count=lambda: len(train_split),
        items=items,
    )
def main():
    """Run the extraction step and then launch the fine-tune.

    The function:

    1. Reads the three required API keys from the environment.
    2. Builds the custom prompt sources with ``gen_prompts`` and
       ``extract_prompts_from_convos`` and combines them with several
       ``hub_prompts`` sources into one training ``Dataset``.
    3. Runs an ``Extractor`` over that dataset using the
       ``hf:TheDrummer/Behemoth-123B-v1.2`` model behind the
       ``https://glhf.chat/api/openai/v1`` endpoint, writing under
       ``output``.
    4. Passes the extractor's output dataset to ``llama_3_1_70b_together``
       with the LoRA settings below, uploads the files and starts the
       fine-tune.

    Raises:
        KeyError: If ``GLHF_API_KEY``, ``TOGETHER_API_KEY`` or
            ``WANDB_API_KEY`` is not set in the environment.
    """
    # Read every credential before any dataset is loaded or any request is
    # sent: a missing key then fails immediately instead of after the whole
    # extraction run has already completed.
    glhf_api_key = os.environ["GLHF_API_KEY"]
    together_api_key = os.environ["TOGETHER_API_KEY"]
    wandb_api_key = os.environ["WANDB_API_KEY"]

    output_dir = "output"

    rp_english = extract_prompts_from_convos(
        ds_name="OdiaGenAI/roleplay_english",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="user",
    )
    bluemoon = extract_prompts_from_convos(
        ds_name="xDAN2099/RolePlay-Mixed-Bluemoon-Limarp",
        messages_field="conversations",
        role_field="from",
        content_field="value",
        user_role="human",
    )
    # These two datasets wrap each prompt in a speaker prefix and a trailing
    # reply marker; the patterns strip both.
    roleplay_prompts = gen_prompts(
        ds_name="AlekseyKorshuk/roleplay-io",
        text_field="input_text",
        start_regex=re.compile(r'^User: '),
        end_regex=re.compile(r'Bot:\s*$'),
    )
    roleplay_instr_prompts = gen_prompts(
        ds_name="iamketan25/roleplay-instructions-dataset",
        text_field="prompt",
        start_regex=re.compile(r'^Human: '),
        end_regex=re.compile(r'Assistant:\s*$'),
    )

    extractor = Extractor(
        max_concurrent=50,
        output_dir=output_dir,
        client=OpenAiCompatClient(
            base_url="https://glhf.chat/api/openai/v1",
            api_key=glhf_api_key,
            model="hf:TheDrummer/Behemoth-123B-v1.2",
            retries=20,
        ),
        dataset=Dataset(
            train=[
                hub_prompts(
                    name="mlabonne/harmful_behaviors",
                    text_field="text",
                    split="train",
                ),
                roleplay_instr_prompts,
                roleplay_prompts,
                rp_english,
                bluemoon,
                hub_prompts(
                    name="TheDrummer/AmoralQA-v2",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="vicgalle/OpenHermesPreferences-roleplay",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="mrcuddle/DPO_Pairs_Roleplay-Alpaca",
                    text_field="prompt",
                    split="train",
                ),
                hub_prompts(
                    name="ResplendentAI/theory_of_mind_fixed_output",
                    text_field="instruction",
                    split="train",
                ),
                # Only this source is capped, via HubSplit(max_rows=1000).
                hub_prompts(
                    name="mlabonne/harmless_alpaca",
                    text_field="text",
                    split=HubSplit(name="train", max_rows=1000),
                ),
            ],
        ),
    )
    extractor.run()

    dataset = extractor.output_dataset()
    together_config = llama_3_1_70b_together(
        output_dir=output_dir,
        dataset=dataset,
        api_key=together_api_key,
        settings=LoraSettings(
            rank=32,
            alpha=16,
            dropout=0.01,
            num_epochs=2,
            learning_rate=4e-4,
            evals_per_epoch=0,
            wandb_project="behemoth-distill",
            wandb_api_key=wandb_api_key,
        ),
    )
    files = together_config.upload_files()
    together_config.finetune(files)
# Run the pipeline only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|