Skip to content

Commit 7ea7ca8

Browse files
xmunozewdurbin
andcommitted
Add initial hook-based check execution mechanism (#7160)
* Add initial hook-based check execution mechanism * scratch/poc * Add initial hook-based check execution mechanism * Use sqlalchemy event hooks for malware checks * Fix unit tests * Add enum for MalwareCheckObjectType * Add unit tests for init. * Add tests for tasks, services, and utils. Also, some small bugfixes in MalwareCheckFactory and the get_enabled_checks method. * Fix spurious task test. * Add missing drop enum to downgrade function. * Added TODO to dev/environment * Be more explicit in check lookup Co-authored-by: Ernest W. Durbin III <[email protected]>
1 parent 0b29743 commit 7ea7ca8

20 files changed

+700
-10
lines changed

dev/environment

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ MAIL_BACKEND=warehouse.email.services.SMTPEmailSender host=smtp port=2525 ssl=fa
2929

3030
BREACHED_PASSWORDS=warehouse.accounts.NullPasswordBreachedService
3131

32+
#TODO: change this to PrinterMalwareCheckService before deploy
33+
MALWARE_CHECK_BACKEND=warehouse.malware.services.DatabaseMalwareCheckService
34+
3235
METRICS_BACKEND=warehouse.metrics.DataDogMetrics host=notdatadog
3336

3437
STATUSPAGE_URL=https://2p66nmmycsj3.statuspage.io

tests/common/db/malware.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
import factory
1616
import factory.fuzzy
1717

18-
from warehouse.malware.models import MalwareCheck, MalwareCheckState, MalwareCheckType
18+
from warehouse.malware.models import (
19+
MalwareCheck,
20+
MalwareCheckObjectType,
21+
MalwareCheckState,
22+
MalwareCheckType,
23+
)
1924

2025
from .base import WarehouseFactory
2126

@@ -29,11 +34,7 @@ class Meta:
2934
short_description = factory.fuzzy.FuzzyText(length=80)
3035
long_description = factory.fuzzy.FuzzyText(length=300)
3136
check_type = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckType])
32-
hook_name = (
33-
"project:release:file:upload"
34-
if check_type == MalwareCheckType.event_hook
35-
else None
36-
)
37+
hooked_object = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckObjectType])
3738
state = factory.fuzzy.FuzzyChoice([e for e in MalwareCheckState])
3839
created = factory.fuzzy.FuzzyNaiveDateTime(
3940
datetime.datetime.utcnow() - datetime.timedelta(days=7)

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def app_config(database):
174174
"files.backend": "warehouse.packaging.services.LocalFileStorage",
175175
"docs.backend": "warehouse.packaging.services.LocalFileStorage",
176176
"mail.backend": "warehouse.email.services.SMTPEmailSender",
177+
"malware_check.backend": (
178+
"warehouse.malware.services.PrinterMalwareCheckService"
179+
),
177180
"files.url": "http://localhost:7000/",
178181
"sessions.secret": "123456",
179182
"sessions.url": "redis://localhost:0/",

tests/unit/malware/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.

tests/unit/malware/test_init.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
from collections import defaultdict
14+
15+
import pretend
16+
17+
from warehouse import malware
18+
from warehouse.malware import utils
19+
from warehouse.malware.interfaces import IMalwareCheckService
20+
21+
from ...common.db.accounts import UserFactory
22+
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory
23+
24+
25+
def test_determine_malware_checks_no_checks(monkeypatch, db_request):
26+
def get_enabled_checks(session):
27+
return defaultdict(list)
28+
29+
monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
30+
31+
project = ProjectFactory.create(name="foo")
32+
release = ReleaseFactory.create(project=project)
33+
file0 = FileFactory.create(release=release, filename="foo.bar")
34+
35+
session = pretend.stub(info={}, new={file0, release, project}, dirty={}, deleted={})
36+
37+
malware.determine_malware_checks(pretend.stub(), session, pretend.stub())
38+
assert session.info["warehouse.malware.checks"] == set()
39+
40+
41+
def test_determine_malware_checks_nothing_new(monkeypatch, db_request):
42+
def get_enabled_checks(session):
43+
result = defaultdict(list)
44+
result["File"] = ["Check1", "Check2"]
45+
result["Release"] = ["Check3"]
46+
return result
47+
48+
monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
49+
50+
project = ProjectFactory.create(name="foo")
51+
release = ReleaseFactory.create(project=project)
52+
file0 = FileFactory.create(release=release, filename="foo.bar")
53+
54+
session = pretend.stub(info={}, new={}, dirty={file0, release}, deleted={})
55+
56+
malware.determine_malware_checks(pretend.stub(), session, pretend.stub())
57+
assert session.info.get("warehouse.malware.checks") is None
58+
59+
60+
def test_determine_malware_checks_unsupported_object(monkeypatch, db_request):
61+
def get_enabled_checks(session):
62+
result = defaultdict(list)
63+
result["File"] = ["Check1", "Check2"]
64+
result["Release"] = ["Check3"]
65+
return result
66+
67+
monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
68+
69+
user = UserFactory.create()
70+
71+
session = pretend.stub(info={}, new={user}, dirty={}, deleted={})
72+
73+
malware.determine_malware_checks(pretend.stub(), session, pretend.stub())
74+
assert session.info.get("warehouse.malware.checks") is None
75+
76+
77+
def test_determine_malware_checks_file_only(monkeypatch, db_request):
78+
def get_enabled_checks(session):
79+
result = defaultdict(list)
80+
result["File"] = ["Check1", "Check2"]
81+
result["Release"] = ["Check3"]
82+
return result
83+
84+
monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
85+
86+
project = ProjectFactory.create(name="foo")
87+
release = ReleaseFactory.create(project=project)
88+
file0 = FileFactory.create(release=release, filename="foo.bar")
89+
90+
session = pretend.stub(info={}, new={file0}, dirty={}, deleted={})
91+
92+
checks = set(["Check%d:%s" % (x, file0.id) for x in range(1, 3)])
93+
malware.determine_malware_checks(pretend.stub(), session, pretend.stub())
94+
assert session.info["warehouse.malware.checks"] == checks
95+
96+
97+
def test_determine_malware_checks_file_and_release(monkeypatch, db_request):
98+
def get_enabled_checks(session):
99+
result = defaultdict(list)
100+
result["File"] = ["Check1", "Check2"]
101+
result["Release"] = ["Check3"]
102+
return result
103+
104+
monkeypatch.setattr(utils, "get_enabled_checks", get_enabled_checks)
105+
106+
project = ProjectFactory.create(name="foo")
107+
release = ReleaseFactory.create(project=project)
108+
file0 = FileFactory.create(release=release, filename="foo.bar")
109+
file1 = FileFactory.create(release=release, filename="foo.baz")
110+
111+
session = pretend.stub(
112+
info={}, new={project, release, file0, file1}, dirty={}, deleted={}
113+
)
114+
115+
checks = set(["Check%d:%s" % (x, file0.id) for x in range(1, 3)])
116+
checks.update(["Check%d:%s" % (x, file1.id) for x in range(1, 3)])
117+
checks.add("Check3:%s" % release.id)
118+
119+
malware.determine_malware_checks(pretend.stub(), session, pretend.stub())
120+
121+
assert session.info["warehouse.malware.checks"] == checks
122+
123+
124+
def test_enqueue_malware_checks(app_config):
125+
malware_check = pretend.stub(
126+
run_checks=pretend.call_recorder(lambda malware_checks: None)
127+
)
128+
factory = pretend.call_recorder(lambda ctx, config: malware_check)
129+
app_config.register_service_factory(factory, IMalwareCheckService)
130+
app_config.commit()
131+
session = pretend.stub(
132+
info={
133+
"warehouse.malware.checks": {"Check1:ba70267f-fabf-496f-9ac2-d237a983b187"}
134+
}
135+
)
136+
137+
malware.queue_malware_checks(app_config, session)
138+
139+
assert factory.calls == [pretend.call(None, app_config)]
140+
assert malware_check.run_checks.calls == [
141+
pretend.call({"Check1:ba70267f-fabf-496f-9ac2-d237a983b187"})
142+
]
143+
assert "warehouse.malware.checks" not in session.info
144+
145+
146+
def test_enqueue_malware_checks_no_checks(app_config):
147+
session = pretend.stub(info={})
148+
malware.queue_malware_checks(app_config, session)
149+
assert "warehouse.malware.checks" not in session.info
150+
151+
152+
def test_includeme():
153+
malware_check_class = pretend.stub(
154+
create_service=pretend.call_recorder(lambda *a, **kw: pretend.stub())
155+
)
156+
157+
config = pretend.stub(
158+
maybe_dotted=lambda dotted: malware_check_class,
159+
register_service_factory=pretend.call_recorder(
160+
lambda factory, iface, name=None: None
161+
),
162+
registry=pretend.stub(
163+
settings={"malware_check.backend": "TestMalwareCheckService"}
164+
),
165+
)
166+
167+
malware.includeme(config)
168+
169+
assert config.register_service_factory.calls == [
170+
pretend.call(malware_check_class.create_service, IMalwareCheckService),
171+
]

tests/unit/malware/test_services.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import pretend
14+
15+
from zope.interface.verify import verifyClass
16+
17+
from warehouse.malware.interfaces import IMalwareCheckService
18+
from warehouse.malware.services import (
19+
DatabaseMalwareCheckService,
20+
PrinterMalwareCheckService,
21+
)
22+
from warehouse.malware.tasks import run_check
23+
24+
25+
class TestPrinterMalwareCheckService:
26+
def test_verify_service(self):
27+
assert verifyClass(IMalwareCheckService, PrinterMalwareCheckService)
28+
29+
def test_create_service(self):
30+
request = pretend.stub()
31+
service = PrinterMalwareCheckService.create_service(None, request)
32+
assert service.executor == print
33+
34+
def test_run_checks(self, capfd):
35+
request = pretend.stub()
36+
service = PrinterMalwareCheckService.create_service(None, request)
37+
checks = ["one", "two", "three"]
38+
service.run_checks(checks)
39+
out, err = capfd.readouterr()
40+
assert out == "one\ntwo\nthree\n"
41+
42+
43+
class TestDatabaseMalwareService:
44+
def test_verify_service(self):
45+
assert verifyClass(IMalwareCheckService, DatabaseMalwareCheckService)
46+
47+
def test_create_service(self, db_request):
48+
_delay = pretend.call_recorder(lambda *args: None)
49+
db_request.task = lambda x: pretend.stub(delay=_delay)
50+
service = DatabaseMalwareCheckService.create_service(None, db_request)
51+
assert service.executor == db_request.task(run_check).delay
52+
53+
def test_run_checks(self, db_request):
54+
_delay = pretend.call_recorder(lambda *args: None)
55+
db_request.task = lambda x: pretend.stub(delay=_delay)
56+
service = DatabaseMalwareCheckService.create_service(None, db_request)
57+
checks = ["MyTestCheck:ba70267f-fabf-496f-9ac2-d237a983b187"]
58+
service.run_checks(checks)
59+
assert _delay.calls == [
60+
pretend.call("MyTestCheck", "ba70267f-fabf-496f-9ac2-d237a983b187")
61+
]

tests/unit/malware/test_tasks.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
13+
import celery
14+
import pretend
15+
import pytest
16+
17+
from sqlalchemy.orm.exc import NoResultFound
18+
19+
import warehouse.malware.checks as checks
20+
21+
from warehouse.malware.models import MalwareVerdict
22+
from warehouse.malware.tasks import run_check
23+
24+
from ...common.db.malware import MalwareCheckFactory
25+
from ...common.db.packaging import FileFactory, ProjectFactory, ReleaseFactory
26+
27+
28+
def test_run_check(monkeypatch, db_request):
29+
project = ProjectFactory.create(name="foo")
30+
release = ReleaseFactory.create(project=project)
31+
file0 = FileFactory.create(release=release, filename="foo.bar")
32+
MalwareCheckFactory.create(name="ExampleCheck", state="enabled")
33+
34+
task = pretend.stub()
35+
run_check(task, db_request, "ExampleCheck", file0.id)
36+
assert db_request.db.query(MalwareVerdict).one()
37+
38+
39+
def test_run_check_missing_check_id(monkeypatch, db_session):
40+
exc = NoResultFound("No row was found for one()")
41+
42+
class FakeMalwareCheck:
43+
def __init__(self, db):
44+
raise exc
45+
46+
class Task:
47+
@staticmethod
48+
@pretend.call_recorder
49+
def retry(exc):
50+
raise celery.exceptions.Retry
51+
52+
task = Task()
53+
54+
checks.FakeMalwareCheck = FakeMalwareCheck
55+
56+
request = pretend.stub(
57+
db=db_session,
58+
log=pretend.stub(error=pretend.call_recorder(lambda *args, **kwargs: None),),
59+
)
60+
61+
with pytest.raises(celery.exceptions.Retry):
62+
run_check(
63+
task, request, "FakeMalwareCheck", "d03d75d1-2511-4a8b-9759-62294a6fe3a7"
64+
)
65+
66+
assert request.log.error.calls == [
67+
pretend.call(
68+
"Error executing check %s: %s",
69+
"FakeMalwareCheck",
70+
"No row was found for one()",
71+
)
72+
]
73+
74+
assert task.retry.calls == [pretend.call(exc=exc)]
75+
76+
77+
def test_run_check_missing_check(db_request):
78+
task = pretend.stub()
79+
with pytest.raises(AttributeError):
80+
run_check(
81+
task,
82+
db_request,
83+
"DoesNotExistCheck",
84+
"d03d75d1-2511-4a8b-9759-62294a6fe3a7",
85+
)

0 commit comments

Comments
 (0)