From cfb9eac6237b40b07fe2a8d533fe60126dfde788 Mon Sep 17 00:00:00 2001 From: Jayaram Kancherla Date: Sun, 12 Jan 2025 17:06:37 -0800 Subject: [PATCH] Fix accessing dimnames on matrices --- src/rds2py/read_matrix.py | 4 ++-- tests/data/generate_files.R | 2 ++ tests/test_matrices.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/rds2py/read_matrix.py b/src/rds2py/read_matrix.py index 79d6e5b..af378b8 100644 --- a/src/rds2py/read_matrix.py +++ b/src/rds2py/read_matrix.py @@ -93,8 +93,8 @@ def _as_sparse_matrix(robject: dict, **kwargs) -> spmatrix: ) names = None - if "dimnames" in robject["attributes"]: - names = _dispatcher(robject["attributes"]["dimnames"], **kwargs) + if "Dimnames" in robject["attributes"]: + names = _dispatcher(robject["attributes"]["Dimnames"], **kwargs) if names is not None and len(names) > 0: return MatrixWrapper(mat, names) diff --git a/tests/data/generate_files.R b/tests/data/generate_files.R index 592e2c2..2c38c9a 100644 --- a/tests/data/generate_files.R +++ b/tests/data/generate_files.R @@ -97,6 +97,8 @@ saveRDS(df, file="lists_df_rownames.rds") y <- Matrix::rsparsematrix(100, 10, 0.05) saveRDS(y, file="s4_matrix.rds") +rownames(y) <- paste("row", 1:nrow(y), sep="_") + setClass("FOO", slots=c(bar="integer")) y <- new("FOO", bar=2L) saveRDS(y, file="s4_class.rds") diff --git a/tests/test_matrices.py b/tests/test_matrices.py index 52d59b9..1f24339 100644 --- a/tests/test_matrices.py +++ b/tests/test_matrices.py @@ -17,6 +17,21 @@ def test_read_s4_matrix_dgc(): assert array is not None assert isinstance(array, sp.spmatrix) +def test_read_s4_matrix_dgc_with_rownames(): + array = read_rds("tests/data/matrix_with_row_names.rds") + + assert array is not None + assert isinstance(array, MatrixWrapper) + assert len(array.dimnames[0]) = 100 + + +def test_read_s4_matrix_dgc_with_bothnames(): + array = read_rds("tests/data/matrix_with_dim_names.rds") + + assert array is not None + assert isinstance(array, MatrixWrapper) + assert len(array.dimnames[0]) = 100 + assert len(array.dimnames[0]) = 10 def test_read_s4_matrix_dgt(): array = read_rds("tests/data/s4_matrix_dgt.rds")