Skip to content

Commit

Permalink
fix: Modify the integration to fully use HA async paradigm. (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
cayossarian authored Jan 2, 2025
1 parent 6232651 commit 31d3dbc
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 117 deletions.
185 changes: 132 additions & 53 deletions custom_components/stateful_scenes/StatefulScenes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Stateful Scenes for Home Assistant."""

import asyncio
import logging
from typing import Any

from homeassistant.core import Event, EventStateChangedData, HomeAssistant
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.template import area_id, area_name

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(

async def async_start(self, callback) -> None:
"""Start a new timer if we have a duration."""
await self._hass.async_add_executor_job(self.cancel_if_active)
await self.async_cancel_if_active()
if self.transition_time > 0 and self._hass is not None:
_LOGGER.debug(
"Starting scene evaluation timer for %s seconds",
Expand Down Expand Up @@ -88,14 +88,10 @@ def set_debounce_time(self, time: float) -> None:
"""Set the timer duration."""
self._debounce_time = time or 0.0

def set(self, cancel_callback) -> None:
"""Store new timer's cancel callback."""
self._cancel_callback = cancel_callback

def cancel_if_active(self) -> None:
async def async_cancel_if_active(self) -> None:
"""Cancel current timer if active."""
if self._cancel_callback:
_LOGGER.debug("Cancelling active scene evaluation timer")
_LOGGER.debug("Async cancelling active scene evaluation timer")
self._cancel_callback()
self._cancel_callback = None

Expand Down Expand Up @@ -144,6 +140,8 @@ def __init__(self, hass: HomeAssistant, scene_conf: dict) -> None:
if self._entity_id is None:
self._entity_id = get_entity_id_from_id(self.hass, self._id)

hass.async_create_task(self.async_initialize())

@property
def attributes(self) -> SceneStateAttributes:
"""Return scene attributes matching SceneStateProtocol."""
Expand Down Expand Up @@ -178,7 +176,7 @@ def area_id(self) -> str:
"""Return the area_id of the scene."""
return self._area_id

def turn_on(self):
async def async_turn_on(self):
"""Turn on the scene."""
if self._entity_id is None:
raise StatefulScenesYamlInvalid(
Expand All @@ -187,16 +185,13 @@ def turn_on(self):

# Store the current state of the entities
for entity_id in self.entities:
self.store_entity_state(entity_id)
await self.async_store_entity_state(entity_id)

asyncio.run_coroutine_threadsafe(
self._scene_evaluation_timer.async_start(
self.async_timer_evaluate_scene_state
),
self.hass.loop,
).result()
await self._scene_evaluation_timer.async_start(
self.async_timer_evaluate_scene_state
)

self.hass.services.call(
await self.hass.services.async_call(
domain="scene",
service="turn_on",
target={"entity_id": self._entity_id},
Expand All @@ -215,29 +210,30 @@ def set_off_scene(self, entity_id: str | None) -> None:
if entity_id:
self._restore_on_deactivate = False

def turn_off(self):
async def async_set_off_scene(self, entity_id: str | None) -> None:
"""Set the off scene entity_id asynchronously."""
self.set_off_scene(entity_id)

async def async_turn_off(self):
"""Turn off all entities in the scene."""
if not self._is_on: # already off
return

if self._off_scene_entity_id:
self._scene_evaluation_timer.cancel_if_active()
self.hass.services.call(
await self._scene_evaluation_timer.async_cancel_if_active()
await self.hass.services.async_call(
domain="scene",
service="turn_on",
target={"entity_id": self._off_scene_entity_id},
service_data={"transition": self._transition_time},
)
elif self.restore_on_deactivate:
asyncio.run_coroutine_threadsafe(
self._scene_evaluation_timer.async_start(
self.async_timer_evaluate_scene_state
),
self.hass.loop,
).result()
self.restore()
await self._scene_evaluation_timer.async_start(
self.async_timer_evaluate_scene_state
)
await self.async_restore()
else:
self.hass.services.call(
await self.hass.services.async_call(
domain="homeassistant",
service="turn_off",
target={"entity_id": list(self.entities.keys())},
Expand Down Expand Up @@ -292,41 +288,70 @@ def set_ignore_unavailable(self, ignore_unavailable):
"""Set the ignore unavailable flag."""
self._ignore_unavailable = ignore_unavailable

def register_callback(self):
async def async_initialize(self) -> None:
"""Initialize the scene and evaluate its initial state."""
_LOGGER.debug("Initializing scene: %s", self.name)
await self.async_check_all_states()
_LOGGER.debug(
"Initial state for scene %s: %s", self.name, "on" if self._is_on else "off"
)

async def async_register_callback(self):
"""Register callback."""
schedule_update_func = self.callback_funcs.get("schedule_update_func", None)
state_change_func = self.callback_funcs.get("state_change_func", None)
if schedule_update_func is None or state_change_func is None:
raise ValueError("No callback functions provided for scene.")

self.schedule_update = schedule_update_func

# Register state change callback for all entities in the scene
entity_ids = list(self.entities.keys())
_LOGGER.debug(
"Registering callbacks for entities: %s in scene: %s",
entity_ids,
self.name,
)

# Set up state change tracking
self.callback = state_change_func(
self.hass, self.entities.keys(), self.update_callback
self.hass, entity_ids, self.async_update_callback
)

def unregister_callback(self):
async def async_unregister_callback(self):
"""Unregister callbacks."""
if self.callback is not None:
self.callback()
self.callback = None

def update_callback(self, event: Event[EventStateChangedData]):
async def async_update_callback(self, event: Event[EventStateChangedData]):
"""Update the scene when a tracked entity changes state."""
entity_id = event.data.get("entity_id")
new_state = event.data.get("new_state")
old_state = event.data.get("old_state")

self.store_entity_state(entity_id, old_state)
_LOGGER.debug(
"State change callback for %s in scene %s: old=%s new=%s",
entity_id,
self.name,
old_state.state if old_state else None,
new_state.state if new_state else None,
)

# Store the old state
await self.async_store_entity_state(entity_id, old_state)

# Check if this update is interesting
if self.is_interesting_update(old_state, new_state):
if not self._scene_evaluation_timer.is_active():
asyncio.run_coroutine_threadsafe(
self.async_evaluate_scene_state(), self.hass.loop
).result()
await self.async_evaluate_scene_state()

async def async_evaluate_scene_state(self):
"""Evaluate scene state immediately."""
await self.hass.async_add_executor_job(self.check_all_states)
_LOGGER.debug("[Scene: %s] Starting scene evaluation", self.name)
await self.async_check_all_states()
if self.schedule_update:
await self.hass.async_add_executor_job(self.schedule_update, True)
self.schedule_update(True)

async def async_timer_evaluate_scene_state(self, _now):
"""Handle Callback from HA after expiration of SceneEvaluationTimer."""
Expand All @@ -337,6 +362,8 @@ async def async_timer_evaluate_scene_state(self, _now):
def is_interesting_update(self, old_state, new_state):
"""Check if the state change is interesting."""
if old_state is None:
if new_state is None:
_LOGGER.warning("New State is None and Old State is None")
return True
if not self.compare_values(old_state.state, new_state.state):
return True
Expand All @@ -354,11 +381,31 @@ def is_interesting_update(self, old_state, new_state):
return True
return False

def check_state(self, entity_id, new_state):
async def async_check_state(self, entity_id, new_state):
"""Check if entity's current state matches the scene's defined state."""
if new_state is None:
_LOGGER.warning("Entity not found: %s", entity_id)
return False
# Check if entity exists in registry
# Get entity registry directly
registry = er.async_get(self.hass)
entry = registry.async_get(entity_id)

if entry is None:
_LOGGER.debug(
"[Scene: %s] Entity %s not found in registry.",
self.name,
entity_id,
)
return False

# Check if entity exists in state
new_state = self.hass.states.get(entity_id)
if new_state is None:
_LOGGER.debug(
"[Scene: %s] Entity %s not found in state.",
self.name,
entity_id,
)
return False

if self.ignore_unavailable and new_state.state == "unavailable":
return None
Expand Down Expand Up @@ -407,7 +454,7 @@ def check_state(self, entity_id, new_state):
)
return True

def check_all_states(self):
async def async_check_all_states(self):
"""Check the state of the scene.
If all entities are in the desired state, the scene is on. If any entity is not
Expand All @@ -416,24 +463,56 @@ def check_all_states(self):
"""
for entity_id in self.entities:
state = self.hass.states.get(entity_id)
self.states[entity_id] = self.check_state(entity_id, state)
self.states[entity_id] = await self.async_check_state(entity_id, state)

states = [state for state in self.states.values() if state is not None]
result = all(states) if states else False
self._is_on = result

if not states:
self._is_on = False
else:
self._is_on = all(states)

def store_entity_state(self, entity_id, state=None):
"""Store the state of an entity.
If the state is not provided, the current state of the entity is used.
"""
async def async_store_entity_state(self, entity_id, state=None):
"""Store the state of an entity."""
if state is None:
state = self.hass.states.get(entity_id)
self.restore_states[entity_id] = state

async def async_restore(self):
"""Restore the state entities."""
entities = {}
for entity_id, state in self.restore_states.items():
if state is None:
continue

# restore state
entities[entity_id] = {"state": state.state}

# do not restore attributes if the entity is off
if state.state == "off":
continue

# restore attributes
if state.domain in ATTRIBUTES_TO_CHECK:
entity_attrs = state.attributes
for attribute in ATTRIBUTES_TO_CHECK.get(state.domain):
if attribute not in entity_attrs:
continue
entities[entity_id][attribute] = entity_attrs[attribute]

service_data = {"entities": entities}
if self._transition_time is not None:
service_data["transition"] = self._transition_time
await self.hass.services.async_call(
domain="scene", service="apply", service_data=service_data
)

# def store_entity_state(self, entity_id, state=None):
# """Store the state of an entity.
#
# If the state is not provided, the current state of the entity is used.
# """
# if state is None:
# state = self.hass.states.get(entity_id)
# self.restore_states[entity_id] = state

def restore(self):
"""Restore the state entities."""
entities = {}
Expand Down
3 changes: 2 additions & 1 deletion custom_components/stateful_scenes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Platform.SWITCH,
]


# https://developers.home-assistant.io/docs/config_entries_index/#setting-up-an-entry
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up this integration using UI."""
Expand All @@ -54,7 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:

if is_hub and entry.data.get(CONF_ENABLE_DISCOVERY, False):
discovery_manager = DiscoveryManager(hass, entry)
await discovery_manager.start_discovery()
await discovery_manager.async_start_discovery()

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

Expand Down
2 changes: 1 addition & 1 deletion custom_components/stateful_scenes/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, hass: HomeAssistant, ha_config: ConfigType) -> None:
self.hass = hass
self.ha_config = ha_config

async def start_discovery(self) -> None:
async def async_start_discovery(self) -> None:
"""Start the discovery procedure."""
_LOGGER.debug("Start auto discovering devices")
entity_registry = er.async_get(self.hass)
Expand Down
6 changes: 3 additions & 3 deletions custom_components/stateful_scenes/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def device_info(self) -> DeviceInfo | None:
suggested_area=self._scene.area_id,
)

def set_native_value(self, value: float) -> None:
async def async_set_native_value(self, value: float) -> None:
"""Update the current value."""
self._scene.set_transition_time(value)

Expand Down Expand Up @@ -159,7 +159,7 @@ def device_info(self) -> DeviceInfo | None:
manufacturer=DEVICE_INFO_MANUFACTURER,
)

def set_native_value(self, value: float) -> None:
async def async_set_native_value(self, value: float) -> None:
"""Update the current value."""
self._scene.set_debounce_time(value)

Expand Down Expand Up @@ -220,7 +220,7 @@ def device_info(self) -> DeviceInfo | None:
manufacturer=DEVICE_INFO_MANUFACTURER,
)

def set_native_value(self, value: int) -> None:
async def async_set_native_value(self, value: int) -> None:
"""Update the current value."""
self._scene.set_number_tolerance(value)

Expand Down
Loading

0 comments on commit 31d3dbc

Please sign in to comment.