@@ -159,6 +159,46 @@ def row_col_indicator_g(n_cond):
159159 np .fill_diagonal (col_i [- n_cond :, :], 1 )
160160 return (row_i , col_i )
161161
162+ def row_col_indicator_g_sparse (n_cond , dtype = np .int8 , index_dtype = np .int64 ):
163+ """
164+ Generates row and column indicator matrices as sparse matrices (CSR format)
165+ for a vectorized second moment matrix.
166+ The vectorized version has the off-diagonal elements first, then appends the diagonal.
167+
168+ Args:
169+ n_cond (int): Number of conditions underlying the second moment.
170+ dtype (numpy.dtype, optional): Data type for the non-zero elements
171+ of the sparse matrices. Defaults to np.int8.
172+ index_dtype (numpy.dtype, optional): Data type for the indices in the sparse matrix.
173+ Defaults to np.int64.
174+
175+ Returns:
176+ tuple: (row_indicator_sparse, col_indicator_sparse)
177+ row_indicator_sparse (scipy.sparse.csr_matrix): Sparse indicator matrix for rows.
178+ col_indicator_sparse (scipy.sparse.csr_matrix): Sparse indicator matrix for columns.
179+ """
180+ num_off_diag = n_cond * (n_cond - 1 ) // 2
181+ n_elem = num_off_diag + n_cond
182+ shape = (n_elem , n_cond )
183+ if n_elem == 0 :
184+ return (csr_matrix (shape , dtype = dtype ),
185+ csr_matrix (shape , dtype = dtype ))
186+ sparse_data = np .ones (n_elem , dtype = dtype )
187+
188+ sparse_row_coords = np .arange (n_elem , dtype = index_dtype )
189+ row_i_col_coords = np .empty (n_elem , dtype = index_dtype )
190+ col_i_col_coords = np .empty (n_elem , dtype = index_dtype )
191+ if num_off_diag > 0 :
192+ upper_p , upper_q = np .triu_indices (n_cond , k = 1 )
193+ row_i_col_coords [:num_off_diag ] = upper_p .astype (index_dtype )
194+ col_i_col_coords [:num_off_diag ] = upper_q .astype (index_dtype )
195+ if n_cond > 0 :
196+ diag_indices_in_matrix = np .arange (n_cond , dtype = index_dtype )
197+ row_i_col_coords [num_off_diag :] = diag_indices_in_matrix
198+ col_i_col_coords [num_off_diag :] = diag_indices_in_matrix
199+ row_i_sparse = csr_matrix ((sparse_data , (sparse_row_coords , row_i_col_coords )), shape = shape )
200+ col_i_sparse = csr_matrix ((sparse_data , (sparse_row_coords , col_i_col_coords )), shape = shape )
201+ return (row_i_sparse , col_i_sparse )
162202
163203def run () -> spmatrix :
164204 a = csr_matrix (np .eye (3 ))
0 commit comments