|
1 | | -import os |
2 | 1 | import hashlib |
3 | 2 | from tqdm import tqdm |
4 | 3 | import numpy as np |
5 | | -import uuid |
6 | 4 | import cv2 |
7 | | -import socket |
8 | | -from urllib.request import urlretrieve |
9 | 5 | import torch |
10 | | - |
| 6 | +import base64 |
| 7 | +import urllib.request |
11 | 8 |
|
12 | 9 | class Img_Tools(object): |
13 | 10 | """ |
14 | | - Img操作 |
| 11 | + Img工具类 |
15 | 12 | """ |
16 | 13 |
|
17 | 14 | def __init__(self): |
18 | 15 | pass |
19 | | - |
| 16 | + |
20 | 17 | @staticmethod |
21 | | - def down_urls(url_list, save_path, time=5): |
22 | | - """ |
23 | | - 根据url下载图片,随机uuid命名。 |
| 18 | + def read_web_img(image_url=None,image_file=None,image_base64=None,url_time_out=10): |
| 19 | + ''' |
| 20 | + 参数三选一,当传入多个参数,仅返回最高优先级的图像 |
| 21 | +
|
| 22 | + 优先级: 文件 > base64 > url |
| 23 | + |
| 24 | + url_time_out : URL下载耗时限制,默认10秒 |
| 25 | + ''' |
| 26 | + if image_file: |
| 27 | + try: |
| 28 | + img = cv2.imdecode(np.frombuffer(image_file, np.uint8), cv2.IMREAD_COLOR) |
| 29 | + if img.any(): |
| 30 | + return img |
| 31 | + else: |
| 32 | + return 'IMAGE_ERROR_UNSUPPORTED_FORMAT' |
| 33 | + except: |
| 34 | + return 'IMAGE_ERROR_UNSUPPORTED_FORMAT' |
24 | 35 |
|
25 | | - url_list(list): URL列表 eg:['http://a.jpg', http://aaacc', ...] |
26 | | - save_path(str): 图像保存路径 |
27 | | - time(int):耗时限制,单位s |
28 | | - """ |
29 | | - socket.setdefaulttimeout(time) # 超时限制 |
30 | | - assert len(url_list) > 0 and os.path.isdir(save_path) # 验证文件夹是否存在 |
31 | | - print("start download imgs...") |
32 | | - for url in tqdm(url_list): |
| 36 | + elif image_base64: |
33 | 37 | try: |
34 | | - urlretrieve(url, save_path + str(uuid.uuid1()) + ".jpg") |
35 | | - except socket.timeout: |
36 | | - print("error url: ", url, "\n") |
| 38 | + img = base64.b64decode(image_base64) |
| 39 | + img_array = np.frombuffer(img, np.uint8) |
| 40 | + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) |
| 41 | + if img.any(): |
| 42 | + return img |
| 43 | + else: |
| 44 | + return 'IMAGE_ERROR_UNSUPPORTED_FORMAT' |
| 45 | + except: |
| 46 | + return 'IMAGE_ERROR_UNSUPPORTED_FORMAT' |
| 47 | + elif image_url: |
| 48 | + try: |
| 49 | + resp = urllib.request.urlopen(image_url,time_out=url_time_out) |
| 50 | + except: |
| 51 | + return 'URL_DOWNLOAD_TIMEOUT' |
| 52 | + try: |
| 53 | + image = np.asarray(bytearray(resp.read()), dtype="uint8") |
| 54 | + img = cv2.imdecode(image, cv2.IMREAD_COLOR) |
| 55 | + if img.any(): |
| 56 | + return img |
| 57 | + else: |
| 58 | + return 'IMAGE_ERROR_UNSUPPORTED_FORMAT' |
| 59 | + except: |
| 60 | + return 'INVALID_IMAGE_URL' |
| 61 | + else: |
| 62 | + return 'MISSING_ARGUMENTS' |
37 | 63 |
|
38 | 64 | @staticmethod |
39 | 65 | def plot_bbox(img, bbox, name, prob): |
|
0 commit comments