Commit ec32ee4

NivekT authored and facebook-github-bot committed
Adding timeout option for GDriveReader (#153)
Summary:
Pull Request resolved: #153

Fixes #132
Fixes #137

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D33370399

Pulled By: NivekT

fbshipit-source-id: d20d6b3b17103bbeee79c788732c76bfb4f29f76
1 parent 7899c86 · commit ec32ee4
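For context (not part of this commit), a minimal usage sketch of the new option. It assumes torchdata is installed, that IterableWrapper and GDriveReaderDataPipe are importable from the paths below, and uses a placeholder Google Drive URL and an illustrative timeout value.

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.load.online import GDriveReaderDataPipe

# A source datapipe of Google Drive file URLs (placeholder file ID).
source_dp = IterableWrapper(["https://drive.google.com/uc?export=download&id=SOME_FILE_ID"])

# timeout is the new keyword-only option; it is forwarded to requests.Session.get.
gdrive_dp = GDriveReaderDataPipe(source_dp, timeout=30.0)

for file_name, stream in gdrive_dp:
    print(file_name)         # file name parsed from the content-disposition header
    payload = stream.read()  # StreamWrapper around the raw HTTP response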

File tree

2 files changed: +44 -40 lines


torchdata/datapipes/iter/load/online.py

Lines changed: 41 additions & 36 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import re
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple
 from urllib.parse import urlparse

 import requests
@@ -10,7 +10,7 @@
 from torchdata.datapipes.utils import StreamWrapper


-def _get_response_from_http(url: str, *, timeout: float) -> Tuple[str, StreamWrapper]:
+def _get_response_from_http(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
     try:
         with requests.Session() as session:
             if timeout is None:
@@ -29,15 +29,14 @@ def _get_response_from_http(url: str, *, timeout: float) -> Tuple[str, StreamWra
 class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     r""":class:`HTTPReaderIterDataPipe`

-    Iterable DataPipe that takes file URLs (http URLs pointing to files), and
-    yields tuples of file URL and IO stream
+    Iterable DataPipe that takes file URLs (http URLs pointing to files), and yields tuples of file URL and IO stream.

     Args:
         source_datapipe: a DataPipe that contains URLs
-        timeout : timeout in seconds for http request
+        timeout: timeout in seconds for http request
     """

-    def __init__(self, source_datapipe: IterDataPipe[str], timeout=None) -> None:
+    def __init__(self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None) -> None:
         self.source_datapipe: IterDataPipe[str] = source_datapipe
         self.timeout = timeout

@@ -49,47 +48,54 @@ def __len__(self) -> int:
         return len(self.source_datapipe)


-def _get_response_from_google_drive(url: str) -> Tuple[str, StreamWrapper]:
+def _get_response_from_google_drive(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
     confirm_token = None
-    session = requests.Session()
-    response = session.get(url, stream=True)
-    for k, v in response.cookies.items():
-        if k.startswith("download_warning"):
-            confirm_token = v
-    if confirm_token is None:
-        if "Quota exceeded" in str(response.content):
-            raise RuntimeError(f"Google drive link {url} is currently unavailable, because the quota was exceeded.")
-
-    if confirm_token:
-        url = url + "&confirm=" + confirm_token
-
-    response = session.get(url, stream=True)
-
-    if "content-disposition" not in response.headers:
-        raise RuntimeError("Internal error: headers don't contain content-disposition.")
-
-    filename = re.findall('filename="(.+)"', response.headers["content-disposition"])
-    if filename is None:
-        raise RuntimeError("Filename could not be autodetected")
+    with requests.Session() as session:
+        if timeout is None:
+            response = session.get(url, stream=True)
+        else:
+            response = session.get(url, timeout=timeout, stream=True)
+        for k, v in response.cookies.items():
+            if k.startswith("download_warning"):
+                confirm_token = v
+        if confirm_token is None:
+            if "Quota exceeded" in str(response.content):
+                raise RuntimeError(f"Google drive link {url} is currently unavailable, because the quota was exceeded.")
+
+        if confirm_token:
+            url = url + "&confirm=" + confirm_token
+
+        if timeout is None:
+            response = session.get(url, stream=True)
+        else:
+            response = session.get(url, timeout=timeout, stream=True)
+
+        if "content-disposition" not in response.headers:
+            raise RuntimeError("Internal error: headers don't contain content-disposition.")
+
+        filename = re.findall('filename="(.+)"', response.headers["content-disposition"])
+        if filename is None:
+            raise RuntimeError("Filename could not be autodetected")
     return filename[0], StreamWrapper(response.raw)


 class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     r"""
-    Iterable DataPipe that takes URLs point at GDrive files, and
-    yields tuples of file name and IO stream
+    Iterable DataPipe that takes URLs point at GDrive files, and yields tuples of file name and IO stream.

     Args:
         source_datapipe: a DataPipe that contains URLs to GDrive files
+        timeout: timeout in seconds for http request
     """
     source_datapipe: IterDataPipe[str]

-    def __init__(self, source_datapipe: IterDataPipe[str]) -> None:
+    def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
         self.source_datapipe = source_datapipe
+        self.timeout = timeout

     def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
         for url in self.source_datapipe:
-            yield _get_response_from_google_drive(url)
+            yield _get_response_from_google_drive(url, timeout=self.timeout)

     def __len__(self) -> int:
         return len(self.source_datapipe)
@@ -98,15 +104,15 @@ def __len__(self) -> int:
 class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
     r""":class:
     Iterable DataPipe that takes file URLs (can be http URLs pointing to files or URLs to GDrive files), and
-    yields tuples of file URL and IO stream
+    yields tuples of file URL and IO stream.

     Args:
         source_datapipe: a DataPipe that contains URLs
-        timeout : timeout in seconds for http request
+        timeout: timeout in seconds for http request
     """
     source_datapipe: IterDataPipe[str]

-    def __init__(self, source_datapipe: IterDataPipe[str], *, timeout=None) -> None:
+    def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
         self.source_datapipe = source_datapipe
         self.timeout = timeout

@@ -115,8 +121,7 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
             parts = urlparse(url)

             if re.match(r"(drive|docs)[.]google[.]com", parts.netloc):
-                # TODO(137): can this also have a timeout?
-                yield _get_response_from_google_drive(url)
+                yield _get_response_from_google_drive(url, timeout=self.timeout)
             else:
                 yield _get_response_from_http(url, timeout=self.timeout)
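A brief, hedged sketch of what the dispatch above means in practice: after this change, OnlineReaderIterDataPipe forwards the same timeout to both the Google Drive branch and the plain HTTP branch (previously only the HTTP branch honored it). The URLs below are placeholders and the example is illustrative, not taken from the commit.

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.load.online import OnlineReaderIterDataPipe

urls = IterableWrapper([
    "https://drive.google.com/uc?export=download&id=SOME_FILE_ID",  # routed to the GDrive helper
    "https://example.com/archive.tar.gz",                           # routed to the HTTP helper
])

# The single timeout (in seconds) now applies on both code paths.
online_dp = OnlineReaderIterDataPipe(urls, timeout=10.0)

for name_or_url, stream in online_dp:
    print(name_or_url)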

torchdata/datapipes/iter/util/hashchecker.py

Lines changed: 3 additions & 4 deletions
@@ -58,7 +58,7 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
                 hash_func.update(data)
             # File Stream
             else:
-                # Not all of streams have `read(bytes)` method.
+                # Not all streams have `read(bytes)` method.
                 # `__iter__` method is chosen because it is a common interface for IOBase.
                 for d in data:
                     hash_func.update(d)
@@ -72,9 +72,8 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:

             if hash_func.hexdigest() != self.hash_dict[file_name]:
                 raise RuntimeError(
-                    "The hash {} of {} does not match. Delete the file manually and retry.".format(
-                        hash_func.hexdigest(), file_name
-                    )
+                    f"The computed hash {hash_func.hexdigest()} of {file_name} does not match the expected"
+                    f"hash {self.hash_dict[file_name]}. Delete the file manually and retry."
                 )

             if isinstance(data, (str, bytes, bytearray)):
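For reference, a small hedged sketch of the comparison behind the reworded error message, done directly with hashlib outside the datapipe; the file name, payload, and expected digest are made up for illustration.

import hashlib

# Hypothetical expected digest (sha256 of the string "test", deliberately
# mismatched with the payload below so the error path is exercised).
expected = {"example.bin": "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"}

file_name = "example.bin"
digest = hashlib.sha256(b"hello").hexdigest()

if digest != expected[file_name]:
    raise RuntimeError(
        f"The computed hash {digest} of {file_name} does not match the expected "
        f"hash {expected[file_name]}. Delete the file manually and retry."
    )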
