From bc3d3376dffbf65a5f0fef7324289dc4aac19e75 Mon Sep 17 00:00:00 2001 From: Jim Albright <37345113+albrja@users.noreply.github.com> Date: Fri, 6 Dec 2024 14:27:50 -0800 Subject: [PATCH] Albrja/mic 5609/mortality component (#16) Albrja/mic 5609/mortality component Implement cause of death feature of Mortality component - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5609 - *Research reference*: https://github.com/ihmeuw/vivarium_research/pull/1566/files Changes and notes -create custom mortality component -determine cause of death based on case fatality rates for maternal disorders for each simulant Verification and Testing Ran simulation and simulants died because of expected causes. --- README.rst | 2 +- python_versions.json | 2 +- .../components/__init__.py | 1 + .../components/antenatal_care.py | 6 +- .../components/mortality.py | 167 ++++++++++++++++++ .../constants/data_values.py | 5 +- src/vivarium_gates_mncnh/data/loader.py | 1 + .../model_specifications/model_spec.yaml | 2 + tests/test_mortality.py | 27 +++ 9 files changed, 206 insertions(+), 7 deletions(-) create mode 100644 src/vivarium_gates_mncnh/components/mortality.py create mode 100644 tests/test_mortality.py diff --git a/README.rst b/README.rst index 08a92a2..c23dbc8 100644 --- a/README.rst +++ b/README.rst @@ -27,7 +27,7 @@ all necessary requirements as follows:: (vivarium_gates_mncnh) :~$ pip install -e . ...pip will install vivarium and other requirements... -Supported Python versions: 3.9, 3.10, 3.11 +Supported Python versions: 3.10, 3.11 Note the ``-e`` flag that follows pip install. This will install the python package in-place, which is important for making the model specifications later. 
diff --git a/python_versions.json b/python_versions.json index a32f85f..8b914b0 100644 --- a/python_versions.json +++ b/python_versions.json @@ -1 +1 @@ -["3.9", "3.10", "3.11"] +["3.10", "3.11"] diff --git a/src/vivarium_gates_mncnh/components/__init__.py b/src/vivarium_gates_mncnh/components/__init__.py index 79cd6bd..19d91cd 100644 --- a/src/vivarium_gates_mncnh/components/__init__.py +++ b/src/vivarium_gates_mncnh/components/__init__.py @@ -1,5 +1,6 @@ from vivarium_gates_mncnh.components.antenatal_care import AntenatalCare from vivarium_gates_mncnh.components.maternal_disorders import MaternalDisorder +from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders from vivarium_gates_mncnh.components.observers import ( ANCObserver, BirthObserver, diff --git a/src/vivarium_gates_mncnh/components/antenatal_care.py b/src/vivarium_gates_mncnh/components/antenatal_care.py index a2e8966..8ebe050 100644 --- a/src/vivarium_gates_mncnh/components/antenatal_care.py +++ b/src/vivarium_gates_mncnh/components/antenatal_care.py @@ -11,6 +11,7 @@ from vivarium.types import ClockTime from vivarium_gates_mncnh.components.tree import DecisionTreeState, TreeMachine +from vivarium_gates_mncnh.constants.data_keys import ANC from vivarium_gates_mncnh.constants.data_values import ( ANC_RATES, COLUMNS, @@ -68,10 +69,7 @@ def setup(self, builder: Builder): self.location = get_location(builder) def build_all_lookup_tables(self, builder: Builder) -> None: - # TODO: update data key to constant - anc_attendance_probability = builder.data.load( - "covariate.antenatal_care_1_visit_coverage_proportion.estimate" - ) + anc_attendance_probability = builder.data.load(ANC.ESTIMATE) self.lookup_tables["anc_attendance_probability"] = self.build_lookup_table( builder=builder, data_source=anc_attendance_probability, diff --git a/src/vivarium_gates_mncnh/components/mortality.py b/src/vivarium_gates_mncnh/components/mortality.py new file mode 100644 index 0000000..d8d3bbf --- 
/dev/null
+++ b/src/vivarium_gates_mncnh/components/mortality.py
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+from functools import partial
+from typing import Any
+
+import pandas as pd
+from vivarium import Component
+from vivarium.framework.engine import Builder
+from vivarium.framework.event import Event
+from vivarium.framework.population import SimulantData
+
+from vivarium_gates_mncnh.constants.data_values import COLUMNS, SIMULATION_EVENT_NAMES
+from vivarium_gates_mncnh.constants.metadata import ARTIFACT_INDEX_COLUMNS
+from vivarium_gates_mncnh.utilities import get_location
+
+
+class MortalityDueToMaternalDisorders(Component):
+    """A component to handle mortality caused by the modeled maternal disorders.
+
+    During the mortality simulation event, each simulant with at least one
+    maternal disorder gets a mortality probability equal to the sum of the
+    case fatality rates (CFRs) of the disorders they have. Simulants that die
+    are assigned a cause of death drawn from the maternal disorders, weighted
+    by each disorder's share of that sum.
+    """
+
+    ##############
+    # Properties #
+    ##############
+
+    @property
+    def configuration_defaults(self) -> dict[str, Any]:
+        """Data sources: life expectancy plus one CFR loader per maternal disorder."""
+        return {
+            self.name: {
+                "data_sources": {
+                    "life_expectancy": "population.theoretical_minimum_risk_life_expectancy",
+                    **{
+                        f"{cause}_case_fatality_rate": partial(
+                            self.load_cfr_data, key_name=cause
+                        )
+                        for cause in self.maternal_disorders
+                    },
+                },
+            },
+        }
+
+    @property
+    def columns_created(self) -> list[str]:
+        """Columns this component creates: cause of death and years of life lost."""
+        return [COLUMNS.CAUSE_OF_DEATH, COLUMNS.YEARS_OF_LIFE_LOST]
+
+    @property
+    def columns_required(self) -> list[str]:
+        """Columns this component requires, including one per maternal disorder."""
+        return [
+            COLUMNS.ALIVE,
+            COLUMNS.EXIT_TIME,
+            COLUMNS.AGE,
+            COLUMNS.SEX,
+        ] + self.maternal_disorders
+
+    #####################
+    # Lifecycle methods #
+    #####################
+
+    def __init__(self) -> None:
+        super().__init__()
+        # TODO: update list of maternal disorders when implemented
+        self.maternal_disorders = [
+            COLUMNS.OBSTRUCTED_LABOR,
+            COLUMNS.MATERNAL_HEMORRHAGE,
+            COLUMNS.MATERNAL_SEPSIS,
+        ]
+
+    def setup(self, builder: Builder) -> None:
+        """Store the simulation event name getter, a randomness stream and the location."""
+        self._sim_step_name = builder.time.simulation_event_name()
+        self.randomness = builder.randomness.get_stream(self.name)
+        self.location = get_location(builder)
+
+    ########################
+    # Event-driven methods #
+    ########################
+
+    def on_initialize_simulants(self, pop_data: SimulantData) -> None:
+        """Initialize new simulants as "not_dead" with zero years of life lost."""
+        pop_update = pd.DataFrame(
+            {
+                COLUMNS.CAUSE_OF_DEATH: "not_dead",
+                COLUMNS.YEARS_OF_LIFE_LOST: 0.0,
+            },
+            index=pop_data.index,
+        )
+        self.population_view.update(pop_update)
+
+    def on_time_step(self, event: Event) -> None:
+        """Decide which simulants die of a maternal disorder and record why.
+
+        Does nothing unless the current simulation event is the mortality
+        step. Simulants that die are marked "dead", assigned a cause of death
+        and given years of life lost from the life expectancy lookup table.
+        All simulants in ``event.index`` are written back to the state table.
+        """
+        if self._sim_step_name() != SIMULATION_EVENT_NAMES.MORTALITY:
+            return
+
+        pop = self.population_view.get(event.index)
+        has_maternal_disorders = pop[self.maternal_disorders]
+        has_maternal_disorders = has_maternal_disorders.loc[
+            has_maternal_disorders.any(axis=1)
+        ]
+
+        # Get raw and conditional case fatality rates for each simulant
+        choice_data = has_maternal_disorders.copy()
+        choice_data = self.calculate_case_fatality_rates(choice_data)
+
+        # Decide what simulants die from what maternal disorders
+        dead_idx = self.randomness.filter_for_probability(
+            choice_data.index,
+            choice_data["mortality_probability"],
+            "mortality_choice",
+        )
+
+        # Update metadata for simulants that died
+        # NOTE(review): EXIT_TIME is listed in columns_required but is never
+        # written here; confirm whether it should be set when a simulant dies.
+        if not dead_idx.empty:
+            pop.loc[dead_idx, COLUMNS.ALIVE] = "dead"
+
+            # Choose one cause of death per dead simulant, weighted by each
+            # disorder's share of that simulant's total case fatality rate
+            cause_of_death = self.randomness.choice(
+                index=dead_idx,
+                choices=self.maternal_disorders,
+                p=choice_data.loc[
+                    dead_idx,
+                    [f"{disorder}_proportional_cfr" for disorder in self.maternal_disorders],
+                ],
+                additional_key="cause_of_death",
+            )
+            pop.loc[dead_idx, COLUMNS.CAUSE_OF_DEATH] = cause_of_death
+            pop.loc[dead_idx, COLUMNS.YEARS_OF_LIFE_LOST] = self.lookup_tables[
+                "life_expectancy"
+            ](dead_idx)
+
+        self.population_view.update(pop)
+
+    ##################
+    # Helper methods #
+    ##################
+
+    def load_cfr_data(self, builder: Builder, key_name: str) -> pd.DataFrame:
+        """Load case fatality rate data for maternal disorders.
+
+        The rate is the cause-specific mortality rate divided by the incidence
+        rate of the disorder named by ``key_name`` (text from
+        ``_case_fatality_rate`` onward is dropped), both loaded through
+        ``builder.data``. Undefined ratios are returned as 0.
+        """
+        maternal_disorder = key_name.split("_case_fatality_rate")[0]
+        incidence_rate = builder.data.load(
+            f"cause.{maternal_disorder}.incidence_rate"
+        ).set_index(ARTIFACT_INDEX_COLUMNS)
+        csmr = builder.data.load(
+            f"cause.{maternal_disorder}.cause_specific_mortality_rate"
+        ).set_index(ARTIFACT_INDEX_COLUMNS)
+        # A zero incidence rate leaves the ratio undefined: 0 / 0 gives NaN and
+        # x / 0 gives +/-inf. fillna alone would let an infinite rate through to
+        # the mortality probability, so map both cases to 0.
+        cfr = (csmr / incidence_rate).replace([float("inf"), float("-inf")], 0.0)
+
+        return cfr.fillna(0).reset_index()
+
+    def calculate_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame:
+        """Calculate the total and proportional case fatality rate for each simulant.
+
+        ``simulants`` is modified in place and returned. Each disorder column is
+        multiplied by that disorder's CFR lookup, ``mortality_probability`` is
+        the row-wise sum of those columns, and proportional columns are added.
+        """
+
+        # Simulants is a boolean dataframe of whether or not a simulant has each maternal disorder.
+        for disorder in self.maternal_disorders:
+            simulants[disorder] = simulants[disorder] * self.lookup_tables[
+                f"{disorder}_case_fatality_rate"
+            ](simulants.index)
+        simulants["mortality_probability"] = simulants[self.maternal_disorders].sum(axis=1)
+        cfr_data = self.get_proportional_case_fatality_rates(simulants)
+
+        return cfr_data
+
+    def get_proportional_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame:
+        """Calculate the proportional case fatality rates for each maternal disorder.
+
+        Adds one ``<disorder>_proportional_cfr`` column per disorder to
+        ``simulants`` (in place), holding that disorder's CFR divided by the
+        simulant's ``mortality_probability``, and returns the same frame.
+        """
+
+        # A simulant whose mortality_probability is 0 gets NaN here (0 / 0);
+        # on_time_step only reads these columns for simulants that died.
+        for disorder in self.maternal_disorders:
+            simulants[f"{disorder}_proportional_cfr"] = (
+                simulants[disorder] / simulants["mortality_probability"]
+            )
+
+        return simulants
diff --git a/src/vivarium_gates_mncnh/constants/data_values.py b/src/vivarium_gates_mncnh/constants/data_values.py index 4c9fe83..d486045 100644 --- a/src/vivarium_gates_mncnh/constants/data_values.py +++ b/src/vivarium_gates_mncnh/constants/data_values.py @@ -83,6 +83,7 @@ class _SimulationEventNames(NamedTuple): MATERNAL_HEMORRHAGE = "maternal_hemorrhage" OBSTRUCTED_LABOR = "maternal_obstructed_labor_and_uterine_rupture" NEONATAL = "neonatal" + MORTALITY = "mortality" SIMULATION_EVENT_NAMES = _SimulationEventNames() @@ -139,9 +140,12 @@ class __ANCRates(NamedTuple): class __Columns(NamedTuple): TRACKED = "tracked" + EXIT_TIME = "exit_time" SEX = "sex" ALIVE = "alive" AGE = "age" + CAUSE_OF_DEATH = "cause_of_death" + YEARS_OF_LIFE_LOST = "years_of_life_lost" LOCATION = "location" PREGNANCY_OUTCOME = "pregnancy_outcome" PREGNANCY_DURATION = 
"pregnancy_duration" @@ -154,7 +158,6 @@ class __Columns(NamedTuple): STATED_GESTATIONAL_AGE = "stated_gestational_age" SUCCESSFUL_LBW_IDENTIFICATION = "successful_lbw_identification" ANC_STATE = "anc_state" - SEPSIS_STATE = "sepsis_state" MATERNAL_SEPSIS = "maternal_sepsis_and_other_maternal_infections" MATERNAL_HEMORRHAGE = "maternal_hemorrhage" OBSTRUCTED_LABOR = "maternal_obstructed_labor_and_uterine_rupture" diff --git a/src/vivarium_gates_mncnh/data/loader.py b/src/vivarium_gates_mncnh/data/loader.py index 653146a..7c3b8fa 100644 --- a/src/vivarium_gates_mncnh/data/loader.py +++ b/src/vivarium_gates_mncnh/data/loader.py @@ -11,6 +11,7 @@ No logging is done here. Logging is done in vivarium inputs itself and forwarded. """ + from typing import List, Optional, Union import numpy as np diff --git a/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml b/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml index 08e228e..23ab3ab 100644 --- a/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml +++ b/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml @@ -15,6 +15,7 @@ components: - ResultsStratifier() - BirthObserver() - AntenatalCare() + - MortalityDueToMaternalDisorders() - MaternalDisorder("maternal_obstructed_labor_and_uterine_rupture") - MaternalDisorder("maternal_hemorrhage") - MaternalDisorder("maternal_sepsis_and_other_maternal_infections") @@ -43,6 +44,7 @@ configuration: - 'maternal_obstructed_labor_and_uterine_rupture' - 'maternal_hemorrhage' - 'maternal_sepsis_and_other_maternal_infections' + - 'mortality' - 'neonatal' population: diff --git a/tests/test_mortality.py b/tests/test_mortality.py new file mode 100644 index 0000000..96f9b95 --- /dev/null +++ b/tests/test_mortality.py @@ -0,0 +1,27 @@ +import pandas as pd +import pytest + +from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders +from vivarium_gates_mncnh.constants.data_values import COLUMNS + + +def 
test_get_proportional_case_fatality_rates():
+    """Tests that each simulant's proportional case fatality rates sum to 1."""
+
+    # Make case fatality data: every disorder gets the same per-simulant rate
+    simulant_idx = pd.Index(list(range(10)))
+    data_vals = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
+    choice_data = pd.DataFrame(index=simulant_idx)
+
+    mortality = MortalityDueToMaternalDisorders()
+    for disorder in mortality.maternal_disorders:
+        choice_data[disorder] = data_vals
+    # Get total case fatality rates
+    choice_data["mortality_probability"] = choice_data.sum(axis=1)
+
+    proportional_cfr_data = mortality.get_proportional_case_fatality_rates(choice_data)
+
+    # `.all() == 1.0` compared one bool with 1.0 and passed for any non-zero sums;
+    # check every row sum against 1 (within float tolerance) instead.
+    row_sums = proportional_cfr_data.filter(like="proportional_cfr").sum(axis=1)
+    assert ((row_sums - 1.0).abs() < 1e-9).all()