diff --git a/fakeredis/commands_mixins/generic_mixin.py b/fakeredis/commands_mixins/generic_mixin.py index 7200f154..3268287f 100644 --- a/fakeredis/commands_mixins/generic_mixin.py +++ b/fakeredis/commands_mixins/generic_mixin.py @@ -118,9 +118,9 @@ def expire(self, key: CommandItem, seconds: int, *args: bytes) -> int: res = self._expireat(key, self._db.time + seconds, *args) return res - @command(name="EXPIREAT", fixed=(Key(), Int)) - def expireat(self, key: CommandItem, timestamp: int) -> int: - return self._expireat(key, float(timestamp)) + @command(name="EXPIREAT", fixed=(Key(), Int), repeat=(bytes,)) + def expireat(self, key: CommandItem, timestamp: int, *args: bytes) -> int: + return self._expireat(key, float(timestamp), *args) @command(name="KEYS", fixed=(bytes,)) def keys(self, pattern: bytes) -> List[bytes]: @@ -148,13 +148,13 @@ def persist(self, key: CommandItem) -> int: key.expireat = None return 1 - @command(name="PEXPIRE", fixed=(Key(), Int)) - def pexpire(self, key: CommandItem, ms: int) -> int: - return self._expireat(key, self._db.time + ms / 1000.0) + @command(name="PEXPIRE", fixed=(Key(), Int), repeat=(bytes,)) + def pexpire(self, key: CommandItem, ms: int, *args: bytes) -> int: + return self._expireat(key, self._db.time + ms / 1000.0, *args) - @command(name="PEXPIREAT", fixed=(Key(), Int)) - def pexpireat(self, key: CommandItem, ms_timestamp: int) -> int: - return self._expireat(key, ms_timestamp / 1000.0) + @command(name="PEXPIREAT", fixed=(Key(), Int), repeat=(bytes,)) + def pexpireat(self, key: CommandItem, ms_timestamp: int, *args: bytes) -> int: + return self._expireat(key, ms_timestamp / 1000.0, *args) @command(name="PTTL", fixed=(Key(),)) def pttl(self, key: CommandItem) -> int: diff --git a/test/test_mixins/test_generic_commands.py b/test/test_mixins/test_generic_commands.py index 1f6d86ae..ab1301aa 100644 --- a/test/test_mixins/test_generic_commands.py +++ b/test/test_mixins/test_generic_commands.py @@ -39,6 +39,14 @@ def test_expireat_should_return_false_for_missing_key(r: redis.Redis): assert r.expireat("missing", int(time() + 1)) is False +@pytest.mark.min_server("7") +def test_expireat_should_not_expire_when_expire_is_set(r: redis.Redis): + r.set("foo", "bar") + assert r.get("foo") == b"bar" + assert r.expireat("foo", int(time() + 100), nx=True) == 1 + assert r.expireat("foo", int(time() + 200), nx=True) == 0 + + def test_del_operator(r: redis.Redis): r["foo"] = "bar" del r["foo"]