Skip to content

Commit 3b2aae5

Browse files
jbouwhemontnemery
andauthored
Refactor MQTT discovery (#67966)
* Proof of concept * remove notify platform * remove loose test * Add rework from #67912 (#1) * Move notify serviceupdater to Mixins * Move tag discovery handler to Mixins * fix tests * Add typing for async_load_platform_helper * Add add entry unload support for notify platform * Simplify discovery updates * Remove not needed extra logic * Cleanup inrelevant or duplicate code * reuse update_device and move to mixins * Remove notify platform * revert changes to notify platform * Rename update class * unify tag entry setup * Use shared code for device_trigger `update_device` * PoC shared dispatcher for device_trigger * Fix bugs * Improve typing - remove async_update * Unload config_entry and tests * Release dispatcher after setup and deduplicate * closures to methods, revert `in` to `=`, updates * Re-add update support for tag platform * Re-add update support for device-trigger platform * Cleanup rediscovery code revert related changes * Undo discovery code shift * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Erik Montnemery <[email protected]> * Update homeassistant/components/mqtt/device_trigger.py Co-authored-by: Erik Montnemery <[email protected]> * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Erik Montnemery <[email protected]> * revert doc string changes * move conditions * typing and check config_entry_id * Update homeassistant/components/mqtt/mixins.py Co-authored-by: Erik Montnemery <[email protected]> * cleanup not used attribute * Remove entry_unload code and tests * update comment * add second comment Co-authored-by: Erik Montnemery <[email protected]>
1 parent c932407 commit 3b2aae5

File tree

5 files changed

+364
-286
lines changed

5 files changed

+364
-286
lines changed

homeassistant/components/mqtt/device_trigger.py

Lines changed: 104 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from collections.abc import Callable
55
import logging
6-
from typing import Any
6+
from typing import Any, cast
77

88
import attr
99
import voluptuous as vol
@@ -13,6 +13,7 @@
1313
AutomationTriggerInfo,
1414
)
1515
from homeassistant.components.device_automation import DEVICE_TRIGGER_BASE_SCHEMA
16+
from homeassistant.config_entries import ConfigEntry
1617
from homeassistant.const import (
1718
CONF_DEVICE,
1819
CONF_DEVICE_ID,
@@ -23,30 +24,19 @@
2324
)
2425
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
2526
from homeassistant.exceptions import HomeAssistantError
26-
from homeassistant.helpers import config_validation as cv, device_registry as dr
27-
from homeassistant.helpers.dispatcher import (
28-
async_dispatcher_connect,
29-
async_dispatcher_send,
30-
)
27+
from homeassistant.helpers import config_validation as cv
28+
from homeassistant.helpers.dispatcher import async_dispatcher_send
3129
from homeassistant.helpers.typing import ConfigType
3230

3331
from . import debug_info, trigger as mqtt_trigger
3432
from .. import mqtt
35-
from .const import (
36-
ATTR_DISCOVERY_HASH,
37-
ATTR_DISCOVERY_TOPIC,
38-
CONF_PAYLOAD,
39-
CONF_QOS,
40-
CONF_TOPIC,
41-
DOMAIN,
42-
)
43-
from .discovery import MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_UPDATED, clear_discovery_hash
33+
from .const import ATTR_DISCOVERY_HASH, CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, DOMAIN
34+
from .discovery import MQTT_DISCOVERY_DONE
4435
from .mixins import (
45-
CONF_CONNECTIONS,
46-
CONF_IDENTIFIERS,
4736
MQTT_ENTITY_DEVICE_INFO_SCHEMA,
48-
cleanup_device_registry,
49-
device_info_from_config,
37+
MqttDiscoveryDeviceUpdate,
38+
send_discovery_done,
39+
update_device,
5040
)
5141

5242
_LOGGER = logging.getLogger(__name__)
@@ -89,6 +79,8 @@
8979

9080
DEVICE_TRIGGERS = "mqtt_device_triggers"
9181

82+
LOG_NAME = "Device trigger"
83+
9284

9385
@attr.s(slots=True)
9486
class TriggerInstance:
@@ -99,7 +91,7 @@ class TriggerInstance:
9991
trigger: Trigger = attr.ib()
10092
remove: CALLBACK_TYPE | None = attr.ib(default=None)
10193

102-
async def async_attach_trigger(self):
94+
async def async_attach_trigger(self) -> None:
10395
"""Attach MQTT trigger."""
10496
mqtt_config = {
10597
mqtt_trigger.CONF_PLATFORM: mqtt.DOMAIN,
@@ -132,14 +124,15 @@ class Trigger:
132124
hass: HomeAssistant = attr.ib()
133125
payload: str | None = attr.ib()
134126
qos: int | None = attr.ib()
135-
remove_signal: Callable[[], None] | None = attr.ib()
136127
subtype: str = attr.ib()
137128
topic: str | None = attr.ib()
138129
type: str = attr.ib()
139130
value_template: str | None = attr.ib()
140131
trigger_instances: list[TriggerInstance] = attr.ib(factory=list)
141132

142-
async def add_trigger(self, action, automation_info):
133+
async def add_trigger(
134+
self, action: AutomationActionType, automation_info: AutomationTriggerInfo
135+
) -> Callable:
143136
"""Add MQTT trigger."""
144137
instance = TriggerInstance(action, automation_info, self)
145138
self.trigger_instances.append(instance)
@@ -160,9 +153,8 @@ def async_remove() -> None:
160153

161154
return async_remove
162155

163-
async def update_trigger(self, config, discovery_hash, remove_signal):
156+
async def update_trigger(self, config: ConfigType) -> None:
164157
"""Update MQTT device trigger."""
165-
self.remove_signal = remove_signal
166158
self.type = config[CONF_TYPE]
167159
self.subtype = config[CONF_SUBTYPE]
168160
self.payload = config[CONF_PAYLOAD]
@@ -178,7 +170,7 @@ async def update_trigger(self, config, discovery_hash, remove_signal):
178170
for trig in self.trigger_instances:
179171
await trig.async_attach_trigger()
180172

181-
def detach_trigger(self):
173+
def detach_trigger(self) -> None:
182174
"""Remove MQTT device trigger."""
183175
# Mark trigger as unknown
184176
self.topic = None
@@ -190,110 +182,110 @@ def detach_trigger(self):
190182
trig.remove = None
191183

192184

193-
def _update_device(hass, config_entry, config):
194-
"""Update device registry."""
195-
device_registry = dr.async_get(hass)
196-
config_entry_id = config_entry.entry_id
197-
device_info = device_info_from_config(config[CONF_DEVICE])
198-
199-
if config_entry_id is not None and device_info is not None:
200-
device_info["config_entry_id"] = config_entry_id
201-
device_registry.async_get_or_create(**device_info)
202-
185+
class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate):
186+
"""Setup a MQTT device trigger with auto discovery."""
187+
188+
def __init__(
189+
self,
190+
hass: HomeAssistant,
191+
config: ConfigType,
192+
device_id: str,
193+
discovery_data: dict,
194+
config_entry: ConfigEntry,
195+
) -> None:
196+
"""Initialize."""
197+
self._config = config
198+
self._config_entry = config_entry
199+
self.device_id = device_id
200+
self.discovery_data = discovery_data
201+
self.hass = hass
202+
203+
MqttDiscoveryDeviceUpdate.__init__(
204+
self,
205+
hass,
206+
discovery_data,
207+
device_id,
208+
config_entry,
209+
LOG_NAME,
210+
)
203211

204-
async def async_setup_trigger(hass, config, config_entry, discovery_data):
205-
"""Set up the MQTT device trigger."""
206-
config = TRIGGER_DISCOVERY_SCHEMA(config)
207-
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
208-
discovery_id = discovery_hash[1]
209-
remove_signal = None
212+
async def async_setup(self) -> None:
213+
"""Initialize the device trigger."""
214+
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
215+
discovery_id = discovery_hash[1]
216+
if discovery_id not in self.hass.data.setdefault(DEVICE_TRIGGERS, {}):
217+
self.hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
218+
hass=self.hass,
219+
device_id=self.device_id,
220+
discovery_data=self.discovery_data,
221+
type=self._config[CONF_TYPE],
222+
subtype=self._config[CONF_SUBTYPE],
223+
topic=self._config[CONF_TOPIC],
224+
payload=self._config[CONF_PAYLOAD],
225+
qos=self._config[CONF_QOS],
226+
value_template=self._config[CONF_VALUE_TEMPLATE],
227+
)
228+
else:
229+
await self.hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
230+
self._config
231+
)
232+
debug_info.add_trigger_discovery_data(
233+
self.hass, discovery_hash, self.discovery_data, self.device_id
234+
)
210235

211-
async def discovery_update(payload):
212-
"""Handle discovery update."""
213-
_LOGGER.info(
214-
"Got update for trigger with hash: %s '%s'", discovery_hash, payload
236+
async def async_update(self, discovery_data: dict) -> None:
237+
"""Handle MQTT device trigger discovery updates."""
238+
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
239+
discovery_id = discovery_hash[1]
240+
debug_info.update_trigger_discovery_data(
241+
self.hass, discovery_hash, discovery_data
215242
)
216-
if not payload:
217-
# Empty payload: Remove trigger
243+
config = TRIGGER_DISCOVERY_SCHEMA(discovery_data)
244+
update_device(self.hass, self._config_entry, config)
245+
device_trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
246+
await device_trigger.update_trigger(config)
247+
248+
async def async_tear_down(self) -> None:
249+
"""Cleanup device trigger."""
250+
discovery_hash = self.discovery_data[ATTR_DISCOVERY_HASH]
251+
discovery_id = discovery_hash[1]
252+
if discovery_id in self.hass.data[DEVICE_TRIGGERS]:
218253
_LOGGER.info("Removing trigger: %s", discovery_hash)
219-
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
220-
if discovery_id in hass.data[DEVICE_TRIGGERS]:
221-
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
222-
device_trigger.detach_trigger()
223-
clear_discovery_hash(hass, discovery_hash)
224-
remove_signal()
225-
await cleanup_device_registry(hass, device.id, config_entry.entry_id)
226-
else:
227-
# Non-empty payload: Update trigger
228-
_LOGGER.info("Updating trigger: %s", discovery_hash)
229-
debug_info.update_trigger_discovery_data(hass, discovery_hash, payload)
230-
config = TRIGGER_DISCOVERY_SCHEMA(payload)
231-
_update_device(hass, config_entry, config)
232-
device_trigger = hass.data[DEVICE_TRIGGERS][discovery_id]
233-
await device_trigger.update_trigger(config, discovery_hash, remove_signal)
234-
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
235-
236-
remove_signal = async_dispatcher_connect(
237-
hass, MQTT_DISCOVERY_UPDATED.format(discovery_hash), discovery_update
238-
)
254+
trigger: Trigger = self.hass.data[DEVICE_TRIGGERS][discovery_id]
255+
trigger.detach_trigger()
256+
debug_info.remove_trigger_discovery_data(self.hass, discovery_hash)
239257

240-
_update_device(hass, config_entry, config)
241258

242-
device_registry = dr.async_get(hass)
243-
device = device_registry.async_get_device(
244-
{(DOMAIN, id_) for id_ in config[CONF_DEVICE][CONF_IDENTIFIERS]},
245-
{tuple(x) for x in config[CONF_DEVICE][CONF_CONNECTIONS]},
246-
)
259+
async def async_setup_trigger(
260+
hass, config: ConfigType, config_entry: ConfigEntry, discovery_data: dict
261+
) -> None:
262+
"""Set up the MQTT device trigger."""
263+
config = TRIGGER_DISCOVERY_SCHEMA(config)
264+
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
247265

248-
if device is None:
266+
if (device_id := update_device(hass, config_entry, config)) is None:
249267
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
250268
return
251269

252-
if DEVICE_TRIGGERS not in hass.data:
253-
hass.data[DEVICE_TRIGGERS] = {}
254-
if discovery_id not in hass.data[DEVICE_TRIGGERS]:
255-
hass.data[DEVICE_TRIGGERS][discovery_id] = Trigger(
256-
hass=hass,
257-
device_id=device.id,
258-
discovery_data=discovery_data,
259-
type=config[CONF_TYPE],
260-
subtype=config[CONF_SUBTYPE],
261-
topic=config[CONF_TOPIC],
262-
payload=config[CONF_PAYLOAD],
263-
qos=config[CONF_QOS],
264-
remove_signal=remove_signal,
265-
value_template=config[CONF_VALUE_TEMPLATE],
266-
)
267-
else:
268-
await hass.data[DEVICE_TRIGGERS][discovery_id].update_trigger(
269-
config, discovery_hash, remove_signal
270-
)
271-
debug_info.add_trigger_discovery_data(
272-
hass, discovery_hash, discovery_data, device.id
270+
mqtt_device_trigger = MqttDeviceTrigger(
271+
hass, config, device_id, discovery_data, config_entry
273272
)
274-
275-
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
273+
await mqtt_device_trigger.async_setup()
274+
send_discovery_done(hass, discovery_data)
276275

277276

278-
async def async_removed_from_device(hass: HomeAssistant, device_id: str):
277+
async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None:
279278
"""Handle Mqtt removed from a device."""
280279
triggers = await async_get_triggers(hass, device_id)
281280
for trig in triggers:
282-
device_trigger = hass.data[DEVICE_TRIGGERS].pop(trig[CONF_DISCOVERY_ID])
281+
device_trigger: Trigger = hass.data[DEVICE_TRIGGERS].pop(
282+
trig[CONF_DISCOVERY_ID]
283+
)
283284
if device_trigger:
284-
discovery_hash = device_trigger.discovery_data[ATTR_DISCOVERY_HASH]
285-
discovery_topic = device_trigger.discovery_data[ATTR_DISCOVERY_TOPIC]
286-
287-
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
288285
device_trigger.detach_trigger()
289-
clear_discovery_hash(hass, discovery_hash)
290-
device_trigger.remove_signal()
291-
mqtt.publish(
292-
hass,
293-
discovery_topic,
294-
"",
295-
retain=True,
296-
)
286+
discovery_data = cast(dict, device_trigger.discovery_data)
287+
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
288+
debug_info.remove_trigger_discovery_data(hass, discovery_hash)
297289

298290

299291
async def async_get_triggers(
@@ -328,8 +320,7 @@ async def async_attach_trigger(
328320
automation_info: AutomationTriggerInfo,
329321
) -> CALLBACK_TYPE:
330322
"""Attach a trigger."""
331-
if DEVICE_TRIGGERS not in hass.data:
332-
hass.data[DEVICE_TRIGGERS] = {}
323+
hass.data.setdefault(DEVICE_TRIGGERS, {})
333324
device_id = config[CONF_DEVICE_ID]
334325
discovery_id = config[CONF_DISCOVERY_ID]
335326

@@ -338,7 +329,6 @@ async def async_attach_trigger(
338329
hass=hass,
339330
device_id=device_id,
340331
discovery_data=None,
341-
remove_signal=None,
342332
type=config[CONF_TYPE],
343333
subtype=config[CONF_SUBTYPE],
344334
topic=None,

homeassistant/components/mqtt/discovery.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""Support for MQTT discovery."""
2+
from __future__ import annotations
3+
24
import asyncio
35
from collections import deque
46
import functools
@@ -73,20 +75,22 @@
7375
TOPIC_BASE = "~"
7476

7577

76-
def clear_discovery_hash(hass, discovery_hash):
78+
class MQTTConfig(dict):
79+
"""Dummy class to allow adding attributes."""
80+
81+
discovery_data: dict
82+
83+
84+
def clear_discovery_hash(hass: HomeAssistant, discovery_hash: tuple) -> None:
7785
"""Clear entry in ALREADY_DISCOVERED list."""
7886
del hass.data[ALREADY_DISCOVERED][discovery_hash]
7987

8088

81-
def set_discovery_hash(hass, discovery_hash):
89+
def set_discovery_hash(hass: HomeAssistant, discovery_hash: tuple):
8290
"""Clear entry in ALREADY_DISCOVERED list."""
8391
hass.data[ALREADY_DISCOVERED][discovery_hash] = {}
8492

8593

86-
class MQTTConfig(dict):
87-
"""Dummy class to allow adding attributes."""
88-
89-
9094
async def async_start( # noqa: C901
9195
hass: HomeAssistant, discovery_topic, config_entry=None
9296
) -> None:
@@ -181,6 +185,7 @@ async def async_discovery_message_received(msg):
181185
await async_process_discovery_payload(component, discovery_id, payload)
182186

183187
async def async_process_discovery_payload(component, discovery_id, payload):
188+
"""Process the payload of a new discovery."""
184189

185190
_LOGGER.debug("Process discovery payload %s", payload)
186191
discovery_hash = (component, discovery_id)

0 commit comments

Comments
 (0)