captchaboy committed
Commit 0fbbe99 · 1 Parent(s): 87bc419

Update demo.py

Files changed (1):
  demo.py +54 -59

demo.py CHANGED
@@ -1,15 +1,29 @@
+# from transformers import AutoModel
 import argparse
 import logging
 import os
 import glob
 import tqdm
-import torch
+import torch, re
 import PIL
 import cv2
 import numpy as np
 import torch.nn.functional as F
 from torchvision import transforms
 from utils import Config, Logger, CharsetMapper
+import gradio as gr
+#dfgdfg
+import gdown
+gdown.download(id='16PF_b4dURVkBt4OT7E-a-vq-SRxi0uDl', output='lol.pth')
+gdown.download(id='19rGjfo73P25O_keQv30snfe3IHrK0uV2', output='config.yaml')
+
+# gdown.download(id='1qyNV80qmYHx_r4KsG3_8PXQ6ff1a1dov', output='modules.zip')
+
+# gdown.download(id='1UMZ7i8SpfuNw0N2JvVY8euaNx9gu3x6N', output='configs.zip')
+
+# gdown.download(id='1yHD7_4DD_keUwGs2nenAYDaQ2CNEA5IU', output='data.zip')
+# os.system('unzip data.zip && unzip configs.zip && unzip modules.zip')
+
 
 def get_model(config):
     import importlib
@@ -21,12 +35,22 @@ def get_model(config):
     model = model.eval()
     return model
 
-def preprocess(img, width, height):
-    img = cv2.resize(np.array(img), (width, height))
-    img = transforms.ToTensor()(img).unsqueeze(0)
-    mean = torch.tensor([0.485, 0.456, 0.406])
-    std = torch.tensor([0.229, 0.224, 0.225])
-    return (img-mean[...,None,None]) / std[...,None,None]
+
+def load(model, file, device=None, strict=True):
+    if device is None: device = 'cpu'
+    elif isinstance(device, int): device = torch.device('cuda', device)
+    assert os.path.isfile(file)
+    state = torch.load(file, map_location=device)
+    if set(state.keys()) == {'model', 'opt'}:
+        state = state['model']
+    model.load_state_dict(state, strict=strict)
+    return model
+
+config = Config('config.yaml')
+config.model_vision_checkpoint = None
+model = get_model(config)
+model = load(model, 'lol.pth')
+
 
 def postprocess(output, charset, model_eval):
     def _get_output(last_output, model_eval):
@@ -54,56 +78,27 @@ def postprocess(output, charset, model_eval):
 
     return pt_text, pt_scores, pt_lengths_
 
-def load(model, file, device=None, strict=True):
-    if device is None: device = 'cpu'
-    elif isinstance(device, int): device = torch.device('cuda', device)
-    assert os.path.isfile(file)
-    state = torch.load(file, map_location=device)
-    if set(state.keys()) == {'model', 'opt'}:
-        state = state['model']
-    model.load_state_dict(state, strict=strict)
-    return model
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
-                        help='path to config file')
-    parser.add_argument('--input', type=str, default='figs/test')
-    parser.add_argument('--cuda', type=int, default=-1)
-    parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
-    parser.add_argument('--model_eval', type=str, default='alignment',
-                        choices=['alignment', 'vision', 'language'])
-    args = parser.parse_args()
-    config = Config(args.config)
-    if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
-    if args.model_eval is not None: config.model_eval = args.model_eval
-    config.global_phase = 'test'
-    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
-    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'
-
-    Logger.init(config.global_workdir, config.global_name, config.global_phase)
-    Logger.enable_file()
-    logging.info(config)
-
-    logging.info('Construct model.')
-    model = get_model(config).to(device)
-    model = load(model, config.model_checkpoint, device=device)
-    charset = CharsetMapper(filename=config.dataset_charset_path,
-                            max_length=config.dataset_max_length + 1)
-
-    if os.path.isdir(args.input):
-        paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
-    else:
-        paths = glob.glob(os.path.expanduser(args.input))
-    assert paths, "The input path(s) was not found"
-    paths = sorted(paths)
-    for path in tqdm.tqdm(paths):
-        img = PIL.Image.open(path).convert('RGB')
-        img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
-        img = img.to(device)
-        res = model(img)
-        pt_text, _, __ = postprocess(res, charset, config.model_eval)
-        logging.info(f'{path}: {pt_text[0]}')
-
-if __name__ == '__main__':
-    main()
+def preprocess(img, width, height):
+    img = cv2.resize(np.array(img), (width, height))
+    img = transforms.ToTensor()(img).unsqueeze(0)
+    mean = torch.tensor([0.485, 0.456, 0.406])
+    std = torch.tensor([0.229, 0.224, 0.225])
+    return (img-mean[...,None,None]) / std[...,None,None]
+
+def process_image(image):
+    charset = CharsetMapper(filename=config.dataset_charset_path, max_length=config.dataset_max_length + 1)
+
+    img = image.convert('RGB')
+    img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
+    res = model(img)
+    return postprocess(res, charset, 'alignment')[0][0]
+
+iface = gr.Interface(fn=process_image,
+                     inputs=gr.inputs.Image(type="pil"),
+                     outputs=gr.outputs.Textbox(),
+                     title="8kun kek",
+                     description="Making Jim Watkins sheete because he is a techlet pedo",
+                     # article=article,
+                     # examples=glob.glob('figs/test/*.png')
+                     )
+iface.launch(debug=True)
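
With this change demo.py is no longer a CLI: importing the module fetches lol.pth and config.yaml via gdown, builds and loads the model once at module level, and hands process_image to gr.Interface. Running python demo.py therefore downloads the weights and then serves the recognizer; debug=True keeps the process in the foreground so tracebacks raised inside process_image appear in the console.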
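
A note on the relocated load(): it accepts either a bare state dict or a full training checkpoint saved as {'model': ..., 'opt': ...}, and unwraps the latter before calling load_state_dict. A minimal sketch of both cases, using a hypothetical toy module rather than the ABINet model:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)

# Case 1: a bare state dict is loaded as-is.
torch.save(net.state_dict(), 'bare.pth')
load(nn.Linear(4, 2), 'bare.pth')

# Case 2: a training checkpoint that also carries optimizer state.
# load() sees the {'model', 'opt'} key set and unwraps state['model'].
torch.save({'model': net.state_dict(), 'opt': {}}, 'ckpt.pth')
load(nn.Linear(4, 2), 'ckpt.pth')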
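
preprocess() normalizes with the standard ImageNet mean/std, broadcasting mean[..., None, None] (shape 3x1x1) against the 1x3xHxW batch. An equivalent formulation with transforms.Normalize, shown only as a cross-check of what the committed code computes (a sketch, not part of this commit):

import cv2
import numpy as np
from torchvision import transforms

def preprocess_normalize(img, width, height):
    # Same pipeline as demo.py's preprocess(), with the manual
    # broadcasting replaced by torchvision's Normalize transform.
    img = cv2.resize(np.array(img), (width, height))    # HxWx3 uint8
    tensor = transforms.ToTensor()(img)                 # 3xHxW, scaled to [0, 1]
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return normalize(tensor).unsqueeze(0)               # 1x3xHxW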