Skip to content

Commit

Permalink
resolves #639, credit due to Dong Liu (#722)
Browse files Browse the repository at this point in the history
* resolves #639, credit due to Dong Liu
* avoid importing *
* add optional tf pytest marker, option test_dragonnet
* lint
* pass on tensorflow ImportError
* update CONTRIBUTING, docs/installation.rst

---------

Co-authored-by: Roland Stevenson <[email protected]>
  • Loading branch information
ras44 and rolandrmgservices authored Jan 29, 2024
1 parent 9e1f892 commit 1d8e095
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 1 deletion.
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,19 @@ Before submitting a PR, make sure the change to pass all tests and test coverage
$ pytest -vs tests/ --cov causalml/
```

To run tests that require tensorflow (i.e. DragonNet), make sure tensorflow is installed and include the `--runtf` option with the `pytest` command. For example:

```bash
$ pytest --runtf -vs tests/test_dragonnet.py
```

You can also run tests via make:
```bash
$ make test
```



## Submission :tada:

In your PR, please include:
Expand Down
34 changes: 34 additions & 0 deletions causalml/inference/tf/dragonnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow.keras.layers import Dense, Concatenate
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model

from causalml.inference.tf.utils import (
dragonnet_loss_binarycross,
Expand Down Expand Up @@ -290,3 +291,36 @@ def fit_predict(self, X, treatment, y, p=None, return_components=False):
"""
self.fit(X, treatment, y)
return self.predict_tau(X)

def save(self, h5_filepath):
"""
Save the dragonnet model as a H5 file.
Args:
h5_filepath (H5 file path): H5 file path
"""
self.dragonnet.save(h5_filepath)

def load(self, h5_filepath, ratio=1.0, dragonnet_loss=dragonnet_loss_binarycross):
"""
Load the dragonnet model from a H5 file.
Args:
h5_filepath (H5 file path): H5 file path
ratio (float): weight assigned to the targeted regularization loss component
dragonnet_loss (function): a loss function
"""
self.dragonnet = load_model(
h5_filepath,
custom_objects={
"EpsilonLayer": EpsilonLayer,
"dragonnet_loss_binarycross": dragonnet_loss_binarycross,
"tarreg_ATE_unbounded_domain_loss": make_tarreg_loss(
ratio=ratio, dragonnet_loss=dragonnet_loss
),
"regression_loss": regression_loss,
"binary_classification_loss": binary_classification_loss,
"treatment_accuracy": treatment_accuracy,
"track_epsilon": track_epsilon,
},
)
11 changes: 10 additions & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ with ``tensorflow``:
pip install .[tf]
=======

Windows
-------
Expand All @@ -106,11 +107,19 @@ See content in https://github.com/uber/causalml/issues/678
Running Tests
-------------

Make sure pytest is installed before attempting to run tests.

Run all tests with:

.. code-block:: bash
pytest -vs tests/ --cov causalml/
Add ``--runtf`` to run optional tensorflow tests which will be skipped by default.

You can also run tests via make:

.. code-block:: bash
make test
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,21 @@ def _generate_data():
return data

yield _generate_data


def pytest_addoption(parser):
parser.addoption("--runtf", action="store_true", default=False, help="run tf tests")


def pytest_configure(config):
config.addinivalue_line("markers", "tf: mark test as tf to run")


def pytest_collection_modifyitems(config, items):
if config.getoption("--runtf"):
# --runtf given in cli: do not skip tf tests
return
skip_tf = pytest.mark.skip(reason="need --runtf option to run")
for item in items:
if "tf" in item.keywords:
item.add_marker(skip_tf)
23 changes: 23 additions & 0 deletions tests/test_dragonnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
try:
from causalml.inference.tf import DragonNet
except ImportError:
pass
from causalml.dataset.regression import simulate_nuisance_and_easy_treatment
import shutil
import pytest


@pytest.mark.tf
def test_save_load_dragonnet():
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)

dragon = DragonNet(neurons_per_layer=200, targeted_reg=True, verbose=False)
dragon_ite = dragon.fit_predict(X, w, y, return_components=False)
dragon_ate = dragon_ite.mean()
dragon.save("smaug")

smaug = DragonNet()
smaug.load("smaug")
shutil.rmtree("smaug")

assert smaug.predict_tau(X).mean() == dragon_ate

0 comments on commit 1d8e095

Please sign in to comment.