Skip to content

Commit

Permalink
chore: add configuration manager
Browse files Browse the repository at this point in the history
  • Loading branch information
Arnault Chazareix committed Jan 4, 2024
1 parent 3c78845 commit f771ca7
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
40 changes: 40 additions & 0 deletions stqdm/configuration_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from abc import ABCMeta
from contextlib import contextmanager
from typing import Any, Generic, Iterator, Mapping, Optional, TypeVar, cast


class ScopeError(RuntimeError):
pass


ScopeConfig = TypeVar("ScopeConfig", bound=Mapping[str, Any])


class ScopeManager(Generic[ScopeConfig], metaclass=ABCMeta): # pylint: disable=invalid-name,inconsistent-mro
def __init__(self, stack: Optional[list[ScopeConfig]] = None) -> None:
if stack is None:
stack = []
self._scope_stack: list[ScopeConfig] = stack

def set_default_kwargs(self, scope_config: ScopeConfig) -> None:
if not self._scope_stack:
self._scope_stack.append(scope_config)
else:
self._scope_stack[0] = scope_config

@contextmanager
def scope(self, scope_config: ScopeConfig) -> Iterator[ScopeConfig]:
self._scope_stack.append(scope_config)
try:
yield scope_config
finally:
self._scope_stack.pop()

def get_current_defaults(self) -> ScopeConfig:
if self._scope_stack:
return self._scope_stack[-1]
raise ScopeError("No default scope set. You may have messed up with the scope stack. Please report this issue.")

def use_current_default_if_config_not_provided(self, config: ScopeConfig) -> ScopeConfig:
# This typing is probably wrong, but in our case, we will be using a TypedDict and it should be ok
return cast(ScopeConfig, {**self.get_current_defaults(), **config})
35 changes: 35 additions & 0 deletions tests/test_configuration_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from stqdm.configuration_manager import ScopeManager


def test_scope_manager__get_current_defaults():
scope_manager = ScopeManager([{}])
assert scope_manager.get_current_defaults() == {}
scope_manager = ScopeManager([{"foo": "bar"}])
assert scope_manager.get_current_defaults() == {"foo": "bar"}


def test_scope_manager__set_default_kwargs():
scope_manager = ScopeManager([{}])
assert scope_manager.get_current_defaults() == {}
scope_manager.set_default_kwargs({"foo": "bar"})
assert scope_manager.get_current_defaults() == {"foo": "bar"}
scope_manager.set_default_kwargs({"foo": "baz", "bar": "foo"})
assert scope_manager.get_current_defaults() == {"foo": "baz", "bar": "foo"}


def test_scope_manager__use_current_default_if_config_not_provided():
scope_manager = ScopeManager([{"foo": "bar"}])
assert scope_manager.use_current_default_if_config_not_provided({}) == {"foo": "bar"}
assert scope_manager.use_current_default_if_config_not_provided({"foo": "baz"}) == {"foo": "baz"}
assert scope_manager.use_current_default_if_config_not_provided({"bar": "foo"}) == {"foo": "bar", "bar": "foo"}


def test_scope_manager__scope():
scope_manager = ScopeManager([{}])
assert scope_manager.get_current_defaults() == {}
with scope_manager.scope({"foo": "bar"}):
assert scope_manager.get_current_defaults() == {"foo": "bar"}
with scope_manager.scope({"foo": "baz"}):
assert scope_manager.get_current_defaults() == {"foo": "baz"}
assert scope_manager.get_current_defaults() == {"foo": "bar"}
assert scope_manager.get_current_defaults() == {}

0 comments on commit f771ca7

Please sign in to comment.