diff --git a/asyncua/common/connection.py b/asyncua/common/connection.py index e45230de6..f61485938 100644 --- a/asyncua/common/connection.py +++ b/asyncua/common/connection.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta import logging import copy +from typing_extensions import Literal from asyncua import ua from asyncua.ua.uaerrors import UaInvalidParameterError @@ -51,17 +52,30 @@ def is_chunk_count_within_limit(self, sz: int) -> bool: _logger.error("Number of message chunks: %s is > configured max chunk count: %s", sz, self.max_chunk_count) return within_limit + def _set_new_limits_from_ack(self, ack: ua.Acknowledge, role: Literal["client", "server"]) -> None: + max_recv_buffer = ack.ReceiveBufferSize if role == "client" else ack.SendBufferSize + max_send_buffer = ack.SendBufferSize if role == "client" else ack.ReceiveBufferSize + + new_limits = TransportLimits( + max_chunk_count=ack.MaxChunkCount, + max_recv_buffer=max_recv_buffer, + max_send_buffer=max_send_buffer, + max_message_size=ack.MaxMessageSize, + ) + if new_limits != self: + self.max_chunk_count = new_limits.max_chunk_count + self.max_recv_buffer = new_limits.max_recv_buffer + self.max_send_buffer = new_limits.max_send_buffer + self.max_message_size = new_limits.max_message_size + _logger.info("updating %s limits to: %s", role, self) + def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: ack = ua.Acknowledge() ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_send_buffer) ack.SendBufferSize = min(msg.SendBufferSize, self.max_recv_buffer) ack.MaxChunkCount = self._select_limit(msg.MaxChunkCount, self.max_chunk_count) ack.MaxMessageSize = self._select_limit(msg.MaxMessageSize, self.max_message_size) - self.max_chunk_count = ack.MaxChunkCount - self.max_recv_buffer = ack.SendBufferSize - self.max_send_buffer = ack.ReceiveBufferSize - self.max_message_size = ack.MaxMessageSize - _logger.info("updating server limits to: %s", self) + self._set_new_limits_from_ack(ack, "server") return ack def create_hello_limits(self, msg: ua.Hello) -> ua.Hello: @@ -72,11 +86,7 @@ def create_hello_limits(self, msg: ua.Hello) -> ua.Hello: return msg def update_client_limits(self, msg: ua.Acknowledge) -> None: - self.max_chunk_count = msg.MaxChunkCount - self.max_recv_buffer = msg.ReceiveBufferSize - self.max_send_buffer = msg.SendBufferSize - self.max_message_size = msg.MaxMessageSize - _logger.info("updating client limits to: %s", self) + self._set_new_limits_from_ack(msg, "client") class MessageChunk: diff --git a/tests/common/test_connection.py b/tests/common/test_connection.py new file mode 100644 index 000000000..3d9cb04d9 --- /dev/null +++ b/tests/common/test_connection.py @@ -0,0 +1,22 @@ +import pytest + +from asyncua.common.connection import TransportLimits, ua + +@pytest.mark.parametrize('transport_limits, ua_hello, expected_ack, expected_transport_limits', [ + (TransportLimits(), ua.Hello(), ua.Acknowledge(0, 65535, 65535, 104857600, 1601), TransportLimits()), + (TransportLimits(), ua.Hello(0, 100, 200, 1000, 10), ua.Acknowledge(0, 100, 200, 1000, 10), TransportLimits(200, 100, 10, 1000)), +]) +def test_create_acknowledge_and_set_limits(transport_limits, ua_hello, expected_ack, expected_transport_limits): + ua_ack = transport_limits.create_acknowledge_and_set_limits(ua_hello) + assert ua_ack == expected_ack + assert transport_limits == expected_transport_limits + + +@pytest.mark.parametrize('ua_ack, expected_transport_limits', [ + (ua.Acknowledge(0, 65535, 65535, 104857600, 1601), TransportLimits()), + (ua.Acknowledge(0, 100, 200, 1000, 10), TransportLimits(100, 200, 10, 1000)), +]) +def test_update_client_limits(ua_ack, expected_transport_limits): + transport_limits = TransportLimits() + transport_limits.update_client_limits(ua_ack) + assert transport_limits == expected_transport_limits