Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions usaspending_api/common/retrieve_file_from_uri.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import boto3
import io
import requests
import tempfile
import urllib
Expand All @@ -16,6 +17,41 @@
)


class SpooledTempFileIOBase(tempfile.SpooledTemporaryFile, io.IOBase):
    """SpooledTemporaryFile variant that also presents the io.IOBase API.

    The standard library's SpooledTemporaryFile does not mimic the IOBase
    abstract class. That gap is a documented issue
    (https://bugs.python.org/issue26175), and a PR to fix it
    (https://github.com/python/cpython/pull/3249) was open when this class
    was written.

    Inheriting from both classes and forwarding a handful of calls to the
    wrapped ``_file`` object gets this class _close_ to what it needs to be.
    If further issues appear, it might be better to bite the bullet and only
    use tempfile.NamedTemporaryFile().
    """

    def readable(self):
        """Return whatever the wrapped file reports for ``readable()``."""
        return self._file.readable()

    def readinto(self, b):
        """Forward ``readinto(b)`` to the wrapped file and return its result."""
        return self._file.readinto(b)

    def writable(self):
        """Return whatever the wrapped file reports for ``writable()``."""
        return self._file.writable()

    def seekable(self):
        """Return whatever the wrapped file reports for ``seekable()``."""
        return self._file.seekable()

    def seek(self, *args):
        """Forward ``seek`` to the wrapped file and return its result."""
        return self._file.seek(*args)

    def truncate(self, size=None):
        """Truncate the wrapped file and return the wrapped call's result.

        Without ``size`` the call is forwarded with no argument. With an
        explicit ``size`` larger than ``self._max_size``, ``rollover()`` is
        invoked before the wrapped file is truncated to ``size``.
        """
        if size is not None:
            if size > self._max_size:
                self.rollover()
            return self._file.truncate(size)

        return self._file.truncate()


class RetrieveFileFromUri:
def __init__(self, ruri):
self.ruri = ruri # Relative Uniform Resource Locator
Expand Down Expand Up @@ -43,7 +79,7 @@ def get_file_object(self, text=False):
def copy(self, dest_file_path):
"""
create a copy of the file and place at "dest_file_path" which
currently must be a filesystem path (not s3 or http).
currently must be a file system path (not s3 or http).
"""
if self.parsed_url_obj.scheme == "s3":
file_path = self.parsed_url_obj.path[1:] # remove leading '/' character
Expand All @@ -69,20 +105,20 @@ def _handle_s3(self, text):
boto3_s3 = boto3.resource("s3", region_name=settings.USASPENDING_AWS_REGION)
s3_bucket = boto3_s3.Bucket(self.parsed_url_obj.netloc)

f = tempfile.SpooledTemporaryFile() # Must be in binary mode (default)
f = SpooledTempFileIOBase() # Must be in binary mode (default)
s3_bucket.download_fileobj(file_path, f)

if text:
byte_str = f._file.getvalue()
f = tempfile.SpooledTemporaryFile(mode="r")
f = SpooledTempFileIOBase(mode="r")
f.write(byte_str.decode())

f.seek(0) # go to beginning of file for reading
return f

def _handle_http(self, text):
r = requests.get(self.ruri, allow_redirects=True)
f = tempfile.SpooledTemporaryFile(mode="w" if text else "w+b")
f = SpooledTempFileIOBase(mode="w" if text else "w+b")
f.write(r.text if text else r.content)
f.seek(0) # go to beginning of file for reading
return f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Example.com exists for this very purpose. I don't know about its uptime or
# anything. If downtime proves to be an issue, we can switch to something else.
URL = "http://example.com/"
SMALL_FILE_URL = "https://github.com/fedspendingtransparency/usaspending-api/blob/dev/README.md"


def test_retrieve_from_file():
Expand Down Expand Up @@ -63,3 +64,25 @@ def test_http_copy(temp_file_path):
c = f.read()
assert type(c) is str
assert len(c) > 0


def test_iobase_api(temp_file_path):
    """Verify retrieved file objects expose the expected io.IOBase methods.

    IOBase API reference: https://docs.python.org/3/library/io.html#io.IOBase

    Objects from a well-written standard library function such as open()
    pass these checks with flying colors. tempfile.SpooledTemporaryFile() is
    missing some of the expected API, which can cause issues. If an error
    like the one below shows up, it might be good to leverage the custom
    class instead:

        AttributeError: 'SpooledTemporaryFile' object has no attribute 'readable'
    """

    def check_iobase_contract(file_obj):
        # A freshly retrieved object should sit at offset 0 and report
        # itself as both readable and seekable.
        assert file_obj.tell() == 0
        assert file_obj.readable() is True
        assert file_obj.seekable() is True

    # Exercise one local path (this test module) and one remote URL.
    for source in (__file__, SMALL_FILE_URL):
        with RetrieveFileFromUri(source).get_file_object() as f:
            check_iobase_contract(f)