Skip to content

Commit

Permalink
Fix running async functions without event loop during testing (#36859)
Browse files Browse the repository at this point in the history
* Keep device commissioning method in CommissionDeviceTest class

* Improve readability

* Run matter testing on a single even loop context

* Update all run_tests_no_exit() usages
  • Loading branch information
arkq authored Jan 17, 2025
1 parent 594ffe2 commit 4fd7215
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -972,18 +972,6 @@ def __init__(self, *args):
# The named pipe name must be set in the derived classes
self.app_pipe = None

async def commission_devices(self) -> bool:
conf = self.matter_test_config

for commission_idx, node_id in enumerate(conf.dut_node_ids):
logging.info(
f"Starting commissioning for root index {conf.root_of_trust_index}, fabric ID 0x{conf.fabric_id:016X}, node ID 0x{node_id:016X}")
logging.info(f"Commissioning method: {conf.commissioning_method}")

await CommissionDeviceTest.commission_device(self, commission_idx)

return True

def get_test_steps(self, test: str) -> list[TestStep]:
''' Retrieves the test step list for the given test
Expand Down Expand Up @@ -1168,17 +1156,14 @@ def setup_test(self):
self.step(1)

def teardown_class(self):
"""Final teardown after all tests: log all problems"""
if len(self.problems) == 0:
return

logging.info("###########################################################")
logging.info("Problems found:")
logging.info("===============")
for problem in self.problems:
logging.info(str(problem))
logging.info("###########################################################")

"""Final teardown after all tests: log all problems."""
if len(self.problems) > 0:
logging.info("###########################################################")
logging.info("Problems found:")
logging.info("===============")
for problem in self.problems:
logging.info(str(problem))
logging.info("###########################################################")
super().teardown_class()

def check_pics(self, pics_key: str) -> bool:
Expand Down Expand Up @@ -2107,8 +2092,7 @@ def parse_matter_test_args(argv: Optional[List[str]] = None) -> MatterTestConfig

def _async_runner(body, self: MatterBaseTest, *args, **kwargs):
timeout = self.matter_test_config.timeout if self.matter_test_config.timeout is not None else self.default_timeout
runner_with_timeout = asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout)
return asyncio.run(runner_with_timeout)
return self.event_loop.run_until_complete(asyncio.wait_for(body(self, *args, **kwargs), timeout=timeout))


def async_test_body(body):
Expand Down Expand Up @@ -2301,7 +2285,7 @@ def run_on_singleton_matching_endpoint(accept_function: EndpointCheckFunction):
def run_on_singleton_matching_endpoint_internal(body):
def matching_runner(self: MatterBaseTest, *args, **kwargs):
runner_with_timeout = asyncio.wait_for(_get_all_matching_endpoints(self, accept_function), timeout=30)
matching = asyncio.run(runner_with_timeout)
matching = self.event_loop.run_until_complete(runner_with_timeout)
asserts.assert_less_equal(len(matching), 1, "More than one matching endpoint found for singleton test.")
if not matching:
logging.info("Test is not applicable to any endpoint - skipping test")
Expand Down Expand Up @@ -2348,7 +2332,7 @@ def run_if_endpoint_matches(accept_function: EndpointCheckFunction):
def run_if_endpoint_matches_internal(body):
def per_endpoint_runner(self: MatterBaseTest, *args, **kwargs):
runner_with_timeout = asyncio.wait_for(should_run_test_on_endpoint(self, accept_function), timeout=60)
should_run_test = asyncio.run(runner_with_timeout)
should_run_test = self.event_loop.run_until_complete(runner_with_timeout)
if not should_run_test:
logging.info("Test is not applicable to this endpoint - skipping test")
asserts.skip('Endpoint does not match test requirements')
Expand All @@ -2367,14 +2351,25 @@ def __init__(self, *args):
self.is_commissioning = True

def test_run_commissioning(self):
if not asyncio.run(self.commission_devices()):
raise signals.TestAbortAll("Failed to commission node")
if not self.event_loop.run_until_complete(self.commission_devices()):
raise signals.TestAbortAll("Failed to commission node(s)")

async def commission_devices(self) -> bool:
conf = self.matter_test_config

async def commission_device(instance: MatterBaseTest, i) -> bool:
dev_ctrl = instance.default_controller
conf = instance.matter_test_config
commissioned = []
setup_payloads = self.get_setup_payload_info()
for node_id, setup_payload in zip(conf.dut_node_ids, setup_payloads):
logging.info(f"Starting commissioning for root index {conf.root_of_trust_index}, "
f"fabric ID 0x{conf.fabric_id:016X}, node ID 0x{node_id:016X}")
logging.info(f"Commissioning method: {conf.commissioning_method}")
commissioned.append(await self.commission_device(node_id, setup_payload))

info = instance.get_setup_payload_info()[i]
return all(commissioned)

async def commission_device(self, node_id: int, info: SetupPayloadInfo) -> bool:
dev_ctrl = self.default_controller
conf = self.matter_test_config

if conf.tc_version_to_simulate is not None and conf.tc_user_response_to_simulate is not None:
logging.debug(
Expand All @@ -2384,7 +2379,7 @@ async def commission_device(instance: MatterBaseTest, i) -> bool:
if conf.commissioning_method == "on-network":
try:
await dev_ctrl.CommissionOnNetwork(
nodeId=conf.dut_node_ids[i],
nodeId=node_id,
setupPinCode=info.passcode,
filterType=info.filter_type,
filter=info.filter_value
Expand All @@ -2398,7 +2393,7 @@ async def commission_device(instance: MatterBaseTest, i) -> bool:
await dev_ctrl.CommissionWiFi(
info.filter_value,
info.passcode,
conf.dut_node_ids[i],
node_id,
conf.wifi_ssid,
conf.wifi_passphrase,
isShortDiscriminator=(info.filter_type == DiscoveryFilterType.SHORT_DISCRIMINATOR)
Expand All @@ -2412,7 +2407,7 @@ async def commission_device(instance: MatterBaseTest, i) -> bool:
await dev_ctrl.CommissionThread(
info.filter_value,
info.passcode,
conf.dut_node_ids[i],
node_id,
conf.thread_operational_dataset,
isShortDiscriminator=(info.filter_type == DiscoveryFilterType.SHORT_DISCRIMINATOR)
)
Expand All @@ -2425,7 +2420,8 @@ async def commission_device(instance: MatterBaseTest, i) -> bool:
logging.warning("==== USING A DIRECT IP COMMISSIONING METHOD NOT SUPPORTED IN THE LONG TERM ====")
await dev_ctrl.CommissionIP(
ipaddr=conf.commissionee_ip_address_just_for_testing,
setupPinCode=info.passcode, nodeid=conf.dut_node_ids[i]
setupPinCode=info.passcode,
nodeid=node_id,
)
return True
except ChipStackError as e:
Expand All @@ -2441,10 +2437,10 @@ def default_matter_test_main():
In this case, only one test class in a test script is allowed.
To make your test script executable, add the following to your file:
.. code-block:: python
from chip.testing.matter_testing.py import default_matter_test_main
from chip.testing.matter_testing import default_matter_test_main
...
if __name__ == '__main__':
default_matter_test_main.main()
default_matter_test_main()
"""

matter_test_config = parse_matter_test_args()
Expand Down Expand Up @@ -2473,7 +2469,15 @@ def get_test_info(test_class: MatterBaseTest, matter_test_config: MatterTestConf
return info


def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTestConfig, hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> bool:
def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTestConfig,
event_loop: asyncio.AbstractEventLoop, hooks: TestRunnerHooks,
default_controller=None, external_stack=None) -> bool:

# NOTE: It's not possible to pass event loop via Mobly TestRunConfig user params, because the
# Mobly deep copies the user params before passing them to the test class and the event
# loop is not serializable. So, we are setting the event loop as a test class member.
CommissionDeviceTest.event_loop = event_loop
test_class.event_loop = event_loop

get_test_info(test_class, matter_test_config)

Expand Down Expand Up @@ -2553,9 +2557,13 @@ def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTest
duration = (datetime.now(timezone.utc) - runner_start_time) / timedelta(microseconds=1)
hooks.stop(duration=duration)

# Shutdown the stack when all done
if not external_stack:
stack.Shutdown()
async def shutdown():
stack.Shutdown()
# Shutdown the stack when all done. Use the async runner to ensure that
# during the shutdown callbacks can use tha same async context which was used
# during the initialization.
event_loop.run_until_complete(shutdown())

if ok:
logging.info("Final result: PASS !")
Expand All @@ -2564,6 +2572,9 @@ def run_tests_no_exit(test_class: MatterBaseTest, matter_test_config: MatterTest
return ok


def run_tests(test_class: MatterBaseTest, matter_test_config: MatterTestConfig, hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> None:
if not run_tests_no_exit(test_class, matter_test_config, hooks, default_controller, external_stack):
sys.exit(1)
def run_tests(test_class: MatterBaseTest, matter_test_config: MatterTestConfig,
hooks: TestRunnerHooks, default_controller=None, external_stack=None) -> None:
with asyncio.Runner() as runner:
if not run_tests_no_exit(test_class, matter_test_config, runner.get_loop(),
hooks, default_controller, external_stack):
sys.exit(1)
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# pip install opencv-python requests click_option_group
# python src/python_testing/post_certification_tests/production_device_checks.py

import asyncio
import base64
import hashlib
import importlib
Expand Down Expand Up @@ -390,9 +391,9 @@ def run_test(test_class: MatterBaseTest, tests: typing.List[str], test_config: T
stack = test_config.get_stack()
controller = test_config.get_controller()
matter_config = test_config.get_config(tests)
ok = run_tests_no_exit(test_class, matter_config, hooks, controller, stack)
if not ok:
print(f"Test failure. Failed on step: {hooks.get_failures()}")
with asyncio.Runner() as runner:
if not run_tests_no_exit(test_class, matter_config, runner.get_loop(), hooks, controller, stack):
print(f"Test failure. Failed on step: {hooks.get_failures()}")
return hooks.get_failures()


Expand Down
5 changes: 4 additions & 1 deletion src/python_testing/test_testing/MockTestRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

import asyncio
import importlib
import os
import sys
Expand Down Expand Up @@ -75,4 +76,6 @@ def run_test_with_mock_read(self, read_cache: Attribute.AsyncReadTransaction.Re
self.default_controller.Read = AsyncMock(return_value=read_cache)
# This doesn't need to do anything since we are overriding the read anyway
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
with asyncio.Runner() as runner:
return run_tests_no_exit(self.test_class, self.config, runner.get_loop(),
hooks, self.default_controller, self.stack)
5 changes: 4 additions & 1 deletion src/python_testing/test_testing/test_TC_CCNTL_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.
#

import asyncio
import base64
import os
import pathlib
Expand Down Expand Up @@ -166,7 +167,9 @@ def run_test_with_mock(self, dynamic_invoke_return: typing.Callable, dynamic_eve
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
self.default_controller.ReadEvent = AsyncMock(return_value=[], side_effect=dynamic_event_return)

return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
with asyncio.Runner() as runner:
return run_tests_no_exit(self.test_class, self.config, runner.get_loop(),
hooks, self.default_controller, self.stack)


@click.command()
Expand Down
5 changes: 4 additions & 1 deletion src/python_testing/test_testing/test_TC_MCORE_FS_1_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.
#

import asyncio
import base64
import os
import pathlib
Expand Down Expand Up @@ -137,7 +138,9 @@ def run_test_with_mock(self, dynamic_invoke_return: typing.Callable, dynamic_eve
self.default_controller.FindOrEstablishPASESession = AsyncMock(return_value=None)
self.default_controller.ReadEvent = AsyncMock(return_value=[], side_effect=dynamic_event_return)

return run_tests_no_exit(self.test_class, self.config, hooks, self.default_controller, self.stack)
with asyncio.Runner() as runner:
return run_tests_no_exit(self.test_class, self.config, runner.get_loop(),
hooks, self.default_controller, self.stack)


@click.command()
Expand Down

0 comments on commit 4fd7215

Please sign in to comment.