diff --git a/nxbench/data/loader.py b/nxbench/data/loader.py index 0b0349d..3bab221 100644 --- a/nxbench/data/loader.py +++ b/nxbench/data/loader.py @@ -1,4 +1,3 @@ -import importlib import importlib.resources as importlib_resources import logging import os @@ -14,6 +13,7 @@ from scipy.io import mmread from nxbench.benchmarks.config import DatasetConfig +from nxbench.data.synthesize import generate_graph warnings.filterwarnings("ignore") @@ -390,37 +390,25 @@ def _load_local_graph( def _generate_graph( self, config: DatasetConfig ) -> tuple[nx.Graph | nx.DiGraph, dict[str, Any]]: - """Generate a synthetic network using networkx generator functions.""" + """Generate a synthetic network using a generator function.""" generator_name = config.params.get("generator") if not generator_name: raise ValueError("Generator name must be specified in params.") - try: - module_path, func_name = generator_name.rsplit(".", 1) - module = importlib.import_module(module_path) - generator = getattr(module, func_name) - except Exception: - raise ValueError(f"Invalid generator {generator_name}") - gen_params = config.params.copy() gen_params.pop("generator", None) + directed = config.metadata.get("directed", False) + try: - graph = generator(**gen_params) + graph = generate_graph(generator_name, gen_params, directed) except Exception: - raise ValueError( - f"Failed to generate graph with {generator_name} and params " - f"{gen_params}" + logger.exception( + f"Failed to generate graph with generator '{generator_name}'" ) - - directed = config.metadata.get("directed", False) - if directed and not graph.is_directed(): - graph = graph.to_directed() - elif not directed and graph.is_directed(): - graph = graph.to_undirected() + raise graph.graph.update(config.metadata) - return graph, config.metadata def load_network_sync( diff --git a/nxbench/data/synthesize.py b/nxbench/data/synthesize.py new file mode 100644 index 0000000..92dc901 --- /dev/null +++ b/nxbench/data/synthesize.py @@ -0,0 +1,38 @@ +import importlib +import logging +import warnings + +import networkx as nx + +warnings.filterwarnings("ignore") + +logger = logging.getLogger("nxbench") + + +def generate_graph( + generator_name: str, gen_params: dict, directed: bool = False +) -> nx.Graph: + """Generate a synthetic network using networkx generator functions.""" + if not generator_name: + raise ValueError("Generator name must be specified.") + + try: + module_path, func_name = generator_name.rsplit(".", 1) + module = importlib.import_module(module_path) + generator = getattr(module, func_name) + except Exception as e: + raise ValueError(f"Invalid generator {generator_name}") from e + + try: + graph = generator(**gen_params) + except Exception as e: + raise ValueError( + f"Failed to generate graph with {generator_name} and params {gen_params}" + ) from e + + if directed and not graph.is_directed(): + graph = graph.to_directed() + elif not directed and graph.is_directed(): + graph = graph.to_undirected() + + return graph