PierreEpron commited on
Commit
329bf91
·
verified ·
1 Parent(s): dae84f1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+
4
+ from datetime import datetime
5
+ import demoji
6
+ from huggingface_hub import CommitScheduler
7
+ from pathlib import Path
8
+ import re
9
+ from transformers import pipeline
10
+ from uuid import uuid4
11
+
12
+ #based on https://huggingface.co/spaces/Wauplin/space_to_dataset_saver/blob/main/app_json.py
13
+ #data is saved at https://huggingface.co/datasets/MR17u/tweeteval-irony-mcc/tree/main
14
+
15
+ # JSON_DATASET_DIR = Path("json_dataset")
16
+ # JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
17
+
18
+ # JSON_DATASET_PATH = JSON_DATASET_DIR / f"data-{uuid4()}.json"
19
+
20
+
21
+ prompt = '''### Instruction:
22
+ Classify if the following tweet is ironic or not
23
+ ### Input:
24
+ {text}
25
+ ### Response:
26
+ '''
27
+
28
+
29
+ # scheduler = CommitScheduler(
30
+ # repo_id="tweeteval-irony-mcc",
31
+ # repo_type="dataset",
32
+ # folder_path=JSON_DATASET_DIR,
33
+ # path_in_repo="data",
34
+ # )
35
+
36
+ classifier = pipeline("text-generation", model="meta-llama/Llama-2-7b-hf")
37
+ classifier.load_lora_weights("PierreEpron/llama7b-irony", weight_name="adapter_model.safetensors")
38
+
39
+ def clean_brackets(text):
40
+ return text.replace('{', '(').replace('}', ')')
41
+
42
+ def clean_emojis(text, type:str = ''):
43
+ if type=='rem':
44
+ return demoji.replace(text, '')
45
+ elif type!='keep':
46
+ return demoji.replace_with_desc(text, type)
47
+ else:
48
+ return text
49
+
50
+ def clean_hashtags(text, hashtags=['#irony', '#sarcasm','#not']):
51
+ for hashtag in hashtags:
52
+ text = re.sub(hashtag, '', text, flags=re.I)
53
+ return re.sub(r' +', r' ', text)
54
+
55
+ def clean_text(text):
56
+ return re.sub(' {2,}', ' ',clean_emojis(clean_hashtags(clean_brackets(text)))).strip()
57
+
58
+ # def save_json(entry: str, result) -> None:
59
+ # with scheduler.lock:
60
+ # with JSON_DATASET_PATH.open("a") as f:
61
+ # result = json.loads(result.replace("'",'"'))[0]
62
+ # json.dump({"entry": entry, "label": result['label'], "score": result['score'], "datetime": datetime.now().isoformat()}, f)
63
+ # f.write("\n")
64
+
65
+ def classif(text: str):
66
+ return classifier(prompt.format(text=clean_text(text)))
67
+
68
+ with gr.Blocks() as demo:
69
+ with gr.Row():
70
+ entry = gr.Textbox(label="Input")
71
+ result = gr.Textbox(label="Classification")
72
+ input_btn = gr.Button("Submit")
73
+ input_btn.click(fn=classif, inputs=entry, outputs=result).success(
74
+ fn=print, #save_json,
75
+ inputs=[entry, result],
76
+ outputs=None
77
+ )
78
+
79
+ demo.launch()