1
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
2
import re
3
- from typing import Iterator , Tuple
3
+ from typing import Iterator , Optional , Tuple
4
4
from urllib .parse import urlparse
5
5
6
6
import requests
10
10
from torchdata .datapipes .utils import StreamWrapper
11
11
12
12
13
- def _get_response_from_http (url : str , * , timeout : float ) -> Tuple [str , StreamWrapper ]:
13
+ def _get_response_from_http (url : str , * , timeout : Optional [ float ] ) -> Tuple [str , StreamWrapper ]:
14
14
try :
15
15
with requests .Session () as session :
16
16
if timeout is None :
@@ -29,15 +29,14 @@ def _get_response_from_http(url: str, *, timeout: float) -> Tuple[str, StreamWra
29
29
class HTTPReaderIterDataPipe (IterDataPipe [Tuple [str , StreamWrapper ]]):
30
30
r""":class:`HTTPReaderIterDataPipe`
31
31
32
- Iterable DataPipe that takes file URLs (http URLs pointing to files), and
33
- yields tuples of file URL and IO stream
32
+ Iterable DataPipe that takes file URLs (http URLs pointing to files), and yields tuples of file URL and IO stream.
34
33
35
34
Args:
36
35
source_datapipe: a DataPipe that contains URLs
37
- timeout : timeout in seconds for http request
36
+ timeout: timeout in seconds for http request
38
37
"""
39
38
40
- def __init__ (self , source_datapipe : IterDataPipe [str ], timeout = None ) -> None :
39
+ def __init__ (self , source_datapipe : IterDataPipe [str ], timeout : Optional [ float ] = None ) -> None :
41
40
self .source_datapipe : IterDataPipe [str ] = source_datapipe
42
41
self .timeout = timeout
43
42
@@ -49,47 +48,54 @@ def __len__(self) -> int:
49
48
return len (self .source_datapipe )
50
49
51
50
52
def _get_response_from_google_drive(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
    """Fetch a Google Drive download URL and return ``(filename, stream)``.

    Args:
        url: Google Drive download URL.
        timeout: timeout in seconds for each HTTP request; ``None`` falls back
            to requests' default behavior (no timeout).

    Returns:
        Tuple of the server-reported filename and a ``StreamWrapper`` over the
        raw response stream.

    Raises:
        RuntimeError: if the download quota was exceeded, the response has no
            ``content-disposition`` header, or no filename could be extracted
            from that header.
    """
    confirm_token = None
    with requests.Session() as session:
        if timeout is None:
            response = session.get(url, stream=True)
        else:
            response = session.get(url, timeout=timeout, stream=True)

        # Google Drive sets a "download_warning" cookie whose value must be
        # echoed back as a "confirm" query parameter to proceed with the
        # download.
        for k, v in response.cookies.items():
            if k.startswith("download_warning"):
                confirm_token = v
        if confirm_token is None:
            if "Quota exceeded" in str(response.content):
                raise RuntimeError(f"Google drive link {url} is currently unavailable, because the quota was exceeded.")

        if confirm_token:
            url = url + "&confirm=" + confirm_token

        if timeout is None:
            response = session.get(url, stream=True)
        else:
            response = session.get(url, timeout=timeout, stream=True)

        if "content-disposition" not in response.headers:
            raise RuntimeError("Internal error: headers don't contain content-disposition.")

        filename = re.findall('filename="(.+)"', response.headers["content-disposition"])
        # BUG FIX: re.findall returns a (possibly empty) list, never None, so
        # the previous `if filename is None` check could never fire and a
        # missing filename crashed with IndexError at `filename[0]` instead of
        # raising the intended error.
        if not filename:
            raise RuntimeError("Filename could not be autodetected")
    # NOTE(review): response.raw is handed out after the session context has
    # exited; presumably the underlying connection remains readable — confirm.
    return filename[0], StreamWrapper(response.raw)
75
80
76
81
77
82
class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
    r"""
    Iterable DataPipe that takes URLs pointing at GDrive files, and yields tuples of file name and IO stream.

    Args:
        source_datapipe: a DataPipe that contains URLs to GDrive files
        timeout: timeout in seconds for http request
    """
    source_datapipe: IterDataPipe[str]

    def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[float] = None) -> None:
        self.source_datapipe = source_datapipe
        self.timeout = timeout

    def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
        # Delegate each URL to the module-level GDrive fetcher, forwarding the
        # configured timeout.
        yield from (_get_response_from_google_drive(u, timeout=self.timeout) for u in self.source_datapipe)

    def __len__(self) -> int:
        return len(self.source_datapipe)
@@ -98,15 +104,15 @@ def __len__(self) -> int:
98
104
class OnlineReaderIterDataPipe (IterDataPipe [Tuple [str , StreamWrapper ]]):
99
105
r""":class:
100
106
Iterable DataPipe that takes file URLs (can be http URLs pointing to files or URLs to GDrive files), and
101
- yields tuples of file URL and IO stream
107
+ yields tuples of file URL and IO stream.
102
108
103
109
Args:
104
110
source_datapipe: a DataPipe that contains URLs
105
- timeout : timeout in seconds for http request
111
+ timeout: timeout in seconds for http request
106
112
"""
107
113
source_datapipe : IterDataPipe [str ]
108
114
109
- def __init__ (self , source_datapipe : IterDataPipe [str ], * , timeout = None ) -> None :
115
+ def __init__ (self , source_datapipe : IterDataPipe [str ], * , timeout : Optional [ float ] = None ) -> None :
110
116
self .source_datapipe = source_datapipe
111
117
self .timeout = timeout
112
118
@@ -115,8 +121,7 @@ def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
115
121
parts = urlparse (url )
116
122
117
123
if re .match (r"(drive|docs)[.]google[.]com" , parts .netloc ):
118
- # TODO(137): can this also have a timeout?
119
- yield _get_response_from_google_drive (url )
124
+ yield _get_response_from_google_drive (url , timeout = self .timeout )
120
125
else :
121
126
yield _get_response_from_http (url , timeout = self .timeout )
122
127
0 commit comments