test: expand unit testing
dPys committed Dec 28, 2024
1 parent 0483083 commit 715aa6a
Showing 2 changed files with 155 additions and 10 deletions.
78 changes: 78 additions & 0 deletions nxbench/data/tests/test_db.py
@@ -235,3 +235,81 @@ def test_delete_results_no_match(benchmark_db, sample_benchmark_result):

remaining = benchmark_db.get_results(as_pandas=False)
assert len(remaining) == 1


def test_save_results_with_error_and_parameters(benchmark_db, sample_benchmark_result):
sample_benchmark_result.error = "Test error message"
benchmark_db.save_results(sample_benchmark_result)

results = benchmark_db.get_results(as_pandas=False)
assert len(results) == 1
result = results[0]
assert result["error"] == "Test error message"


def test_filter_results_by_start_and_end_date(benchmark_db, sample_benchmark_result):
"""
Test that results can be filtered correctly using both start_date and end_date.
We manually update timestamps in the DB to simulate older/newer records.
"""
import sqlite3
from datetime import datetime, timedelta, timezone

benchmark_db.save_results(sample_benchmark_result)
old_timestamp = (datetime.now(timezone.utc) - timedelta(days=2)).isoformat()
with sqlite3.connect(benchmark_db.db_path) as conn:
conn.execute("UPDATE benchmarks SET timestamp=? WHERE id=1", (old_timestamp,))
conn.commit()

benchmark_db.save_results(sample_benchmark_result)
new_timestamp = datetime.now(timezone.utc).isoformat()

results = benchmark_db.get_results(
start_date=old_timestamp, end_date=new_timestamp, as_pandas=False
)
assert (
len(results) == 2
), "Both old and new records should be included in the date range."

middle_timestamp = (datetime.now(timezone.utc) - timedelta(days=1)).isoformat()
results = benchmark_db.get_results(
start_date=middle_timestamp, end_date=new_timestamp, as_pandas=False
)
assert (
len(results) == 1
), "Only the newer record should match after filtering by a middle start date."


def test_directed_and_weighted_flags_are_integers(
benchmark_db, sample_benchmark_result
):
"""Ensure that 'directed' and 'weighted' fields are correctly stored as integers
(0 or 1).
"""
sample_benchmark_result.is_directed = True
sample_benchmark_result.is_weighted = True
benchmark_db.save_results(sample_benchmark_result)

results = benchmark_db.get_results(as_pandas=False)
assert len(results) == 1
result = results[0]
assert (
result["directed"] == 1
), "'is_directed=True' should be stored as integer '1'."
assert (
result["weighted"] == 1
), "'is_weighted=True' should be stored as integer '1'."

sample_benchmark_result.is_directed = False
sample_benchmark_result.is_weighted = False
benchmark_db.save_results(sample_benchmark_result)

results = benchmark_db.get_results(as_pandas=False)
assert len(results) == 2
new_record = results[1]
assert (
new_record["directed"] == 0
), "'is_directed=False' should be stored as integer '0'."
assert (
new_record["weighted"] == 0
), "'is_weighted=False' should be stored as integer '0'."
87 changes: 77 additions & 10 deletions nxbench/viz/tests/test_utils.py
@@ -45,10 +45,7 @@ def test_load_data(csv_file_path):
"""Test that load_data reads a CSV file and returns a DataFrame."""
df = load_data(csv_file_path)
assert not df.empty
# Check that data types are not coerced here, just loaded
# e.g. 'execution_time' might still be object before preprocess_data
assert df["execution_time"].dtype in (object, "O")
# Check that we "hacked" pd.DataFrame.iteritems -> .items
assert hasattr(pd.DataFrame, "iteritems")
assert df.shape[0] == 4

@@ -139,17 +136,12 @@ def test_load_and_prepare_data(mock_load_data, raw_df):
load_and_prepare_data("fake/path.csv", logger)
)

# Expect that the function used load_data internally
mock_load_data.assert_called_once_with("fake/path.csv")

# Check that data was cleaned
# We already tested correctness in test_preprocess_data, so minimal checks here
assert not cleaned_df.empty
assert "execution_time_with_preloading" in cleaned_df.columns

# Check aggregated data shape
assert not df_agg.empty
# The group columns should include "algorithm" and possibly "backend_full" etc.
assert "algorithm" in group_columns
assert len(available_parcats_columns) >= 1

@@ -172,13 +164,88 @@ def test_load_and_prepare_data_with_different_logger_levels(logger_level, raw_df
logger = logging.getLogger("test_logger")
logger.setLevel(logger_level)

# We patch load_data to return our fixture
with patch("nxbench.viz.utils.load_data", return_value=raw_df):
cleaned_df, df_agg, group_columns, available_parcats_columns = (
load_and_prepare_data("any_path.csv", logger)
)
# Basic sanity checks
assert not cleaned_df.empty
assert not df_agg.empty
assert len(group_columns) > 0
assert isinstance(available_parcats_columns, list)


def test_preprocess_data_no_preloading_col():
"""
Ensures coverage for the branch that adds 'execution_time_with_preloading'
if it's missing from the DataFrame.
"""
df = pd.DataFrame(
{
"algorithm": ["bfs"],
"execution_time": ["1.5"],
"memory_used": [128],
"num_nodes": [100],
"num_edges": [500],
"num_thread": ["4"],
}
)
cleaned_df = preprocess_data(df)
assert "execution_time_with_preloading" in cleaned_df.columns
assert cleaned_df["execution_time_with_preloading"].iloc[0] == 1.5


def test_preprocess_data_single_node_edge():
"""
Ensures coverage for the else-branches that skip binning
when there's only 1 unique num_nodes or num_edges.
"""
df = pd.DataFrame(
{
"algorithm": ["bfs", "bfs"],
"execution_time": ["0.5", "0.6"],
"execution_time_with_preloading": [None, None],
"memory_used": [100.0, 110.0],
"num_nodes": [10, 10],
"num_edges": [50, 50],
"num_thread": ["2", "2"],
}
)

cleaned_df = preprocess_data(df)
assert all(cleaned_df["num_nodes_bin"] == 10)
assert all(cleaned_df["num_edges_bin"] == 50)


def test_aggregate_data_no_backend_version():
"""
Ensures coverage for the else-branch that warns when 'backend_version'
is missing from the DataFrame.
"""
df = pd.DataFrame(
{
"algorithm": ["bfs", "dfs"],
"dataset": ["ds1", "ds2"],
"backend": ["networkx", "gunrock"],
"num_nodes_bin": ["10 <= x < 50", "50 <= x < 100"],
"num_edges_bin": ["50 <= x < 250", "250 <= x < 500"],
"is_directed": [False, True],
"is_weighted": [False, True],
"python_version": ["3.8", "3.9"],
"cpu": ["intel", "amd"],
"os": ["linux", "windows"],
"num_thread": [4, 8],
"execution_time": [0.5, 1.2],
"execution_time_with_preloading": [0.5, 1.1],
"memory_used": [100.0, 200.0],
}
)

logger = logging.getLogger("nxbench")
with patch.object(logger, "warning") as mock_warn:
df_agg, group_cols, _ = aggregate_data(df)
mock_warn.assert_called_once_with(
"No 'backend_version' column found in the dataframe."
)

assert "backend_full" not in group_cols
assert "backend" in group_cols
