Skip to content

Add wipe-out functionality #7202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions tests/common/db/malware.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@
MalwareCheckObjectType,
MalwareCheckState,
MalwareCheckType,
MalwareVerdict,
VerdictClassification,
VerdictConfidence,
)

from .base import WarehouseFactory
from .packaging import FileFactory


class MalwareCheckFactory(WarehouseFactory):
Expand All @@ -33,9 +37,20 @@ class Meta:
version = 1
short_description = factory.fuzzy.FuzzyText(length=80)
long_description = factory.fuzzy.FuzzyText(length=300)
check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType])
hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType])
state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState])
check_type = factory.fuzzy.FuzzyChoice(list(MalwareCheckType))
hooked_object = factory.fuzzy.FuzzyChoice(list(MalwareCheckObjectType))
state = factory.fuzzy.FuzzyChoice(list(MalwareCheckState))
created = factory.fuzzy.FuzzyNaiveDateTime(
datetime.datetime.utcnow() - datetime.timedelta(days=7)
)


class MalwareVerdictFactory(WarehouseFactory):
class Meta:
model = MalwareVerdict

check = factory.SubFactory(MalwareCheckFactory)
release_file = factory.SubFactory(FileFactory)
classification = factory.fuzzy.FuzzyChoice(list(VerdictClassification))
confidence = factory.fuzzy.FuzzyChoice(list(VerdictConfidence))
message = factory.fuzzy.FuzzyText(length=80)
22 changes: 18 additions & 4 deletions tests/unit/admin/views/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,41 @@ def test_get_check_not_found(self, db_request):


class TestChangeCheckState:
def test_change_to_enabled(self, db_request):
@pytest.mark.parametrize(
("final_state"), [MalwareCheckState.disabled, MalwareCheckState.wiped_out]
)
def test_change_to_valid_state(self, db_request, final_state):
check = MalwareCheckFactory.create(
name="MyCheck", state=MalwareCheckState.disabled
)

db_request.POST = {"id": check.id, "check_state": "enabled"}
db_request.POST = {"id": check.id, "check_state": final_state.value}
db_request.matchdict["check_name"] = check.name

db_request.session = pretend.stub(
flash=pretend.call_recorder(lambda *a, **kw: None)
)
wipe_out_recorder = pretend.stub(
delay=pretend.call_recorder(lambda *a, **kw: None)
)
db_request.task = pretend.call_recorder(lambda *a, **kw: wipe_out_recorder)

db_request.route_path = pretend.call_recorder(
lambda *a, **kw: "/admin/checks/MyCheck/change_state"
)

views.change_check_state(db_request)

assert db_request.session.flash.calls == [
pretend.call("Changed 'MyCheck' check to 'enabled'!", queue="success")
pretend.call(
"Changed 'MyCheck' check to '%s'!" % final_state.value, queue="success"
)
]
assert check.state == MalwareCheckState.enabled

assert check.state == final_state

if final_state == MalwareCheckState.wiped_out:
assert wipe_out_recorder.delay.calls == [pretend.call("MyCheck")]

def test_change_to_invalid_state(self, db_request):
check = MalwareCheckFactory.create(name="MyCheck")
Expand Down
56 changes: 54 additions & 2 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import warehouse.malware.checks as checks

from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.tasks import run_check, sync_checks
from warehouse.malware.tasks import remove_verdicts, run_check, sync_checks

from ...common.db.malware import MalwareCheckFactory
from ...common.db.malware import MalwareCheckFactory, MalwareVerdictFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


Expand Down Expand Up @@ -255,3 +255,55 @@ def test_only_wiped_out(self, db_session):
from codebase."
),
]


class TestRemoveVerdicts:
def test_no_verdicts(self, db_session):
check = MalwareCheckFactory.create()

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)
task = pretend.stub()
remove_verdicts(task, request, check.name)

assert request.log.info.calls == [
pretend.call(
"Removing 0 malware verdicts associated with %s version 1." % check.name
),
]

@pytest.mark.parametrize(("check_with_verdicts"), [True, False])
def test_many_verdicts(self, db_session, check_with_verdicts):
check0 = MalwareCheckFactory.create()
check1 = MalwareCheckFactory.create()
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
num_verdicts = 10

for i in range(num_verdicts):
MalwareVerdictFactory.create(check=check1, release_file=file0)

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

task = pretend.stub()

if check_with_verdicts:
wiped_out_check = check1
else:
wiped_out_check = check0
num_verdicts = 0

remove_verdicts(task, request, wiped_out_check.name)

assert request.log.info.calls == [
pretend.call(
"Removing %d malware verdicts associated with %s version 1."
% (num_verdicts, wiped_out_check.name)
),
]
3 changes: 3 additions & 0 deletions warehouse/admin/views/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sqlalchemy.orm.exc import NoResultFound

from warehouse.malware.models import MalwareCheck, MalwareCheckState
from warehouse.malware.tasks import remove_verdicts


@view_config(
Expand Down Expand Up @@ -80,6 +81,8 @@ def change_check_state(request):
except (AttributeError, KeyError):
request.session.flash("Invalid check state provided.", queue="error")
else:
if check.state == MalwareCheckState.wiped_out:
request.task(remove_verdicts).delay(check.name)
request.session.flash(
f"Changed {check.name!r} check to {check.state.value!r}!", queue="success"
)
Expand Down
22 changes: 21 additions & 1 deletion warehouse/malware/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import warehouse.malware.checks as checks

from warehouse.malware.models import MalwareCheck, MalwareCheckState
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.utils import get_check_fields
from warehouse.tasks import task

Expand Down Expand Up @@ -86,3 +86,23 @@ def sync_checks(task, request):
request.log.info("Adding new %s to the database." % check_name)
fields = get_check_fields(check)
request.db.add(MalwareCheck(**fields))


@task(bind=True, ignore_result=True, acks_late=True)
def remove_verdicts(task, request, check_name):
check_ids = (
request.db.query(MalwareCheck.id, MalwareCheck.version)
.filter(MalwareCheck.name == check_name)
.all()
)

for check_id, check_version in check_ids:
query = request.db.query(MalwareVerdict).filter(
MalwareVerdict.check_id == check_id
)
num_verdicts = query.count()
request.log.info(
"Removing %d malware verdicts associated with %s version %d."
% (num_verdicts, check_name, check_version)
)
query.delete(synchronize_session=False)