Skip to content

Commit 9eb554f

Browse files
Merge pull request #441 from Tal-Golan/row_col_indicator_g_sparse
A sparse version of row_col_indicator_g
2 parents d673b50 + 315f740 commit 9eb554f

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

src/rsatoolbox/rdm/compare.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from rsatoolbox.util.matrix import pairwise_contrast
1414
from rsatoolbox.util.rdm_utils import _get_n_from_reduced_vectors
1515
from rsatoolbox.util.rdm_utils import _get_n_from_length
16-
from rsatoolbox.util.matrix import row_col_indicator_g
16+
from rsatoolbox.util.matrix import row_col_indicator_g, row_col_indicator_g_sparse
1717
from rsatoolbox.util.rdm_utils import batch_to_matrices
1818

1919

@@ -446,7 +446,11 @@ def _cov_weighting(vector, nan_idx, sigma_k=None):
446446
N, n_dist = vector.shape
447447
n_cond = _get_n_from_length(nan_idx.shape[0])
448448
vector_w = -0.5 * np.c_[vector, np.zeros((N, n_cond))]
449-
rowI, colI = row_col_indicator_g(n_cond)
449+
SPARSE_THRESHOLD = 100 # threshold for switching to sparse matrices
450+
if n_cond >= SPARSE_THRESHOLD:
451+
rowI, colI = row_col_indicator_g_sparse(n_cond) # use sparse indicator matrices
452+
else:
453+
rowI, colI = row_col_indicator_g(n_cond) # use dense indicator matrices
450454
sumI = rowI + colI
451455
if np.all(nan_idx):
452456
# column and row means

src/rsatoolbox/util/matrix.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,46 @@ def row_col_indicator_g(n_cond):
159159
np.fill_diagonal(col_i[-n_cond:, :], 1)
160160
return (row_i, col_i)
161161

162+
def row_col_indicator_g_sparse(n_cond, dtype=np.int8, index_dtype=np.int64):
163+
"""
164+
Generates row and column indicator matrices as sparse matrices (CSR format)
165+
for a vectorized second moment matrix.
166+
The vectorized version has the off-diagonal elements first, then appends the diagonal.
167+
168+
Args:
169+
n_cond (int): Number of conditions underlying the second moment.
170+
dtype (numpy.dtype, optional): Data type for the non-zero elements
171+
of the sparse matrices. Defaults to np.int8.
172+
index_dtype (numpy.dtype, optional): Data type for the indices in the sparse matrix.
173+
Defaults to np.int64.
174+
175+
Returns:
176+
tuple: (row_indicator_sparse, col_indicator_sparse)
177+
row_indicator_sparse (scipy.sparse.csr_matrix): Sparse indicator matrix for rows.
178+
col_indicator_sparse (scipy.sparse.csr_matrix): Sparse indicator matrix for columns.
179+
"""
180+
num_off_diag = n_cond * (n_cond - 1) // 2
181+
n_elem = num_off_diag + n_cond
182+
shape = (n_elem, n_cond)
183+
if n_elem == 0:
184+
return (csr_matrix(shape, dtype=dtype),
185+
csr_matrix(shape, dtype=dtype))
186+
sparse_data = np.ones(n_elem, dtype=dtype)
187+
188+
sparse_row_coords = np.arange(n_elem, dtype=index_dtype)
189+
row_i_col_coords = np.empty(n_elem, dtype=index_dtype)
190+
col_i_col_coords = np.empty(n_elem, dtype=index_dtype)
191+
if num_off_diag > 0:
192+
upper_p, upper_q = np.triu_indices(n_cond, k=1)
193+
row_i_col_coords[:num_off_diag] = upper_p.astype(index_dtype)
194+
col_i_col_coords[:num_off_diag] = upper_q.astype(index_dtype)
195+
if n_cond > 0:
196+
diag_indices_in_matrix = np.arange(n_cond, dtype=index_dtype)
197+
row_i_col_coords[num_off_diag:] = diag_indices_in_matrix
198+
col_i_col_coords[num_off_diag:] = diag_indices_in_matrix
199+
row_i_sparse = csr_matrix((sparse_data, (sparse_row_coords, row_i_col_coords)), shape=shape)
200+
col_i_sparse = csr_matrix((sparse_data, (sparse_row_coords, col_i_col_coords)), shape=shape)
201+
return (row_i_sparse, col_i_sparse)
162202

163203
def run() -> spmatrix:
164204
a = csr_matrix(np.eye(3))

tests/test_util_matrix.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def test_centering(self):
4545
self.assertEqual(n_row, 10)
4646
self.assertEqual(n_col, 10)
4747

48+
def test_row_col_indicator_g(self):
49+
""" Tests the equivalence of the sparse and dense versions of row_col_indicator_g """
50+
for n_cond in [0, 1, 2, 3, 4, 5, 10, 20, 50, 100]:
51+
rowI_sparse, colI_sparse = rsu.matrix.row_col_indicator_g_sparse(n_cond)
52+
rowI_dense, colI_dense = rsu.matrix.row_col_indicator_g(n_cond)
53+
self.assertTrue(np.array_equal(rowI_sparse.toarray(), rowI_dense))
54+
self.assertTrue(np.array_equal(colI_sparse.toarray(), colI_dense))
4855

4956
if __name__ == '__main__':
5057
unittest.main()

0 commit comments

Comments
 (0)