Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: use numpyro implementation of dismod at to derive a consistent remission rate #1

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,4 @@ dmypy.json

# Version file
src/*/_version.py
*~
6 changes: 5 additions & 1 deletion artifact_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ vivarium_cluster_tools>=2.0.0
black==22.3.0
isort
jupyterlab
matplotlib
matplotlib
jax
numpyro
diffrax
interpax
Empty file removed isort
Empty file.
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@

setup_requires = ["setuptools_scm"]

data_requirements = ["vivarium_inputs[data]>=5.0.7"]
data_requirements = [
"vivarium_inputs[data]>=5.0.7",
"jax",
"numpyro",
"diffrax",
"interpax",
]
cluster_requirements = ["vivarium_cluster_tools>=2.0.3"]
test_requirements = ["pytest"]
lint_requirements = ["black", "isort"]
Expand Down
96 changes: 94 additions & 2 deletions src/vivarium_nih_moud/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from loguru import logger
from vivarium.framework.artifact import Artifact, EntityKey

from vivarium_nih_moud.constants import data_keys
from vivarium_nih_moud.data import loader
from ..constants import data_keys
from ..data import dismod_at, loader


def open_artifact(output_path: Path, location: str) -> Artifact:
Expand Down Expand Up @@ -127,3 +128,94 @@ def write_data_by_draw(artifact: Artifact, key: str, data: pd.DataFrame):
data = data.reset_index(drop=True)
for c in data.columns:
store.put(f"{key.path}/{c}", data[c])


def generate_consistent_moud_rates(art: Artifact, location: str, years: Optional[str]):
"""Generates consistent rates for MOUD data.

Parameters
----------
art
The artifact to read from and write to.
location
The location associated with the data to load and the artifact to
write to.
years
The years to load data for.

"""
# TODO: check if the consistent rates are already in the artifact, and if so, skip rest of this function

# copy metadata
for key in [
"cause.opioid_use_disorders.restrictions",
"cause.opioid_use_disorders.disability_weight",
]:
data = art.load(key)
write_or_replace(art, key.replace("opioid_use_disorders", "oud_consistent"), data)

ages = np.arange(0, 96, 5)
years = np.array([2020, 2025])
sexes = ["Male", "Female"]
key = {
"i": "cause.opioid_use_disorders.incidence_rate",
"p": "cause.opioid_use_disorders.prevalence",
"f": "cause.opioid_use_disorders.excess_mortality_rate",
"m_all": "cause.all_causes.cause_specific_mortality_rate",
"csmr_with": "cause.opioid_use_disorders.cause_specific_mortality_rate",
"pop": "population.structure",
}

def oud_data(sex):
df_data = pd.concat(
[
dismod_at.transform_to_data("p", art.load(key["p"]), sex, ages, [2021]),
dismod_at.transform_to_data("i", art.load(key["i"]), sex, ages, [2021]),
dismod_at.transform_to_data("f", art.load(key["f"]), sex, ages, [2021]),
dismod_at.transform_to_data(
"m",
art.load(key["m_all"]) - art.load(key["csmr_with"]),
sex,
ages,
[2021],
),
]
)
return df_data

def get_rates(model_dict, rate_type, year):
df_out = []
for model in model_dict.values():
df_out.append(model.get_rate(rate_type, year))
df_out = pd.concat(df_out)
return df_out

# fit model separately for Male and Female
m = {}
for sex in sexes:
m[sex] = dismod_at.ConsistentModel(sex, ages, years)
m[sex].fit(oud_data(sex))

# store consistent rates in artifact
for rate_type in "ipfr":
# generate data for k
df_out = get_rates(m, rate_type, 2020)
# store generated data in artifact
if rate_type != "r":
rate_name = key[rate_type]
else:
rate_name = "cause.opioid_use_disorders.remission_rate"
rate_name = rate_name.replace("opioid_use_disorders", "oud_consistent")
write_or_replace(art, rate_name, df_out)

# then do cause specific mortality rate
df_out = get_rates(m, "p", 2020) * get_rates(m, "f", 2020)
rate_name = "cause.oud_consistent.cause_specific_mortality_rate"
write_or_replace(art, rate_name, df_out)


def write_or_replace(art, key, data):
if key in art.keys:
art.replace(key, data)
else:
art.write(key, data)
Loading
Loading