Skip to content

Commit

Permalink
Albrja/mic 5609/mortality component (#16)
Browse files Browse the repository at this point in the history
Albrja/mic 5609/mortality component

Implement cause of death feature of Mortality component
- *Category*: Feature
- *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5609
- *Research reference*: https://github.com/ihmeuw/vivarium_research/pull/1566/files

Changes and notes
-create custom mortality component
-determine cause of death based on case fatality rates for maternal disorders for each simulant

Verification and Testing
Ran simulation and simulants died because of expected causes.
  • Loading branch information
albrja authored Dec 6, 2024
1 parent 7aa907e commit bc3d337
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ all necessary requirements as follows::
(vivarium_gates_mncnh) :~$ pip install -e .
...pip will install vivarium and other requirements...

Supported Python versions: 3.9, 3.10, 3.11
Supported Python versions: 3.10, 3.11

Note the ``-e`` flag that follows pip install. This will install the python
package in-place, which is important for making the model specifications later.
Expand Down
2 changes: 1 addition & 1 deletion python_versions.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
["3.9", "3.10", "3.11"]
["3.10", "3.11"]
1 change: 1 addition & 0 deletions src/vivarium_gates_mncnh/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from vivarium_gates_mncnh.components.antenatal_care import AntenatalCare
from vivarium_gates_mncnh.components.maternal_disorders import MaternalDisorder
from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders
from vivarium_gates_mncnh.components.observers import (
ANCObserver,
BirthObserver,
Expand Down
6 changes: 2 additions & 4 deletions src/vivarium_gates_mncnh/components/antenatal_care.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vivarium.types import ClockTime

from vivarium_gates_mncnh.components.tree import DecisionTreeState, TreeMachine
from vivarium_gates_mncnh.constants.data_keys import ANC
from vivarium_gates_mncnh.constants.data_values import (
ANC_RATES,
COLUMNS,
Expand Down Expand Up @@ -68,10 +69,7 @@ def setup(self, builder: Builder):
self.location = get_location(builder)

def build_all_lookup_tables(self, builder: Builder) -> None:
# TODO: update data key to constant
anc_attendance_probability = builder.data.load(
"covariate.antenatal_care_1_visit_coverage_proportion.estimate"
)
anc_attendance_probability = builder.data.load(ANC.ESTIMATE)
self.lookup_tables["anc_attendance_probability"] = self.build_lookup_table(
builder=builder,
data_source=anc_attendance_probability,
Expand Down
167 changes: 167 additions & 0 deletions src/vivarium_gates_mncnh/components/mortality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from __future__ import annotations

from functools import partial
from typing import Any

import pandas as pd
from vivarium import Component
from vivarium.framework.engine import Builder
from vivarium.framework.event import Event
from vivarium.framework.population import SimulantData

from vivarium_gates_mncnh.constants.data_values import COLUMNS, SIMULATION_EVENT_NAMES
from vivarium_gates_mncnh.constants.metadata import ARTIFACT_INDEX_COLUMNS
from vivarium_gates_mncnh.utilities import get_location


class MortalityDueToMaternalDisorders(Component):
"""A component to handle mortality caused by the modeled maternal disorders."""

##############
# Properties #
##############

@property
def configuration_defaults(self) -> dict[str, Any]:
return {
self.name: {
"data_sources": {
**{
"life_expectancy": "population.theoretical_minimum_risk_life_expectancy"
},
**{
f"{cause}_case_fatality_rate": partial(
self.load_cfr_data, key_name=cause
)
for cause in self.maternal_disorders
},
},
},
}

@property
def columns_created(self) -> list[str]:
return [COLUMNS.CAUSE_OF_DEATH, COLUMNS.YEARS_OF_LIFE_LOST]

@property
def columns_required(self) -> list[str]:
return [
COLUMNS.ALIVE,
COLUMNS.EXIT_TIME,
COLUMNS.AGE,
COLUMNS.SEX,
] + self.maternal_disorders

#####################
# Lifecycle methods #
#####################

def __init__(self) -> None:
super().__init__()
# TODO: update list of maternal disorders when implemented
self.maternal_disorders = [
COLUMNS.OBSTRUCTED_LABOR,
COLUMNS.MATERNAL_HEMORRHAGE,
COLUMNS.MATERNAL_SEPSIS,
]

def setup(self, builder: Builder) -> None:
self._sim_step_name = builder.time.simulation_event_name()
self.randomness = builder.randomness.get_stream(self.name)
self.location = get_location(builder)

########################
# Event-driven methods #
########################

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pop_update = pd.DataFrame(
{
COLUMNS.CAUSE_OF_DEATH: "not_dead",
COLUMNS.YEARS_OF_LIFE_LOST: 0.0,
},
index=pop_data.index,
)
self.population_view.update(pop_update)

def on_time_step(self, event) -> None:
if self._sim_step_name() != SIMULATION_EVENT_NAMES.MORTALITY:
return

pop = self.population_view.get(event.index)
has_maternal_disorders = pop[self.maternal_disorders]
has_maternal_disorders = has_maternal_disorders.loc[
has_maternal_disorders.any(axis=1)
]

# Get raw and conditional case fatality rates for each simulant
choice_data = has_maternal_disorders.copy()
choice_data = self.calculate_case_fatality_rates(choice_data)

# Decide what simulants die from what maternal disorders
dead_idx = self.randomness.filter_for_probability(
choice_data.index,
choice_data["mortality_probability"],
"mortality_choice",
)

# Update metadata for simulants that died
if not dead_idx.empty:
pop.loc[dead_idx, COLUMNS.ALIVE] = "dead"

# Get maternal disorders each simulant is affect by
cause_of_death = self.randomness.choice(
index=dead_idx,
choices=self.maternal_disorders,
p=choice_data.loc[
dead_idx,
[f"{disorder}_proportional_cfr" for disorder in self.maternal_disorders],
],
additional_key="cause_of_death",
)
pop.loc[dead_idx, COLUMNS.CAUSE_OF_DEATH] = cause_of_death
pop.loc[dead_idx, COLUMNS.YEARS_OF_LIFE_LOST] = self.lookup_tables[
"life_expectancy"
](dead_idx)

self.population_view.update(pop)

##################
# Helper methods #
##################

def load_cfr_data(self, builder: Builder, key_name: str) -> pd.DataFrame:
"""Load case fatality rate data for maternal disorders."""
maternal_disorder = key_name.split("_case_fatality_rate")[0]
incidence_rate = builder.data.load(
f"cause.{maternal_disorder}.incidence_rate"
).set_index(ARTIFACT_INDEX_COLUMNS)
csmr = builder.data.load(
f"cause.{maternal_disorder}.cause_specific_mortality_rate"
).set_index(ARTIFACT_INDEX_COLUMNS)
cfr = (csmr / incidence_rate).fillna(0).reset_index()

return cfr

def calculate_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame:
"""Calculate the total and proportional case fatality rate for each simulant."""

# Simulants is a boolean dataframe of whether or not a simulant has each maternal disorder.
for disorder in self.maternal_disorders:
simulants[disorder] = simulants[disorder] * self.lookup_tables[
f"{disorder}_case_fatality_rate"
](simulants.index)
simulants["mortality_probability"] = simulants[self.maternal_disorders].sum(axis=1)
cfr_data = self.get_proportional_case_fatality_rates(simulants)

return cfr_data

def get_proportional_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame:
"""Calculate the proportional case fatality rates for each maternal disorder."""

for disorder in self.maternal_disorders:
simulants[f"{disorder}_proportional_cfr"] = (
simulants[disorder] / simulants["mortality_probability"]
)

return simulants
5 changes: 4 additions & 1 deletion src/vivarium_gates_mncnh/constants/data_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class _SimulationEventNames(NamedTuple):
MATERNAL_HEMORRHAGE = "maternal_hemorrhage"
OBSTRUCTED_LABOR = "maternal_obstructed_labor_and_uterine_rupture"
NEONATAL = "neonatal"
MORTALITY = "mortality"


SIMULATION_EVENT_NAMES = _SimulationEventNames()
Expand Down Expand Up @@ -139,9 +140,12 @@ class __ANCRates(NamedTuple):

class __Columns(NamedTuple):
TRACKED = "tracked"
EXIT_TIME = "exit_time"
SEX = "sex"
ALIVE = "alive"
AGE = "age"
CAUSE_OF_DEATH = "cause_of_death"
YEARS_OF_LIFE_LOST = "years_of_life_lost"
LOCATION = "location"
PREGNANCY_OUTCOME = "pregnancy_outcome"
PREGNANCY_DURATION = "pregnancy_duration"
Expand All @@ -154,7 +158,6 @@ class __Columns(NamedTuple):
STATED_GESTATIONAL_AGE = "stated_gestational_age"
SUCCESSFUL_LBW_IDENTIFICATION = "successful_lbw_identification"
ANC_STATE = "anc_state"
SEPSIS_STATE = "sepsis_state"
MATERNAL_SEPSIS = "maternal_sepsis_and_other_maternal_infections"
MATERNAL_HEMORRHAGE = "maternal_hemorrhage"
OBSTRUCTED_LABOR = "maternal_obstructed_labor_and_uterine_rupture"
Expand Down
1 change: 1 addition & 0 deletions src/vivarium_gates_mncnh/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
No logging is done here. Logging is done in vivarium inputs itself and forwarded.
"""

from typing import List, Optional, Union

import numpy as np
Expand Down
2 changes: 2 additions & 0 deletions src/vivarium_gates_mncnh/model_specifications/model_spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ components:
- ResultsStratifier()
- BirthObserver()
- AntenatalCare()
- MortalityDueToMaternalDisorders()
- MaternalDisorder("maternal_obstructed_labor_and_uterine_rupture")
- MaternalDisorder("maternal_hemorrhage")
- MaternalDisorder("maternal_sepsis_and_other_maternal_infections")
Expand Down Expand Up @@ -43,6 +44,7 @@ configuration:
- 'maternal_obstructed_labor_and_uterine_rupture'
- 'maternal_hemorrhage'
- 'maternal_sepsis_and_other_maternal_infections'
- 'mortality'
- 'neonatal'

population:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_mortality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pandas as pd
import pytest

from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders
from vivarium_gates_mncnh.constants.data_values import COLUMNS


def test_get_proportional_case_fatality_rates():
"""Tests that proportional case fatality rates sum to 1."""

# Make case fatality data
simulant_idx = pd.Index(list(range(10)))
data_vals = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
choice_data = pd.DataFrame(index=simulant_idx)

mortality = MortalityDueToMaternalDisorders()
for disoder in mortality.maternal_disorders:
choice_data[disoder] = data_vals
# Get total case fatality rates
choice_data["mortality_probability"] = choice_data.sum(axis=1)

proportional_cfr_data = mortality.get_proportional_case_fatality_rates(choice_data)

proportional_cfr_cols = [
col for col in proportional_cfr_data.columns if "proportional_cfr" in col
]
assert proportional_cfr_data[proportional_cfr_cols].sum(axis=1).all() == 1.0

0 comments on commit bc3d337

Please sign in to comment.