Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions tests/unit/cli/test_hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import factory
import pretend

from warehouse import db
from warehouse.accounts.models import User
from warehouse.cli import hashing

from ...common.db.accounts import UserEventFactory, UserFactory
from ...common.db.ip_addresses import IpAddress


class TestBackfillIpAddresses:
def test_no_records_to_backfill(self, cli, db_request, monkeypatch):
engine = pretend.stub()
registry_dict = {}
config = pretend.stub(
registry=pretend.stub(
__getitem__=registry_dict.__getitem__,
__setitem__=registry_dict.__setitem__,
settings={"warehouse.ip_salt": "NaCl"},
)
)
config.registry["sqlalchemy.engine"] = engine
Comment on lines +27 to +35
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: ‏I struggled with this syntax for pretend.stub() to make an object that is both dict-accessible as well as method-friendly, since we call both config.registry["somekey"] as well as config.registry.settings["somekey"].
If there's a better way to represent this nesting with pretend.stub(), happy to change to that!

session_cls = pretend.call_recorder(lambda bind: db_request.db)
monkeypatch.setattr(db, "Session", session_cls)

assert db_request.db.query(User.Event).count() == 0

result = cli.invoke(hashing.backfill_ipaddrs, obj=config)

assert result.exit_code == 0
assert result.output.strip() == "No rows to backfill. Done!"

def test_backfill_with_no_ipaddr_obj(self, cli, db_session, monkeypatch):
engine = pretend.stub()
registry_dict = {}
config = pretend.stub(
registry=pretend.stub(
__getitem__=registry_dict.__getitem__,
__setitem__=registry_dict.__setitem__,
settings={"warehouse.ip_salt": "NaCl"},
)
)
config.registry["sqlalchemy.engine"] = engine
session_cls = pretend.call_recorder(lambda bind: db_session)
monkeypatch.setattr(db, "Session", session_cls)

user = UserFactory.create()
UserEventFactory.create_batch(
3,
source=user,
tag="dummy:tag",
ip_address_string=factory.Faker("ipv4"),
)
assert db_session.query(User.Event).count() == 3
assert db_session.query(IpAddress).count() == 0

result = cli.invoke(hashing.backfill_ipaddrs, obj=config)

assert result.exit_code == 0
assert db_session.query(IpAddress).count() == 3

def tests_backfills_records(self, cli, db_request, remote_addr, monkeypatch):
engine = pretend.stub()
registry_dict = {}
config = pretend.stub(
registry=pretend.stub(
__getitem__=registry_dict.__getitem__,
__setitem__=registry_dict.__setitem__,
settings={"warehouse.ip_salt": "NaCl"},
)
)
config.registry["sqlalchemy.engine"] = engine
session_cls = pretend.call_recorder(lambda bind: db_request.db)
monkeypatch.setattr(db, "Session", session_cls)

user = UserFactory.create()
UserEventFactory.create_batch(
3,
source=user,
tag="dummy:tag",
ip_address_string=remote_addr,
)
assert db_request.db.query(User.Event).count() == 3

args = [
"--batch-size",
"2",
]

result = cli.invoke(hashing.backfill_ipaddrs, args, obj=config)

assert result.exit_code == 0
assert result.output.strip() == "Backfilling 2 rows...\nBackfilled 2 rows"
# check that two of the ip addresses have been hashed
assert (
db_request.db.query(User.Event)
.where(User.Event.ip_address_id.is_not(None))
.count()
== 2
)
# and that there's only a single unassociated ip address left
assert (
db_request.db.query(User.Event)
.where(User.Event.ip_address_id.is_(None))
.one()
)

def test_continue_until_done(self, cli, db_request, remote_addr, monkeypatch):
engine = pretend.stub()
registry_dict = {}
config = pretend.stub(
registry=pretend.stub(
__getitem__=registry_dict.__getitem__,
__setitem__=registry_dict.__setitem__,
settings={"warehouse.ip_salt": "NaCl"},
)
)
config.registry["sqlalchemy.engine"] = engine
session_cls = pretend.call_recorder(lambda bind: db_request.db)
monkeypatch.setattr(db, "Session", session_cls)

user = UserFactory.create()
UserEventFactory.create_batch(
3,
source=user,
tag="dummy:tag",
ip_address_string=remote_addr,
)

args = [
"--batch-size",
"1",
"--sleep-time",
"0",
"--continue-until-done",
]

result = cli.invoke(hashing.backfill_ipaddrs, args, obj=config)

assert result.exit_code == 0
# check that all the ip addresses have been associated
assert (
db_request.db.query(User.Event)
.where(User.Event.ip_address_id.is_not(None))
.count()
== 3
)
assert (
db_request.db.query(User.Event)
.where(User.Event.ip_address_id.is_(None))
.count()
== 0
)
146 changes: 146 additions & 0 deletions warehouse/cli/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import time

import click

from sqlalchemy import select
from sqlalchemy.exc import NoResultFound

from warehouse.cli import warehouse


# Registers `hashing` as a subcommand group of the warehouse CLI,
# i.e. `warehouse hashing <subcommand>`.
@warehouse.group()
def hashing():
    """
    Run Hashing operations for Warehouse data
    """


@hashing.command()
@click.option(
    "-b",
    "--batch-size",
    default=10_000,
    show_default=True,
    help="Number of rows to associate at a time",
)
@click.option(
    "-st",
    "--sleep-time",
    default=1,
    show_default=True,
    help="Number of seconds to sleep between batches",
)
@click.option(
    "--continue-until-done",
    is_flag=True,
    default=False,
    help="Continue until all rows are complete",
)
@click.pass_obj
def backfill_ipaddrs(
    config,
    batch_size: int,
    sleep_time: int,
    continue_until_done: bool,
):
    """
    Backfill the `ip_addresses.ip_address` column for Events

    Associates User events that lack an `ip_address_id` with IpAddress
    rows (creating them, with a salted hash of the IP, where missing),
    processing `batch_size` rows at a time.
    """
    # Imported here because we don't want to trigger an import from anything
    # but warehouse.cli at the module scope.
    from warehouse.db import Session

    # This lives in the outer function so we only create a single session per
    # invocation of the CLI command.
    session = Session(bind=config.registry["sqlalchemy.engine"])

    # Salt mixed into each IP string before hashing, from deployment settings.
    salt = config.registry.settings["warehouse.ip_salt"]

    _backfill_ips(session, salt, batch_size, sleep_time, continue_until_done)


def _backfill_ips(
    session,
    salt: str,
    batch_size: int,
    sleep_time: int,
    continue_until_done: bool,
) -> None:
    """
    Create missing IpAddress objects for events that don't have them.

    Broken out from the CLI command so that it can be tested directly.

    Iterates rather than recursing between batches: the previous recursive
    form would hit Python's recursion limit (~1000 frames) on a large
    backlog (batch_size * ~1000 rows). Each batch is committed before the
    next begins, so an interruption never loses completed work.

    TODO: Currently operates on only User events, but should be expanded to
    include Project events and others.
    """
    from warehouse.accounts.models import User
    from warehouse.ip_addresses.models import IpAddress

    while True:
        # Get rows a batch at a time, only if the row doesn't have an
        # `ip_address_id`. Oldest events first, so progress is deterministic.
        no_ip_obj_rows = session.scalars(
            select(User.Event)
            .where(User.Event.ip_address_id.is_(None))  # type: ignore[attr-defined]
            .order_by(User.Event.time)  # type: ignore[attr-defined]
            .limit(batch_size)
        ).all()

        if not no_ip_obj_rows:
            click.echo("No rows to backfill. Done!")
            return

        how_many = len(no_ip_obj_rows)

        click.echo(f"Backfilling {how_many} rows...")
        for row in no_ip_obj_rows:
            # See if there's already an IPAddress object for this IP.
            # If not, create one with the salted hash of the address.
            try:
                ip_addr = (
                    session.query(IpAddress)
                    .filter(IpAddress.ip_address == row.ip_address_string)
                    .one()
                )
            except NoResultFound:
                ip_addr = IpAddress(  # type: ignore[call-arg]
                    ip_address=row.ip_address_string,
                    hashed_ip_address=hashlib.sha256(
                        (row.ip_address_string + salt).encode("utf8")
                    ).hexdigest(),
                )
            # Associate the IPAddress object with the Event
            row.ip_address_obj = ip_addr
            session.add(ip_addr)

        # Commit this batch before (possibly) continuing, so that progress
        # survives an interruption or a later failure.
        session.add_all(no_ip_obj_rows)
        session.commit()

        # A full batch suggests more rows may remain; loop again if asked to.
        if continue_until_done and how_many == batch_size:
            click.echo(
                f"Backfilled {batch_size} rows. Sleeping for {sleep_time} second(s)..."
            )
            time.sleep(sleep_time)
            continue

        click.echo(f"Backfilled {how_many} rows")
        return