-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathsparse_to_csr.m
76 lines (71 loc) · 2.24 KB
/
sparse_to_csr.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
function [rp ci ai ncol]=sparse_to_csr(A,varargin)
% SPARSE_TO_CSR Convert a sparse matrix into compressed row storage arrays
%
% [rp ci ai] = sparse_to_csr(A) returns the row pointer (rp), column index
% (ci) and value index (ai) arrays of a compressed sparse representation of
% the matrix A.
%
% [rp ci ai] = sparse_to_csr(i,j,v,n) returns a csr representation of the
% index sets i,j,v with n rows.
%
% Example:
% A=sparse(6,6); A(1,1)=5; A(1,5)=2; A(2,3)=-1; A(4,1)=1; A(5,6)=1;
% [rp ci ai]=sparse_to_csr(A)
%
% See also CSR_TO_SPARSE, SPARSE
% David F. Gleich
% Copyright, Stanford University, 2008-2009
% History
% 2008-04-07: Initial version
% 2008-04-24: Added triple array input
% 2009-05-01: Added ncol output
% 2009-05-15: Fixed triplet input
error(nargchk(1, 5, nargin, 'struct'))
retc = nargout>1; reta = nargout>2;
if nargin>1
if nargin>4, ncol = varargin{4}; end
nzi = A; nzj = varargin{1};
if reta && length(varargin) > 2, nzv = varargin{2}; end
if nargin<4, n=max(nzi); else n=varargin{3}; end
nz = length(A);
if length(nzi) ~= length(nzj), error('gaimc:invalidInput',...
'length of nzi (%i) not equal to length of nzj (%i)', nz, ...
length(nzj));
end
if reta && length(varargin) < 3, error('gaimc:invalidInput',...
'no value array passed for triplet input, see usage');
end
if ~isscalar(n), error('gaimc:invalidInput',...
['the 4th input to sparse_to_csr with triple input was not ' ...
'a scalar']);
end
if nargin < 5, ncol = max(nzj);
elseif ~isscalar(ncol), error('gaimc:invalidInput',...
['the 5th input to sparse_to_csr with triple input was not ' ...
'a scalar']);
end
else
n = size(A,1); nz = nnz(A); ncol = size(A,2);
retc = nargout>1; reta = nargout>2;
if reta, [nzi nzj nzv] = find(A);
else [nzi nzj] = find(A);
end
end
if retc, ci = zeros(nz,1); end
if reta, ai = zeros(nz,1); end
rp = zeros(n+1,1);
for i=1:nz
rp(nzi(i)+1)=rp(nzi(i)+1)+1;
end
rp=cumsum(rp);
if ~retc && ~reta, rp=rp+1; return; end
for i=1:nz
if reta, ai(rp(nzi(i))+1)=nzv(i); end
ci(rp(nzi(i))+1)=nzj(i);
rp(nzi(i))=rp(nzi(i))+1;
end
for i=n:-1:1
rp(i+1)=rp(i);
end
rp(1)=0;
rp=rp+1;