GNU Octave  4.0.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
sparse-dmsolve.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2006-2015 David Bateman
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26 
27 #include <vector>
28 
29 #include "MArray.h"
30 #include "MSparse.h"
31 #include "SparseQR.h"
32 #include "SparseCmplxQR.h"
33 #include "MatrixType.h"
34 #include "oct-sort.h"
35 #include "oct-locbuf.h"
36 #include "oct-inttypes.h"
37 
38 template <class T>
39 static MSparse<T>
41  const octave_idx_type *Q, octave_idx_type rst,
43  octave_idx_type cend, octave_idx_type maxnz = -1,
44  bool lazy = false)
45 {
46  octave_idx_type nr = rend - rst;
47  octave_idx_type nc = cend - cst;
48  maxnz = (maxnz < 0 ? A.nnz () : maxnz);
49  octave_idx_type nz;
50 
51  // Cast to uint64 to handle overflow in this multiplication
52  if (octave_uint64 (nr)*octave_uint64 (nc) < octave_uint64 (maxnz))
53  nz = nr*nc;
54  else
55  nz = maxnz;
56 
57  MSparse<T> B (nr, nc, (nz < maxnz ? nz : maxnz));
58  // Some sparse functions can support lazy indexing (where elements
59  // in the row are in no particular order), even though octave in
60  // general can't. For those functions that can using it is a big
61  // win here in terms of speed.
62  if (lazy)
63  {
64  nz = 0;
65  for (octave_idx_type j = cst ; j < cend ; j++)
66  {
67  octave_idx_type qq = (Q ? Q[j] : j);
68  B.xcidx (j - cst) = nz;
69  for (octave_idx_type p = A.cidx (qq) ; p < A.cidx (qq+1) ; p++)
70  {
71  octave_quit ();
72  octave_idx_type r = (Pinv ? Pinv[A.ridx (p)] : A.ridx (p));
73  if (r >= rst && r < rend)
74  {
75  B.xdata (nz) = A.data (p);
76  B.xridx (nz++) = r - rst ;
77  }
78  }
79  }
80  B.xcidx (cend - cst) = nz ;
81  }
82  else
83  {
84  OCTAVE_LOCAL_BUFFER (T, X, rend - rst);
86  octave_idx_type *ri = B.xridx ();
87  nz = 0;
88  for (octave_idx_type j = cst ; j < cend ; j++)
89  {
90  octave_idx_type qq = (Q ? Q[j] : j);
91  B.xcidx (j - cst) = nz;
92  for (octave_idx_type p = A.cidx (qq) ; p < A.cidx (qq+1) ; p++)
93  {
94  octave_quit ();
95  octave_idx_type r = (Pinv ? Pinv[A.ridx (p)] : A.ridx (p));
96  if (r >= rst && r < rend)
97  {
98  X[r-rst] = A.data (p);
99  B.xridx (nz++) = r - rst ;
100  }
101  }
102  sort.sort (ri + B.xcidx (j - cst), nz - B.xcidx (j - cst));
103  for (octave_idx_type p = B.cidx (j - cst); p < nz; p++)
104  B.xdata (p) = X[B.xridx (p)];
105  }
106  B.xcidx (cend - cst) = nz ;
107  }
108 
109  return B;
110 }
111 
// Non-template declarations of the sparse dmsolve_extract overloads for
// MSparse<double> and (presumably, the declarator line is not shown)
// MSparse<Complex>; compiled only when CXX_NEW_FRIEND_TEMPLATE_DECL is not
// defined.
// NOTE(review): source lines 116, 121 and 123 are missing from this listing.
112 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
113 static MSparse<double>
114 dmsolve_extract (const MSparse<double> &A, const octave_idx_type *Pinv,
115  const octave_idx_type *Q, octave_idx_type rst,
117  octave_idx_type cend, octave_idx_type maxnz,
118  bool lazy);
119 
120 static MSparse<Complex>
122  const octave_idx_type *Q, octave_idx_type rst,
124  octave_idx_type cend, octave_idx_type maxnz,
125  bool lazy);
126 #endif
127 
128 template <class T>
129 static MArray<T>
133  octave_idx_type c2)
134 {
135  r2 -= 1;
136  c2 -= 1;
137  if (r1 > r2) { std::swap (r1, r2); }
138  if (c1 > c2) { std::swap (c1, c2); }
139 
140  octave_idx_type new_r = r2 - r1 + 1;
141  octave_idx_type new_c = c2 - c1 + 1;
142 
143  MArray<T> result (dim_vector (new_r, new_c));
144 
145  for (octave_idx_type j = 0; j < new_c; j++)
146  for (octave_idx_type i = 0; i < new_r; i++)
147  result.xelem (i, j) = m.elem (r1+i, c1+j);
148 
149  return result;
150 }
151 
152 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
153 static MArray<double>
157  octave_idx_type c2)
158 
159 static MArray<Complex>
163  octave_idx_type c2)
164 #endif
165 
166 template <class T>
167 static void
170 {
171  T *ax = a.fortran_vec ();
172  const T *bx = b.fortran_vec ();
173  octave_idx_type anr = a.rows ();
174  octave_idx_type nr = b.rows ();
175  octave_idx_type nc = b.cols ();
176  for (octave_idx_type j = 0; j < nc; j++)
177  {
178  octave_idx_type aoff = (c + j) * anr;
179  octave_idx_type boff = j * nr;
180  for (octave_idx_type i = 0; i < nr; i++)
181  {
182  octave_quit ();
183  ax[Q[r + i] + aoff] = bx[i + boff];
184  }
185  }
186 }
187 
// Presumably non-template declarations of the dense dmsolve_insert
// overloads, compiled only when CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
// NOTE(review): the declarator lines (source 190-191 and 194-195) are
// missing from this listing, so whether each declaration is terminated with
// ';' cannot be checked here.
188 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
189 static void
192 
193 static void
196 #endif
197 
198 template <class T>
199 static void
202 {
203  octave_idx_type b_rows = b.rows ();
204  octave_idx_type b_cols = b.cols ();
205  octave_idx_type nr = a.rows ();
206  octave_idx_type nc = a.cols ();
207 
209  for (octave_idx_type i = 0; i < nr; i++)
210  Qinv[Q[i]] = i;
211 
212  // First count the number of elements in the final array
213  octave_idx_type nel = a.xcidx (c) + b.nnz ();
214 
215  if (c + b_cols < nc)
216  nel += a.xcidx (nc) - a.xcidx (c + b_cols);
217 
218  for (octave_idx_type i = c; i < c + b_cols; i++)
219  for (octave_idx_type j = a.xcidx (i); j < a.xcidx (i+1); j++)
220  if (Qinv[a.xridx (j)] < r || Qinv[a.xridx (j)] >= r + b_rows)
221  nel++;
222 
223  OCTAVE_LOCAL_BUFFER (T, X, nr);
225  MSparse<T> tmp (a);
226  a = MSparse<T> (nr, nc, nel);
227  octave_idx_type *ri = a.xridx ();
228 
229  for (octave_idx_type i = 0; i < tmp.cidx (c); i++)
230  {
231  a.xdata (i) = tmp.xdata (i);
232  a.xridx (i) = tmp.xridx (i);
233  }
234  for (octave_idx_type i = 0; i < c + 1; i++)
235  a.xcidx (i) = tmp.xcidx (i);
236 
237  octave_idx_type ii = a.xcidx (c);
238 
239  for (octave_idx_type i = c; i < c + b_cols; i++)
240  {
241  octave_quit ();
242 
243  for (octave_idx_type j = tmp.xcidx (i); j < tmp.xcidx (i+1); j++)
244  if (Qinv[tmp.xridx (j)] < r || Qinv[tmp.xridx (j)] >= r + b_rows)
245  {
246  X[tmp.xridx (j)] = tmp.xdata (j);
247  a.xridx (ii++) = tmp.xridx (j);
248  }
249 
250  octave_quit ();
251 
252  for (octave_idx_type j = b.cidx (i-c); j < b.cidx (i-c+1); j++)
253  {
254  X[Q[r + b.ridx (j)]] = b.data (j);
255  a.xridx (ii++) = Q[r + b.ridx (j)];
256  }
257 
258  sort.sort (ri + a.xcidx (i), ii - a.xcidx (i));
259  for (octave_idx_type p = a.xcidx (i); p < ii; p++)
260  a.xdata (p) = X[a.xridx (p)];
261  a.xcidx (i+1) = ii;
262  }
263 
264  for (octave_idx_type i = c + b_cols; i < nc; i++)
265  {
266  for (octave_idx_type j = tmp.xcidx (i); j < tmp.cidx (i+1); j++)
267  {
268  a.xdata (ii) = tmp.xdata (j);
269  a.xridx (ii++) = tmp.xridx (j);
270  }
271  a.xcidx (i+1) = ii;
272  }
273 }
274 
// Presumably non-template declarations of the sparse dmsolve_insert
// overloads, compiled only when CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
// NOTE(review): the declarator lines (source 277-278 and 281-282) are
// missing from this listing.
275 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
276 static void
279 
280 static void
283 #endif
284 
285 template <class T, class RT>
286 static void
288 {
289  octave_idx_type b_nr = b.rows ();
290  octave_idx_type b_nc = b.cols ();
291  const T *Bx = b.fortran_vec ();
292  a.resize (dim_vector (b_nr, b_nc));
293  RT *Btx = a.fortran_vec ();
294  for (octave_idx_type j = 0; j < b_nc; j++)
295  {
296  octave_idx_type off = j * b_nr;
297  for (octave_idx_type i = 0; i < b_nr; i++)
298  {
299  octave_quit ();
300  Btx[p[i] + off] = Bx[ i + off];
301  }
302  }
303 }
304 
// Presumably non-template declarations of the three dense dmsolve_permute
// instantiations, compiled only when CXX_NEW_FRIEND_TEMPLATE_DECL is not
// defined.
// NOTE(review): the declarator lines (source 307, 311 and 315) are missing
// from this listing.
305 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
306 static void
308  const octave_idx_type *p);
309 
310 static void
312  const octave_idx_type *p);
313 
314 static void
316  const octave_idx_type *p);
317 #endif
318 
319 template <class T, class RT>
320 static void
322 {
323  octave_idx_type b_nr = b.rows ();
324  octave_idx_type b_nc = b.cols ();
325  octave_idx_type b_nz = b.nnz ();
326  octave_idx_type nz = 0;
327  a = MSparse<RT> (b_nr, b_nc, b_nz);
329  octave_idx_type *ri = a.xridx ();
330  OCTAVE_LOCAL_BUFFER (RT, X, b_nr);
331  a.xcidx (0) = 0;
332  for (octave_idx_type j = 0; j < b_nc; j++)
333  {
334  for (octave_idx_type i = b.cidx (j); i < b.cidx (j+1); i++)
335  {
336  octave_quit ();
337  octave_idx_type r = p[b.ridx (i)];
338  X[r] = b.data (i);
339  a.xridx (nz++) = p[b.ridx (i)];
340  }
341  sort.sort (ri + a.xcidx (j), nz - a.xcidx (j));
342  for (octave_idx_type i = a.cidx (j); i < nz; i++)
343  {
344  octave_quit ();
345  a.xdata (i) = X[a.xridx (i)];
346  }
347  a.xcidx (j+1) = nz;
348  }
349 }
350 
// Presumably non-template declarations of the three sparse dmsolve_permute
// instantiations, compiled only when CXX_NEW_FRIEND_TEMPLATE_DECL is not
// defined.
// NOTE(review): the declarator lines (source 353, 357 and 361) are missing
// from this listing.
351 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
352 static void
354  const octave_idx_type *p);
355 
356 static void
358  const octave_idx_type *p);
359 
360 static void
362  const octave_idx_type *p);
363 #endif
364 
// Dummy singularity handler: it deliberately does nothing, so that the LU
// solver does not flag an error for numerically rank-deficient matrices.
static void
solve_singularity_warning (double)
{
  // Intentionally empty.
}
371 
// Solve a * x = b for x using the Dulmage-Mendelsohn permutation computed
// by CXSparse (_dmperm).  The permuted system is handled block by block:
//   * rows [rr[2], nr) x columns [cc[3], nc): solved with qrsolve;
//   * rows [rr[1], rr[2]) x columns [cc[2], cc[3]): solved with m.solve,
//     retried with qrsolve if that leaves info != 0;
//   * rows [0, rr[1]) x columns [0, cc[2]): solved with qrsolve.
// After each of the first two blocks is solved, its contribution is
// subtracted from the rows of the permuted right-hand side above it
// (btmp.insert (ctmp - m * mtmp, 0, 0)), and the block's solution is
// scattered into retval through the column permutation q.  The later
// blocks are skipped once info becomes non-zero.
//
// RT is the result type, ST the (sparse) type of a, T the type of b.  A
// dimension mismatch is reported through current_liboctave_error_handler;
// empty inputs yield an nc x b_nc zero result.  In the general case INFO is
// reset to 0 and then overwritten by the solvers.  Without HAVE_CXSPARSE the
// error handler is called and RT () is returned.
372 template <class RT, class ST, class T>
373 RT
374 dmsolve (const ST &a, const T &b, octave_idx_type &info)
375 {
376 #ifdef HAVE_CXSPARSE
377  octave_idx_type nr = a.rows ();
378  octave_idx_type nc = a.cols ();
379  octave_idx_type b_nr = b.rows ();
380  octave_idx_type b_nc = b.cols ();
381  RT retval;
382 
383  if (nr < 0 || nc < 0 || nr != b_nr)
384  (*current_liboctave_error_handler)
385  ("matrix dimension mismatch in solution of minimum norm problem");
386  else if (nr == 0 || nc == 0 || b_nc == 0)
387  retval = RT (nc, b_nc, 0.0);
388  else
389  {
// Upper bound on the nonzeros not yet extracted; passed as maxnz so each
// extracted block reserves no more storage than is left.
390  octave_idx_type nnz_remaining = a.nnz ();
// Pattern-only CSparse view of a (x == 0, so no values are passed).
391  CXSPARSE_DNAME () csm;
392  csm.m = nr;
393  csm.n = nc;
394  csm.x = 0;
395  csm.nz = -1;
396  csm.nzmax = a.nnz ();
397  // Cast away const on A, with full knowledge that CSparse won't touch it.
398  // Prevents the methods below making a copy of the data.
399  csm.p = const_cast<octave_idx_type *>(a.cidx ());
400  csm.i = const_cast<octave_idx_type *>(a.ridx ());
401 
// The _dmperm signature and the permutation field names differ between
// CSparse versions.
402 #if defined (CS_VER) && (CS_VER >= 2)
403  CXSPARSE_DNAME (d) *dm = CXSPARSE_DNAME(_dmperm) (&csm, 0);
404  octave_idx_type *p = dm->p;
405  octave_idx_type *q = dm->q;
406 #else
407  CXSPARSE_DNAME (d) *dm = CXSPARSE_DNAME(_dmperm) (&csm);
408  octave_idx_type *p = dm->P;
409  octave_idx_type *q = dm->Q;
410 #endif
// NOTE(review): source line 411 is missing from this listing; it presumably
// declares the pinv buffer that the loop below fills with the inverse of
// the row permutation p.
412  for (octave_idx_type i = 0; i < nr; i++)
413  pinv[p[i]] = i;
414  RT btmp;
415  dmsolve_permute (btmp, b, pinv);
416  info = 0;
417  retval.resize (nc, b_nc);
418 
419  // Leading over-determined block
420  if (dm->rr[2] < nr && dm->cc[3] < nc)
421  {
422  ST m = dmsolve_extract (a, pinv, q, dm->rr[2], nr, dm->cc[3], nc,
423  nnz_remaining, true);
424  nnz_remaining -= m.nnz ();
425  RT mtmp =
426  qrsolve (m, dmsolve_extract (btmp, 0, 0, dm->rr[2], b_nr, 0,
427  b_nc), info);
428  dmsolve_insert (retval, mtmp, q, dm->cc[3], 0);
// Subtract this block's contribution from the rows above it.
429  if (dm->rr[2] > 0 && !info)
430  {
431  m = dmsolve_extract (a, pinv, q, 0, dm->rr[2],
432  dm->cc[3], nc, nnz_remaining, true);
433  nnz_remaining -= m.nnz ();
434  RT ctmp = dmsolve_extract (btmp, 0, 0, 0,
435  dm->rr[2], 0, b_nc);
436  btmp.insert (ctmp - m * mtmp, 0, 0);
437  }
438  }
439 
440  // Structurally non-singular blocks
441  // FIXME: Should use fine Dulmage-Mendelsohn decomposition here.
442  if (dm->rr[1] < dm->rr[2] && dm->cc[2] < dm->cc[3] && !info)
443  {
444  ST m = dmsolve_extract (a, pinv, q, dm->rr[1], dm->rr[2],
445  dm->cc[2], dm->cc[3], nnz_remaining, false);
446  nnz_remaining -= m.nnz ();
447  RT btmp2 = dmsolve_extract (btmp, 0, 0, dm->rr[1], dm->rr[2],
448  0, b_nc);
449  double rcond = 0.0;
// NOTE(review): source lines 450 and 452 are missing from this listing
// (presumably the declaration of mtyp and the remaining arguments of
// m.solve).
451  RT mtmp = m.solve (mtyp, btmp2, info, rcond,
// If the direct solve reports failure, fall back to a QR solve.
453  if (info != 0)
454  {
455  info = 0;
456  mtmp = qrsolve (m, btmp2, info);
457  }
458 
459  dmsolve_insert (retval, mtmp, q, dm->cc[2], 0);
// Subtract this block's contribution from the rows above it.
460  if (dm->rr[1] > 0 && !info)
461  {
462  m = dmsolve_extract (a, pinv, q, 0, dm->rr[1], dm->cc[2],
463  dm->cc[3], nnz_remaining, true);
464  nnz_remaining -= m.nnz ();
465  RT ctmp = dmsolve_extract (btmp, 0, 0, 0,
466  dm->rr[1], 0, b_nc);
467  btmp.insert (ctmp - m * mtmp, 0, 0);
468  }
469  }
470 
471  // Trailing under-determined block
472  if (dm->rr[1] > 0 && dm->cc[2] > 0 && !info)
473  {
474  ST m = dmsolve_extract (a, pinv, q, 0, dm->rr[1], 0,
475  dm->cc[2], nnz_remaining, true);
476  RT mtmp =
477  qrsolve (m, dmsolve_extract (btmp, 0, 0, 0, dm->rr[1] , 0,
478  b_nc), info);
479  dmsolve_insert (retval, mtmp, q, 0, 0);
480  }
481 
// NOTE(review): dm is released only on this normal path, and the result of
// _dmperm is never checked for null.  If any call above can throw, dm
// leaks; if _dmperm can return null (e.g. on allocation failure), it is
// dereferenced.  Consider an RAII guard plus a null check -- confirm.
482  CXSPARSE_DNAME (_dfree) (dm);
483  }
484  return retval;
485 #else
486  (*current_liboctave_error_handler)
487  ("CXSPARSE unavailable; cannot solve minimum norm problem");
488  return RT ();
489 #endif
490 }
491 
// Declarations of the concrete dmsolve overloads (sparse a with a dense or
// sparse b, real or complex), compiled only when
// CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
// NOTE(review): source line 522 (the declarator of the last overload,
// presumably taking SparseComplexMatrix a and b) is missing from this
// listing.
492 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
493 extern Matrix
494 dmsolve (const SparseMatrix &a, const Matrix &b,
495  octave_idx_type &info);
496 
497 extern ComplexMatrix
498 dmsolve (const SparseMatrix &a, const ComplexMatrix &b,
499  octave_idx_type &info);
500 
501 extern ComplexMatrix
502 dmsolve (const SparseComplexMatrix &a, const Matrix &b,
503  octave_idx_type &info);
504 
505 extern ComplexMatrix
506 dmsolve (const SparseComplexMatrix &a, const ComplexMatrix &b,
507  octave_idx_type &info);
508 
509 extern SparseMatrix
510 dmsolve (const SparseMatrix &a, const SparseMatrix &b,
511  octave_idx_type &info);
512 
513 extern SparseComplexMatrix
514 dmsolve (const SparseMatrix &a, const SparseComplexMatrix &b,
515  octave_idx_type &info);
516 
517 extern SparseComplexMatrix
518 dmsolve (const SparseComplexMatrix &a, const SparseMatrix &b,
519  octave_idx_type &info);
520 
521 extern SparseComplexMatrix
523  octave_idx_type &info);
524 #endif
octave_idx_type * xridx(void)
Definition: Sparse.h:524
octave_idx_type cols(void) const
Definition: Sparse.h:264
T * data(void)
Definition: Sparse.h:509
octave_idx_type rows(void) const
Definition: Sparse.h:263
static void solve_singularity_warning(double)
F77_RET_T const octave_idx_type Complex * A
Definition: CmplxGEPBAL.cc:39
octave_idx_type * xcidx(void)
Definition: Sparse.h:537
octave_idx_type * cidx(void)
Definition: Sparse.h:531
T & elem(octave_idx_type n)
Definition: Array.h:380
Definition: MArray.h:36
octave_idx_type rows(void) const
Definition: Array.h:313
F77_RET_T const double const double double * d
octave_idx_type nnz(void) const
Definition: Sparse.h:248
static MArray< double > const octave_idx_type const octave_idx_type octave_idx_type octave_idx_type r2
#define CXSPARSE_DNAME(name)
Definition: SparseQR.h:37
Sparse< T > sort(octave_idx_type dim=0, sortmode mode=ASCENDING) const
Definition: Sparse.cc:2240
void resize(const dim_vector &dv, const T &rfv)
Definition: Array.cc:1033
octave_int< uint64_t > octave_uint64
Definition: dMatrix.h:35
static void dmsolve_permute(MArray< RT > &a, const MArray< T > &b, const octave_idx_type *p)
T & xelem(octave_idx_type n)
Definition: Array.h:353
void sort(T *data, octave_idx_type nel)
Definition: oct-sort.cc:1514
octave_idx_type * ridx(void)
Definition: Sparse.h:518
F77_RET_T const octave_idx_type Complex const octave_idx_type Complex * B
Definition: CmplxGEPBAL.cc:39
ComplexMatrix qrsolve(const SparseComplexMatrix &a, const Matrix &b, octave_idx_type &info)
T * xdata(void)
Definition: Sparse.h:511
#define OCTAVE_LOCAL_BUFFER(T, buf, size)
Definition: oct-locbuf.h:197
F77_RET_T const octave_idx_type const octave_idx_type const octave_idx_type double const octave_idx_type double const octave_idx_type double * Q
Definition: qz.cc:114
static MSparse< T > dmsolve_extract(const MSparse< T > &A, const octave_idx_type *Pinv, const octave_idx_type *Q, octave_idx_type rst, octave_idx_type rend, octave_idx_type cst, octave_idx_type cend, octave_idx_type maxnz=-1, bool lazy=false)
static MArray< double > const octave_idx_type const octave_idx_type octave_idx_type octave_idx_type octave_idx_type c1
static MArray< double > const octave_idx_type const octave_idx_type octave_idx_type r1
const T * fortran_vec(void) const
Definition: Array.h:481
octave_idx_type cols(void) const
Definition: Array.h:321
static MArray< double > const octave_idx_type const octave_idx_type octave_idx_type octave_idx_type octave_idx_type octave_idx_type c2
static void dmsolve_insert(MArray< T > &a, const MArray< T > &b, const octave_idx_type *Q, octave_idx_type r, octave_idx_type c)
RT dmsolve(const ST &a, const T &b, octave_idx_type &info)