Skip to content

Commit

Permalink
Add RR and RR interpolation for LBWSG. Test load standard data for pafs
Browse files Browse the repository at this point in the history
  • Loading branch information
albrja committed Jan 3, 2025
1 parent 89887e6 commit a796f62
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/vivarium_gates_mncnh/constants/data_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ class __LowBirthWeightShortGestation(NamedTuple):
EXPOSURE: str = "risk_factor.low_birth_weight_and_short_gestation.exposure"
DISTRIBUTION: str = "risk_factor.low_birth_weight_and_short_gestation.distribution"
CATEGORIES: str = "risk_factor.low_birth_weight_and_short_gestation.categories"
RELATIVE_RISK: str = "risk_factor.low_birth_weight_and_short_gestation.relative_risk"
RELATIVE_RISK_INTERPOLATOR: str = (
"risk_factor.low_birth_weight_and_short_gestation.relative_risk_interpolator"
)
PAF: str = (
"risk_factor.low_birth_weight_and_short_gestation.population_attributable_fraction"
)

@property
def name(self):
Expand Down
74 changes: 74 additions & 0 deletions src/vivarium_gates_mncnh/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
No logging is done here. Logging is done in vivarium inputs itself and forwarded.
"""

import pickle
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import vivarium_inputs.validation.sim as validation
from gbd_mapping import causes, covariates, risk_factors
from scipy.interpolate import RectBivariateSpline, griddata
from vivarium.framework.artifact import EntityKey
from vivarium_gbd_access import gbd
from vivarium_inputs import core as vi_core
Expand Down Expand Up @@ -65,6 +67,9 @@ def get_data(
data_keys.LBWSG.DISTRIBUTION: load_metadata,
data_keys.LBWSG.CATEGORIES: load_metadata,
data_keys.LBWSG.EXPOSURE: load_lbwsg_exposure,
data_keys.LBWSG.RELATIVE_RISK: load_lbwsg_rr,
data_keys.LBWSG.RELATIVE_RISK_INTERPOLATOR: load_lbwsg_interpolated_rr,
data_keys.LBWSG.PAF: load_standard_data,
data_keys.ANC.ESTIMATE: load_anc_proportion,
data_keys.MATERNAL_SEPSIS.RAW_INCIDENCE_RATE: load_standard_data,
data_keys.MATERNAL_SEPSIS.CSMR: load_standard_data,
Expand Down Expand Up @@ -314,6 +319,75 @@ def load_maternal_disorder_yld_rate(
return yld_rate


def load_lbwsg_rr(
    key: str, location: str, years: Optional[Union[int, str, list[int]]] = None
) -> pd.DataFrame:
    """Load the LBWSG relative risk table, restricted to ``year_start == 2021``.

    The standard relative-risk data is loaded for ``key``, filtered to the
    rows whose ``year_start`` is 2021, stripped of the ``affected_entity``
    and ``affected_measure`` index levels, and de-duplicated on the
    remaining index (the first occurrence of each index entry is kept).

    Parameters
    ----------
    key
        Must equal ``data_keys.LBWSG.RELATIVE_RISK``.
    location
        Location forwarded to ``load_standard_data``.
    years
        Year selection forwarded unchanged to ``load_standard_data``.

    Raises
    ------
    ValueError
        If ``key`` is not the LBWSG relative risk key.
    """
    expected_key = data_keys.LBWSG.RELATIVE_RISK
    if key != expected_key:
        raise ValueError(f"Unrecognized key {key}")

    all_years = load_standard_data(key, location, years)
    rr_2021 = all_years.query("year_start == 2021")
    rr_2021 = rr_2021.droplevel(["affected_entity", "affected_measure"])
    # Dropping the two levels can leave repeated index entries; keep the first of each.
    is_repeat = rr_2021.index.duplicated()
    return rr_2021[~is_repeat]


def load_lbwsg_interpolated_rr(
    key: str, location: str, years: Optional[Union[int, str, list[int]]] = None
) -> pd.DataFrame:
    """Build serialized log-relative-risk interpolators for LBWSG.

    The categorical LBWSG relative risks are log-transformed, spread onto a
    rectangular (gestational age x birth weight) grid with nearest-neighbour
    lookup, and wrapped in a degree-1 ``RectBivariateSpline`` per row. Each
    spline is pickled and hex-encoded so it can be stored as a string.

    Parameters
    ----------
    key
        Must equal ``data_keys.LBWSG.RELATIVE_RISK_INTERPOLATOR``.
    location
        Location forwarded to ``get_data`` for the relative risks and the
        category metadata.
    years
        Accepted so this loader has the same call signature as
        ``load_lbwsg_rr``. It is intentionally not forwarded: the relative
        risks are requested with ``get_data``'s default year selection.

    Returns
    -------
    pd.DataFrame
        Hex-encoded pickles of ``RectBivariateSpline`` objects that
        interpolate log relative risk.

    Raises
    ------
    ValueError
        If ``key`` is not the LBWSG relative-risk-interpolator key.
    """
    if key != data_keys.LBWSG.RELATIVE_RISK_INTERPOLATOR:
        raise ValueError(f"Unrecognized key {key}")

    rr = get_data(data_keys.LBWSG.RELATIVE_RISK, location).reset_index()
    # An ordered categorical makes "parameter" sort by category number instead
    # of lexicographically.
    # NOTE(review): the label range is generated from metadata.DRAW_COUNT;
    # presumably that is only used as an upper bound on category numbers - confirm.
    rr["parameter"] = pd.Categorical(
        rr["parameter"], [f"cat{i}" for i in range(metadata.DRAW_COUNT)]
    )
    # Reshape so that each row holds the log RR of every category.
    rr = (
        rr.sort_values("parameter")
        .set_index(metadata.ARTIFACT_INDEX_COLUMNS + ["parameter"])
        .stack()
        .unstack("parameter")
        .apply(np.log)
    )

    # Load the category descriptions once; both exposure dimensions are parsed
    # from the same metadata.
    categories = get_data(f"risk_factor.{data_keys.LBWSG.name}.categories", location)

    # get category midpoints
    def get_category_midpoints(lbwsg_type: str) -> pd.Series:
        return utilities.get_intervals_from_categories(lbwsg_type, categories).apply(
            lambda x: x.mid
        )

    gestational_age_midpoints = get_category_midpoints("short_gestation")
    birth_weight_midpoints = get_category_midpoints("low_birth_weight")

    # build grid of gestational age and birth weight
    def get_grid(midpoints: pd.Series, endpoints: tuple[float, float]) -> np.ndarray:
        # Unique midpoints plus the outer endpoints, in ascending order.
        grid = np.append(np.unique(midpoints), endpoints)
        grid.sort()
        return grid

    gestational_age_grid = get_grid(gestational_age_midpoints, (0.0, 42.0))
    birth_weight_grid = get_grid(birth_weight_midpoints, (0.0, 4500.0))

    def make_interpolator(log_rr_for_age_sex_draw: pd.Series) -> RectBivariateSpline:
        # Use scipy.interpolate.griddata to extrapolate to grid using nearest neighbor interpolation
        # NOTE(review): values are matched to the midpoints by position, so this
        # assumes the row's category order equals the midpoint order - confirm.
        log_rr_grid_nearest = griddata(
            (gestational_age_midpoints, birth_weight_midpoints),
            log_rr_for_age_sex_draw,
            (gestational_age_grid[:, None], birth_weight_grid[None, :]),
            method="nearest",
            rescale=True,
        )
        # return a RectBivariateSpline object from the extrapolated values on grid
        # (kx=ky=1 gives a degree-1 spline in each direction)
        return RectBivariateSpline(
            gestational_age_grid, birth_weight_grid, log_rr_grid_nearest, kx=1, ky=1
        )

    # Pickle + hex turns each spline into a plain string.
    # NOTE: consumers must only unpickle these values from trusted artifacts.
    log_rr_interpolator = (
        rr.apply(make_interpolator, axis="columns")
        .apply(lambda x: pickle.dumps(x).hex())
        .unstack()
    )
    return log_rr_interpolator


def reshape_to_vivarium_format(df, location):
df = vi_utils.reshape(df, value_cols=vi_globals.DRAW_COLUMNS)
df = vi_utils.scrub_gbd_conventions(df, location)
Expand Down
51 changes: 51 additions & 0 deletions src/vivarium_gates_mncnh/data/utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union

import pandas as pd
from gbd_mapping import ModelableEntity, causes, covariates, risk_factors
from vivarium.framework.artifact import EntityKey
from vivarium_inputs.mapping_extension import alternative_risk_factors
Expand All @@ -15,3 +16,53 @@ def get_entity(key: Union[str, EntityKey]):
}
key = EntityKey(key)
return type_map[key.type][key.name]


def get_intervals_from_categories(lbwsg_type: str, categories: dict[str, str]) -> pd.Series:
    """Parse LBWSG category descriptions into intervals for one exposure dimension.

    Parameters
    ----------
    lbwsg_type
        Either ``"low_birth_weight"`` or ``"short_gestation"``; selects which
        part of each description is parsed.
    categories
        Mapping of category name to its description string.

    Returns
    -------
    pd.Series
        One ``pd.Interval`` per category, indexed by category name. The
        series is named ``"<lbwsg_type>.endpoints"`` and its index is named
        ``"parameter"``.

    Raises
    ------
    ValueError
        If ``lbwsg_type`` is not one of the two recognized risk types.
    """
    # The builtin generic ``dict[str, str]`` is used in the signature because
    # ``typing.Dict`` is not imported by this module.
    if lbwsg_type == "low_birth_weight":
        category_endpoints = pd.Series(
            {
                cat: parse_low_birth_weight_description(description)
                for cat, description in categories.items()
            },
            name=f"{lbwsg_type}.endpoints",
        )
    elif lbwsg_type == "short_gestation":
        category_endpoints = pd.Series(
            {
                cat: parse_short_gestation_description(description)
                for cat, description in categories.items()
            },
            name=f"{lbwsg_type}.endpoints",
        )
    else:
        raise ValueError(
            f"Unrecognized risk type {lbwsg_type}. Risk type must be low_birth_weight or short_gestation"
        )
    category_endpoints.index.name = "parameter"

    return category_endpoints


def parse_low_birth_weight_description(description: str) -> pd.Interval:
    """Extract the birth-weight bounds from an LBWSG category description.

    The birth-weight range is the second bracketed group of the description.
    Its two comma-separated numbers are converted to floats and returned as
    a ``pd.Interval`` built with pandas' default closure.
    """
    # descriptions look like this: 'Birth prevalence - [34, 36) wks, [2000, 2500) g'

    # Everything after the ", [" that opens the birth-weight range.
    weight_part = description.split(", [")[1]
    # Trim at the closing bracket, which may be either ")" or "]".
    for closing_bracket in (")", "]"):
        weight_part = weight_part.split(closing_bracket)[0]
    bounds = map(float, weight_part.split(", "))
    return pd.Interval(*bounds)


def parse_short_gestation_description(description: str) -> pd.Interval:
    """Extract the gestational-age bounds from an LBWSG category description.

    The gestational-age range is the first bracketed group of the
    description, following ``"- ["``. Its two comma-separated numbers are
    converted to floats and returned as a ``pd.Interval`` built with pandas'
    default closure.
    """
    # descriptions look like this: 'Birth prevalence - [34, 36) wks, [2000, 2500) g'

    # Everything after the "- [" that opens the gestational-age range.
    age_part = description.split("- [")[1]
    # Trim at the first ")" and then at a "+" (as in an open-ended "42+" bound).
    for terminator in (")", "+"):
        age_part = age_part.split(terminator)[0]
    bounds = map(float, age_part.split(", "))
    return pd.Interval(*bounds)

0 comments on commit a796f62

Please sign in to comment.