Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Albrja/mic 5748/lbwsg #24

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vivarium_gates_mncnh/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vivarium_gates_mncnh.components.antenatal_care import AntenatalCare
from vivarium_gates_mncnh.components.lbwsg import LBWSGRiskEffect
from vivarium_gates_mncnh.components.maternal_disorders import MaternalDisorder
from vivarium_gates_mncnh.components.mortality import (
MaternalDisordersBurden,
Expand Down
30 changes: 30 additions & 0 deletions src/vivarium_gates_mncnh/components/lbwsg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from vivarium.framework.engine import Builder
from vivarium_public_health.risks.implementations.low_birth_weight_and_short_gestation import (
LBWSGRiskEffect as LBWSGRiskEffect_,
)
from vivarium_public_health.utilities import get_lookup_columns


class LBWSGRiskEffect(LBWSGRiskEffect_):
"""Subclass of LBWSGRiskEffect to expose the PAF pipeline to be accessable by other components."""

def setup(self, builder: Builder) -> None:
# Paf pipeline needs to be registered before the super setup is called
self.paf = builder.value.register_value_producer(
"paf",
source=self.lookup_tables["population_attributable_fraction"],
component=self,
required_resources=get_lookup_columns(
[self.lookup_tables["population_attributable_fraction"]]
),
)
super().setup(builder)

# NOTE: We will be manually handling the paf effect so the target_paf_pipeline
# has not been created and will throw a warning
def register_paf_modifier(self, builder: Builder) -> None:
builder.value.register_value_modifier(
self.target_paf_pipeline_name,
modifier=self.paf,
component=self,
)
13 changes: 10 additions & 3 deletions src/vivarium_gates_mncnh/components/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,20 @@ def setup(self, builder: Builder) -> None:
self.location = get_location(builder)

self.all_cause_mortality_rate = builder.value.register_value_producer(
"all_cause_mortality_rate",
"all_causes.cause_specific_mortality_rate",
source=self.lookup_tables["all_cause_mortality_rate"],
component=self,
requires_columns=get_lookup_columns(
required_resources=get_lookup_columns(
[self.lookup_tables["all_cause_mortality_rate"]]
),
)
# Modify ACMR pipeline with CSMR for neonatal causes
self.death_in_age_group = builder.value.register_value_producer(
"death_in_age_group_probability",
source=self.all_cause_mortality_rate,
component=self,
required_resources=[self.all_cause_mortality_rate],
)

def on_initialize_simulants(self, pop_data: SimulantData) -> None:
pop_update = pd.DataFrame(
Expand All @@ -234,7 +241,7 @@ def on_time_step(self, event: Event) -> None:

pop = self.population_view.get(event.index)
alive_children = pop.loc[pop[COLUMNS.CHILD_ALIVE] == "alive"]
mortality_rates = self.all_cause_mortality_rate(alive_children.index)
mortality_rates = self.death_in_age_group(alive_children.index)
# Convert to rates to probability
if self._sim_step_name() == SIMULATION_EVENT_NAMES.EARLY_NEONATAL_MORTALITY:
duration = 7 / 365.0
Expand Down
17 changes: 0 additions & 17 deletions src/vivarium_gates_mncnh/data/extra_gbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,6 @@
from vivarium_gates_mncnh.data import utilities


@vi_utils.cache
def load_lbwsg_exposure(location: str):
entity = utilities.get_entity(data_keys.LBWSG.EXPOSURE)
location_id = utility_data.get_location_id(location)
data = vi_utils.get_draws(
gbd_id_type="rei_id",
gbd_id=entity.gbd_id,
source=gbd_constants.SOURCES.EXPOSURE,
location_id=location_id,
year_id=2021,
sex_id=gbd_constants.SEX.MALE + gbd_constants.SEX.FEMALE,
age_group_id=164, # Birth prevalence
release_id=gbd_constants.RELEASE_IDS.GBD_2021,
)
return data


@vi_utils.cache
def get_maternal_disorder_yld_rate(key: str, location: str) -> pd.DataFrame:
entity = utilities.get_entity(key)
Expand Down
24 changes: 1 addition & 23 deletions src/vivarium_gates_mncnh/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_data(
data_keys.PREGNANCY.RAW_INCIDENCE_RATE_ECTOPIC: load_raw_incidence_data,
data_keys.LBWSG.DISTRIBUTION: load_metadata,
data_keys.LBWSG.CATEGORIES: load_metadata,
data_keys.LBWSG.EXPOSURE: load_lbwsg_exposure,
data_keys.LBWSG.EXPOSURE: load_standard_data,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was producing an incorrect index and load_standard_data works now.

data_keys.LBWSG.RELATIVE_RISK: load_lbwsg_rr,
data_keys.LBWSG.RELATIVE_RISK_INTERPOLATOR: load_lbwsg_interpolated_rr,
data_keys.LBWSG.PAF: load_standard_data,
Expand Down Expand Up @@ -257,28 +257,6 @@ def load_scaling_factor(
return preg_inc


def load_lbwsg_exposure(
key: str, location: str, years: Optional[Union[int, str, list[int]]] = None
) -> pd.DataFrame:
entity = utilities.get_entity(data_keys.LBWSG.EXPOSURE)
data = extra_gbd.load_lbwsg_exposure(location)
# This category was a mistake in GBD 2019, so drop.
extra_residual_category = vi_globals.EXTRA_RESIDUAL_CATEGORY[entity.name]
data = data.loc[data["parameter"] != extra_residual_category]
idx_cols = ["location_id", "sex_id", "parameter"]
data = data.set_index(idx_cols)[vi_globals.DRAW_COLUMNS]

# Sometimes there are data values on the order of 10e-300 that cause
# floating point headaches, so clip everything to reasonable values
data = data.clip(lower=vi_globals.MINIMUM_EXPOSURE_VALUE)

# normalize so all categories sum to 1
total_exposure = data.groupby(["location_id", "sex_id"]).transform("sum")
data = (data / total_exposure).reset_index()
data = reshape_to_vivarium_format(data, location)
return data


def load_anc_proportion(
key: str, location: str, years: Optional[Union[int, str, list[int]]] = None
) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ plugins:
builder_interface: vivarium_gates_mncnh.plugins.time.TimeInterface

components:
vivarium_public_health:
risks:
- LBWSGRisk()

vivarium_gates_mncnh:
components:
- AgelessPopulation("population.scaling_factor")
Expand All @@ -17,14 +21,15 @@ components:
- MaternalDisorder("maternal_sepsis_and_other_maternal_infections")
- MaternalDisordersBurden()
- NeonatalMortality()
- LBWSGRiskEffect('cause.all_causes.cause_specific_mortality_rate')
# Add model observers below here
- ANCObserver()
- MaternalDisordersBurdenObserver()

configuration:
input_data:
input_draw_number: 0
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/mortality/ethiopia.hdf"
artifact_path: "/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/artifacts/neonatal/ethiopia.hdf"
interpolation:
order: 0
extrapolate: True
Expand Down
Loading