Skip to content

Commit 9a23004

Browse files
niklubnik
andauthored
fix: DIA-2214: Implement ro utility to access base64 from url (#460)
Co-authored-by: nik <[email protected]>
1 parent 856f396 commit 9a23004

File tree

1 file changed

+127
-2
lines changed
  • src/label_studio_sdk/_extensions/label_studio_tools/core/utils

1 file changed

+127
-2
lines changed

src/label_studio_sdk/_extensions/label_studio_tools/core/utils/io.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import shutil
6+
import base64
67
from contextlib import contextmanager
78
from tempfile import mkdtemp
89
from urllib.parse import urlparse
@@ -201,7 +202,7 @@ def download_and_cache(
201202

202203
# local storage: /data/local-files?d=dir/1.jpg => 1.jpg
203204
if is_local_storage_file:
204-
url_filename = os.path.basename(url.split('?d=')[1])
205+
url_filename = os.path.basename(url.split("?d=")[1])
205206
# cloud storage: s3://bucket/1.jpg => 1.jpg
206207
elif is_cloud_storage_file:
207208
url_filename = os.path.basename(url)
@@ -213,7 +214,11 @@ def download_and_cache(
213214
filepath = os.path.join(cache_dir, url_hash + "__" + url_filename)
214215

215216
if not os.path.exists(filepath):
216-
logger.info("Download {url} to {filepath}. download_resources: {download_resources}".format(url=url, filepath=filepath, download_resources=download_resources))
217+
logger.info(
218+
"Download {url} to {filepath}. download_resources: {download_resources}".format(
219+
url=url, filepath=filepath, download_resources=download_resources
220+
)
221+
)
217222
if download_resources:
218223
headers = {
219224
# avoid requests.exceptions.HTTPError: 403 Client Error: Forbidden. Please comply with the User-Agent policy:
@@ -256,3 +261,123 @@ def get_all_files_from_dir(d):
256261
if os.path.isfile(filepath):
257262
out.append(filepath)
258263
return out
264+
265+
266+
def get_base64_content(
267+
url,
268+
hostname=None,
269+
access_token=None,
270+
task_id=None,
271+
):
272+
"""This helper function is used to download a file and return its base64 representation without saving to filesystem.
273+
274+
:param url: File URL to download, it can be a uploaded file, local storage, cloud storage file or just http(s) url
275+
:param hostname: Label Studio Hostname, it will be used for uploaded files, local storage files and cloud storage files
276+
if not provided, it will be taken from LABEL_STUDIO_URL env variable
277+
:param access_token: Label Studio access token, it will be used for uploaded files, local storage files and cloud storage files
278+
if not provided, it will be taken from LABEL_STUDIO_API_KEY env variable
279+
:param task_id: Label Studio Task ID, required for cloud storage files
280+
because the URL will be rebuilt to `{hostname}/tasks/{task_id}/presign/?fileuri={url}`
281+
282+
:return: base64 encoded file content
283+
"""
284+
# get environment variables
285+
hostname = (
286+
hostname
287+
or os.getenv("LABEL_STUDIO_URL", "")
288+
or os.getenv("LABEL_STUDIO_HOST", "")
289+
)
290+
access_token = (
291+
access_token
292+
or os.getenv("LABEL_STUDIO_API_KEY", "")
293+
or os.getenv("LABEL_STUDIO_ACCESS_TOKEN", "")
294+
)
295+
if "localhost" in hostname:
296+
logger.warning(
297+
f"Using `localhost` ({hostname}) in LABEL_STUDIO_URL, "
298+
f"`localhost` is not accessible inside of docker containers. "
299+
f"You can check your IP with utilities like `ifconfig` and set it as LABEL_STUDIO_URL."
300+
)
301+
if hostname and not (
302+
hostname.startswith("http://") or hostname.startswith("https://")
303+
):
304+
raise ValueError(
305+
f"Invalid hostname in LABEL_STUDIO_URL: {hostname}. "
306+
"Please provide full URL starting with protocol (http:// or https://)."
307+
)
308+
309+
# fix file upload url
310+
if url.startswith("upload") or url.startswith("/upload"):
311+
url = "/data" + ("" if url.startswith("/") else "/") + url
312+
313+
is_uploaded_file = url.startswith("/data/upload")
314+
is_local_storage_file = url.startswith("/data/") and "?d=" in url
315+
is_cloud_storage_file = (
316+
url.startswith("s3:") or url.startswith("gs:") or url.startswith("azure-blob:")
317+
)
318+
319+
# Local storage file: try to load locally
320+
if is_local_storage_file:
321+
filepath = url.split("?d=")[1]
322+
filepath = safe_build_path(LOCAL_FILES_DOCUMENT_ROOT, filepath)
323+
if os.path.exists(filepath):
324+
logger.debug(
325+
f"Local Storage file path exists locally, read content directly: {filepath}"
326+
)
327+
with open(filepath, "rb") as f:
328+
return base64.b64encode(f.read()).decode("utf-8")
329+
330+
# Upload or Local Storage file
331+
if is_uploaded_file or is_local_storage_file or is_cloud_storage_file:
332+
# hostname check
333+
if not hostname:
334+
raise FileNotFoundError(
335+
f"Can't resolve url, hostname not provided: {url}. "
336+
"You can set LABEL_STUDIO_URL environment variable to use it as a hostname."
337+
)
338+
# uploaded and local storage file
339+
elif is_uploaded_file or is_local_storage_file:
340+
url = concat_urls(hostname, url)
341+
logger.info("Resolving url using hostname [" + hostname + "]: " + url)
342+
# s3, gs, azure-blob file
343+
elif is_cloud_storage_file:
344+
if task_id is None:
345+
raise Exception(
346+
"Label Studio Task ID is required for cloud storage files"
347+
)
348+
url = concat_urls(hostname, f"/tasks/{task_id}/presign/?fileuri={url}")
349+
logger.info(
350+
"Cloud storage file: Resolving url using hostname ["
351+
+ hostname
352+
+ "]: "
353+
+ url
354+
)
355+
356+
# check access token
357+
if not access_token:
358+
raise FileNotFoundError(
359+
"To access uploaded and local storage files you have to "
360+
"set LABEL_STUDIO_API_KEY environment variable."
361+
)
362+
363+
# Download the content but don't save to filesystem
364+
headers = {
365+
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"
366+
}
367+
368+
# check if url matches hostname - then uses access token to this Label Studio instance
369+
parsed_url = urlparse(url)
370+
if access_token and hostname and parsed_url.netloc == urlparse(hostname).netloc:
371+
headers["Authorization"] = "Token " + access_token
372+
logger.debug("Authorization token is used for get_base64_content")
373+
374+
try:
375+
r = requests.get(url, headers=headers, verify=VERIFY_SSL)
376+
r.raise_for_status()
377+
return base64.b64encode(r.content).decode("utf-8")
378+
except requests.exceptions.SSLError as e:
379+
logger.error(
380+
f"SSL error during requests.get('{url}'): {e}\n"
381+
f"Try to set VERIFY_SSL=False in environment variables to bypass SSL verification."
382+
)
383+
raise e

0 commit comments

Comments
 (0)