GNU Octave  3.8.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
Sparse-diag-op-defs.h
Go to the documentation of this file.
1 /* -*- C++ -*-
2 
3 Copyright (C) 2009-2013 Jason Riedy, Jaroslav Hajek
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 #if !defined (octave_Sparse_diag_op_defs_h)
24 #define octave_Sparse_diag_op_defs_h 1
25 
26 // Matrix multiplication
27 
28 template <typename RT, typename DM, typename SM>
29 RT do_mul_dm_sm (const DM& d, const SM& a)
30 {
31  const octave_idx_type nr = d.rows ();
32  const octave_idx_type nc = d.cols ();
33 
34  const octave_idx_type a_nr = a.rows ();
35  const octave_idx_type a_nc = a.cols ();
36 
37  if (nc != a_nr)
38  {
39  gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
40  return RT ();
41  }
42  else
43  {
44  RT r (nr, a_nc, a.nnz ());
45 
46  octave_idx_type l = 0;
47 
48  for (octave_idx_type j = 0; j < a_nc; j++)
49  {
50  r.xcidx (j) = l;
51  const octave_idx_type colend = a.cidx (j+1);
52  for (octave_idx_type k = a.cidx (j); k < colend; k++)
53  {
54  const octave_idx_type i = a.ridx (k);
55  if (i >= nr) break;
56  r.xdata (l) = d.dgelem (i) * a.data (k);
57  r.xridx (l) = i;
58  l++;
59  }
60  }
61 
62  r.xcidx (a_nc) = l;
63 
64  r.maybe_compress (true);
65  return r;
66  }
67 }
68 
69 template <typename RT, typename SM, typename DM>
70 RT do_mul_sm_dm (const SM& a, const DM& d)
71 {
72  const octave_idx_type nr = d.rows ();
73  const octave_idx_type nc = d.cols ();
74 
75  const octave_idx_type a_nr = a.rows ();
76  const octave_idx_type a_nc = a.cols ();
77 
78  if (nr != a_nc)
79  {
80  gripe_nonconformant ("operator *", a_nr, a_nc, nr, nc);
81  return RT ();
82  }
83  else
84  {
85 
86  const octave_idx_type mnc = nc < a_nc ? nc: a_nc;
87  RT r (a_nr, nc, a.cidx (mnc));
88 
89  for (octave_idx_type j = 0; j < mnc; ++j)
90  {
91  const typename DM::element_type s = d.dgelem (j);
92  const octave_idx_type colend = a.cidx (j+1);
93  r.xcidx (j) = a.cidx (j);
94  for (octave_idx_type k = a.cidx (j); k < colend; ++k)
95  {
96  r.xdata (k) = s * a.data (k);
97  r.xridx (k) = a.ridx (k);
98  }
99  }
100  for (octave_idx_type j = mnc; j <= nc; ++j)
101  r.xcidx (j) = a.cidx (mnc);
102 
103  r.maybe_compress (true);
104  return r;
105  }
106 }
107 
108 // FIXME: functors such as this should be gathered somewhere
// Unary functor that returns its argument unchanged.  It is the "leave the
// operand as it is" counterpart of std::negate for the addition helpers in
// this file.
//
// NOTE(review): the line carrying the struct head was lost from this
// listing; the name identity_val is restored from the upstream header --
// confirm against the original file.
//
// The argument_type / result_type typedefs are declared directly instead of
// inheriting them from std::unary_function, which was deprecated in C++11
// and removed in C++17.
template <typename T>
struct identity_val
{
  typedef T argument_type;
  typedef T result_type;

  // const so that the functor can also be called through a const object.
  T operator () (const T x) const { return x; }
};
115 
116 // Matrix addition
117 
118 template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
119 RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
120 {
121  using std::min;
122  const octave_idx_type nr = d.rows ();
123  const octave_idx_type nc = d.cols ();
124  const octave_idx_type n = min (nr, nc);
125 
126  const octave_idx_type a_nr = a.rows ();
127  const octave_idx_type a_nc = a.cols ();
128 
129  const octave_idx_type nz = a.nnz ();
130  RT r (a_nr, a_nc, nz + n);
131  octave_idx_type k = 0;
132 
133  for (octave_idx_type j = 0; j < nc; ++j)
134  {
135  octave_quit ();
136  const octave_idx_type colend = a.cidx (j+1);
137  r.xcidx (j) = k;
138  octave_idx_type k_src = a.cidx (j), k_split;
139 
140  for (k_split = k_src; k_split < colend; k_split++)
141  if (a.ridx (k_split) >= j)
142  break;
143 
144  for (; k_src < k_split; k_src++, k++)
145  {
146  r.xridx (k) = a.ridx (k_src);
147  r.xdata (k) = opa (a.data (k_src));
148  }
149 
150  if (k_src < colend && a.ridx (k_src) == j)
151  {
152  r.xridx (k) = j;
153  r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
154  k++; k_src++;
155  }
156  else
157  {
158  r.xridx (k) = j;
159  r.xdata (k) = opd (d.dgelem (j));
160  k++;
161  }
162 
163  for (; k_src < colend; k_src++, k++)
164  {
165  r.xridx (k) = a.ridx (k_src);
166  r.xdata (k) = opa (a.data (k_src));
167  }
168 
169  }
170  r.xcidx (nc) = k;
171 
172  r.maybe_compress (true);
173  return r;
174 }
175 
176 template <typename RT, typename DM, typename SM>
177 RT do_commutative_add_dm_sm (const DM& d, const SM& a)
178 {
179  // Extra function to ensure this is only emitted once.
180  return inner_do_add_sm_dm<RT> (a, d,
183 }
184 
185 template <typename RT, typename DM, typename SM>
186 RT do_add_dm_sm (const DM& d, const SM& a)
187 {
188  if (a.rows () != d.rows () || a.cols () != d.cols ())
189  {
190  gripe_nonconformant ("operator +", d.rows (), d.cols (), a.rows (), a.cols ());
191  return RT ();
192  }
193  else
194  return do_commutative_add_dm_sm<RT> (d, a);
195 }
196 
197 template <typename RT, typename DM, typename SM>
198 RT do_sub_dm_sm (const DM& d, const SM& a)
199 {
200  if (a.rows () != d.rows () || a.cols () != d.cols ())
201  {
202  gripe_nonconformant ("operator -", d.rows (), d.cols (), a.rows (), a.cols ());
203  return RT ();
204  }
205  else
206  return inner_do_add_sm_dm<RT> (a, d, std::negate<typename SM::element_type> (),
208 }
209 
210 template <typename RT, typename SM, typename DM>
211 RT do_add_sm_dm (const SM& a, const DM& d)
212 {
213  if (a.rows () != d.rows () || a.cols () != d.cols ())
214  {
215  gripe_nonconformant ("operator +", a.rows (), a.cols (), d.rows (), d.cols ());
216  return RT ();
217  }
218  else
219  return do_commutative_add_dm_sm<RT> (d, a);
220 }
221 
222 template <typename RT, typename SM, typename DM>
223 RT do_sub_sm_dm (const SM& a, const DM& d)
224 {
225  if (a.rows () != d.rows () || a.cols () != d.cols ())
226  {
227  gripe_nonconformant ("operator -", a.rows (), a.cols (), d.rows (), d.cols ());
228  return RT ();
229  }
230  else
231  return inner_do_add_sm_dm<RT> (a, d,
233  std::negate<typename DM::element_type> ());
234 }
235 
236 #endif // octave_Sparse_diag_op_defs_h