
Commit e54eace

Merge branch 'main' into main

2 parents: edd6014 + 94cd878

17 files changed: +202 additions, -112 deletions

arq/cli.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ async def watch_reload(path: str, worker_settings: 'WorkerSettingsType') -> None
     except ImportError as e:  # pragma: no cover
         raise ImportError('watchfiles not installed, use `pip install watchfiles`') from e
 
-    loop = asyncio.get_event_loop()
+    loop = asyncio.get_running_loop()
     stop_event = asyncio.Event()
 
     def worker_on_stop(s: Signals) -> None:
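
Note on this change: a minimal standalone sketch (not arq's code) of why `get_running_loop()` is the better fit here. Inside a coroutine it returns the active loop unambiguously, whereas `asyncio.get_event_loop()` is deprecated for this use since Python 3.10 and may create a fresh loop when none is running.

    import asyncio

    async def main() -> None:
        # inside a coroutine there is always a running loop
        loop = asyncio.get_running_loop()
        print(loop.is_running())  # True

    asyncio.run(main())

    # outside async code there is no running loop, so this raises RuntimeError
    try:
        asyncio.get_running_loop()
    except RuntimeError as e:
        print('no running loop:', e)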

arq/connections.py

Lines changed: 9 additions & 6 deletions
@@ -55,7 +55,8 @@ class RedisSettings:
     @classmethod
     def from_dsn(cls, dsn: str) -> 'RedisSettings':
         conf = urlparse(dsn)
-        assert conf.scheme in {'redis', 'rediss', 'unix'}, 'invalid DSN scheme'
+        if conf.scheme not in {'redis', 'rediss', 'unix'}:
+            raise RuntimeError('invalid DSN scheme')
         query_db = parse_qs(conf.query).get('db')
         if query_db:
             # e.g. redis://localhost:6379?db=1
@@ -143,7 +144,8 @@ async def enqueue_job(
         _queue_name = self.default_queue_name
         job_id = _job_id or uuid4().hex
         job_key = job_key_prefix + job_id
-        assert not (_defer_until and _defer_by), "use either 'defer_until' or 'defer_by' or neither, not both"
+        if _defer_until and _defer_by:
+            raise RuntimeError("use either 'defer_until' or 'defer_by' or neither, not both")
 
         defer_by_ms = to_ms(_defer_by)
         expires_ms = to_ms(_expires)
@@ -195,9 +197,11 @@ async def all_job_results(self) -> List[JobResult]:
     async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
         key = job_key_prefix + job_id.decode()
         v = await self.get(key)
-        assert v is not None, f'job "{key}" not found'
+        if v is None:
+            raise RuntimeError(f'job "{key}" not found')
         jd = deserialize_job(v, deserializer=self.job_deserializer)
         jd.score = score
+        jd.job_id = job_id.decode()
         return jd
 
     async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]:
@@ -226,9 +230,8 @@ async def create_pool(
     """
     settings: RedisSettings = RedisSettings() if settings_ is None else settings_
 
-    assert not (
-        type(settings.host) is str and settings.sentinel
-    ), "str provided for 'host' but 'sentinel' is true; list of sentinels expected"
+    if isinstance(settings.host, str) and settings.sentinel:
+        raise RuntimeError("str provided for 'host' but 'sentinel' is true; list of sentinels expected")
 
     if settings.sentinel:
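
Note on these changes: the `assert` statements become explicit `RuntimeError`s so the checks still fire under `python -O`, which strips asserts. A small sketch of the resulting behaviour (the DSN check needs no Redis server; names follow the surrounding diff):

    from arq.connections import RedisSettings

    # from_dsn now raises RuntimeError rather than AssertionError
    try:
        RedisSettings.from_dsn('http://localhost:6379')
    except RuntimeError as e:
        print(e)  # invalid DSN scheme

    settings = RedisSettings.from_dsn('redis://localhost:6379?db=1')

    # enqueue_job validates the same way - this would now raise RuntimeError:
    # await redis.enqueue_job('the_task', _defer_until=dt, _defer_by=10)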

arq/cron.py

Lines changed: 9 additions & 6 deletions
@@ -58,17 +58,18 @@ def _get_next_dt(dt_: datetime, options: Options) -> Optional[datetime]: # noqa
         next_v = getattr(dt_, field)
         if isinstance(v, int):
             mismatch = next_v != v
-        else:
-            assert isinstance(v, (set, list, tuple)), v
+        elif isinstance(v, (set, list, tuple)):
             mismatch = next_v not in v
+        else:
+            raise RuntimeError(v)
         # print(field, v, next_v, mismatch)
         if mismatch:
             micro = max(dt_.microsecond - options.microsecond, 0)
             if field == 'month':
                 if dt_.month == 12:
-                    return datetime(dt_.year + 1, 1, 1)
+                    return datetime(dt_.year + 1, 1, 1, tzinfo=dt_.tzinfo)
                 else:
-                    return datetime(dt_.year, dt_.month + 1, 1)
+                    return datetime(dt_.year, dt_.month + 1, 1, tzinfo=dt_.tzinfo)
             elif field in ('day', 'weekday'):
                 return (
                     dt_
@@ -82,7 +83,8 @@ def _get_next_dt(dt_: datetime, options: Options) -> Optional[datetime]: # noqa
         elif field == 'second':
             return dt_ + timedelta(seconds=1) - timedelta(microseconds=micro)
         else:
-            assert field == 'microsecond', field
+            if field != 'microsecond':
+                raise RuntimeError(field)
             return dt_ + timedelta(microseconds=options.microsecond - dt_.microsecond)
     return None
@@ -173,7 +175,8 @@ def cron(
     else:
         coroutine_ = coroutine
 
-    assert asyncio.iscoroutinefunction(coroutine_), f'{coroutine_} is not a coroutine function'
+    if not asyncio.iscoroutinefunction(coroutine_):
+        raise RuntimeError(f'{coroutine_} is not a coroutine function')
     timeout = to_seconds(timeout)
     keep_result = to_seconds(keep_result)
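
Note on the `tzinfo` fix: previously the month rollover built a naive `datetime`, silently dropping the timezone of an aware input, and comparing the result against aware datetimes then fails. A standalone stdlib sketch of the failure mode and the fix:

    from datetime import datetime, timezone

    dt = datetime(2022, 12, 31, 23, 59, tzinfo=timezone.utc)

    # before: rollover produced a naive datetime...
    naive = datetime(dt.year + 1, 1, 1)
    try:
        naive < dt  # ...which cannot be compared with an aware one
    except TypeError as e:
        print(e)  # can't compare offset-naive and offset-aware datetimes

    # after: tzinfo is carried over, so comparisons keep working
    aware = datetime(dt.year + 1, 1, 1, tzinfo=dt.tzinfo)
    print(aware > dt)  # True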

arq/jobs.py

Lines changed: 5 additions & 1 deletion
@@ -47,6 +47,7 @@ class JobDef:
     job_try: int
     enqueue_time: datetime
     score: Optional[int]
+    job_id: Optional[str]
 
     def __post_init__(self) -> None:
         if isinstance(self.score, float):
@@ -60,7 +61,6 @@ class JobResult(JobDef):
     start_time: datetime
     finish_time: datetime
     queue_name: str
-    job_id: Optional[str] = None
 
 
 class Job:
@@ -238,6 +238,7 @@ def serialize_result(
     finished_ms: int,
     ref: str,
     queue_name: str,
+    job_id: str,
     *,
     serializer: Optional[Serializer] = None,
 ) -> Optional[bytes]:
@@ -252,6 +253,7 @@ def serialize_result(
         'st': start_ms,
         'ft': finished_ms,
         'q': queue_name,
+        'id': job_id,
     }
     if serializer is None:
         serializer = pickle.dumps
@@ -281,6 +283,7 @@ def deserialize_job(r: bytes, *, deserializer: Optional[Deserializer] = None) ->
             job_try=d['t'],
             enqueue_time=ms_to_datetime(d['et']),
             score=None,
+            job_id=None,
         )
     except Exception as e:
         raise DeserializationError('unable to deserialize job') from e
@@ -315,6 +318,7 @@ def deserialize_result(r: bytes, *, deserializer: Optional[Deserializer] = None)
             start_time=ms_to_datetime(d['st']),
             finish_time=ms_to_datetime(d['ft']),
             queue_name=d.get('q', '<unknown>'),
+            job_id=d.get('id', '<unknown>'),
         )
     except Exception as e:
         raise DeserializationError('unable to deserialize job result') from e
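
Note on these changes: `job_id` moves from `JobResult` up to the shared `JobDef` dataclass and is now persisted in serialized results under the `'id'` key. Old result payloads have no `'id'`, so the reader falls back to a sentinel rather than failing - a dict-level sketch of that compatibility behaviour:

    # mirrors the diff's `d.get('id', '<unknown>')` fallback
    old_result = {'q': 'arq:queue'}                   # pre-change payload
    new_result = {'q': 'arq:queue', 'id': 'abc123'}   # post-change payload

    for d in (old_result, new_result):
        print(d.get('id', '<unknown>'))
    # <unknown>
    # abc123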

arq/worker.py

Lines changed: 30 additions & 7 deletions
@@ -85,7 +85,8 @@ def func(
     else:
         coroutine_ = coroutine
 
-    assert asyncio.iscoroutinefunction(coroutine_), f'{coroutine_} is not a coroutine function'
+    if not asyncio.iscoroutinefunction(coroutine_):
+        raise RuntimeError(f'{coroutine_} is not a coroutine function')
     timeout = to_seconds(timeout)
     keep_result = to_seconds(keep_result)
@@ -226,17 +227,23 @@ def __init__(
         self.queue_name = queue_name
         self.cron_jobs: List[CronJob] = []
         if cron_jobs is not None:
-            assert all(isinstance(cj, CronJob) for cj in cron_jobs), 'cron_jobs, must be instances of CronJob'
+            if not all(isinstance(cj, CronJob) for cj in cron_jobs):
+                raise RuntimeError('cron_jobs, must be instances of CronJob')
             self.cron_jobs = list(cron_jobs)
             self.functions.update({cj.name: cj for cj in self.cron_jobs})
-        assert len(self.functions) > 0, 'at least one function or cron_job must be registered'
+        if len(self.functions) == 0:
+            raise RuntimeError('at least one function or cron_job must be registered')
         self.burst = burst
         self.on_startup = on_startup
         self.on_shutdown = on_shutdown
         self.on_job_start = on_job_start
         self.on_job_end = on_job_end
         self.after_job_end = after_job_end
-        self.sem = asyncio.BoundedSemaphore(max_jobs)
+
+        self.max_jobs = max_jobs
+        self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
+        self.job_counter: int = 0
+
         self.job_timeout_s = to_seconds(job_timeout)
         self.keep_result_s = to_seconds(keep_result)
         self.keep_result_forever = keep_result_forever
@@ -374,13 +381,13 @@ async def _poll_iteration(self) -> None:
             return
         count = min(burst_jobs_remaining, count)
         if self.allow_pick_jobs:
-            async with self.sem:  # don't bother with zrangebyscore until we have "space" to run the jobs
+            if self.job_counter < self.max_jobs:
                 now = timestamp_ms()
                 job_ids = await self.pool.zrangebyscore(
                     self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now
                 )
 
-            await self.start_jobs(job_ids)
+                await self.start_jobs(job_ids)
 
         if self.allow_abort_jobs:
             await self._cancel_aborted_jobs()
@@ -419,12 +426,23 @@ async def _cancel_aborted_jobs(self) -> None:
         self.aborting_tasks.update(aborted)
         await self.pool.zrem(abort_jobs_ss, *aborted)
 
+    def _release_sem_dec_counter_on_complete(self) -> None:
+        self.job_counter = self.job_counter - 1
+        self.sem.release()
+
     async def start_jobs(self, job_ids: List[bytes]) -> None:
         """
         For each job id, get the job definition, check it's not running and start it in a task
         """
         for job_id_b in job_ids:
             await self.sem.acquire()
+
+            if self.job_counter >= self.max_jobs:
+                self.sem.release()
+                return None
+
+            self.job_counter = self.job_counter + 1
+
             job_id = job_id_b.decode()
             in_progress_key = in_progress_key_prefix + job_id
             async with self.pool.pipeline(transaction=True) as pipe:
@@ -433,6 +451,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
                 score = await pipe.zscore(self.queue_name, job_id)
                 if ongoing_exists or not score:
                     # job already started elsewhere, or already finished and removed from queue
+                    self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('job %s already running elsewhere', job_id)
                     continue
@@ -445,11 +464,12 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
                     await pipe.execute()
                 except (ResponseError, WatchError):
                     # job already started elsewhere since we got 'existing'
+                    self.job_counter = self.job_counter - 1
                     self.sem.release()
                     logger.debug('multi-exec error, job %s already started elsewhere', job_id)
                 else:
                     t = self.loop.create_task(self.run_job(job_id, int(score)))
-                    t.add_done_callback(lambda _: self.sem.release())
+                    t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete())
                     self.tasks[job_id] = t
 
     async def run_job(self, job_id: str, score: int) -> None:  # noqa: C901
@@ -484,6 +504,7 @@ async def job_failed(exc: BaseException) -> None:
                 ref=f'{job_id}:{function_name}',
                 serializer=self.job_serializer,
                 queue_name=self.queue_name,
+                job_id=job_id,
             )
             await asyncio.shield(self.finish_failed_job(job_id, result_data_))
 
@@ -539,6 +560,7 @@ async def job_failed(exc: BaseException) -> None:
                 timestamp_ms(),
                 ref,
                 self.queue_name,
+                job_id=job_id,
                 serializer=self.job_serializer,
             )
             return await asyncio.shield(self.finish_failed_job(job_id, result_data))
@@ -632,6 +654,7 @@ async def job_failed(exc: BaseException) -> None:
             finished_ms,
             ref,
             self.queue_name,
+            job_id=job_id,
             serializer=self.job_serializer,
         )
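
Note on the concurrency change: the worker now tracks running jobs with an explicit `job_counter` checked against `max_jobs`, and sizes the semaphore at `max_jobs + 1` so `start_jobs` can always acquire one extra slot to inspect the counter and back out when saturated, instead of the poll loop blocking on `async with self.sem`. A stripped-down sketch of the pattern (hypothetical `Limiter` class, no Redis):

    import asyncio

    class Limiter:
        def __init__(self, max_jobs: int) -> None:
            self.max_jobs = max_jobs
            # one extra slot lets a would-be starter get in, see the
            # counter is full, and back out without blocking
            self.sem = asyncio.BoundedSemaphore(max_jobs + 1)
            self.job_counter = 0

        async def try_start(self, coro) -> bool:
            await self.sem.acquire()
            if self.job_counter >= self.max_jobs:
                self.sem.release()  # saturated: bail out, as start_jobs does
                coro.close()        # avoid a "never awaited" warning in the sketch
                return False
            self.job_counter += 1
            task = asyncio.get_running_loop().create_task(coro)
            task.add_done_callback(lambda _: self._on_complete())
            return True

        def _on_complete(self) -> None:
            # mirrors _release_sem_dec_counter_on_complete in the diff
            self.job_counter -= 1
            self.sem.release()

    async def main() -> None:
        limiter = Limiter(max_jobs=2)
        started = [await limiter.try_start(asyncio.sleep(0.1)) for _ in range(3)]
        print(started)  # [True, True, False] - third rejected while saturated
        await asyncio.sleep(0.2)  # let the two running jobs finish
        print(await limiter.try_start(asyncio.sleep(0)))  # True - slots freed

    asyncio.run(main())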

docs/examples/job_ids.py

Lines changed: 10 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 from arq import create_pool
 from arq.connections import RedisSettings
+from arq.jobs import Job
+
 
 async def the_task(ctx):
     print('running the task with id', ctx['job_id'])
@@ -37,6 +39,14 @@ async def main():
     > None
     """
 
+    # you can retrieve jobs by using arq.jobs.Job
+    await redis.enqueue_job('the_task', _job_id='my_job')
+    job5 = Job(job_id='my_job', redis=redis)
+    print(job5)
+    """
+    <arq job my_job>
+    """
+
 class WorkerSettings:
     functions = [the_task]
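
A possible continuation of this example, using the rest of `arq.jobs.Job`'s public API (return values assume a worker is processing the same queue):

    # status = await job5.status()           # e.g. JobStatus.queued
    # info = await job5.info()               # JobDef details, including job_id
    # result = await job5.result(timeout=5)  # waits for and returns the result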

docs/index.rst

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ Sometimes you want a job to only be run once at a time (eg. a backup) or once fo
 invoices for a particular company).
 
 *arq* supports this via custom job ids, see :func:`arq.connections.ArqRedis.enqueue_job`. It guarantees
-that a job with a particular ID cannot be enqueued again until its execution has finished.
+that a job with a particular ID cannot be enqueued again until its execution has finished and its result has cleared. To control when a finished job's result clears, you can use the `keep_result` setting on your worker, see :func:`arq.worker.func`.
 
 .. literalinclude:: examples/job_ids.py
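
For illustration, a minimal sketch of tuning `keep_result` per function via `arq.worker.func` (`the_task` is the function from the example above):

    from arq.worker import func

    class WorkerSettings:
        # drop each finished job's result after 10 seconds, freeing its
        # custom ID for re-enqueueing sooner; keep_result=0 disables storage
        functions = [func(the_task, keep_result=10)]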

requirements/docs.txt

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 #
-# This file is autogenerated by pip-compile with python 3.9
+# This file is autogenerated by pip-compile with python 3.11
 # To update, run:
 #
 #    pip-compile --output-file=requirements/docs.txt requirements/docs.in
@@ -35,7 +35,7 @@ requests==2.28.1
 snowballstemmer==2.2.0
     # via sphinx
 sphinx==5.1.1
-    # via -r docs.in
+    # via -r requirements/docs.in
 sphinxcontrib-applehelp==1.0.2
     # via sphinx
 sphinxcontrib-devhelp==1.0.2

requirements/linting.txt

Lines changed: 2 additions & 8 deletions
@@ -1,5 +1,5 @@
 #
-# This file is autogenerated by pip-compile with python 3.9
+# This file is autogenerated by pip-compile with python 3.11
 # To update, run:
 #
 #    pip-compile --output-file=requirements/linting.txt requirements/linting.in
@@ -34,15 +34,9 @@ pycodestyle==2.9.1
     # via flake8
 pyflakes==2.5.0
     # via flake8
-tomli==2.0.1
-    # via
-    #   black
-    #   mypy
 types-pytz==2022.2.1.0
     # via -r requirements/linting.in
 types-redis==4.2.8
     # via -r requirements/linting.in
 typing-extensions==4.3.0
-    # via
-    #   black
-    #   mypy
+    # via mypy

requirements/pyproject.txt

Lines changed: 3 additions & 11 deletions
@@ -1,5 +1,5 @@
 #
-# This file is autogenerated by pip-compile with python 3.9
+# This file is autogenerated by pip-compile with python 3.11
 # To update, run:
 #
 #    pip-compile --extra=watch --output-file=requirements/pyproject.txt pyproject.toml
@@ -10,23 +10,15 @@ async-timeout==4.0.2
     # via redis
 click==8.1.3
     # via arq (pyproject.toml)
-deprecated==1.2.13
-    # via redis
-hiredis==2.0.0
+hiredis==2.1.0
     # via redis
 idna==3.3
     # via anyio
-packaging==21.3
-    # via redis
-pyparsing==3.0.9
-    # via packaging
-redis[hiredis]==4.3.4
+redis[hiredis]==4.4.0
     # via arq (pyproject.toml)
 sniffio==1.2.0
     # via anyio
 typing-extensions==4.3.0
     # via arq (pyproject.toml)
 watchfiles==0.16.1
     # via arq (pyproject.toml)
-wrapt==1.14.1
-    # via deprecated

requirements/testing.in

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ dirty-equals>=0.4,<1
 msgpack>=1,<2
 pydantic>=1.9.2,<2
 pytest>=7,<8
-pytest-asyncio>=0.19,<0.20
+pytest-asyncio>=0.20.3
 pytest-mock>=3,<4
 pytest-sugar>=0.9,<1
 pytest-timeout>=2,<3
