Skip to content

Add malware check syncing mechanism #7190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bin/release
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ set -eo pipefail

# Migrate our database to the latest revision.
python -m warehouse db upgrade head

# Insert/upgrade malware checks.
python -m warehouse malware sync-checks
36 changes: 36 additions & 0 deletions tests/unit/cli/test_malware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pretend

from warehouse.cli.malware import sync_checks
from warehouse.malware.tasks import sync_checks as _sync_checks


class TestCLIMalware:
    """Tests for the malware-related CLI commands."""

    def test_sync_checks(self, cli):
        """Invoke ``sync_checks`` through the CLI runner and verify the calls.

        The command is expected to exit cleanly, look up the ``_sync_checks``
        task on the config object twice, fetch a request from the task once,
        and run the task once with that request.
        """
        fake_request = pretend.stub()

        # Keep direct handles on the recorders so the assertions below do not
        # have to reach back through the stubs.
        get_request = pretend.call_recorder(lambda *a, **kw: fake_request)
        run = pretend.call_recorder(lambda *a, **kw: None)
        fake_task = pretend.stub(get_request=get_request, run=run)

        task_lookup = pretend.call_recorder(lambda *a, **kw: fake_task)
        config = pretend.stub(task=task_lookup)

        outcome = cli.invoke(sync_checks, obj=config)

        assert outcome.exit_code == 0
        assert task_lookup.calls == [
            pretend.call(_sync_checks),
            pretend.call(_sync_checks),
        ]
        assert get_request.calls == [pretend.call()]
        assert run.calls == [pretend.call(fake_request)]
258 changes: 215 additions & 43 deletions tests/unit/malware/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,68 +18,240 @@

import warehouse.malware.checks as checks

from warehouse.malware.models import MalwareVerdict
from warehouse.malware.tasks import run_check
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
from warehouse.malware.tasks import run_check, sync_checks

from ...common.db.malware import MalwareCheckFactory
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory


def test_run_check(monkeypatch, db_request):
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
MalwareCheckFactory.create(name="ExampleCheck", state="enabled")
class TestRunCheck:
def test_success(self, monkeypatch, db_request):
project = ProjectFactory.create(name="foo")
release = ReleaseFactory.create(project=project)
file0 = FileFactory.create(release=release, filename="foo.bar")
MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)

task = pretend.stub()
run_check(task, db_request, "ExampleCheck", file0.id)
assert db_request.db.query(MalwareVerdict).one()
task = pretend.stub()
run_check(task, db_request, "ExampleCheck", file0.id)
assert db_request.db.query(MalwareVerdict).one()

def test_missing_check_id(self, monkeypatch, db_session):
exc = NoResultFound("No row was found for one()")

def test_run_check_missing_check_id(monkeypatch, db_session):
exc = NoResultFound("No row was found for one()")
class FakeMalwareCheck:
def __init__(self, db):
raise exc

class FakeMalwareCheck:
def __init__(self, db):
raise exc
class Task:
@staticmethod
@pretend.call_recorder
def retry(exc):
raise celery.exceptions.Retry

class Task:
@staticmethod
@pretend.call_recorder
def retry(exc):
raise celery.exceptions.Retry
task = Task()

task = Task()
checks.FakeMalwareCheck = FakeMalwareCheck

checks.FakeMalwareCheck = FakeMalwareCheck
request = pretend.stub(
db=db_session,
log=pretend.stub(
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)

with pytest.raises(celery.exceptions.Retry):
run_check(
task,
request,
"FakeMalwareCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
)

assert request.log.error.calls == [
pretend.call(
"Error executing check %s: %s",
"FakeMalwareCheck",
"No row was found for one()",
)
]

assert task.retry.calls == [pretend.call(exc=exc)]

del checks.FakeMalwareCheck

def test_missing_check(self, db_request):
    """``run_check`` raises ``AttributeError`` for the "DoesNotExistCheck" name."""
    check_name = "DoesNotExistCheck"
    file_id = "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
    stub_task = pretend.stub()

    with pytest.raises(AttributeError):
        run_check(stub_task, db_request, check_name, file_id)


class TestSyncChecks:
def test_no_updates(self, db_session):
    """Syncing with an existing "ExampleCheck" row logs it as unmodified."""
    MalwareCheckFactory.create(
        name="ExampleCheck", state=MalwareCheckState.disabled
    )

    # Record every info-level log call so the messages can be compared below.
    info = pretend.call_recorder(lambda *args, **kwargs: None)
    fake_request = pretend.stub(db=db_session, log=pretend.stub(info=info))
    stub_task = pretend.stub()

    sync_checks(stub_task, fake_request)

    expected_messages = [
        "1 malware checks found in codebase.",
        "ExampleCheck is unmodified.",
    ]
    assert info.calls == [pretend.call(message) for message in expected_messages]

request = pretend.stub(
db=db_session,
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),),
@pytest.mark.parametrize(
("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled]
)
def test_upgrade_check(self, monkeypatch, db_session, final_state):
MalwareCheckFactory.create(name="ExampleCheck", state=final_state)

class ExampleCheck:
version = 2
short_description = "This is a short description."
long_description = "This is a longer description."
check_type = "scheduled"

monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck)

task = pretend.stub()
request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
pretend.call("Updating existing ExampleCheck."),
]
db_checks = (
db_session.query(MalwareCheck)
.filter(MalwareCheck.name == "ExampleCheck")
.all()
)

assert len(db_checks) == 2

with pytest.raises(celery.exceptions.Retry):
run_check(
task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
if final_state == MalwareCheckState.disabled:
assert (
db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled
)

else:
for c in db_checks:
if c.state == final_state:
assert c.version == 2
else:
assert c.version == 1

def test_one_new_check(self, db_session):
task = pretend.stub()

class FakeMalwareCheck:
version = 1
short_description = "This is a short description."
long_description = "This is a longer description."
check_type = "scheduled"

checks.FakeMalwareCheck = FakeMalwareCheck

request = pretend.stub(
db=db_session,
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
)

MalwareCheckFactory.create(
name="ExampleCheck", state=MalwareCheckState.evaluation
)

sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("2 malware checks found in codebase."),
pretend.call("ExampleCheck is unmodified."),
pretend.call("Adding new FakeMalwareCheck to the database."),
]
assert db_session.query(MalwareCheck).count() == 2

new_check = (
db_session.query(MalwareCheck)
.filter(MalwareCheck.name == "FakeMalwareCheck")
.one()
)

assert request.log.error.calls == [
pretend.call(
"Error executing check %s: %s",
"FakeMalwareCheck",
"No row was found for one()",
assert new_check.state == MalwareCheckState.disabled

del checks.FakeMalwareCheck

def test_too_many_db_checks(self, db_session):
task = pretend.stub()

MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)
MalwareCheckFactory.create(
name="AnotherCheck", state=MalwareCheckState.disabled
)
MalwareCheckFactory.create(
name="AnotherCheck", state=MalwareCheckState.evaluation, version=2
)

request = pretend.stub(
db=db_session,
log=pretend.stub(
info=pretend.call_recorder(lambda *args, **kwargs: None),
error=pretend.call_recorder(lambda *args, **kwargs: None),
),
)
]

assert task.retry.calls == [pretend.call(exc=exc)]
with pytest.raises(Exception):
sync_checks(task, request)

assert request.log.info.calls == [
pretend.call("1 malware checks found in codebase."),
]

def test_run_check_missing_check(db_request):
task = pretend.stub()
with pytest.raises(AttributeError):
run_check(
task,
db_request,
"DoesNotExistCheck",
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
assert request.log.error.calls == [
pretend.call(
"""Found 2 active checks in the db, but only 1 checks in \
code. Please manually move superfluous checks to the wiped_out state \
in the check admin."""
),
]

def test_only_wiped_out(self, db_session):
    """A check whose only db row is ``wiped_out`` is reported via ``log.error``.

    Only the "checks found" message is expected at info level; the
    wiped-out notice is expected as the single error-level message.
    """
    MalwareCheckFactory.create(
        name="ExampleCheck", state=MalwareCheckState.wiped_out
    )

    info = pretend.call_recorder(lambda *args, **kwargs: None)
    error = pretend.call_recorder(lambda *args, **kwargs: None)
    fake_request = pretend.stub(
        db=db_session,
        log=pretend.stub(info=info, error=error),
    )
    stub_task = pretend.stub()

    sync_checks(stub_task, fake_request)

    assert info.calls == [pretend.call("1 malware checks found in codebase.")]

    # The continuation line must stay at column 0: any leading whitespace
    # would become part of the expected message.
    assert error.calls == [
        pretend.call(
            "ExampleCheck is wiped_out and cannot be synced. \
Please remove check from codebase."
        ),
    ]
Loading