Skip to content

Commit 0656b65

Browse files
authored
Ensure we hold strong references to tasks (#382)
1 parent f99db35 commit 0656b65

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

aioesphomeapi/connection.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
) -> None:
9797
self._params = params
9898
self.on_stop: Optional[Callable[[], Coroutine[Any, Any, None]]] = on_stop
99+
self._on_stop_task: Optional[asyncio.Task[None]] = None
99100
self._socket: Optional[socket.socket] = None
100101
self._frame_helper: Optional[APIFrameHelper] = None
101102
self._api_version: Optional[APIVersion] = None
@@ -117,6 +118,7 @@ def __init__(
117118
self._ping_stop_event = asyncio.Event()
118119

119120
self._connect_task: Optional[asyncio.Task[None]] = None
121+
self._keep_alive_task: Optional[asyncio.Task[None]] = None
120122
self._fatal_exception: Optional[Exception] = None
121123
self._expected_disconnect = False
122124

@@ -142,6 +144,10 @@ def _cleanup(self) -> None:
142144
self._connect_task.cancel()
143145
self._connect_task = None
144146

147+
if self._keep_alive_task is not None:
148+
self._keep_alive_task.cancel()
149+
self._keep_alive_task = None
150+
145151
if self._frame_helper is not None:
146152
self._frame_helper.close()
147153
self._frame_helper = None
@@ -151,8 +157,19 @@ def _cleanup(self) -> None:
151157
self._socket = None
152158

153159
if self.on_stop and self._connect_complete:
160+
161+
def _remove_on_stop_task(_fut: asyncio.Future[None]) -> None:
162+
"""Remove the stop task.
163+
164+
We need to do this because the asyncio does not hold
165+
a strong reference to the task, so it can be garbage
166+
collected unexpectedly.
167+
"""
168+
self._on_stop_task = None
169+
154170
# Ensure on_stop is called only once
155-
asyncio.create_task(self.on_stop())
171+
self._on_stop_task = asyncio.create_task(self.on_stop())
172+
self._on_stop_task.add_done_callback(_remove_on_stop_task)
156173
self.on_stop = None
157174

158175
# Note: we don't explicitly cancel the ping/read task here
@@ -318,7 +335,7 @@ async def _keep_alive_loop() -> None:
318335
self._report_fatal_error(err)
319336
return
320337

321-
asyncio.create_task(_keep_alive_loop())
338+
self._keep_alive_task = asyncio.create_task(_keep_alive_loop())
322339

323340
async def connect(self, *, login: bool) -> None:
324341
if self._connection_state != ConnectionState.INITIALIZED:

aioesphomeapi/reconnect_logic.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
self._wait_task_lock = asyncio.Lock()
6060
# Event for tracking when logic should stop
6161
self._stop_event = asyncio.Event()
62+
self._stop_task: Optional[asyncio.Task[None]] = None
6263

6364
@property
6465
def _is_stopped(self) -> bool:
@@ -200,7 +201,17 @@ async def stop(self) -> None:
200201
await self._stop_zc_listen()
201202

202203
def stop_callback(self) -> None:
203-
asyncio.create_task(self.stop())
204+
def _remove_stop_task(_fut: asyncio.Future[None]) -> None:
205+
"""Remove the stop task from the reconnect loop.
206+
207+
We need to do this because the asyncio does not hold
208+
a strong reference to the task, so it can be garbage
209+
collected unexpectedly.
210+
"""
211+
self._stop_task = None
212+
213+
self._stop_task = asyncio.create_task(self.stop())
214+
self._stop_task.add_done_callback(_remove_stop_task)
204215

205216
async def _start_zc_listen(self) -> None:
206217
"""Listen for mDNS records.

0 commit comments

Comments
 (0)