-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathomniparser.py
More file actions
110 lines (97 loc) · 3.22 KB
/
omniparser.py
File metadata and controls
110 lines (97 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from utils import (
get_som_labeled_img,
check_ocr_box,
get_caption_model_processor,
get_dino_model,
get_yolo_model,
)
import torch
from ultralytics import YOLO
from PIL import Image
from typing import Dict, Tuple, List
import io
import base64
# Runtime settings for the Omniparser demo below.
config = dict(
    # Weights for the SoM icon-detection model, loaded in Omniparser.__init__.
    som_model_path="weights/icon_detect_v1_5/model_v1_5.pt",
    # NOTE(review): in this file "device" is only echoed by the timing printout;
    # the detection model is loaded without a device argument — confirm intended.
    device="cpu",
    # Only referenced by the commented-out caption-model setup in Omniparser.__init__.
    caption_model_path="florence2",
    # Drawing options forwarded as-is to get_som_labeled_img(draw_bbox_config=...).
    draw_bbox_config=dict(
        text_scale=0.8,
        text_thickness=2,
        text_padding=3,
        thickness=3,
    ),
    # Box threshold passed to get_som_labeled_img; the key's spelling
    # ("TRESHOLD") must match what Omniparser.parse looks up.
    BOX_TRESHOLD=0.05,
)
class Omniparser:
    """Run OCR plus the SoM icon-detection model on a screenshot.

    Config keys read: ``som_model_path`` (in ``__init__``), ``draw_bbox_config``
    and ``BOX_TRESHOLD`` (in ``parse``).
    """

    def __init__(self, config: Dict):
        """Store *config* and load the detection model from ``config["som_model_path"]``."""
        self.config = config
        self.som_model = get_yolo_model(model_path=config["som_model_path"])
        # Icon captioning is disabled: parse() passes caption_model_processor=None.
        # To enable it, load the processor here, e.g.:
        # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=config['device'])
        # self.caption_model_processor['model'].to(torch.float32)

    @staticmethod
    def _build_elements(label_coordinates: Dict, parsed_content_list: List[str]) -> List[Dict]:
        """Turn labelled boxes into output element dicts, preserving their order.

        The first ``len(parsed_content_list)`` entries of *label_coordinates* are
        paired positionally with the matching parsed string and typed ``"text"``;
        any remaining entries are typed ``"icon"`` with the placeholder text
        ``"None"``. Each coordinate's first four values become
        ``x``/``y``/``width``/``height`` (presumably x, y, w, h — confirm against
        ``get_som_labeled_img``).
        """
        n_text = len(parsed_content_list)
        elements = []
        for i, coord in enumerate(label_coordinates.values()):
            is_text = i < n_text
            if is_text:
                # Entries appear to be "<label>: <content>". Split only on the first
                # separator so content that itself contains ": " is kept whole.
                text = parsed_content_list[i].split(": ", 1)[-1]
            else:
                text = "None"
            elements.append(
                {
                    "from": "omniparser",
                    "shape": {
                        "x": coord[0],
                        "y": coord[1],
                        "width": coord[2],
                        "height": coord[3],
                    },
                    "text": text,
                    "type": "text" if is_text else "icon",
                }
            )
        return elements

    def parse(self, image_path: str):
        """Detect and label UI elements in the image at *image_path*.

        Args:
            image_path: Path of the screenshot; passed to both ``check_ocr_box``
                and ``get_som_labeled_img``.

        Returns:
            A two-item list ``[image, elements]``. *image* is the annotated image
            decoded from the base64 payload returned by ``get_som_labeled_img`` and
            opened with PIL. *elements* is the list produced by ``_build_elements``.
        """
        print("Parsing image:", image_path)
        (text, ocr_bbox), _ = check_ocr_box(
            image_path,
            display_img=False,
            output_bb_format="xyxy",
            goal_filtering=None,
            easyocr_args={"paragraph": False, "text_threshold": 0.9},
        )
        labeled_img_b64, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_path,
            self.som_model,
            BOX_TRESHOLD=self.config["BOX_TRESHOLD"],
            output_coord_in_ratio=False,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=self.config["draw_bbox_config"],
            caption_model_processor=None,  # icon captioning is disabled
            ocr_text=text,
            use_local_semantics=False,
        )
        image = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
        return [image, self._build_elements(label_coordinates, parsed_content_list)]
if __name__ == "__main__":
    # Demo: parse a sample screenshot and report how long it took. Guarded so that
    # importing this module (e.g. to reuse Omniparser) does not load the model or
    # parse the hard-coded image as a side effect.
    import time

    parser = Omniparser(config)
    image_path = "imgs/ios.png"
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    image, parsed_content_list = parser.parse(image_path)
    device = config["device"]
    print(f"Time taken for Omniparser on {device}:", time.perf_counter() - start)