Skip to content

Commit

Permalink
feat: abstract out synthesize functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
dPys committed Dec 4, 2024
1 parent 7e4041d commit f2757b7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
28 changes: 8 additions & 20 deletions nxbench/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import importlib.resources as importlib_resources
import logging
import os
Expand All @@ -14,6 +13,7 @@
from scipy.io import mmread

from nxbench.benchmarks.config import DatasetConfig
from nxbench.data.synthesize import generate_graph

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -390,37 +390,25 @@ def _load_local_graph(
def _generate_graph(
self, config: DatasetConfig
) -> tuple[nx.Graph | nx.DiGraph, dict[str, Any]]:
"""Generate a synthetic network using networkx generator functions."""
"""Generate a synthetic network using a generator function."""
generator_name = config.params.get("generator")
if not generator_name:
raise ValueError("Generator name must be specified in params.")

try:
module_path, func_name = generator_name.rsplit(".", 1)
module = importlib.import_module(module_path)
generator = getattr(module, func_name)
except Exception:
raise ValueError(f"Invalid generator {generator_name}")

gen_params = config.params.copy()
gen_params.pop("generator", None)

directed = config.metadata.get("directed", False)

try:
graph = generator(**gen_params)
graph = generate_graph(generator_name, gen_params, directed)
except Exception:
raise ValueError(
f"Failed to generate graph with {generator_name} and params "
f"{gen_params}"
logger.exception(
f"Failed to generate graph with generator '{generator_name}'"
)

directed = config.metadata.get("directed", False)
if directed and not graph.is_directed():
graph = graph.to_directed()
elif not directed and graph.is_directed():
graph = graph.to_undirected()
raise

graph.graph.update(config.metadata)

return graph, config.metadata

def load_network_sync(
Expand Down
38 changes: 38 additions & 0 deletions nxbench/data/synthesize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import importlib
import logging
import warnings

import networkx as nx

warnings.filterwarnings("ignore")

logger = logging.getLogger("nxbench")


def generate_graph(
generator_name: str, gen_params: dict, directed: bool = False
) -> nx.Graph:
"""Generate a synthetic network using networkx generator functions."""
if not generator_name:
raise ValueError("Generator name must be specified.")

try:
module_path, func_name = generator_name.rsplit(".", 1)
module = importlib.import_module(module_path)
generator = getattr(module, func_name)
except Exception as e:
raise ValueError(f"Invalid generator {generator_name}") from e

try:
graph = generator(**gen_params)
except Exception as e:
raise ValueError(
f"Failed to generate graph with {generator_name} and params {gen_params}"
) from e

if directed and not graph.is_directed():
graph = graph.to_directed()
elif not directed and graph.is_directed():
graph = graph.to_undirected()

return graph

0 comments on commit f2757b7

Please sign in to comment.