Skip to content

Commit ed8d6a0

Browse files
authored
make IFileStorage instances individually namable (#3416)
1 parent 9bb930d commit ed8d6a0

File tree

5 files changed

+54
-16
lines changed

5 files changed

+54
-16
lines changed

tests/unit/packaging/test_init.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import pretend
1414
import pytest
1515

16+
from functools import partial
17+
1618
from celery.schedules import crontab
1719

1820
from warehouse import packaging
@@ -24,7 +26,9 @@
2426

2527
@pytest.mark.parametrize("with_trending", [True, False])
2628
def test_includme(monkeypatch, with_trending):
27-
storage_class = pretend.stub(create_service=pretend.stub())
29+
storage_class = pretend.stub(
30+
create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub())
31+
)
2832

2933
def key_factory(keystring, iterate_on=None):
3034
return pretend.call(keystring, iterate_on=iterate_on)
@@ -50,10 +54,18 @@ def key_factory(keystring, iterate_on=None):
5054

5155
packaging.includeme(config)
5256

53-
assert config.register_service_factory.calls == [
54-
pretend.call(storage_class.create_service, IFileStorage, name='files'),
55-
pretend.call(storage_class.create_service, IFileStorage, name='docs'),
56-
]
57+
assert repr(config.register_service_factory.calls[0]) == repr(
58+
pretend.call(
59+
partial(storage_class.create_service, name='files'),
60+
IFileStorage, name='files'
61+
)
62+
)
63+
assert repr(config.register_service_factory.calls[1]) == repr(
64+
pretend.call(
65+
partial(storage_class.create_service, name='docs'),
66+
IFileStorage, name='docs'
67+
)
68+
)
5769
assert config.register_origin_cache_keys.calls == [
5870
pretend.call(
5971
File,

tests/unit/packaging/test_services.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,18 @@ def test_create_service(self):
3939
settings={"files.path": "/the/one/two/"},
4040
),
4141
)
42-
storage = LocalFileStorage.create_service(None, request)
42+
storage = LocalFileStorage.create_service(None, request, name='files')
4343
assert storage.base == "/the/one/two/"
4444

45+
def test_create_service_no_name(self):
46+
request = pretend.stub(
47+
registry=pretend.stub(
48+
settings={"files.path": "/the/one/two/"},
49+
),
50+
)
51+
with pytest.raises(ValueError):
52+
LocalFileStorage.create_service(None, request)
53+
4554
def test_gets_file(self, tmpdir):
4655
with open(str(tmpdir.join("file.txt")), "wb") as fp:
4756
fp.write(b"my test file contents")
@@ -152,11 +161,20 @@ def test_create_service(self):
152161
find_service=pretend.call_recorder(lambda name: session),
153162
registry=pretend.stub(settings={"files.bucket": "froblob"}),
154163
)
155-
storage = S3FileStorage.create_service(None, request)
164+
storage = S3FileStorage.create_service(None, request, name='files')
156165

157166
assert request.find_service.calls == [pretend.call(name="aws.session")]
158167
assert storage.bucket.name == "froblob"
159168

169+
def test_create_service_without_name(self):
170+
session = boto3.session.Session()
171+
request = pretend.stub(
172+
find_service=pretend.call_recorder(lambda name: session),
173+
registry=pretend.stub(settings={"files.bucket": "froblob"}),
174+
)
175+
with pytest.raises(ValueError):
176+
S3FileStorage.create_service(None, request)
177+
160178
def test_gets_file(self):
161179
s3key = pretend.stub(get=lambda: {"Body": io.BytesIO(b"my contents")})
162180
bucket = pretend.stub(Object=pretend.call_recorder(lambda path: s3key))

warehouse/packaging/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13+
from functools import partial
14+
1315
from celery.schedules import crontab
1416
from sqlalchemy.orm.base import NO_VALUE
1517

@@ -40,14 +42,16 @@ def includeme(config):
4042
config.registry.settings["files.backend"],
4143
)
4244
config.register_service_factory(
43-
files_storage_class.create_service, IFileStorage, name='files'
45+
partial(files_storage_class.create_service, name='files'),
46+
IFileStorage, name='files'
4447
)
4548

4649
docs_storage_class = config.maybe_dotted(
4750
config.registry.settings["docs.backend"],
4851
)
4952
config.register_service_factory(
50-
docs_storage_class.create_service, IFileStorage, name='docs'
53+
partial(docs_storage_class.create_service, name='docs'),
54+
IFileStorage, name='docs'
5155
)
5256

5357
# Register our origin cache keys

warehouse/packaging/interfaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
class IFileStorage(Interface):
1717

18-
def create_service(context, request):
18+
def create_service(context, request, name=None):
1919
"""
2020
Create the service, given the context and request for which it is being
21-
created for.
21+
created for, passing a name for settings.
2222
"""
2323

2424
def get(path):

warehouse/packaging/services.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ def __init__(self, base):
3939
self.base = base
4040

4141
@classmethod
42-
def create_service(cls, context, request):
43-
return cls(request.registry.settings["files.path"])
42+
def create_service(cls, context, request, name=None):
43+
if name is None:
44+
raise ValueError('name is required')
45+
return cls(request.registry.settings[f"{name}.path"])
4446

4547
def get(self, path):
4648
return open(os.path.join(self.base, path), "rb")
@@ -69,12 +71,14 @@ def __init__(self, s3_client, bucket, *, prefix=None):
6971
self.prefix = prefix
7072

7173
@classmethod
72-
def create_service(cls, context, request):
74+
def create_service(cls, context, request, name=None):
75+
if name is None:
76+
raise ValueError('name is required')
7377
session = request.find_service(name="aws.session")
7478
s3_client = session.client("s3")
7579
s3 = session.resource("s3")
76-
bucket = s3.Bucket(request.registry.settings["files.bucket"])
77-
prefix = request.registry.settings.get("files.prefix")
80+
bucket = s3.Bucket(request.registry.settings[f"{name}.bucket"])
81+
prefix = request.registry.settings.get(f"{name}.prefix")
7882
return cls(s3_client, bucket, prefix=prefix)
7983

8084
def _get_path(self, path):

0 commit comments

Comments
 (0)