Skip to content

Commit

Permalink
Merge pull request #19 from dPys/more-env-precautions
Browse files Browse the repository at this point in the history
More env precautions
  • Loading branch information
dPys authored Dec 16, 2024
2 parents 2b92f90 + 073952b commit 2e476e3
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions nxbench/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nxbench.data.loader import BenchmarkDataManager
from nxbench.validation.registry import BenchmarkValidator

nx.config.warnings_to_ignore.add("cache")
warnings.filterwarnings("ignore")

logger = logging.getLogger("nxbench")
Expand Down Expand Up @@ -163,6 +164,13 @@ def prepare_benchmark(
f"{original_graph.number_of_edges()} edges"
)

# initially clear env completely
if hasattr(nx.config.backends, "parallel"):
if hasattr(nx.config.backends.parallel, "active"):
nx.config.backends.parallel.active = False
nx.config.backends.parallel.n_jobs = 1
os.environ["NX_CUGRAPH_AUTOCONFIG"] = "False"

for var_name in [
"NUM_THREAD",
"OMP_NUM_THREADS",
Expand All @@ -180,10 +188,13 @@ def prepare_benchmark(
except ImportError:
logger.exception("nx-parallel backend not available")
return None
nx.config.backends.parallel.active = True
nx.config.backends.parallel.n_jobs = num_thread
return nxp.ParallelGraph(original_graph)

if "cugraph" in backend and is_nx_cugraph_available():
try:
os.environ["NX_CUGRAPH_AUTOCONFIG"] = "True"
cugraph = import_module("nx_cugraph")
except ImportError:
logger.exception("cugraph backend not available")
Expand Down Expand Up @@ -290,10 +301,13 @@ def teardown_specific(self, backend: str, num_thread: int = 1):
nx.config.backends.parallel.active = False
nx.config.backends.parallel.n_jobs = 1

os.environ["NUM_THREAD"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
if "cugraph" in backend:
os.environ["NX_CUGRAPH_AUTOCONFIG"] = "False"

os.environ["NUM_THREAD"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

def teardown(self):
"""ASV teardown method. Called after all benchmarks are run."""
Expand Down

0 comments on commit 2e476e3

Please sign in to comment.