Skip to content

Commit

Permalink
Refactor machine and state components
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Nov 14, 2024
1 parent 705bfaa commit f277568
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 63 deletions.
90 changes: 28 additions & 62 deletions src/vivarium_gates_mncnh/components/antenatal_care.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.state_machine import Machine, State, Transient, Transition
from vivarium.framework.state_machine import Machine, TransientState
from vivarium.types import ClockTime

from vivarium_gates_mncnh.constants import data_keys
Expand All @@ -24,51 +24,17 @@


class TreeMachine(Machine):
@property
def columns_created(self) -> list[str]:
return [self.state_column]

@property
def columns_required(self) -> list[str] | None:
return None

def __init__(self, state_column: str, states: list[State], initial_state: State) -> None:
super().__init__(state_column, states)
self.initial_state = initial_state

def setup(self, builder: Builder) -> None:
super().setup(builder)
self._sim_step_name = builder.time.simulation_event_name()

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
self.population_view.update(
pd.Series(
self.initial_state.state_id, index=pop_data.index, name=self.state_column
),
)

def on_time_step(self, event: Event) -> None:
if self._sim_step_name() == SIMULATION_EVENT_NAMES.PREGNANCY:
self.transition(event.index, event.time)


class DecisionTreeState(State):
def add_decision(
self,
output_state: State,
decision_function: Callable[[pd.Index], pd.Series] = lambda index: pd.Series(
1.0, index=index
),
) -> None:
transition = Transition(self, output_state, decision_function)
self.add_transition(transition)


class TransientDecisionTreeState(DecisionTreeState, Transient):
pass
super().on_time_step(event)


class ANCState(TransientDecisionTreeState):
class ANCState(TransientState):
def __init__(self) -> None:
super().__init__("attended_antental_care")

Expand All @@ -82,7 +48,7 @@ def transition_side_effect(self, index: pd.Index, _event_time: ClockTime) -> Non
self.population_view.update(pop)


class UltrasoundState(TransientDecisionTreeState):
class UltrasoundState(TransientState):
def __init__(self, ultrasound_type: str) -> None:
super().__init__(f"{ultrasound_type}_ultrasound")
self.ultrasound_type = ultrasound_type
Expand Down Expand Up @@ -212,53 +178,53 @@ def _determine_lbw_identification(self, pop: pd.DataFrame) -> pd.Series:
return identification

def create_anc_decision_tree(self) -> Machine:
initial_state = DecisionTreeState("initial")
initial_state = TransientState("initial")
attended_antental_care = ANCState()
gets_ultrasound = TransientDecisionTreeState("gets_ultrasound")
gets_ultrasound = TransientState("gets_ultrasound")
standard_ultasound = UltrasoundState(ULTRASOUND_TYPES.STANDARD)
ai_assisted_ultrasound = UltrasoundState(ULTRASOUND_TYPES.AI_ASSISTED)
end_state = DecisionTreeState("end")
end_state = TransientState("end")

# Decisions
initial_state.add_decision(
attended_antental_care,
# TODO: this data will need to be updated when the ANC artifact is finished
self.get_anc_attendance_rate,
initial_state.add_transition(
output_state=attended_antental_care,
probability_function=self.get_anc_attendance_rate,
)
initial_state.add_decision(
end_state, lambda index: 1 - self.get_anc_attendance_rate(index)
initial_state.add_transition(
output_state=end_state,
probability_function=lambda index: 1 - self.get_anc_attendance_rate(index)
)
attended_antental_care.add_decision(
gets_ultrasound,
lambda index: pd.Series(
attended_antental_care.add_transition(
output_state=gets_ultrasound,
probability_function=lambda index: pd.Series(
ANC_RATES.RECEIVED_ULTRASOUND[self.location], index=index
),
)
attended_antental_care.add_decision(
end_state,
lambda index: pd.Series(
attended_antental_care.add_transition(
output_state=end_state,
probability_function=lambda index: pd.Series(
1 - ANC_RATES.RECEIVED_ULTRASOUND[self.location],
index=index,
),
)
gets_ultrasound.add_decision(
standard_ultasound,
lambda index: pd.Series(
gets_ultrasound.add_transition(
output_state=standard_ultasound,
probability_function=lambda index: pd.Series(
ANC_RATES.ULTRASOUND_TYPE[self.location][ULTRASOUND_TYPES.STANDARD],
index=index,
),
)
gets_ultrasound.add_decision(
ai_assisted_ultrasound,
lambda index: pd.Series(
gets_ultrasound.add_transition(
output_state=ai_assisted_ultrasound,
probability_function=lambda index: pd.Series(
ANC_RATES.ULTRASOUND_TYPE[self.location][ULTRASOUND_TYPES.AI_ASSISTED],
index=index,
),
)
standard_ultasound.add_decision(end_state)
ai_assisted_ultrasound.add_decision(end_state)
standard_ultasound.add_transition(output_state=end_state)
ai_assisted_ultrasound.add_transition(output_state=end_state)

return TreeMachine(
return Machine(
"anc_state",
[
initial_state,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ components:
configuration:
input_data:
input_draw_number: 0
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/anc/ethiopia.hdf"
# artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/anc/ethiopia.hdf"
artifact_path: "/home/albrja/scratch/artifacts/ethiopia.hdf"
interpolation:
order: 0
extrapolate: True
Expand Down

0 comments on commit f277568

Please sign in to comment.