Commit

tests: unit tests for benchmark.utils and additional tests for data.loader
dPys committed Dec 4, 2024
1 parent 32d2e4a commit 7e4041d
Showing 3 changed files with 221 additions and 8 deletions.
101 changes: 101 additions & 0 deletions nxbench/benchmarks/tests/test_utils.py
@@ -0,0 +1,101 @@
import os
from unittest.mock import patch

import pytest

from nxbench.benchmarks.utils import (
get_available_backends,
get_benchmark_config,
get_python_version,
is_cugraph_available,
is_graphblas_available,
is_nx_parallel_available,
)


@pytest.fixture(autouse=True)
def _reset_benchmark_config():
import nxbench.benchmarks.utils

original_config = nxbench.benchmarks.utils._BENCHMARK_CONFIG
nxbench.benchmarks.utils._BENCHMARK_CONFIG = None
yield
nxbench.benchmarks.utils._BENCHMARK_CONFIG = original_config


def test_backend_availability():
"""Test backend availability detection."""
with patch("importlib.util.find_spec") as mock_find_spec:
# test when backends are available
mock_find_spec.return_value = True
assert is_cugraph_available() is True
assert is_graphblas_available() is True
assert is_nx_parallel_available() is True

# test when backends are not available
mock_find_spec.return_value = None
assert is_cugraph_available() is False
assert is_graphblas_available() is False
assert is_nx_parallel_available() is False


def test_get_available_backends():
"""Test getting list of available backends."""
with (
patch("nxbench.benchmarks.utils.is_cugraph_available") as mock_cugraph,
patch("nxbench.benchmarks.utils.is_graphblas_available") as mock_graphblas,
patch("nxbench.benchmarks.utils.is_nx_parallel_available") as mock_parallel,
):

mock_cugraph.return_value = True
mock_graphblas.return_value = True
mock_parallel.return_value = True

backends = get_available_backends()
assert "networkx" in backends
assert "cugraph" in backends
assert "graphblas" in backends
assert "parallel" in backends

mock_cugraph.return_value = False
mock_graphblas.return_value = False
mock_parallel.return_value = False

backends = get_available_backends()
assert backends == ["networkx"]


def test_get_python_version():
"""Test Python version string formatting."""
version = get_python_version()
assert len(version.split(".")) == 3
for part in version.split("."):
assert part.isdigit()


def test_configure_benchmarks_env_vars():
"""Test configuration with environment variables."""
with patch.dict(os.environ, {"NXBENCH_CONFIG_FILE": "test_config.yaml"}):
with patch("pathlib.Path.exists") as mock_exists:
mock_exists.return_value = False
with pytest.raises(FileNotFoundError):
get_benchmark_config()


def test_memory_tracking():
"""Test memory usage tracking context manager."""
from nxbench.benchmarks.utils import memory_tracker

def allocate_memory():
return [0] * 1000000

with memory_tracker() as mem:
data = allocate_memory()
# ensure data exists to prevent premature garbage collection
assert len(data) == 1000000

assert "current" in mem
assert "peak" in mem
assert isinstance(mem["current"], int)
assert isinstance(mem["peak"], int)
assert mem["peak"] >= mem["current"]
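The memory_tracker test above exercises a context manager that reports memory use as integer byte counts under "current" and "peak" keys. A minimal sketch of such a helper, assuming it is built on the standard-library tracemalloc module (the actual nxbench implementation may differ):

import tracemalloc
from contextlib import contextmanager


@contextmanager
def memory_tracker():
    """Yield a dict filled with 'current' and 'peak' byte counts on exit."""
    tracemalloc.start()
    usage = {}
    try:
        yield usage
    finally:
        # tracemalloc reports (current, peak) traced allocation sizes in bytes
        current, peak = tracemalloc.get_traced_memory()
        usage["current"] = current
        usage["peak"] = peak
        tracemalloc.stop()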
25 changes: 17 additions & 8 deletions nxbench/data/loader.py
@@ -59,7 +59,7 @@ def get_metadata(self, name: str) -> dict[str, Any]:
return network.iloc[0].to_dict()

async def load_network(
self, config: DatasetConfig
self, config: DatasetConfig, session: aiohttp.ClientSession | None = None
) -> tuple[nx.Graph | nx.DiGraph, dict[str, Any]]:
"""Load or generate a network based on config."""
source_lower = config.source.lower()
@@ -89,7 +89,7 @@ async def load_network(

source_lower = config.source.lower()
if source_lower == "networkrepository":
graph, metadata = await self._load_nr_graph(config.name, metadata)
graph, metadata = await self._load_nr_graph(config.name, metadata, session)
elif source_lower == "local":
graph, metadata = self._load_local_graph(config)
elif source_lower == "generator":
@@ -241,7 +241,10 @@ def edge_parser():
return graph

async def _load_nr_graph(
self, name: str, metadata: dict[str, Any]
self,
name: str,
metadata: dict[str, Any],
session: aiohttp.ClientSession | None = None,
) -> nx.Graph | nx.DiGraph:
for ext in self.SUPPORTED_FORMATS:
graph_file = self.data_dir / f"{name}{ext}"
@@ -256,7 +259,7 @@ async def _load_nr_graph(
f"Network '{name}' not found in local cache. Attempting to download from "
f"repository."
)
await self._download_and_extract_network(name, url)
await self._download_and_extract_network(name, url, session)

for ext in self.SUPPORTED_FORMATS:
graph_file = self.data_dir / f"{name}{ext}"
Expand All @@ -269,13 +272,15 @@ async def _load_nr_graph(
f"download was successful and the graph file exists."
)

async def _download_and_extract_network(self, name: str, url: str):
async def _download_and_extract_network(
self, name: str, url: str, session: aiohttp.ClientSession | None = None
):
zip_path = self.data_dir / f"{name}.zip"
extracted_folder = self.data_dir / f"{name}_extracted"

if not zip_path.exists():
logger.info(f"Downloading network '{name}' from {url}")
await self._download_file(url, zip_path)
await self._download_file(url, zip_path, session)
logger.info(f"Downloaded network '{name}' to {zip_path}")

if not extracted_folder.exists():
@@ -309,8 +314,12 @@ async def _download_and_extract_network(self, name: str, url: str):
)
raise

async def _download_file(self, url: str, dest: Path):
async with aiohttp.ClientSession() as session:
async def _download_file(
self, url: str, dest: Path, session: aiohttp.ClientSession | None = None
):
if session is None:
session = aiohttp.ClientSession()
async with session:
async with session.get(url) as response:
if response.status != 200:
logger.error(
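The loader change above threads an optional aiohttp.ClientSession through load_network, _load_nr_graph, _download_and_extract_network, and _download_file, so callers and tests can inject a preconfigured or mocked session instead of the loader opening its own. A minimal sketch of that injection pattern, using a hypothetical standalone download_file helper for illustration; unlike the committed code, it closes the session only when it created one, leaving an injected session open for reuse:

import aiohttp
from pathlib import Path


async def download_file(
    url: str, dest: Path, session: aiohttp.ClientSession | None = None
) -> None:
    """Download url to dest, creating a session only when none is injected."""
    owns_session = session is None
    if owns_session:
        session = aiohttp.ClientSession()
    try:
        async with session.get(url) as response:
            response.raise_for_status()
            dest.write_bytes(await response.read())
    finally:
        # close only a session we created; an injected session stays open
        if owns_session:
            await session.close()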
103 changes: 103 additions & 0 deletions nxbench/data/tests/test_loader.py
@@ -1,6 +1,8 @@
import warnings
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import networkx as nx
import pytest

@@ -419,3 +421,104 @@ async def test_load_edge_list_with_extra_columns(data_manager, create_edge_file)
assert data["weight"] == expected_weights.get(
(u, v), expected_weights.get((v, u))
), f"Incorrect weight for edge ({u}, {v})"


def test_load_metadata(data_manager):
"""Test that metadata is loaded correctly."""
metadata = data_manager._metadata_df
assert not metadata.empty, "Metadata DataFrame should not be empty"
expected_names = [
"jazz",
"08blocks",
"patentcite",
"imdb",
"citeseer",
"mixed_delimiters",
"invalid_weights",
"self_loops_duplicates",
"non_sequential_ids",
"example",
"extra_columns",
"twitter",
"invalid_example",
]
assert set(metadata["name"]) == set(
expected_names
), "Metadata names do not match expected names"


def test_get_metadata(data_manager):
"""Test retrieving metadata for a network."""
metadata = data_manager.get_metadata("jazz")
assert metadata["name"] == "jazz", "Metadata 'name' should be 'jazz'"
assert metadata["directed"] is False, "Metadata 'directed' should be False"
assert metadata["weighted"] is True, "Metadata 'weighted' should be True"


@patch("nxbench.data.loader.zipfile.ZipFile")
@pytest.mark.asyncio
async def test_load_network_retry(mock_zipfile_class, data_manager):
"""Test network loading retry behavior."""
data_manager.get_metadata = MagicMock(
return_value={
"download_url": "http://example.com/test.zip",
"directed": False,
"weighted": True,
}
)

config = DatasetConfig(name="test", source="networkrepository", params={})

mock_response = AsyncMock()
mock_response.status = 200
mock_response.content.read = AsyncMock(side_effect=[b"data", b""])

mock_session = AsyncMock(spec=aiohttp.ClientSession)
mock_session.get.return_value.__aenter__.return_value = mock_response

mock_zipfile = MagicMock()
mock_zipfile.__enter__.return_value.extractall = MagicMock()
mock_zipfile_class.return_value = mock_zipfile

with pytest.raises(FileNotFoundError):
await data_manager.load_network(config, session=mock_session)

mock_session.get.assert_called_once_with("http://example.com/test.zip")


def test_load_weighted_graph(data_manager, tmp_path):
"""Test loading weighted graph formats."""
content = """# Test weighted graph
1 2 1.5
2 3 2.5
3 1 3.5
"""
test_file = tmp_path / "test.edges"
test_file.write_text(content)

graph = data_manager._load_graph_file(
test_file, {"directed": False, "weighted": True}
)

assert isinstance(graph, nx.Graph)
assert graph.number_of_nodes() == 3
assert graph.number_of_edges() == 3
for _, _, data in graph.edges(data=True):
assert "weight" in data
assert isinstance(data["weight"], float)


def test_normalize_graph(data_manager, tmp_path):
"""Test graph normalization and cleanup."""
content = """1 1
1 2
2 3"""
test_file = tmp_path / "test.edges"
test_file.write_text(content)

normalized = data_manager._load_graph_file(test_file, {"directed": False})

assert len(list(nx.selfloop_edges(normalized))) == 0
assert all(isinstance(n, str) for n in normalized.nodes())
assert normalized.number_of_nodes() == 3
assert normalized.number_of_edges() == 2
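The retry test above depends on how the mocked session satisfies aiohttp's async context-manager protocol: ClientSession.get is a synchronous call that returns an object used with "async with", which is why the test wires mock_session.get.return_value.__aenter__ rather than treating get itself as a coroutine. A small self-contained illustration of the same wiring, with hypothetical names, assuming unittest.mock's Python 3.8+ behavior where magic-mocked __aenter__/__aexit__ are async:

import asyncio
from unittest.mock import AsyncMock

import aiohttp


async def fetch_status(session: aiohttp.ClientSession, url: str) -> int:
    # mirrors how the loader consumes the injected session
    async with session.get(url) as response:
        return response.status


mock_response = AsyncMock()
mock_response.status = 200

# spec=ClientSession keeps .get synchronous; its return value exposes an
# async __aenter__ whose result stands in for the HTTP response
mock_session = AsyncMock(spec=aiohttp.ClientSession)
mock_session.get.return_value.__aenter__.return_value = mock_response

assert asyncio.run(fetch_status(mock_session, "http://example.com")) == 200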
