diff --git a/tf_quant_finance/models/generic_ito_process_test.py b/tf_quant_finance/models/generic_ito_process_test.py index db058632a..82ba92ce1 100644 --- a/tf_quant_finance/models/generic_ito_process_test.py +++ b/tf_quant_finance/models/generic_ito_process_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for `sample_paths` of `ItoProcess`.""" +import os from unittest import mock # pylint: disable=g-importing-member from absl.testing import parameterized import numpy as np @@ -748,4 +749,5 @@ def _gaussian(xs, variance): if __name__ == '__main__': + os.environ['TF_XLA_FLAGS'] = '--tf_mlir_enable_mlir_bridge=true' tf.test.main()