Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions homeassistant/components/automation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Allow to set up simple automation rules via the config file."""
import asyncio
import importlib
import logging
from typing import Any, Awaitable, Callable, List, Optional, Set
Expand Down Expand Up @@ -127,13 +128,11 @@ def automations_with_entity(hass: HomeAssistant, entity_id: str) -> List[str]:

component = hass.data[DOMAIN]

results = []

for automation_entity in component.entities:
if entity_id in automation_entity.referenced_entities:
results.append(automation_entity.entity_id)

return results
return [
automation_entity.entity_id
for automation_entity in component.entities
if entity_id in automation_entity.referenced_entities
]


@callback
Expand All @@ -160,13 +159,11 @@ def automations_with_device(hass: HomeAssistant, device_id: str) -> List[str]:

component = hass.data[DOMAIN]

results = []

for automation_entity in component.entities:
if device_id in automation_entity.referenced_devices:
results.append(automation_entity.entity_id)

return results
return [
automation_entity.entity_id
for automation_entity in component.entities
if device_id in automation_entity.referenced_devices
]


@callback
Expand Down Expand Up @@ -443,26 +440,29 @@ async def _async_attach_triggers(
self, home_assistant_start: bool
) -> Optional[Callable[[], None]]:
"""Set up the triggers."""
removes = []
info = {"name": self._name, "home_assistant_start": home_assistant_start}

triggers = []
for conf in self._trigger_config:
platform = importlib.import_module(f".{conf[CONF_PLATFORM]}", __name__)

remove = await platform.async_attach_trigger( # type: ignore
self.hass, conf, self.async_trigger, info
triggers.append(
platform.async_attach_trigger( # type: ignore
self.hass, conf, self.async_trigger, info
)
)

if not remove:
_LOGGER.error("Error setting up trigger %s", self._name)
continue
results = await asyncio.gather(*triggers)

_LOGGER.info("Initialized trigger %s", self._name)
removes.append(remove)
if None in results:
_LOGGER.error("Error setting up trigger %s", self._name)

removes = [remove for remove in results if remove is not None]
if not removes:
return None

_LOGGER.info("Initialized trigger %s", self._name)

@callback
def remove_triggers():
"""Remove attached triggers."""
Expand Down
34 changes: 19 additions & 15 deletions homeassistant/components/automation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@ async def async_validate_config_item(hass, config, full_config=None):
config[CONF_TRIGGER] = triggers

if CONF_CONDITION in config:
conditions = []
for cond in config[CONF_CONDITION]:
cond = await condition.async_validate_condition_config(hass, cond)
conditions.append(cond)
config[CONF_CONDITION] = conditions

actions = []
for action in config[CONF_ACTION]:
action = await script.async_validate_action_config(hass, action)
actions.append(action)
config[CONF_ACTION] = actions
config[CONF_CONDITION] = await asyncio.gather(
*[
condition.async_validate_condition_config(hass, cond)
for cond in config[CONF_CONDITION]
]
)

config[CONF_ACTION] = await asyncio.gather(
*[
script.async_validate_action_config(hass, action)
for action in config[CONF_ACTION]
]
)

return config

Expand All @@ -69,16 +71,18 @@ async def _try_async_validate_config_item(hass, config, full_config=None):

async def async_validate_config(hass, config):
"""Validate config."""
automations = []
validated_automations = await asyncio.gather(
*(
_try_async_validate_config_item(hass, p_config, config)
for _, p_config in config_per_platform(config, DOMAIN)
)
)
for validated_automation in validated_automations:
if validated_automation is not None:
automations.append(validated_automation)

automations = [
validated_automation
for validated_automation in validated_automations
if validated_automation is not None
]

# Create a copy of the configuration with all config for current
# component removed and add validated config back in.
Expand Down
10 changes: 7 additions & 3 deletions homeassistant/components/automation/litejet.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,13 @@ def released():
cancel_pressed_more_than()
cancel_pressed_more_than = None
held_time = dt_util.utcnow() - pressed_time
if held_less_than is not None and held_time < held_less_than:
if held_more_than is None or held_time > held_more_than:
hass.add_job(call_action)

if (
held_less_than is not None
and held_time < held_less_than
and (held_more_than is None or held_time > held_more_than)
):
hass.add_job(call_action)

hass.data["litejet_system"].on_switch_pressed(number, pressed)
hass.data["litejet_system"].on_switch_released(number, released)
Expand Down
5 changes: 1 addition & 4 deletions homeassistant/components/automation/zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def zone_automation_listener(entity, from_s, to_s):
return

zone_state = hass.states.get(zone_entity_id)
if from_s:
from_match = condition.zone(hass, zone_state, from_s)
else:
from_match = False
from_match = condition.zone(hass, zone_state, from_s) if from_s else False
to_match = condition.zone(hass, zone_state, to_s)

if (
Expand Down
14 changes: 7 additions & 7 deletions tests/components/automation/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def test_if_fires_on_event(hass, calls):

hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1


async def test_if_fires_on_event_extra_data(hass, calls):
Expand All @@ -64,14 +64,14 @@ async def test_if_fires_on_event_extra_data(hass, calls):

hass.bus.async_fire("test_event", {"extra_key": "extra_data"})
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1

await common.async_turn_off(hass)
await hass.async_block_till_done()

hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1


async def test_if_fires_on_event_with_data(hass, calls):
Expand All @@ -93,7 +93,7 @@ async def test_if_fires_on_event_with_data(hass, calls):

hass.bus.async_fire("test_event", {"some_attr": "some_value", "another": "value"})
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1


async def test_if_fires_on_event_with_empty_data_config(hass, calls):
Expand All @@ -119,7 +119,7 @@ async def test_if_fires_on_event_with_empty_data_config(hass, calls):

hass.bus.async_fire("test_event", {"some_attr": "some_value", "another": "value"})
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1


async def test_if_fires_on_event_with_nested_data(hass, calls):
Expand All @@ -143,7 +143,7 @@ async def test_if_fires_on_event_with_nested_data(hass, calls):
"test_event", {"parent_attr": {"some_attr": "some_value", "another": "value"}}
)
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1


async def test_if_not_fires_if_event_data_not_matches(hass, calls):
Expand All @@ -165,4 +165,4 @@ async def test_if_not_fires_if_event_data_not_matches(hass, calls):

hass.bus.async_fire("test_event", {"some_attr": "some_other_value"})
await hass.async_block_till_done()
assert 0 == len(calls)
assert len(calls) == 0
22 changes: 11 additions & 11 deletions tests/components/automation/test_geo_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ async def test_if_fires_on_zone_enter(hass, calls):
)
await hass.async_block_till_done()

assert 1 == len(calls)
assert len(calls) == 1
assert calls[0].context.parent_id == context.id
assert (
"geo_location - geo_location.entity - hello - hello - test"
== calls[0].data["some"]
calls[0].data["some"]
== "geo_location - geo_location.entity - hello - hello - test"
)

# Set out of zone again so we can trigger call
Expand All @@ -108,7 +108,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
)
await hass.async_block_till_done()

assert 1 == len(calls)
assert len(calls) == 1


async def test_if_not_fires_for_enter_on_zone_leave(hass, calls):
Expand Down Expand Up @@ -143,7 +143,7 @@ async def test_if_not_fires_for_enter_on_zone_leave(hass, calls):
)
await hass.async_block_till_done()

assert 0 == len(calls)
assert len(calls) == 0


async def test_if_fires_on_zone_leave(hass, calls):
Expand Down Expand Up @@ -178,7 +178,7 @@ async def test_if_fires_on_zone_leave(hass, calls):
)
await hass.async_block_till_done()

assert 1 == len(calls)
assert len(calls) == 1


async def test_if_not_fires_for_leave_on_zone_enter(hass, calls):
Expand Down Expand Up @@ -213,7 +213,7 @@ async def test_if_not_fires_for_leave_on_zone_enter(hass, calls):
)
await hass.async_block_till_done()

assert 0 == len(calls)
assert len(calls) == 0


async def test_if_fires_on_zone_appear(hass, calls):
Expand Down Expand Up @@ -258,10 +258,10 @@ async def test_if_fires_on_zone_appear(hass, calls):
)
await hass.async_block_till_done()

assert 1 == len(calls)
assert len(calls) == 1
assert calls[0].context.parent_id == context.id
assert (
"geo_location - geo_location.entity - - hello - test" == calls[0].data["some"]
calls[0].data["some"] == "geo_location - geo_location.entity - - hello - test"
)


Expand Down Expand Up @@ -308,7 +308,7 @@ async def test_if_fires_on_zone_disappear(hass, calls):
hass.states.async_remove("geo_location.entity")
await hass.async_block_till_done()

assert 1 == len(calls)
assert len(calls) == 1
assert (
"geo_location - geo_location.entity - hello - - test" == calls[0].data["some"]
calls[0].data["some"] == "geo_location - geo_location.entity - hello - - test"
)
24 changes: 12 additions & 12 deletions tests/components/automation/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def test_service_specify_entity_id(hass, calls):

hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1
assert ["hello.world"] == calls[0].data.get(ATTR_ENTITY_ID)


Expand All @@ -170,7 +170,7 @@ async def test_service_specify_entity_id_list(hass, calls):

hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1
assert ["hello.world", "hello.world2"] == calls[0].data.get(ATTR_ENTITY_ID)


Expand All @@ -192,10 +192,10 @@ async def test_two_triggers(hass, calls):

hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1
hass.states.async_set("test.entity", "hello")
await hass.async_block_till_done()
assert 2 == len(calls)
assert len(calls) == 2


async def test_trigger_service_ignoring_condition(hass, calls):
Expand Down Expand Up @@ -268,17 +268,17 @@ async def test_two_conditions_with_and(hass, calls):
hass.states.async_set(entity_id, 100)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1

hass.states.async_set(entity_id, 101)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1

hass.states.async_set(entity_id, 151)
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1


async def test_automation_list_setting(hass, calls):
Expand All @@ -302,11 +302,11 @@ async def test_automation_list_setting(hass, calls):

hass.bus.async_fire("test_event")
await hass.async_block_till_done()
assert 1 == len(calls)
assert len(calls) == 1

hass.bus.async_fire("test_event_2")
await hass.async_block_till_done()
assert 2 == len(calls)
assert len(calls) == 2


async def test_automation_calling_two_actions(hass, calls):
Expand Down Expand Up @@ -368,15 +368,15 @@ async def test_shared_context(hass, calls):
assert event_mock.call_count == 2

# Verify automation triggered evenet for 'hello' automation
args, kwargs = event_mock.call_args_list[0]
args, _ = event_mock.call_args_list[0]
first_trigger_context = args[0].context
assert first_trigger_context.parent_id == context.id
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None

# Ensure context set correctly for event fired by 'hello' automation
args, kwargs = first_automation_listener.call_args
args, _ = first_automation_listener.call_args
assert args[0].context is first_trigger_context

# Ensure the 'hello' automation state has the right context
Expand All @@ -385,7 +385,7 @@ async def test_shared_context(hass, calls):
assert state.context is first_trigger_context

# Verify automation triggered evenet for 'bye' automation
args, kwargs = event_mock.call_args_list[1]
args, _ = event_mock.call_args_list[1]
second_trigger_context = args[0].context
assert second_trigger_context.parent_id == first_trigger_context.id
# Ensure event data has all attributes set
Expand Down
Loading