Skip to content

Commit

Permalink
Make torch optional. Update DragonNet w/ latest TF APIs (#790)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongyoonlee authored Sep 14, 2024
1 parent 129b5d9 commit e1c6c31
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 34 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ _build/
.coverage*
*.html
*.prof
.venv/
12 changes: 10 additions & 2 deletions causalml/inference/tf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ class EpsilonLayer(Layer):
Custom keras layer to allow epsilon to be learned during training process.
"""

def __init__(self):
def __init__(self, **kwargs):
"""
Inherits keras' Layer object.
"""
super(EpsilonLayer, self).__init__()
super(EpsilonLayer, self).__init__(**kwargs)

def build(self, input_shape):
"""
Expand All @@ -162,3 +162,11 @@ def build(self, input_shape):

def call(self, inputs, **kwargs):
return self.epsilon * tf.ones_like(inputs)[:, 0:1]

def get_config(self):
config = super().get_config()
return config

@classmethod
def from_config(cls, config):
return cls(**config)
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions docs/examples/cevae_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"import logging\n",
"\n",
"from causalml.inference.meta import BaseXRegressor, BaseRRegressor, BaseSRegressor, BaseTRegressor\n",
"from causalml.inference.nn import CEVAE\n",
"from causalml.inference.torch import CEVAE\n",
"from causalml.propensity import ElasticNetPropensityModel\n",
"from causalml.metrics import *\n",
"from causalml.dataset import simulate_hidden_confounder\n",
Expand Down Expand Up @@ -342,8 +342,8 @@
"treatment = df['treatment'].values\n",
"y = df['y_factual'].values\n",
"y_cf = df['y_cfactual'].values\n",
"tau = df.apply(lambda d: d['y_factual'] - d['y_cfactual'] if d['treatment']==1 \n",
" else d['y_cfactual'] - d['y_factual'], \n",
"tau = df.apply(lambda d: d['y_factual'] - d['y_cfactual'] if d['treatment']==1\n",
" else d['y_cfactual'] - d['y_factual'],\n",
" axis=1)\n",
"mu_0 = df['mu0'].values\n",
"mu_1 = df['mu1'].values"
Expand Down Expand Up @@ -5045,7 +5045,7 @@
" preds_dict_train['X Learner (LR)'].ravel(),\n",
" preds_dict_train['X Learner (XGB)'].ravel(),\n",
" preds_dict_train['R Learner (LR)'].ravel(),\n",
" preds_dict_train['R Learner (XGB)'].ravel(), \n",
" preds_dict_train['R Learner (XGB)'].ravel(),\n",
" preds_dict_train['CEVAE'].ravel(),\n",
" preds_dict_train['generated_data']['tau'].ravel(),\n",
" preds_dict_train['generated_data']['w'].ravel(),\n",
Expand Down Expand Up @@ -5077,7 +5077,7 @@
" preds_dict_valid['X Learner (LR)'].ravel(),\n",
" preds_dict_valid['X Learner (XGB)'].ravel(),\n",
" preds_dict_valid['R Learner (LR)'].ravel(),\n",
" preds_dict_valid['R Learner (XGB)'].ravel(), \n",
" preds_dict_valid['R Learner (XGB)'].ravel(),\n",
" preds_dict_valid['CEVAE'].ravel(),\n",
" preds_dict_valid['generated_data']['tau'].ravel(),\n",
" preds_dict_valid['generated_data']['w'].ravel(),\n",
Expand Down
30 changes: 21 additions & 9 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Installation
============

Installation with ``conda`` or ``pip`` is recommended. Developers can follow the **Install from source** instructions below. If building from source, consider doing so within a conda environment and then exporting the environment for reproducibility.
Installation with ``conda`` or ``pip`` is recommended. Developers can follow the **Install from source** instructions below. If building from source, consider doing so within a conda environment and then exporting the environment for reproducibility.

To use models under the ``inference.tf`` module (e.g. ``DragonNet``), additional dependency of ``tensorflow`` is required. For detailed instructions, see below.
To use models under the ``inference.tf`` or ``inference.torch`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow`` or ``torch`` is required. For detailed instructions, see below.

Install using ``conda``
-----------------------
Expand All @@ -13,7 +13,7 @@ Install ``conda``
^^^^^^^^^^^^^^^^^

.. code-block:: bash
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh -b
source miniconda3/bin/activate
Expand All @@ -36,14 +36,21 @@ Install from ``PyPI``
pip install causalml
Install ``causalml`` with ``tensorflow`` from ``PyPI``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Install ``causalml`` with ``tensorflow`` for ``DragonNet`` from ``PyPI``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash
pip install causalml[tf]
pip install -U numpy # this step is necessary to fix [#338](https://github.com/uber/causalml/issues/338)
Install ``causalml`` with ``torch`` for ``CEVAE`` from ``PyPI``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash
pip install causalml[torch]
Install from source
-------------------
Expand All @@ -68,12 +75,18 @@ Then:
pip install .
python setup.py build_ext --inplace
with ``tensorflow``:
with ``tensorflow`` for ``DragonNet``:

.. code-block:: bash
pip install .[tf]
with ``torch`` for ``CEVAE``:

.. code-block:: bash
pip install .[torch]
=======

Windows
Expand All @@ -93,11 +106,10 @@ Run all tests with:
pytest -vs tests/ --cov causalml/
Add ``--runtf`` to run optional tensorflow tests which will be skipped by default.
Add ``--runtf`` and/or ``--runtorch`` to run optional tensorflow/torch tests which will be skipped by default.

You can also run tests via make:

.. code-block:: bash
make test
3 changes: 1 addition & 2 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ Interpretation
Please see :ref:`Interpretable Causal ML` section

Validation
----------------
----------

Please see :ref:`validation` section

Expand Down Expand Up @@ -306,4 +306,3 @@ For more details, please refer to the `feature_selection.ipynb notebook <https:/
treatment_group = 'treatment1',
n_bins=10)
print(kl_imp)
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ dependencies = [
"lightgbm",
"pygam",
"packaging",
"torch",
"pyro-ppl",
"graphviz",
]

Expand All @@ -55,7 +53,10 @@ test = [
tf = [
"tensorflow>=2.4.0"
]

torch = [
"torch",
"pyro-ppl"
]

[build-system]
requires = [
Expand Down
18 changes: 12 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,23 @@ def _generate_data():

def pytest_addoption(parser):
parser.addoption("--runtf", action="store_true", default=False, help="run tf tests")
parser.addoption(
"--runtorch", action="store_true", default=False, help="run torch tests"
)


def pytest_configure(config):
config.addinivalue_line("markers", "tf: mark test as tf to run")
config.addinivalue_line("markers", "torch: mark test as torch to run")


def pytest_collection_modifyitems(config, items):
if config.getoption("--runtf"):
# --runtf given in cli: do not skip tf tests
return
skip_tf = pytest.mark.skip(reason="need --runtf option to run")

skip_tf = False if config.getoption("--runtf") else True
skip_torch = False if config.getoption("--runtorch") else True

for item in items:
if "tf" in item.keywords:
item.add_marker(skip_tf)
if "tf" in item.keywords and skip_tf:
item.add_marker(pytest.mark.skip(reason="need --runtf option to run"))
if "torch" in item.keywords and skip_torch:
item.add_marker(pytest.mark.skip(reason="need --runtorch option to run"))
9 changes: 7 additions & 2 deletions tests/test_cevae.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import pandas as pd
import torch
import pytest

from causalml.inference.nn import CEVAE
try:
import torch
from causalml.inference.torch import CEVAE
except ImportError:
pass
from causalml.dataset import simulate_hidden_confounder
from causalml.metrics import get_cumgain


@pytest.mark.torch
def test_CEVAE():
y, X, treatment, tau, b, e = simulate_hidden_confounder(
n=10000, p=5, sigma=1.0, adj=0.0
Expand Down
10 changes: 5 additions & 5 deletions tests/test_dragonnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
except ImportError:
pass
from causalml.dataset.regression import simulate_nuisance_and_easy_treatment
import shutil
import pytest


@pytest.mark.tf
def test_save_load_dragonnet():
def test_save_load_dragonnet(tmp_path):
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)

dragon = DragonNet(neurons_per_layer=200, targeted_reg=True, verbose=False)
dragon_ite = dragon.fit_predict(X, w, y, return_components=False)
dragon_ate = dragon_ite.mean()
dragon.save("smaug")

model_file = tmp_path / "smaug.h5"
dragon.save(model_file)

smaug = DragonNet()
smaug.load("smaug")
shutil.rmtree("smaug")
smaug.load(model_file)

assert smaug.predict_tau(X).mean() == dragon_ate

0 comments on commit e1c6c31

Please sign in to comment.