Skip to content
  •  
  •  
  •  
9 changes: 3 additions & 6 deletions homeassistant/components/hassio/addon_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,9 @@ async def get_panels(self):
return {}


def _register_panel(hass, addon, data):
"""Init coroutine to register the panel.

Return coroutine.
"""
return hass.components.panel_custom.async_register_panel(
async def _register_panel(hass, addon, data):
"""Init coroutine to register the panel."""
await hass.components.panel_custom.async_register_panel(
frontend_url_path=addon,
webcomponent_name="hassio-main",
sidebar_title=data[ATTR_TITLE],
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/homekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def _start(self, bridged_states):
)

_LOGGER.debug("Driver start")
self.hass.async_add_executor_job(self.driver.start)
self.hass.add_job(self.driver.start)
Copy link
Member Author

@balloob balloob Apr 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdraco, I found this bug -> calling an async function from sync context

self.status = STATUS_RUNNING

async def async_stop(self, *args):
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/homekit/accessories.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def run_handler(self):
Run inside the Home Assistant event loop.
"""
state = self.hass.states.get(self.entity_id)
self.hass.async_add_executor_job(self.update_state_callback, None, None, state)
self.hass.async_add_job(self.update_state_callback, None, None, state)
Copy link
Member Author

@balloob balloob Apr 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdraco, I found this bug -> passing a callback to the executor. Does this have to run in parallel or could the method just be called too?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make update_state_callback a coroutine instead? It needs to schedule work on the executor. Maybe it would be good to be able to await that work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

async_track_state_change(self.hass, self.entity_id, self.update_state_callback)

@MartinHjelmare Would that be safe to do since its being called from async_track_state_change as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It would be run with hass.async_run_job.

core/homeassistant/core.py

Lines 370 to 388 in d43617c

@callback
def async_run_job(
self, target: Callable[..., Union[None, Awaitable]], *args: Any
) -> None:
"""Run a job from within the event loop.
This method must be run in the event loop.
target: target to call.
args: parameters for method to call.
"""
if (
not asyncio.iscoroutine(target)
and not asyncio.iscoroutinefunction(target)
and is_callback(target)
):
target(*args)
else:
self.async_add_job(target, *args)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adjusted here a04a8fa

I'll cherry pick it out and make another PR after some more testing

async_track_state_change(self.hass, self.entity_id, self.update_state_callback)

battery_charging_state = None
Expand Down
19 changes: 7 additions & 12 deletions homeassistant/components/smhi/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,27 +99,22 @@ def unique_id(self) -> str:
@Throttle(MIN_TIME_BETWEEN_UPDATES)
async def async_update(self) -> None:
"""Refresh the forecast data from SMHI weather API."""

def fail():
"""Postpone updates."""
self._fail_count += 1
if self._fail_count < 3:
self.hass.helpers.event.async_call_later(
RETRY_TIMEOUT, self.retry_update()
)

try:
with async_timeout.timeout(10):
self._forecasts = await self.get_weather_forecast()
self._fail_count = 0

except (asyncio.TimeoutError, SmhiForecastException):
_LOGGER.error("Failed to connect to SMHI API, retry in 5 minutes")
fail()
self._fail_count += 1
if self._fail_count < 3:
self.hass.helpers.event.async_call_later(
RETRY_TIMEOUT, self.retry_update
)

async def retry_update(self):
async def retry_update(self, _):
"""Retry refresh weather forecast."""
self.async_update()
await self.async_update()

async def get_weather_forecast(self) -> []:
"""Return the current forecasts from SMHI API."""
Expand Down
3 changes: 1 addition & 2 deletions homeassistant/components/vera/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
controller.start()

hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP,
lambda event: hass.async_add_executor_job(controller.stop),
EVENT_HOMEASSISTANT_STOP, lambda event: controller.stop()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vangorra, lambdas are not seen as an async function, regardless of what they return and so are already executed in the executor. So we were calling an async function from a sync context. Now it's fixed :)

)

try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test the NEW_NAME config flow."""
from asynctest import patch

from homeassistant import config_entries, setup
from homeassistant.components.NEW_DOMAIN.config_flow import CannotConnect, InvalidAuth
from homeassistant.components.NEW_DOMAIN.const import DOMAIN

from tests.async_mock import patch


async def test_form(hass):
"""Test we get the form."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Test the NEW_NAME config flow."""
from asynctest import patch

from homeassistant import config_entries, setup
from homeassistant.components.NEW_DOMAIN.const import (
DOMAIN,
Expand All @@ -9,6 +7,8 @@
)
from homeassistant.helpers import config_entry_oauth2_flow

from tests.async_mock import patch

CLIENT_ID = "1234"
CLIENT_SECRET = "5678"

Expand Down
8 changes: 8 additions & 0 deletions tests/async_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Mock utilities that are async aware."""
import sys

if sys.version_info[:2] < (3, 8):
from asynctest.mock import * # noqa
from asynctest.mock import CoroutineMock as AsyncMock # noqa
else:
from unittest.mock import * # noqa
5 changes: 2 additions & 3 deletions tests/auth/providers/test_command_line.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Tests for the command_line auth provider."""

import os
from unittest.mock import Mock
import uuid

import pytest
Expand All @@ -11,7 +10,7 @@
from homeassistant.auth.providers import command_line
from homeassistant.const import CONF_TYPE

from tests.common import mock_coro
from tests.async_mock import AsyncMock


@pytest.fixture
Expand Down Expand Up @@ -63,7 +62,7 @@ async def test_match_existing_credentials(store, provider):
data={"username": "good-user"},
is_new=False,
)
provider.async_credentials = Mock(return_value=mock_coro([existing]))
provider.async_credentials = AsyncMock(return_value=[existing])
credentials = await provider.async_get_or_create_credentials(
{"username": "good-user", "password": "irrelevant"}
)
Expand Down
3 changes: 2 additions & 1 deletion tests/auth/providers/test_homeassistant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test the Home Assistant local auth provider."""
import asyncio

from asynctest import Mock, patch
import pytest
import voluptuous as vol

Expand All @@ -12,6 +11,8 @@
homeassistant as hass_auth,
)

from tests.async_mock import Mock, patch


@pytest.fixture
def data(hass):
Expand Down
5 changes: 2 additions & 3 deletions tests/auth/providers/test_insecure_example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
"""Tests for the insecure example auth provider."""
from unittest.mock import Mock
import uuid

import pytest

from homeassistant.auth import AuthManager, auth_store, models as auth_models
from homeassistant.auth.providers import insecure_example

from tests.common import mock_coro
from tests.async_mock import AsyncMock


@pytest.fixture
Expand Down Expand Up @@ -63,7 +62,7 @@ async def test_match_existing_credentials(store, provider):
data={"username": "user-test"},
is_new=False,
)
provider.async_credentials = Mock(return_value=mock_coro([existing]))
provider.async_credentials = AsyncMock(return_value=[existing])
credentials = await provider.async_get_or_create_credentials(
{"username": "user-test", "password": "password-test"}
)
Expand Down
12 changes: 6 additions & 6 deletions tests/auth/test_auth_store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Tests for the auth store."""
import asyncio

import asynctest

from homeassistant.auth import auth_store

from tests.async_mock import patch


async def test_loading_no_group_data_format(hass, hass_storage):
"""Test we correctly load old data without any groups."""
Expand Down Expand Up @@ -229,12 +229,12 @@ async def test_system_groups_store_id_and_name(hass, hass_storage):
async def test_loading_race_condition(hass):
"""Test only one storage load called when concurrent loading occurred ."""
store = auth_store.AuthStore(hass)
with asynctest.patch(
with patch(
"homeassistant.helpers.entity_registry.async_get_registry"
) as mock_ent_registry, asynctest.patch(
) as mock_ent_registry, patch(
"homeassistant.helpers.device_registry.async_get_registry"
) as mock_dev_registry, asynctest.patch(
"homeassistant.helpers.storage.Store.async_load"
) as mock_dev_registry, patch(
"homeassistant.helpers.storage.Store.async_load", return_value=None
) as mock_load:
results = await asyncio.gather(store.async_get_users(), store.async_get_users())

Expand Down
59 changes: 37 additions & 22 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import uuid

from aiohttp.test_utils import unused_port as get_test_instance_port # noqa
from asynctest import MagicMock, Mock, patch

from homeassistant import auth, config_entries, core as ha, loader
from homeassistant.auth import (
Expand Down Expand Up @@ -60,6 +59,8 @@
from homeassistant.util.unit_system import METRIC_SYSTEM
import homeassistant.util.yaml.loader as yaml_loader

from tests.async_mock import AsyncMock, MagicMock, Mock, patch

_LOGGER = logging.getLogger(__name__)
INSTANCES = []
CLIENT_ID = "https://example.com/app"
Expand Down Expand Up @@ -159,20 +160,37 @@ async def async_test_home_assistant(loop):

def async_add_job(target, *args):
"""Add job."""
if isinstance(target, Mock):
return mock_coro(target(*args))
check_target = target
while isinstance(check_target, ft.partial):
check_target = check_target.func

if isinstance(check_target, Mock) and not isinstance(target, AsyncMock):
fut = asyncio.Future()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how I caught the incorrect async/synv usage in homekit. Now this will raise if it is called from the executor during a test because there is no event loop!

fut.set_result(target(*args))
return fut

return orig_async_add_job(target, *args)

def async_add_executor_job(target, *args):
"""Add executor job."""
if isinstance(target, Mock):
return mock_coro(target(*args))
check_target = target
while isinstance(check_target, ft.partial):
check_target = check_target.func

if isinstance(check_target, Mock):
fut = asyncio.Future()
fut.set_result(target(*args))
return fut

return orig_async_add_executor_job(target, *args)

def async_create_task(coroutine):
"""Create task."""
if isinstance(coroutine, Mock):
return mock_coro()
if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
fut = asyncio.Future()
fut.set_result(None)
return fut

return orig_async_create_task(coroutine)

hass.async_add_job = async_add_job
Expand Down Expand Up @@ -311,15 +329,16 @@ async def async_mock_mqtt_component(hass, config=None):
if config is None:
config = {mqtt.CONF_BROKER: "mock-broker"}

async def _async_fire_mqtt_message(topic, payload, qos, retain):
@ha.callback
def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)

with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0)
mock_client().unsubscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
mock_client().publish.side_effect = _async_fire_mqtt_message
mock_client = mock_client.return_value
mock_client.connect.return_value = 0
mock_client.subscribe.return_value = (0, 0)
mock_client.unsubscribe.return_value = (0, 0)
mock_client.publish.side_effect = _async_fire_mqtt_message

result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
assert result
Expand Down Expand Up @@ -503,7 +522,7 @@ def __init__(
self.async_setup = async_setup

if setup is None and async_setup is None:
self.async_setup = mock_coro_func(True)
self.async_setup = AsyncMock(return_value=True)

if async_setup_entry is not None:
self.async_setup_entry = async_setup_entry
Expand Down Expand Up @@ -561,7 +580,7 @@ def __init__(
self.async_setup_entry = async_setup_entry

if setup_platform is None and async_setup_platform is None:
self.async_setup_platform = mock_coro_func()
self.async_setup_platform = AsyncMock(return_value=None)


class MockEntityPlatform(entity_platform.EntityPlatform):
Expand Down Expand Up @@ -731,14 +750,10 @@ def mock_coro(return_value=None, exception=None):
def mock_coro_func(return_value=None, exception=None):
"""Return a method to create a coro function that returns a value."""

@asyncio.coroutine
def coro(*args, **kwargs):
"""Fake coroutine."""
if exception:
raise exception
return return_value
if exception:
return AsyncMock(side_effect=exception)

return coro
return AsyncMock(return_value=return_value)


@contextmanager
Expand Down
16 changes: 5 additions & 11 deletions tests/components/adguard/test_config_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Tests for the AdGuard Home config flow."""
from unittest.mock import patch

import aiohttp

Expand All @@ -15,7 +14,8 @@
CONF_VERIFY_SSL,
)

from tests.common import MockConfigEntry, mock_coro
from tests.async_mock import patch
from tests.common import MockConfigEntry

FIXTURE_USER_INPUT = {
CONF_HOST: "127.0.0.1",
Expand Down Expand Up @@ -156,22 +156,16 @@ async def test_hassio_update_instance_running(hass, aioclient_mock):
entry.add_to_hass(hass)

with patch.object(
hass.config_entries,
"async_forward_entry_setup",
side_effect=lambda *_: mock_coro(True),
hass.config_entries, "async_forward_entry_setup", return_value=True,
) as mock_load:
assert await hass.config_entries.async_setup(entry.entry_id)
assert entry.state == config_entries.ENTRY_STATE_LOADED
assert len(mock_load.mock_calls) == 2

with patch.object(
hass.config_entries,
"async_forward_entry_unload",
side_effect=lambda *_: mock_coro(True),
hass.config_entries, "async_forward_entry_unload", return_value=True,
) as mock_unload, patch.object(
hass.config_entries,
"async_forward_entry_setup",
side_effect=lambda *_: mock_coro(True),
hass.config_entries, "async_forward_entry_setup", return_value=True,
) as mock_load:
result = await hass.config_entries.flow.async_init(
"adguard",
Expand Down
2 changes: 1 addition & 1 deletion tests/components/airly/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json

from airly.exceptions import AirlyError
from asynctest import patch

from homeassistant import data_entry_flow
from homeassistant.components.airly.const import DOMAIN
Expand All @@ -15,6 +14,7 @@
HTTP_FORBIDDEN,
)

from tests.async_mock import patch
from tests.common import MockConfigEntry, load_fixture

CONFIG = {
Expand Down
Loading