Skip to content

Commit 7947a92

Browse files
xmunozewdurbin
authored andcommitted
Add malware check syncing mechanism (#7190)
* Add malware check syncing mechanism * Code review changes.
1 parent e67a761 commit 7947a92

File tree

9 files changed

+464
-74
lines changed

9 files changed

+464
-74
lines changed

bin/release

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ set -eo pipefail
55

66
# Migrate our database to the latest revision.
77
python -m warehouse db upgrade head
8+
9+
# Insert/upgrade malware checks.
10+
python -m warehouse malware sync-checks

tests/unit/cli/test_malware.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import pretend
14+
15+
from warehouse.cli.malware import sync_checks
16+
from warehouse.malware.tasks import sync_checks as _sync_checks
17+
18+
19+
class TestCLIMalware:
20+
def test_sync_checks(self, cli):
21+
request = pretend.stub()
22+
task = pretend.stub(
23+
get_request=pretend.call_recorder(lambda *a, **kw: request),
24+
run=pretend.call_recorder(lambda *a, **kw: None),
25+
)
26+
config = pretend.stub(task=pretend.call_recorder(lambda *a, **kw: task))
27+
28+
result = cli.invoke(sync_checks, obj=config)
29+
30+
assert result.exit_code == 0
31+
assert config.task.calls == [
32+
pretend.call(_sync_checks),
33+
pretend.call(_sync_checks),
34+
]
35+
assert task.get_request.calls == [pretend.call()]
36+
assert task.run.calls == [pretend.call(request)]

tests/unit/malware/test_checks.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import inspect
14+
15+
import warehouse.malware.checks as checks
16+
17+
from warehouse.malware.checks.base import MalwareCheckBase
18+
from warehouse.malware.utils import get_check_fields
19+
20+
21+
def test_checks_subclass_base():
22+
checks_from_module = inspect.getmembers(checks, inspect.isclass)
23+
24+
subclasses_of_malware_base = {
25+
cls.__name__: cls for cls in MalwareCheckBase.__subclasses__()
26+
}
27+
28+
assert len(checks_from_module) == len(subclasses_of_malware_base)
29+
30+
for check_name, check in checks_from_module:
31+
assert subclasses_of_malware_base[check_name] == check
32+
33+
34+
def test_checks_fields():
35+
checks_from_module = inspect.getmembers(checks, inspect.isclass)
36+
37+
for check_name, check in checks_from_module:
38+
elems = inspect.getmembers(check, lambda a: not (inspect.isroutine(a)))
39+
inspection_fields = {"name": check_name}
40+
for elem_name, value in elems:
41+
if not elem_name.startswith("__"):
42+
inspection_fields[elem_name] = value
43+
fields = get_check_fields(check)
44+
45+
assert inspection_fields == fields

tests/unit/malware/test_tasks.py

Lines changed: 215 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,68 +18,240 @@
1818

1919
import warehouse.malware.checks as checks
2020

21-
from warehouse.malware.models import MalwareVerdict
22-
from warehouse.malware.tasks import run_check
21+
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareVerdict
22+
from warehouse.malware.tasks import run_check, sync_checks
2323

2424
from ...common.db.malware import MalwareCheckFactory
2525
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory
2626

2727

28-
def test_run_check(monkeypatch, db_request):
29-
project = ProjectFactory.create(name="foo")
30-
release = ReleaseFactory.create(project=project)
31-
file0 = FileFactory.create(release=release, filename="foo.bar")
32-
MalwareCheckFactory.create(name="ExampleCheck", state="enabled")
28+
class TestRunCheck:
29+
def test_success(self, monkeypatch, db_request):
30+
project = ProjectFactory.create(name="foo")
31+
release = ReleaseFactory.create(project=project)
32+
file0 = FileFactory.create(release=release, filename="foo.bar")
33+
MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)
3334

34-
task = pretend.stub()
35-
run_check(task, db_request, "ExampleCheck", file0.id)
36-
assert db_request.db.query(MalwareVerdict).one()
35+
task = pretend.stub()
36+
run_check(task, db_request, "ExampleCheck", file0.id)
37+
assert db_request.db.query(MalwareVerdict).one()
3738

39+
def test_missing_check_id(self, monkeypatch, db_session):
40+
exc = NoResultFound("No row was found for one()")
3841

39-
def test_run_check_missing_check_id(monkeypatch, db_session):
40-
exc = NoResultFound("No row was found for one()")
42+
class FakeMalwareCheck:
43+
def __init__(self, db):
44+
raise exc
4145

42-
class FakeMalwareCheck:
43-
def __init__(self, db):
44-
raise exc
46+
checks.FakeMalwareCheck = FakeMalwareCheck
4547

46-
class Task:
47-
@staticmethod
48-
@pretend.call_recorder
49-
def retry(exc):
50-
raise celery.exceptions.Retry
48+
class Task:
49+
@staticmethod
50+
@pretend.call_recorder
51+
def retry(exc):
52+
raise celery.exceptions.Retry
5153

52-
task = Task()
54+
task = Task()
5355

54-
checks.FakeMalwareCheck = FakeMalwareCheck
56+
request = pretend.stub(
57+
db=db_session,
58+
log=pretend.stub(
59+
error=pretend.call_recorder(lambda *args, **kwargs: None),
60+
),
61+
)
62+
63+
with pytest.raises(celery.exceptions.Retry):
64+
run_check(
65+
task,
66+
request,
67+
"FakeMalwareCheck",
68+
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
69+
)
70+
71+
assert request.log.error.calls == [
72+
pretend.call(
73+
"Error executing check %s: %s",
74+
"FakeMalwareCheck",
75+
"No row was found for one()",
76+
)
77+
]
78+
79+
assert task.retry.calls == [pretend.call(exc=exc)]
80+
81+
del checks.FakeMalwareCheck
82+
83+
def test_missing_check(self, db_request):
84+
task = pretend.stub()
85+
with pytest.raises(AttributeError):
86+
run_check(
87+
task,
88+
db_request,
89+
"DoesNotExistCheck",
90+
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
91+
)
92+
93+
94+
class TestSyncChecks:
95+
def test_no_updates(self, db_session):
96+
MalwareCheckFactory.create(
97+
name="ExampleCheck", state=MalwareCheckState.disabled
98+
)
99+
100+
task = pretend.stub()
101+
102+
request = pretend.stub(
103+
db=db_session,
104+
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
105+
)
106+
107+
sync_checks(task, request)
108+
109+
assert request.log.info.calls == [
110+
pretend.call("1 malware checks found in codebase."),
111+
pretend.call("ExampleCheck is unmodified."),
112+
]
55113

56-
request = pretend.stub(
57-
db=db_session,
58-
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),),
114+
@pytest.mark.parametrize(
115+
("final_state"), [MalwareCheckState.enabled, MalwareCheckState.disabled]
59116
)
117+
def test_upgrade_check(self, monkeypatch, db_session, final_state):
118+
MalwareCheckFactory.create(name="ExampleCheck", state=final_state)
119+
120+
class ExampleCheck:
121+
version = 2
122+
short_description = "This is a short description."
123+
long_description = "This is a longer description."
124+
check_type = "scheduled"
125+
126+
monkeypatch.setattr(checks, "ExampleCheck", ExampleCheck)
127+
128+
task = pretend.stub()
129+
request = pretend.stub(
130+
db=db_session,
131+
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
132+
)
133+
134+
sync_checks(task, request)
135+
136+
assert request.log.info.calls == [
137+
pretend.call("1 malware checks found in codebase."),
138+
pretend.call("Updating existing ExampleCheck."),
139+
]
140+
db_checks = (
141+
db_session.query(MalwareCheck)
142+
.filter(MalwareCheck.name == "ExampleCheck")
143+
.all()
144+
)
145+
146+
assert len(db_checks) == 2
60147

61-
with pytest.raises(celery.exceptions.Retry):
62-
run_check(
63-
task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
148+
if final_state == MalwareCheckState.disabled:
149+
assert (
150+
db_checks[0].state == db_checks[1].state == MalwareCheckState.disabled
151+
)
152+
153+
else:
154+
for c in db_checks:
155+
if c.state == final_state:
156+
assert c.version == 2
157+
else:
158+
assert c.version == 1
159+
160+
def test_one_new_check(self, db_session):
161+
task = pretend.stub()
162+
163+
class FakeMalwareCheck:
164+
version = 1
165+
short_description = "This is a short description."
166+
long_description = "This is a longer description."
167+
check_type = "scheduled"
168+
169+
checks.FakeMalwareCheck = FakeMalwareCheck
170+
171+
request = pretend.stub(
172+
db=db_session,
173+
log=pretend.stub(info=pretend.call_recorder(lambda *args, **kwargs: None),),
174+
)
175+
176+
MalwareCheckFactory.create(
177+
name="ExampleCheck", state=MalwareCheckState.evaluation
178+
)
179+
180+
sync_checks(task, request)
181+
182+
assert request.log.info.calls == [
183+
pretend.call("2 malware checks found in codebase."),
184+
pretend.call("ExampleCheck is unmodified."),
185+
pretend.call("Adding new FakeMalwareCheck to the database."),
186+
]
187+
assert db_session.query(MalwareCheck).count() == 2
188+
189+
new_check = (
190+
db_session.query(MalwareCheck)
191+
.filter(MalwareCheck.name == "FakeMalwareCheck")
192+
.one()
64193
)
65194

66-
assert request.log.error.calls == [
67-
pretend.call(
68-
"Error executing check %s: %s",
69-
"FakeMalwareCheck",
70-
"No row was found for one()",
195+
assert new_check.state == MalwareCheckState.disabled
196+
197+
del checks.FakeMalwareCheck
198+
199+
def test_too_many_db_checks(self, db_session):
200+
task = pretend.stub()
201+
202+
MalwareCheckFactory.create(name="ExampleCheck", state=MalwareCheckState.enabled)
203+
MalwareCheckFactory.create(
204+
name="AnotherCheck", state=MalwareCheckState.disabled
205+
)
206+
MalwareCheckFactory.create(
207+
name="AnotherCheck", state=MalwareCheckState.evaluation, version=2
208+
)
209+
210+
request = pretend.stub(
211+
db=db_session,
212+
log=pretend.stub(
213+
info=pretend.call_recorder(lambda *args, **kwargs: None),
214+
error=pretend.call_recorder(lambda *args, **kwargs: None),
215+
),
71216
)
72-
]
73217

74-
assert task.retry.calls == [pretend.call(exc=exc)]
218+
with pytest.raises(Exception):
219+
sync_checks(task, request)
75220

221+
assert request.log.info.calls == [
222+
pretend.call("1 malware checks found in codebase."),
223+
]
76224

77-
def test_run_check_missing_check(db_request):
78-
task = pretend.stub()
79-
with pytest.raises(AttributeError):
80-
run_check(
81-
task,
82-
db_request,
83-
"DoesNotExistCheck",
84-
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
225+
assert request.log.error.calls == [
226+
pretend.call(
227+
"Found 2 active checks in the db, but only 1 checks in code. Please \
228+
manually move superfluous checks to the wiped_out state in the check admin: \
229+
AnotherCheck"
230+
),
231+
]
232+
233+
def test_only_wiped_out(self, db_session):
234+
task = pretend.stub()
235+
MalwareCheckFactory.create(
236+
name="ExampleCheck", state=MalwareCheckState.wiped_out
237+
)
238+
request = pretend.stub(
239+
db=db_session,
240+
log=pretend.stub(
241+
info=pretend.call_recorder(lambda *args, **kwargs: None),
242+
error=pretend.call_recorder(lambda *args, **kwargs: None),
243+
),
85244
)
245+
246+
sync_checks(task, request)
247+
248+
assert request.log.info.calls == [
249+
pretend.call("1 malware checks found in codebase."),
250+
]
251+
252+
assert request.log.error.calls == [
253+
pretend.call(
254+
"ExampleCheck is wiped_out and cannot be synced. Please remove check \
255+
from codebase."
256+
),
257+
]

0 commit comments

Comments
 (0)