diff --git a/stqdm/configuration_manager.py b/stqdm/configuration_manager.py new file mode 100644 index 0000000..371a17e --- /dev/null +++ b/stqdm/configuration_manager.py @@ -0,0 +1,40 @@ +from abc import ABCMeta +from contextlib import contextmanager +from typing import Any, Generic, Iterator, Mapping, Optional, TypeVar, cast + + +class ScopeError(RuntimeError): + pass + + +ScopeConfig = TypeVar("ScopeConfig", bound=Mapping[str, Any]) + + +class ScopeManager(Generic[ScopeConfig], metaclass=ABCMeta): # pylint: disable=invalid-name,inconsistent-mro + def __init__(self, stack: Optional[list[ScopeConfig]] = None) -> None: + if stack is None: + stack = [] + self._scope_stack: list[ScopeConfig] = stack + + def set_default_kwargs(self, scope_config: ScopeConfig) -> None: + if not self._scope_stack: + self._scope_stack.append(scope_config) + else: + self._scope_stack[0] = scope_config + + @contextmanager + def scope(self, scope_config: ScopeConfig) -> Iterator[ScopeConfig]: + self._scope_stack.append(scope_config) + try: + yield scope_config + finally: + self._scope_stack.pop() + + def get_current_defaults(self) -> ScopeConfig: + if self._scope_stack: + return self._scope_stack[-1] + raise ScopeError("No default scope set. You may have messed up with the scope stack. Please report this issue.") + + def use_current_default_if_config_not_provided(self, config: ScopeConfig) -> ScopeConfig: + # This typing is probably wrong, but in our case, we will be using a TypedDict and it should be ok + return cast(ScopeConfig, {**self.get_current_defaults(), **config}) diff --git a/tests/test_configuration_manager.py b/tests/test_configuration_manager.py new file mode 100644 index 0000000..6072d0f --- /dev/null +++ b/tests/test_configuration_manager.py @@ -0,0 +1,35 @@ +from stqdm.configuration_manager import ScopeManager + + +def test_scope_manager__get_current_defaults(): + scope_manager = ScopeManager([{}]) + assert scope_manager.get_current_defaults() == {} + scope_manager = ScopeManager([{"foo": "bar"}]) + assert scope_manager.get_current_defaults() == {"foo": "bar"} + + +def test_scope_manager__set_default_kwargs(): + scope_manager = ScopeManager([{}]) + assert scope_manager.get_current_defaults() == {} + scope_manager.set_default_kwargs({"foo": "bar"}) + assert scope_manager.get_current_defaults() == {"foo": "bar"} + scope_manager.set_default_kwargs({"foo": "baz", "bar": "foo"}) + assert scope_manager.get_current_defaults() == {"foo": "baz", "bar": "foo"} + + +def test_scope_manager__use_current_default_if_config_not_provided(): + scope_manager = ScopeManager([{"foo": "bar"}]) + assert scope_manager.use_current_default_if_config_not_provided({}) == {"foo": "bar"} + assert scope_manager.use_current_default_if_config_not_provided({"foo": "baz"}) == {"foo": "baz"} + assert scope_manager.use_current_default_if_config_not_provided({"bar": "foo"}) == {"foo": "bar", "bar": "foo"} + + +def test_scope_manager__scope(): + scope_manager = ScopeManager([{}]) + assert scope_manager.get_current_defaults() == {} + with scope_manager.scope({"foo": "bar"}): + assert scope_manager.get_current_defaults() == {"foo": "bar"} + with scope_manager.scope({"foo": "baz"}): + assert scope_manager.get_current_defaults() == {"foo": "baz"} + assert scope_manager.get_current_defaults() == {"foo": "bar"} + assert scope_manager.get_current_defaults() == {}