diff --git a/src/vivarium_gates_mncnh/components/__init__.py b/src/vivarium_gates_mncnh/components/__init__.py index 19d91cd..de897da 100644 --- a/src/vivarium_gates_mncnh/components/__init__.py +++ b/src/vivarium_gates_mncnh/components/__init__.py @@ -1,6 +1,6 @@ from vivarium_gates_mncnh.components.antenatal_care import AntenatalCare from vivarium_gates_mncnh.components.maternal_disorders import MaternalDisorder -from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders +from vivarium_gates_mncnh.components.mortality import MaternalDisordersBurden from vivarium_gates_mncnh.components.observers import ( ANCObserver, BirthObserver, diff --git a/src/vivarium_gates_mncnh/components/mortality.py b/src/vivarium_gates_mncnh/components/mortality.py index d8d3bbf..8d09068 100644 --- a/src/vivarium_gates_mncnh/components/mortality.py +++ b/src/vivarium_gates_mncnh/components/mortality.py @@ -14,8 +14,8 @@ from vivarium_gates_mncnh.utilities import get_location -class MortalityDueToMaternalDisorders(Component): - """A component to handle mortality caused by the modeled maternal disorders.""" +class MaternalDisordersBurden(Component): + """A component to handle morbidity and mortality caused by the modeled maternal disorders.""" ############## # Properties # @@ -35,13 +35,19 @@ def configuration_defaults(self) -> dict[str, Any]: ) for cause in self.maternal_disorders }, + **{ + f"{cause}_yld_rate": f"cause.{cause}.yld_rate" + for cause in self.maternal_disorders + }, }, }, } @property def columns_created(self) -> list[str]: - return [COLUMNS.CAUSE_OF_DEATH, COLUMNS.YEARS_OF_LIFE_LOST] + return [COLUMNS.CAUSE_OF_DEATH, COLUMNS.YEARS_OF_LIFE_LOST] + [ + f"{cause}_ylds" for cause in self.maternal_disorders + ] @property def columns_required(self) -> list[str]: @@ -77,8 +83,11 @@ def setup(self, builder: Builder) -> None: def on_initialize_simulants(self, pop_data: SimulantData) -> None: pop_update = pd.DataFrame( { - COLUMNS.CAUSE_OF_DEATH: "not_dead", - COLUMNS.YEARS_OF_LIFE_LOST: 0.0, + **{ + COLUMNS.CAUSE_OF_DEATH: "not_dead", + COLUMNS.YEARS_OF_LIFE_LOST: 0.0, + }, + **{f"{cause}_ylds": 0.0 for cause in self.maternal_disorders}, }, index=pop_data.index, ) @@ -124,6 +133,13 @@ def on_time_step(self, event) -> None: "life_expectancy" ](dead_idx) + # Update YLDs for each maternal disorder + yld_idx = has_maternal_disorders.index.difference(dead_idx) + for cause in self.maternal_disorders: + pop.loc[yld_idx, f"{cause}_ylds"] = self.lookup_tables[f"{cause}_yld_rate"]( + yld_idx + ) + self.population_view.update(pop) ################## @@ -147,9 +163,9 @@ def calculate_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame """Calculate the total and proportional case fatality rate for each simulant.""" # Simulants is a boolean dataframe of whether or not a simulant has each maternal disorder. - for disorder in self.maternal_disorders: - simulants[disorder] = simulants[disorder] * self.lookup_tables[ - f"{disorder}_case_fatality_rate" + for cause in self.maternal_disorders: + simulants[cause] = simulants[cause] * self.lookup_tables[ + f"{cause}_case_fatality_rate" ](simulants.index) simulants["mortality_probability"] = simulants[self.maternal_disorders].sum(axis=1) cfr_data = self.get_proportional_case_fatality_rates(simulants) @@ -159,9 +175,9 @@ def calculate_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame def get_proportional_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame: """Calculate the proportional case fatality rates for each maternal disorder.""" - for disorder in self.maternal_disorders: - simulants[f"{disorder}_proportional_cfr"] = ( - simulants[disorder] / simulants["mortality_probability"] + for cause in self.maternal_disorders: + simulants[f"{cause}_proportional_cfr"] = ( + simulants[cause] / simulants["mortality_probability"] ) return simulants diff --git a/src/vivarium_gates_mncnh/constants/data_keys.py b/src/vivarium_gates_mncnh/constants/data_keys.py index e25b809..cf9f5af 100644 --- a/src/vivarium_gates_mncnh/constants/data_keys.py +++ b/src/vivarium_gates_mncnh/constants/data_keys.py @@ -91,6 +91,7 @@ class __MaternalSepsis(NamedTuple): CSMR: str = ( "cause.maternal_sepsis_and_other_maternal_infections.cause_specific_mortality_rate" ) + YLD_RATE: str = "cause.maternal_sepsis_and_other_maternal_infections.yld_rate" @property def name(self): @@ -108,6 +109,7 @@ class __MaternalHemorrhage(NamedTuple): # Keys that will be loaded into the artifact. must have a colon type declaration RAW_INCIDENCE_RATE: str = "cause.maternal_hemorrhage.incidence_rate" CSMR: str = "cause.maternal_hemorrhage.cause_specific_mortality_rate" + YLD_RATE: str = "cause.maternal_hemorrhage.yld_rate" @property def name(self): @@ -129,6 +131,7 @@ class __ObstructedLabor(NamedTuple): CSMR: str = ( "cause.maternal_obstructed_labor_and_uterine_rupture.cause_specific_mortality_rate" ) + YLD_RATE: str = "cause.maternal_obstructed_labor_and_uterine_rupture.yld_rate" @property def name(self): diff --git a/src/vivarium_gates_mncnh/data/extra_gbd.py b/src/vivarium_gates_mncnh/data/extra_gbd.py index ac98cc7..a194383 100644 --- a/src/vivarium_gates_mncnh/data/extra_gbd.py +++ b/src/vivarium_gates_mncnh/data/extra_gbd.py @@ -1,5 +1,7 @@ +import pandas as pd from vivarium_gbd_access import constants as gbd_constants from vivarium_gbd_access import utilities as vi_utils +from vivarium_inputs import globals as vi_globals from vivarium_inputs import utility_data from vivarium_gates_mncnh.constants import data_keys @@ -21,3 +23,20 @@ def load_lbwsg_exposure(location: str): release_id=gbd_constants.RELEASE_IDS.GBD_2021, ) return data + + +@vi_utils.cache +def get_maternal_disorder_yld_rate(key: str, location: str) -> pd.DataFrame: + entity = utilities.get_entity(key) + location_id = utility_data.get_location_id(location) + data = vi_utils.get_draws( + "cause_id", + entity.gbd_id, + source=gbd_constants.SOURCES.COMO, + location_id=location_id, + year_id=2021, + release_id=gbd_constants.RELEASE_IDS.GBD_2021, + measure_id=vi_globals.MEASURES["YLDs"], + metric_id=vi_globals.METRICS["Rate"], + ) + return data diff --git a/src/vivarium_gates_mncnh/data/loader.py b/src/vivarium_gates_mncnh/data/loader.py index 7c3b8fa..fbb761d 100644 --- a/src/vivarium_gates_mncnh/data/loader.py +++ b/src/vivarium_gates_mncnh/data/loader.py @@ -68,10 +68,13 @@ def get_data( data_keys.ANC.ESTIMATE: load_anc_proportion, data_keys.MATERNAL_SEPSIS.RAW_INCIDENCE_RATE: load_standard_data, data_keys.MATERNAL_SEPSIS.CSMR: load_standard_data, + data_keys.MATERNAL_SEPSIS.YLD_RATE: load_maternal_disorder_yld_rate, data_keys.MATERNAL_HEMORRHAGE.RAW_INCIDENCE_RATE: load_standard_data, data_keys.MATERNAL_HEMORRHAGE.CSMR: load_standard_data, + data_keys.MATERNAL_HEMORRHAGE.YLD_RATE: load_maternal_disorder_yld_rate, data_keys.OBSTRUCTED_LABOR.RAW_INCIDENCE_RATE: load_standard_data, data_keys.OBSTRUCTED_LABOR.CSMR: load_standard_data, + data_keys.OBSTRUCTED_LABOR.YLD_RATE: load_maternal_disorder_yld_rate, } return mapping[lookup_key](lookup_key, location, years) @@ -295,6 +298,19 @@ def load_anc_proportion( return anc_proportion_draws_df.set_index(["year_start", "year_end"]) +def load_maternal_disorder_yld_rate( + key: str, location: str, years: Optional[Union[int, str, list[int]]] = None +) -> pd.DataFrame: + + groupby_cols = ["age_group_id", "sex_id", "year_id"] + draw_cols = vi_globals.DRAW_COLUMNS + yld_rate = extra_gbd.get_maternal_disorder_yld_rate(key, location) + yld_rate = yld_rate[groupby_cols + draw_cols] + yld_rate = reshape_to_vivarium_format(yld_rate, location) + + return yld_rate + + def reshape_to_vivarium_format(df, location): df = vi_utils.reshape(df, value_cols=vi_globals.DRAW_COLUMNS) df = vi_utils.scrub_gbd_conventions(df, location) diff --git a/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml b/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml index 23ab3ab..ab2d2f2 100644 --- a/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml +++ b/src/vivarium_gates_mncnh/model_specifications/model_spec.yaml @@ -15,7 +15,7 @@ components: - ResultsStratifier() - BirthObserver() - AntenatalCare() - - MortalityDueToMaternalDisorders() + - MaternalDisordersBurden() - MaternalDisorder("maternal_obstructed_labor_and_uterine_rupture") - MaternalDisorder("maternal_hemorrhage") - MaternalDisorder("maternal_sepsis_and_other_maternal_infections") @@ -25,7 +25,7 @@ components: configuration: input_data: input_draw_number: 0 - artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/hemorrhage/ethiopia.hdf" + artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/mortality/ethiopia.hdf" interpolation: order: 0 extrapolate: True diff --git a/tests/test_mortality.py b/tests/test_mortality.py index 96f9b95..d363458 100644 --- a/tests/test_mortality.py +++ b/tests/test_mortality.py @@ -1,7 +1,7 @@ import pandas as pd import pytest -from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders +from vivarium_gates_mncnh.components.mortality import MaternalDisordersBurden from vivarium_gates_mncnh.constants.data_values import COLUMNS @@ -13,7 +13,7 @@ def test_get_proportional_case_fatality_rates(): data_vals = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5] choice_data = pd.DataFrame(index=simulant_idx) - mortality = MortalityDueToMaternalDisorders() + mortality = MaternalDisordersBurden() for disoder in mortality.maternal_disorders: choice_data[disoder] = data_vals # Get total case fatality rates