Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand unit testing #2

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions nxbench/benchmarks/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
from unittest.mock import patch

import pytest

from nxbench.benchmarks.utils import (
get_available_backends,
get_benchmark_config,
get_python_version,
is_cugraph_available,
is_graphblas_available,
is_nx_parallel_available,
)


@pytest.fixture(autouse=True)
def _reset_benchmark_config():
    """Blank out the module-level benchmark config around every test.

    The current value of ``nxbench.benchmarks.utils._BENCHMARK_CONFIG`` is
    remembered, the attribute is set to ``None`` for the duration of the
    test, and the remembered value is written back afterwards.
    """
    import nxbench.benchmarks.utils as bench_utils

    saved_config = bench_utils._BENCHMARK_CONFIG
    bench_utils._BENCHMARK_CONFIG = None
    yield
    # teardown: put back whatever was there before the test ran
    bench_utils._BENCHMARK_CONFIG = saved_config


def test_backend_availability():
    """Check that each backend probe follows ``importlib.util.find_spec``."""
    probes = (
        is_cugraph_available,
        is_graphblas_available,
        is_nx_parallel_available,
    )

    with patch("importlib.util.find_spec") as find_spec:
        # a truthy spec stands in for "package can be located"
        find_spec.return_value = True
        for probe in probes:
            assert probe() is True

        # ``None`` stands in for "package cannot be located"
        find_spec.return_value = None
        for probe in probes:
            assert probe() is False


def test_get_available_backends():
    """Check the backend list with every optional backend on, then off."""
    with (
        patch("nxbench.benchmarks.utils.is_cugraph_available") as cugraph_probe,
        patch("nxbench.benchmarks.utils.is_graphblas_available") as graphblas_probe,
        patch("nxbench.benchmarks.utils.is_nx_parallel_available") as parallel_probe,
    ):
        probes = (cugraph_probe, graphblas_probe, parallel_probe)

        # every optional backend reports itself as present
        for probe in probes:
            probe.return_value = True

        available = get_available_backends()
        for expected in ("networkx", "cugraph", "graphblas", "parallel"):
            assert expected in available

        # every optional backend reports itself as missing
        for probe in probes:
            probe.return_value = False

        available = get_available_backends()
        assert available == ["networkx"]


def test_get_python_version():
    """Check that the version string is three dot-separated digit groups."""
    components = get_python_version().split(".")
    assert len(components) == 3
    assert all(component.isdigit() for component in components)


def test_configure_benchmarks_env_vars():
    """Check that a config path from the environment must exist on disk.

    ``NXBENCH_CONFIG_FILE`` is pointed at a file while ``Path.exists`` is
    forced to report ``False``; loading the config must then raise
    ``FileNotFoundError``.
    """
    env_override = {"NXBENCH_CONFIG_FILE": "test_config.yaml"}
    with (
        patch.dict(os.environ, env_override),
        patch("pathlib.Path.exists", return_value=False),
        pytest.raises(FileNotFoundError),
    ):
        get_benchmark_config()


def test_memory_tracking():
    """Check the shape of the stats produced by ``memory_tracker``."""
    from nxbench.benchmarks.utils import memory_tracker

    element_count = 1000000

    with memory_tracker() as usage:
        payload = [0] * element_count
        # touch the list so it stays referenced while tracking is active
        assert len(payload) == element_count

    for key in ("current", "peak"):
        assert key in usage
        assert isinstance(usage[key], int)
    assert usage["peak"] >= usage["current"]
53 changes: 25 additions & 28 deletions nxbench/data/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import importlib.resources as importlib_resources
import logging
import os
Expand All @@ -14,6 +13,7 @@
from scipy.io import mmread

from nxbench.benchmarks.config import DatasetConfig
from nxbench.data.synthesize import generate_graph

warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -59,7 +59,7 @@ def get_metadata(self, name: str) -> dict[str, Any]:
return network.iloc[0].to_dict()

async def load_network(
self, config: DatasetConfig
self, config: DatasetConfig, session: aiohttp.ClientSession | None = None
) -> tuple[nx.Graph | nx.DiGraph, dict[str, Any]]:
"""Load or generate a network based on config."""
source_lower = config.source.lower()
Expand Down Expand Up @@ -89,7 +89,7 @@ async def load_network(

source_lower = config.source.lower()
if source_lower == "networkrepository":
graph, metadata = await self._load_nr_graph(config.name, metadata)
graph, metadata = await self._load_nr_graph(config.name, metadata, session)
elif source_lower == "local":
graph, metadata = self._load_local_graph(config)
elif source_lower == "generator":
Expand Down Expand Up @@ -241,7 +241,10 @@ def edge_parser():
return graph

async def _load_nr_graph(
self, name: str, metadata: dict[str, Any]
self,
name: str,
metadata: dict[str, Any],
session: aiohttp.ClientSession | None = None,
) -> nx.Graph | nx.DiGraph:
for ext in self.SUPPORTED_FORMATS:
graph_file = self.data_dir / f"{name}{ext}"
Expand All @@ -256,7 +259,7 @@ async def _load_nr_graph(
f"Network '{name}' not found in local cache. Attempting to download from "
f"repository."
)
await self._download_and_extract_network(name, url)
await self._download_and_extract_network(name, url, session)

for ext in self.SUPPORTED_FORMATS:
graph_file = self.data_dir / f"{name}{ext}"
Expand All @@ -269,13 +272,15 @@ async def _load_nr_graph(
f"download was successful and the graph file exists."
)

async def _download_and_extract_network(self, name: str, url: str):
async def _download_and_extract_network(
self, name: str, url: str, session: aiohttp.ClientSession | None = None
):
zip_path = self.data_dir / f"{name}.zip"
extracted_folder = self.data_dir / f"{name}_extracted"

if not zip_path.exists():
logger.info(f"Downloading network '{name}' from {url}")
await self._download_file(url, zip_path)
await self._download_file(url, zip_path, session)
logger.info(f"Downloaded network '{name}' to {zip_path}")

if not extracted_folder.exists():
Expand Down Expand Up @@ -309,8 +314,12 @@ async def _download_and_extract_network(self, name: str, url: str):
)
raise

async def _download_file(self, url: str, dest: Path):
async with aiohttp.ClientSession() as session:
async def _download_file(
self, url: str, dest: Path, session: aiohttp.ClientSession | None = None
):
if session is None:
session = aiohttp.ClientSession()
async with session:
async with session.get(url) as response:
if response.status != 200:
logger.error(
Expand Down Expand Up @@ -381,37 +390,25 @@ def _load_local_graph(
def _generate_graph(
self, config: DatasetConfig
) -> tuple[nx.Graph | nx.DiGraph, dict[str, Any]]:
"""Generate a synthetic network using networkx generator functions."""
"""Generate a synthetic network using a generator function."""
generator_name = config.params.get("generator")
if not generator_name:
raise ValueError("Generator name must be specified in params.")

try:
module_path, func_name = generator_name.rsplit(".", 1)
module = importlib.import_module(module_path)
generator = getattr(module, func_name)
except Exception:
raise ValueError(f"Invalid generator {generator_name}")

gen_params = config.params.copy()
gen_params.pop("generator", None)

directed = config.metadata.get("directed", False)

try:
graph = generator(**gen_params)
graph = generate_graph(generator_name, gen_params, directed)
except Exception:
raise ValueError(
f"Failed to generate graph with {generator_name} and params "
f"{gen_params}"
logger.exception(
f"Failed to generate graph with generator '{generator_name}'"
)

directed = config.metadata.get("directed", False)
if directed and not graph.is_directed():
graph = graph.to_directed()
elif not directed and graph.is_directed():
graph = graph.to_undirected()
raise

graph.graph.update(config.metadata)

return graph, config.metadata

def load_network_sync(
Expand Down
38 changes: 38 additions & 0 deletions nxbench/data/synthesize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import importlib
import logging
import warnings

import networkx as nx

warnings.filterwarnings("ignore")

logger = logging.getLogger("nxbench")


def generate_graph(
    generator_name: str, gen_params: dict, directed: bool = False
) -> "nx.Graph | nx.DiGraph":
    """Generate a synthetic network from a dotted-path generator function.

    Parameters
    ----------
    generator_name : str
        Fully qualified name of the generator in the form
        ``"<module path>.<function name>"``, for example
        ``"networkx.erdos_renyi_graph"``. The module part is imported
        dynamically, so any importable callable is accepted, not only
        networkx generators.
    gen_params : dict
        Keyword arguments passed unchanged to the generator.
    directed : bool, default=False
        Requested directedness. If the generated graph does not already match
        it, the graph is converted with ``to_directed()`` or
        ``to_undirected()``.

    Returns
    -------
    nx.Graph | nx.DiGraph
        The generated graph, converted to the requested directedness.

    Raises
    ------
    ValueError
        If ``generator_name`` is empty, cannot be resolved to a callable, or
        the generator raises when called with ``gen_params``.
    """
    if not generator_name:
        raise ValueError("Generator name must be specified.")

    # NOTE(security): this imports and calls whatever dotted path it is given.
    # Only pass generator names that come from trusted configuration.
    try:
        module_path, func_name = generator_name.rsplit(".", 1)
        module = importlib.import_module(module_path)
        generator = getattr(module, func_name)
    except Exception as e:
        raise ValueError(f"Invalid generator {generator_name}") from e

    # A resolvable but non-callable attribute (e.g. a constant) is a bad
    # generator name, not a failed generation; report it as such.
    if not callable(generator):
        raise ValueError(f"Invalid generator {generator_name}")

    try:
        graph = generator(**gen_params)
    except Exception as e:
        raise ValueError(
            f"Failed to generate graph with {generator_name} and params {gen_params}"
        ) from e

    # Only convert when the generator's output disagrees with the request.
    if directed and not graph.is_directed():
        graph = graph.to_directed()
    elif not directed and graph.is_directed():
        graph = graph.to_undirected()

    return graph
103 changes: 103 additions & 0 deletions nxbench/data/tests/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import networkx as nx
import pytest

Expand Down Expand Up @@ -419,3 +421,104 @@ async def test_load_edge_list_with_extra_columns(data_manager, create_edge_file)
assert data["weight"] == expected_weights.get(
(u, v), expected_weights.get((v, u))
), f"Incorrect weight for edge ({u}, {v})"


def test_load_metadata(data_manager):
    """Check that the metadata table lists exactly the expected networks."""
    frame = data_manager._metadata_df
    assert not frame.empty, "Metadata DataFrame should not be empty"

    expected_names = {
        "jazz",
        "08blocks",
        "patentcite",
        "imdb",
        "citeseer",
        "mixed_delimiters",
        "invalid_weights",
        "self_loops_duplicates",
        "non_sequential_ids",
        "example",
        "extra_columns",
        "twitter",
        "invalid_example",
    }
    # compared as sets: order and duplicates in the column are irrelevant
    loaded_names = set(frame["name"])
    assert loaded_names == expected_names, "Metadata names do not match expected names"


def test_get_metadata(data_manager):
    """Check the metadata record returned for the ``jazz`` network."""
    jazz = data_manager.get_metadata("jazz")

    assert jazz["name"] == "jazz", "Metadata 'name' should be 'jazz'"
    assert jazz["directed"] is False, "Metadata 'directed' should be False"
    assert jazz["weighted"] is True, "Metadata 'weighted' should be True"


@patch("nxbench.data.loader.zipfile.ZipFile")
@pytest.mark.asyncio
async def test_load_network_retry(mock_zipfile_class, data_manager):
    """Check that a repository network is fetched via the injected session.

    Despite the name, no retry is exercised here: the test asserts that the
    supplied session's ``get`` is called exactly once with the download URL
    and that ``load_network`` raises ``FileNotFoundError``.
    """
    download_url = "http://example.com/test.zip"
    data_manager.get_metadata = MagicMock(
        return_value={
            "download_url": download_url,
            "directed": False,
            "weighted": True,
        }
    )

    # ZipFile is replaced by a mock, so nothing is really extracted
    archive = MagicMock()
    archive.__enter__.return_value.extractall = MagicMock()
    mock_zipfile_class.return_value = archive

    # fake HTTP 200 response whose body reads b"data" and then b""
    response = AsyncMock()
    response.status = 200
    response.content.read = AsyncMock(side_effect=[b"data", b""])

    session = AsyncMock(spec=aiohttp.ClientSession)
    session.get.return_value.__aenter__.return_value = response

    dataset = DatasetConfig(name="test", source="networkrepository", params={})
    with pytest.raises(FileNotFoundError):
        await data_manager.load_network(dataset, session=session)

    session.get.assert_called_once_with(download_url)


def test_load_weighted_graph(data_manager, tmp_path):
    """Check that a three-edge weighted edge list loads with float weights."""
    edge_file = tmp_path / "test.edges"
    edge_file.write_text(
        """# Test weighted graph
1 2 1.5
2 3 2.5
3 1 3.5
"""
    )

    file_metadata = {"directed": False, "weighted": True}
    loaded = data_manager._load_graph_file(edge_file, file_metadata)

    assert isinstance(loaded, nx.Graph)
    assert loaded.number_of_nodes() == 3
    assert loaded.number_of_edges() == 3
    # every edge must carry a float ``weight`` attribute
    for *_, attrs in loaded.edges(data=True):
        assert "weight" in attrs
        assert isinstance(attrs["weight"], float)


def test_normalize_graph(data_manager, tmp_path):
    """Check the cleanup applied when loading an edge list with a self-loop.

    The input has one self-loop (``1 1``) and two ordinary edges; the loaded
    graph is expected to have no self-loops, string node labels, three nodes
    and two edges.
    """
    edge_file = tmp_path / "test.edges"
    edge_file.write_text(
        """1 1
1 2
2 3"""
    )

    graph = data_manager._load_graph_file(edge_file, {"directed": False})

    assert not list(nx.selfloop_edges(graph))
    assert all(isinstance(node, str) for node in graph.nodes())
    assert graph.number_of_nodes() == 3
    assert graph.number_of_edges() == 2
Loading