-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add missing type hints for tests.unittest. #13397
Changes from 8 commits
1674b7f
6c0344a
e36b12d
e4012f0
4853dd5
76270c3
5a21ef7
700dfcf
aca8cf6
603d70a
74d6acb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Adding missing type hints to tests. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| Generic, | ||
| Iterable, | ||
| List, | ||
| NoReturn, | ||
| Optional, | ||
| Tuple, | ||
| Type, | ||
|
|
@@ -39,7 +40,7 @@ | |
| import canonicaljson | ||
| import signedjson.key | ||
| import unpaddedbase64 | ||
| from typing_extensions import Protocol | ||
| from typing_extensions import ParamSpec, Protocol | ||
|
|
||
| from twisted.internet.defer import Deferred, ensureDeferred | ||
| from twisted.python.failure import Failure | ||
|
|
@@ -67,7 +68,7 @@ | |
| from synapse.rest import RegisterServletsFunc | ||
| from synapse.server import HomeServer | ||
| from synapse.storage.keys import FetchKeyResult | ||
| from synapse.types import JsonDict, UserID, create_requester | ||
| from synapse.types import JsonDict, Requester, UserID, create_requester | ||
| from synapse.util import Clock | ||
| from synapse.util.httpresourcetree import create_resource_tree | ||
|
|
||
|
|
@@ -88,6 +89,9 @@ | |
| TV = TypeVar("TV") | ||
| _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) | ||
|
|
||
| P = ParamSpec("P") | ||
| R = TypeVar("R") | ||
|
|
||
|
|
||
| class _TypedFailure(Generic[_ExcType], Protocol): | ||
| """Extension to twisted.Failure, where the 'value' has a certain type.""" | ||
|
|
@@ -97,7 +101,7 @@ def value(self) -> _ExcType: | |
| ... | ||
|
|
||
|
|
||
| def around(target): | ||
| def around(target: TV) -> Callable[[Callable[P, R]], None]: | ||
DMRobertson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """A CLOS-style 'around' modifier, which wraps the original method of the | ||
| given instance with another piece of code. | ||
|
|
||
|
|
@@ -106,11 +110,11 @@ def method_name(orig, *args, **kwargs): | |
| return orig(*args, **kwargs) | ||
| """ | ||
|
|
||
| def _around(code): | ||
| def _around(code: Callable[P, R]) -> None: | ||
| name = code.__name__ | ||
| orig = getattr(target, name) | ||
|
|
||
clokep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def new(*args, **kwargs): | ||
| def new(*args: P.args, **kwargs: P.kwargs) -> R: | ||
| return code(orig, *args, **kwargs) | ||
|
|
||
| setattr(target, name, new) | ||
|
||
|
|
@@ -131,7 +135,7 @@ def __init__(self, methodName: str): | |
| level = getattr(method, "loglevel", getattr(self, "loglevel", None)) | ||
|
|
||
| @around(self) | ||
| def setUp(orig): | ||
| def setUp(orig: Callable[[], R]) -> R: | ||
| # if we're not starting in the sentinel logcontext, then to be honest | ||
| # all future bets are off. | ||
| if current_context(): | ||
|
|
@@ -144,7 +148,7 @@ def setUp(orig): | |
| if level is not None and old_level != level: | ||
|
|
||
| @around(self) | ||
| def tearDown(orig): | ||
| def tearDown(orig: Callable[[], R]) -> R: | ||
| ret = orig() | ||
| logging.getLogger().setLevel(old_level) | ||
| return ret | ||
|
|
@@ -158,7 +162,7 @@ def tearDown(orig): | |
| return orig() | ||
|
|
||
| @around(self) | ||
| def tearDown(orig): | ||
| def tearDown(orig: Callable[[], R]) -> R: | ||
| ret = orig() | ||
| # force a GC to workaround problems with deferreds leaking logcontexts when | ||
| # they are GCed (see the logcontext docs) | ||
|
|
@@ -167,7 +171,7 @@ def tearDown(orig): | |
|
|
||
| return ret | ||
|
|
||
| def assertObjectHasAttributes(self, attrs, obj): | ||
| def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None: | ||
| """Asserts that the given object has each of the attributes given, and | ||
| that the value of each matches according to assertEqual.""" | ||
| for key in attrs.keys(): | ||
|
|
@@ -178,44 +182,44 @@ def assertObjectHasAttributes(self, attrs, obj): | |
| except AssertionError as e: | ||
| raise (type(e))(f"Assert error for '.{key}':") from e | ||
|
|
||
| def assert_dict(self, required, actual): | ||
| def assert_dict(self, required: dict, actual: dict) -> None: | ||
| """Does a partial assert of a dict. | ||
|
|
||
| Args: | ||
| required (dict): The keys and value which MUST be in 'actual'. | ||
| actual (dict): The test result. Extra keys will not be checked. | ||
| required: The keys and value which MUST be in 'actual'. | ||
| actual: The test result. Extra keys will not be checked. | ||
| """ | ||
| for key in required: | ||
| self.assertEqual( | ||
| required[key], actual[key], msg="%s mismatch. %s" % (key, actual) | ||
| ) | ||
|
|
||
|
|
||
| def DEBUG(target): | ||
| def DEBUG(target: TV) -> TV: | ||
| """A decorator to set the .loglevel attribute to logging.DEBUG. | ||
| Can apply to either a TestCase or an individual test method.""" | ||
| target.loglevel = logging.DEBUG | ||
| target.loglevel = logging.DEBUG # type: ignore[attr-defined] | ||
| return target | ||
|
|
||
|
|
||
| def INFO(target): | ||
| def INFO(target: TV) -> TV: | ||
| """A decorator to set the .loglevel attribute to logging.INFO. | ||
| Can apply to either a TestCase or an individual test method.""" | ||
| target.loglevel = logging.INFO | ||
| target.loglevel = logging.INFO # type: ignore[attr-defined] | ||
| return target | ||
|
|
||
|
|
||
| def logcontext_clean(target): | ||
| def logcontext_clean(target: TV) -> TV: | ||
| """A decorator which marks the TestCase or method as 'logcontext_clean' | ||
|
|
||
| ... ie, any logcontext errors should cause a test failure | ||
| """ | ||
|
|
||
| def logcontext_error(msg): | ||
| def logcontext_error(msg: str) -> NoReturn: | ||
| raise AssertionError("logcontext error: %s" % (msg)) | ||
|
|
||
| patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error) | ||
| return patcher(target) | ||
| return patcher(target) # type: ignore[call-overload] | ||
|
|
||
|
|
||
| class HomeserverTestCase(TestCase): | ||
|
|
@@ -255,7 +259,7 @@ def __init__(self, methodName: str): | |
| method = getattr(self, methodName) | ||
| self._extra_config = getattr(method, "_extra_config", None) | ||
|
|
||
| def setUp(self): | ||
| def setUp(self) -> None: | ||
| """ | ||
| Set up the TestCase by calling the homeserver constructor, optionally | ||
| hijacking the authentication system to return a fixed user, and then | ||
|
|
@@ -306,15 +310,21 @@ def setUp(self): | |
| ) | ||
| ) | ||
|
|
||
| async def get_user_by_access_token(token=None, allow_guest=False): | ||
| async def get_user_by_access_token( | ||
| token: Optional[str] = None, allow_guest: bool = False | ||
| ) -> JsonDict: | ||
| assert self.helper.auth_user_id is not None | ||
| return { | ||
| "user": UserID.from_string(self.helper.auth_user_id), | ||
| "token_id": token_id, | ||
| "is_guest": False, | ||
| } | ||
|
|
||
| async def get_user_by_req(request, allow_guest=False): | ||
| async def get_user_by_req( | ||
| request: SynapseRequest, | ||
| allow_guest: bool = False, | ||
| allow_expired: bool = False, | ||
DMRobertson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) -> Requester: | ||
| assert self.helper.auth_user_id is not None | ||
| return create_requester( | ||
| UserID.from_string(self.helper.auth_user_id), | ||
|
|
@@ -339,11 +349,11 @@ async def get_user_by_req(request, allow_guest=False): | |
| if hasattr(self, "prepare"): | ||
| self.prepare(self.reactor, self.clock, self.hs) | ||
|
|
||
| def tearDown(self): | ||
| def tearDown(self) -> None: | ||
| # Reset to not use frozen dicts. | ||
| events.USE_FROZEN_DICTS = False | ||
|
|
||
| def wait_on_thread(self, deferred, timeout=10): | ||
| def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None: | ||
| """ | ||
| Wait until a Deferred is done, where it's waiting on a real thread. | ||
| """ | ||
|
|
@@ -374,7 +384,7 @@ def make_homeserver(self, reactor, clock): | |
| clock (synapse.util.Clock): The Clock, associated with the reactor. | ||
|
|
||
| Returns: | ||
| A homeserver (synapse.server.HomeServer) suitable for testing. | ||
| A homeserver suitable for testing. | ||
|
|
||
| Function to be overridden in subclasses. | ||
| """ | ||
|
|
@@ -408,7 +418,7 @@ def create_resource_dict(self) -> Dict[str, Resource]: | |
| "/_synapse/admin": servlet_resource, | ||
| } | ||
|
|
||
| def default_config(self): | ||
| def default_config(self) -> JsonDict: | ||
| """ | ||
| Get a default HomeServer config dict. | ||
| """ | ||
|
|
@@ -421,7 +431,9 @@ def default_config(self): | |
|
|
||
| return config | ||
|
|
||
| def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): | ||
| def prepare( | ||
| self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | ||
| ) -> None: | ||
| """ | ||
| Prepare for the test. This involves things like mocking out parts of | ||
| the homeserver, or building test data common across the whole test | ||
|
|
@@ -519,7 +531,7 @@ def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer: | |
| config_obj.parse_config_dict(config, "", "") | ||
| kwargs["config"] = config_obj | ||
|
|
||
| async def run_bg_updates(): | ||
| async def run_bg_updates() -> None: | ||
| with LoggingContext("run_bg_updates"): | ||
| self.get_success(stor.db_pool.updates.run_background_updates(False)) | ||
|
|
||
|
|
@@ -538,11 +550,7 @@ def pump(self, by: float = 0.0) -> None: | |
| """ | ||
| self.reactor.pump([by] * 100) | ||
|
|
||
| def get_success( | ||
| self, | ||
| d: Awaitable[TV], | ||
| by: float = 0.0, | ||
| ) -> TV: | ||
| def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV: | ||
| deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] | ||
| self.pump(by=by) | ||
| return self.successResultOf(deferred) | ||
|
|
@@ -755,7 +763,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): | |
| OTHER_SERVER_NAME = "other.example.com" | ||
| OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") | ||
|
|
||
| def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): | ||
| def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||
| super().prepare(reactor, clock, hs) | ||
|
|
||
| # poke the other server's signing key into the key store, so that we don't | ||
|
|
@@ -879,7 +887,7 @@ def _auth_header_for_request( | |
| ) | ||
|
|
||
|
|
||
| def override_config(extra_config): | ||
| def override_config(extra_config: JsonDict) -> Callable[[TV], TV]: | ||
| """A decorator which can be applied to test functions to give additional HS config | ||
|
|
||
| For use | ||
|
|
@@ -892,12 +900,13 @@ def test_foo(self): | |
| ... | ||
|
|
||
| Args: | ||
| extra_config(dict): Additional config settings to be merged into the default | ||
| extra_config: Additional config settings to be merged into the default | ||
| config dict before instantiating the test homeserver. | ||
| """ | ||
|
|
||
| def decorator(func): | ||
| func._extra_config = extra_config | ||
| def decorator(func: TV) -> TV: | ||
| # This attribute is being defined. | ||
| func._extra_config = extra_config # type: ignore[attr-defined] | ||
| return func | ||
|
|
||
| return decorator | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.