Skip to content

Commit

Permalink
Albrja/mic-5609/ylds maternal disorders (#17)
Browse files Browse the repository at this point in the history
Albrja/mic-5609/ylds maternal disorders

Add ylds to mortality
- *Category*: Implementation
- *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-5609
- *Research reference*: <!--Link to research documentation for code -->

Changes and notes
-add yld_rate to artifact for each maternal disorder
-record yld_rate in state table

### Verification and Testing
<!--
Details on how code was verified. Consider: plots, images, (small) csv files.
-->
  • Loading branch information
albrja authored Dec 9, 2024
1 parent bc3d337 commit 64f9590
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/vivarium_gates_mncnh/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from vivarium_gates_mncnh.components.antenatal_care import AntenatalCare
from vivarium_gates_mncnh.components.maternal_disorders import MaternalDisorder
from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders
from vivarium_gates_mncnh.components.mortality import MaternalDisordersBurden
from vivarium_gates_mncnh.components.observers import (
ANCObserver,
BirthObserver,
Expand Down
38 changes: 27 additions & 11 deletions src/vivarium_gates_mncnh/components/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from vivarium_gates_mncnh.utilities import get_location


class MortalityDueToMaternalDisorders(Component):
"""A component to handle mortality caused by the modeled maternal disorders."""
class MaternalDisordersBurden(Component):
"""A component to handle morbidity and mortality caused by the modeled maternal disorders."""

##############
# Properties #
Expand All @@ -35,13 +35,19 @@ def configuration_defaults(self) -> dict[str, Any]:
)
for cause in self.maternal_disorders
},
**{
f"{cause}_yld_rate": f"cause.{cause}.yld_rate"
for cause in self.maternal_disorders
},
},
},
}

@property
def columns_created(self) -> list[str]:
return [COLUMNS.CAUSE_OF_DEATH, COLUMNS.YEARS_OF_LIFE_LOST]
return [COLUMNS.CAUSE_OF_DEATH, COLUMNS.YEARS_OF_LIFE_LOST] + [
f"{cause}_ylds" for cause in self.maternal_disorders
]

@property
def columns_required(self) -> list[str]:
Expand Down Expand Up @@ -77,8 +83,11 @@ def setup(self, builder: Builder) -> None:
def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pop_update = pd.DataFrame(
{
COLUMNS.CAUSE_OF_DEATH: "not_dead",
COLUMNS.YEARS_OF_LIFE_LOST: 0.0,
**{
COLUMNS.CAUSE_OF_DEATH: "not_dead",
COLUMNS.YEARS_OF_LIFE_LOST: 0.0,
},
**{f"{cause}_ylds": 0.0 for cause in self.maternal_disorders},
},
index=pop_data.index,
)
Expand Down Expand Up @@ -124,6 +133,13 @@ def on_time_step(self, event) -> None:
"life_expectancy"
](dead_idx)

# Update YLDs for each maternal disorder
yld_idx = has_maternal_disorders.index.difference(dead_idx)
for cause in self.maternal_disorders:
pop.loc[yld_idx, f"{cause}_ylds"] = self.lookup_tables[f"{cause}_yld_rate"](
yld_idx
)

self.population_view.update(pop)

##################
Expand All @@ -147,9 +163,9 @@ def calculate_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame
"""Calculate the total and proportional case fatality rate for each simulant."""

# Simulants is a boolean dataframe of whether or not a simulant has each maternal disorder.
for disorder in self.maternal_disorders:
simulants[disorder] = simulants[disorder] * self.lookup_tables[
f"{disorder}_case_fatality_rate"
for cause in self.maternal_disorders:
simulants[cause] = simulants[cause] * self.lookup_tables[
f"{cause}_case_fatality_rate"
](simulants.index)
simulants["mortality_probability"] = simulants[self.maternal_disorders].sum(axis=1)
cfr_data = self.get_proportional_case_fatality_rates(simulants)
Expand All @@ -159,9 +175,9 @@ def calculate_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame
def get_proportional_case_fatality_rates(self, simulants: pd.DataFrame) -> pd.DataFrame:
"""Calculate the proportional case fatality rates for each maternal disorder."""

for disorder in self.maternal_disorders:
simulants[f"{disorder}_proportional_cfr"] = (
simulants[disorder] / simulants["mortality_probability"]
for cause in self.maternal_disorders:
simulants[f"{cause}_proportional_cfr"] = (
simulants[cause] / simulants["mortality_probability"]
)

return simulants
3 changes: 3 additions & 0 deletions src/vivarium_gates_mncnh/constants/data_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class __MaternalSepsis(NamedTuple):
CSMR: str = (
"cause.maternal_sepsis_and_other_maternal_infections.cause_specific_mortality_rate"
)
YLD_RATE: str = "cause.maternal_sepsis_and_other_maternal_infections.yld_rate"

@property
def name(self):
Expand All @@ -108,6 +109,7 @@ class __MaternalHemorrhage(NamedTuple):
# Keys that will be loaded into the artifact. must have a colon type declaration
RAW_INCIDENCE_RATE: str = "cause.maternal_hemorrhage.incidence_rate"
CSMR: str = "cause.maternal_hemorrhage.cause_specific_mortality_rate"
YLD_RATE: str = "cause.maternal_hemorrhage.yld_rate"

@property
def name(self):
Expand All @@ -129,6 +131,7 @@ class __ObstructedLabor(NamedTuple):
CSMR: str = (
"cause.maternal_obstructed_labor_and_uterine_rupture.cause_specific_mortality_rate"
)
YLD_RATE: str = "cause.maternal_obstructed_labor_and_uterine_rupture.yld_rate"

@property
def name(self):
Expand Down
19 changes: 19 additions & 0 deletions src/vivarium_gates_mncnh/data/extra_gbd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pandas as pd
from vivarium_gbd_access import constants as gbd_constants
from vivarium_gbd_access import utilities as vi_utils
from vivarium_inputs import globals as vi_globals
from vivarium_inputs import utility_data

from vivarium_gates_mncnh.constants import data_keys
Expand All @@ -21,3 +23,20 @@ def load_lbwsg_exposure(location: str):
release_id=gbd_constants.RELEASE_IDS.GBD_2021,
)
return data


@vi_utils.cache
def get_maternal_disorder_yld_rate(key: str, location: str) -> pd.DataFrame:
entity = utilities.get_entity(key)
location_id = utility_data.get_location_id(location)
data = vi_utils.get_draws(
"cause_id",
entity.gbd_id,
source=gbd_constants.SOURCES.COMO,
location_id=location_id,
year_id=2021,
release_id=gbd_constants.RELEASE_IDS.GBD_2021,
measure_id=vi_globals.MEASURES["YLDs"],
metric_id=vi_globals.METRICS["Rate"],
)
return data
16 changes: 16 additions & 0 deletions src/vivarium_gates_mncnh/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,13 @@ def get_data(
data_keys.ANC.ESTIMATE: load_anc_proportion,
data_keys.MATERNAL_SEPSIS.RAW_INCIDENCE_RATE: load_standard_data,
data_keys.MATERNAL_SEPSIS.CSMR: load_standard_data,
data_keys.MATERNAL_SEPSIS.YLD_RATE: load_maternal_disorder_yld_rate,
data_keys.MATERNAL_HEMORRHAGE.RAW_INCIDENCE_RATE: load_standard_data,
data_keys.MATERNAL_HEMORRHAGE.CSMR: load_standard_data,
data_keys.MATERNAL_HEMORRHAGE.YLD_RATE: load_maternal_disorder_yld_rate,
data_keys.OBSTRUCTED_LABOR.RAW_INCIDENCE_RATE: load_standard_data,
data_keys.OBSTRUCTED_LABOR.CSMR: load_standard_data,
data_keys.OBSTRUCTED_LABOR.YLD_RATE: load_maternal_disorder_yld_rate,
}
return mapping[lookup_key](lookup_key, location, years)

Expand Down Expand Up @@ -295,6 +298,19 @@ def load_anc_proportion(
return anc_proportion_draws_df.set_index(["year_start", "year_end"])


def load_maternal_disorder_yld_rate(
key: str, location: str, years: Optional[Union[int, str, list[int]]] = None
) -> pd.DataFrame:

groupby_cols = ["age_group_id", "sex_id", "year_id"]
draw_cols = vi_globals.DRAW_COLUMNS
yld_rate = extra_gbd.get_maternal_disorder_yld_rate(key, location)
yld_rate = yld_rate[groupby_cols + draw_cols]
yld_rate = reshape_to_vivarium_format(yld_rate, location)

return yld_rate


def reshape_to_vivarium_format(df, location):
df = vi_utils.reshape(df, value_cols=vi_globals.DRAW_COLUMNS)
df = vi_utils.scrub_gbd_conventions(df, location)
Expand Down
4 changes: 2 additions & 2 deletions src/vivarium_gates_mncnh/model_specifications/model_spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ components:
- ResultsStratifier()
- BirthObserver()
- AntenatalCare()
- MortalityDueToMaternalDisorders()
- MaternalDisordersBurden()
- MaternalDisorder("maternal_obstructed_labor_and_uterine_rupture")
- MaternalDisorder("maternal_hemorrhage")
- MaternalDisorder("maternal_sepsis_and_other_maternal_infections")
Expand All @@ -25,7 +25,7 @@ components:
configuration:
input_data:
input_draw_number: 0
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/hemorrhage/ethiopia.hdf"
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/mortality/ethiopia.hdf"
interpolation:
order: 0
extrapolate: True
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mortality.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import pytest

from vivarium_gates_mncnh.components.mortality import MortalityDueToMaternalDisorders
from vivarium_gates_mncnh.components.mortality import MaternalDisordersBurden
from vivarium_gates_mncnh.constants.data_values import COLUMNS


Expand All @@ -13,7 +13,7 @@ def test_get_proportional_case_fatality_rates():
data_vals = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
choice_data = pd.DataFrame(index=simulant_idx)

mortality = MortalityDueToMaternalDisorders()
mortality = MaternalDisordersBurden()
for disoder in mortality.maternal_disorders:
choice_data[disoder] = data_vals
# Get total case fatality rates
Expand Down

0 comments on commit 64f9590

Please sign in to comment.