Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Albrja/mic 5371/maternal sepsis tree #11

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vivarium_gates_mncnh/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vivarium_gates_mncnh.components.antenatal_care import AntenatalCare
from vivarium_gates_mncnh.components.maternal_sepsis import MaternalSepsis
from vivarium_gates_mncnh.components.observers import (
ANCObserver,
BirthObserver,
Expand Down
39 changes: 8 additions & 31 deletions src/vivarium_gates_mncnh/components/antenatal_care.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from __future__ import annotations

from typing import Callable

import numpy as np
import pandas as pd
import scipy.stats as stats
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData
from vivarium.framework.state_machine import Machine, State, TransientState
from vivarium.framework.state_machine import State, TransientState
from vivarium.types import ClockTime

from vivarium_gates_mncnh.constants import data_keys
from vivarium_gates_mncnh.components.tree import DecisionTreeState, TreeMachine
from vivarium_gates_mncnh.constants.data_values import (
ANC_RATES,
COLUMNS,
Expand All @@ -23,30 +21,6 @@
from vivarium_gates_mncnh.utilities import get_location


class TreeMachine(Machine):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These classes got moved to tree.py and were made more configurable for future decision trees.

def setup(self, builder: Builder) -> None:
super().setup(builder)
self._sim_step_name = builder.time.simulation_event_name()

def on_time_step(self, event: Event) -> None:
if self._sim_step_name() == SIMULATION_EVENT_NAMES.PREGNANCY:
super().on_time_step(event)


class ANCState(TransientState):
def __init__(self) -> None:
super().__init__("attended_antental_care")

@property
def columns_required(self) -> list[str]:
return [COLUMNS.ATTENDED_CARE_FACILITY]

def transition_side_effect(self, index: pd.Index, _event_time: ClockTime) -> None:
pop = self.population_view.get(index)
pop[COLUMNS.ATTENDED_CARE_FACILITY] = True
self.population_view.update(pop)


class UltrasoundState(TransientState):
def __init__(self, ultrasound_type: str) -> None:
super().__init__(f"{ultrasound_type}_ultrasound")
Expand Down Expand Up @@ -176,9 +150,11 @@ def _determine_lbw_identification(self, pop: pd.DataFrame) -> pd.Series:
identification[lbw_index] = draws < identification_rates
return identification

def create_anc_decision_tree(self) -> Machine:
def create_anc_decision_tree(self) -> TreeMachine:
initial_state = State("initial")
attended_antental_care = ANCState()
attended_antental_care = DecisionTreeState(
"attended_antental_care", COLUMNS.ATTENDED_CARE_FACILITY, True
)
gets_ultrasound = TransientState("gets_ultrasound")
standard_ultasound = UltrasoundState(ULTRASOUND_TYPES.STANDARD)
ai_assisted_ultrasound = UltrasoundState(ULTRASOUND_TYPES.AI_ASSISTED)
Expand Down Expand Up @@ -224,7 +200,7 @@ def create_anc_decision_tree(self) -> Machine:
ai_assisted_ultrasound.add_transition(output_state=end_state)

return TreeMachine(
"anc_state",
COLUMNS.ANC_STATE,
[
initial_state,
attended_antental_care,
Expand All @@ -234,4 +210,5 @@ def create_anc_decision_tree(self) -> Machine:
end_state,
],
initial_state,
time_step_name=SIMULATION_EVENT_NAMES.PREGNANCY,
)
79 changes: 79 additions & 0 deletions src/vivarium_gates_mncnh/components/maternal_sepsis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData

from vivarium_gates_mncnh.constants import data_keys
from vivarium_gates_mncnh.constants.data_values import (
COLUMNS,
PREGNANCY_OUTCOMES,
SIMULATION_EVENT_NAMES,
)
from vivarium_gates_mncnh.constants.metadata import ARTIFACT_INDEX_COLUMNS
from vivarium_gates_mncnh.utilities import get_location


class MaternalSepsis(Component):
@property
def configuration_defaults(self) -> dict:
return {self.name: {"data_sources": {"incidence_risk": self.load_incidence_risk}}}

@property
def columns_created(self):
return [COLUMNS.MATERNAL_SEPSIS]

@property
def columns_required(self):
return [COLUMNS.PREGNANCY_OUTCOME]

def setup(self, builder: Builder) -> None:
self._sim_step_name = builder.time.simulation_event_name()
self.randomness = builder.randomness.get_stream(self.name)
self.location = get_location(builder)

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
anc_data = pd.DataFrame(
{
COLUMNS.MATERNAL_SEPSIS: False,
},
index=pop_data.index,
)

self.population_view.update(anc_data)

def on_time_step(self, event: Event) -> None:
if self._sim_step_name() != SIMULATION_EVENT_NAMES.MATERNAL_SEPSIS:
return

pop = self.population_view.get(event.index)
full_term = pop.loc[
pop[COLUMNS.PREGNANCY_OUTCOME].isin(
[PREGNANCY_OUTCOMES.STILLBIRTH_OUTCOME, PREGNANCY_OUTCOMES.LIVE_BIRTH_OUTCOME]
)
]
sepsis_risk = self.lookup_tables["incidence_risk"](full_term.index)
got_sepsis = self.randomness.filter_for_probability(
full_term.index,
sepsis_risk,
"got_sepsis_choice",
)
pop.loc[got_sepsis, COLUMNS.MATERNAL_SEPSIS] = True
self.population_view.update(pop)

def load_incidence_risk(self, builder: Builder) -> pd.DataFrame:
raw_incidence = builder.data.load(
data_keys.MATERNAL_SEPSIS.RAW_INCIDENCE_RATE
).set_index(ARTIFACT_INDEX_COLUMNS)
asfr = builder.data.load(data_keys.PREGNANCY.ASFR).set_index(ARTIFACT_INDEX_COLUMNS)
sbr = (
builder.data.load(data_keys.PREGNANCY.SBR)
.set_index("year_start")
.drop(columns=["year_end"])
.reindex(asfr.index, level="year_start")
)
birth_rate = (sbr + 1) * asfr
incidence_risk = (raw_incidence / birth_rate).fillna(0.0)
return incidence_risk.reset_index()
1 change: 0 additions & 1 deletion src/vivarium_gates_mncnh/components/pregnancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def setup(self, builder: Builder):
Interface to several simulation tools.
"""
self.time_step = builder.time.step_size()
self._sim_step_name = builder.time.simulation_event_name()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just wasn't being used for the pregnancy component.

self.randomness = builder.randomness.get_stream(self.name)
self.birth_outcome_probabilities = builder.value.register_value_producer(
"birth_outcome_probabilities",
Expand Down
51 changes: 51 additions & 0 deletions src/vivarium_gates_mncnh/components/tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

import pandas as pd
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.state_machine import Machine, State, TransientState
from vivarium.types import ClockTime

from vivarium_gates_mncnh.constants.data_values import SIMULATION_EVENT_NAMES


class TreeMachine(Machine):
def __init__(
self,
state_column: str,
states: list[State],
initial_state=None,
time_step_name: str = "",
):
super().__init__(state_column, states, initial_state)
# Time step name where the simulants will go through the decision tree
self._time_step_trigger = time_step_name

def setup(self, builder: Builder) -> None:
super().setup(builder)
self._sim_step_name = builder.time.simulation_event_name()

def on_time_step(self, event: Event) -> None:
if self._sim_step_name() == self._time_step_trigger:
super().on_time_step(event)


class DecisionTreeState(TransientState):
def __init__(
self,
state_id: str,
update_col: str,
update_value: str | bool,
) -> None:
super().__init__(state_id)
self.update_column = update_col
self.update_value = update_value

@property
def columns_required(self) -> list[str]:
return [self.update_column]

def transition_side_effect(self, index: pd.Index, _event_time: ClockTime) -> None:
pop = self.population_view.get(index)
pop[self.update_column] = self.update_value
self.population_view.update(pop)
4 changes: 4 additions & 0 deletions src/vivarium_gates_mncnh/constants/data_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class _Durations(NamedTuple):
class _SimulationEventNames(NamedTuple):
PREGNANCY = "pregnancy"
INTRAPARTRUM = "intrapartum"
MATERNAL_SEPSIS = "maternal_sepsis"
NEONATAL = "neonatal"


Expand Down Expand Up @@ -149,6 +150,9 @@ class __Columns(NamedTuple):
ULTRASOUND_TYPE = "ultrasound_type"
STATED_GESTATIONAL_AGE = "stated_gestational_age"
SUCCESSFUL_LBW_IDENTIFICATION = "successful_lbw_identification"
ANC_STATE = "anc_state"
SEPSIS_STATE = "sepsis_state"
MATERNAL_SEPSIS = "maternal_sepsis"


COLUMNS = __Columns()
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ components:
- ResultsStratifier()
- BirthObserver()
- AntenatalCare()
- MaternalSepsis()
# Add model observers below here
- ANCObserver()

configuration:
input_data:
input_draw_number: 0
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/anc/pakistan.hdf"
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/sepis/ethiopia.hdf"
interpolation:
order: 0
extrapolate: True
Expand All @@ -35,7 +37,8 @@ configuration:
day: 1
simulation_events:
- 'pregnancy'
- 'intrapartum'
- 'intrapartum'
- "maternal_sepsis"
- 'neonatal'

population:
Expand Down
Loading