GNU Octave  4.0.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
sqrtm.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2001-2015 Ross Lippert and Paul Kienzle
4 Copyright (C) 2010 VZLU Prague
5 
6 This file is part of Octave.
7 
8 Octave is free software; you can redistribute it and/or modify it
9 under the terms of the GNU General Public License as published by the
10 Free Software Foundation; either version 3 of the License, or (at your
11 option) any later version.
12 
13 Octave is distributed in the hope that it will be useful, but WITHOUT
14 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
15 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
16 for more details.
17 
18 You should have received a copy of the GNU General Public License
19 along with Octave; see the file COPYING. If not, see
20 <http://www.gnu.org/licenses/>.
21 
22 */
23 
24 #ifdef HAVE_CONFIG_H
25 #include <config.h>
26 #endif
27 
28 #include <float.h>
29 
30 #include "CmplxSCHUR.h"
31 #include "fCmplxSCHUR.h"
32 #include "lo-ieee.h"
33 #include "lo-mappers.h"
34 #include "oct-norm.h"
35 
36 #include "defun.h"
37 #include "error.h"
38 #include "gripes.h"
39 #include "utils.h"
40 #include "xnorm.h"
41 
42 template <class Matrix>
43 static void
45 {
46  typedef typename Matrix::element_type element_type;
47 
48  const element_type zero = element_type ();
49 
50  bool singular = false;
51 
52  // The following code is equivalent to this triple loop:
53  //
54  // n = rows (T);
55  // for j = 1:n
56  // T(j,j) = sqrt (T(j,j));
57  // for i = j-1:-1:1
58  // T(i,j) /= (T(i,i) + T(j,j));
59  // k = 1:i-1;
60  // T(k,j) -= T(k,i) * T(i,j);
61  // endfor
62  // endfor
63  //
64  // this is an in-place, cache-aligned variant of the code
65  // given in Higham's paper.
66 
67  const octave_idx_type n = T.rows ();
68  element_type *Tp = T.fortran_vec ();
69  for (octave_idx_type j = 0; j < n; j++)
70  {
71  element_type *colj = Tp + n*j;
72  if (colj[j] != zero)
73  colj[j] = sqrt (colj[j]);
74  else
75  singular = true;
76 
77  for (octave_idx_type i = j-1; i >= 0; i--)
78  {
79  const element_type *coli = Tp + n*i;
80  const element_type colji = colj[i] /= (coli[i] + colj[j]);
81  for (octave_idx_type k = 0; k < i; k++)
82  colj[k] -= coli[k] * colji;
83  }
84  }
85 
86  if (singular)
87  warning_with_id ("Octave:sqrtm:SingularMatrix",
88  "sqrtm: matrix is singular, may not have a square root");
89 }
90 
91 template <class Matrix, class ComplexMatrix, class ComplexSCHUR>
92 static octave_value
94 {
95 
96  octave_value retval;
97 
98  MatrixType mt = arg.matrix_type ();
99 
100  bool iscomplex = arg.is_complex_type ();
101 
102  typedef typename Matrix::element_type real_type;
103 
104  real_type cutoff = 0;
105  real_type one = 1;
106  real_type eps = std::numeric_limits<real_type>::epsilon ();
107 
108  if (! iscomplex)
109  {
111 
112  if (mt.is_unknown ()) // if type is not known, compute it now.
113  arg.matrix_type (mt = MatrixType (x));
114 
115  switch (mt.type ())
116  {
117  case MatrixType::Upper:
119  if (! x.diag ().any_element_is_negative ())
120  {
121  // Do it in real arithmetic.
122  sqrtm_utri_inplace (x);
123  retval = x;
124  retval.matrix_type (mt);
125  }
126  else
127  iscomplex = true;
128  break;
129 
130  case MatrixType::Lower:
131  if (! x.diag ().any_element_is_negative ())
132  {
133  x = x.transpose ();
134  sqrtm_utri_inplace (x);
135  retval = x.transpose ();
136  retval.matrix_type (mt);
137  }
138  else
139  iscomplex = true;
140  break;
141 
142  default:
143  iscomplex = true;
144  break;
145  }
146 
147  if (iscomplex)
148  cutoff = 10 * x.rows () * eps * xnorm (x, one);
149  }
150 
151  if (iscomplex)
152  {
154 
155  if (mt.is_unknown ()) // if type is not known, compute it now.
156  arg.matrix_type (mt = MatrixType (x));
157 
158  switch (mt.type ())
159  {
160  case MatrixType::Upper:
162  sqrtm_utri_inplace (x);
163  retval = x;
164  retval.matrix_type (mt);
165  break;
166 
167  case MatrixType::Lower:
168  x = x.transpose ();
169  sqrtm_utri_inplace (x);
170  retval = x.transpose ();
171  retval.matrix_type (mt);
172  break;
173 
174  default:
175  {
176  ComplexMatrix u;
177 
178  do
179  {
180  ComplexSCHUR schur (x, std::string (), true);
181  x = schur.schur_matrix ();
182  u = schur.unitary_matrix ();
183  }
184  while (0); // schur no longer needed.
185 
186  sqrtm_utri_inplace (x);
187 
188  x = u * x; // original x no longer needed.
190 
191  if (cutoff > 0 && xnorm (imag (res), one) <= cutoff)
192  retval = real (res);
193  else
194  retval = res;
195  }
196  break;
197  }
198  }
199 
200  return retval;
201 }
202 
203 DEFUN (sqrtm, args, nargout,
204  "-*- texinfo -*-\n\
205 @deftypefn {Built-in Function} {@var{s} =} sqrtm (@var{A})\n\
206 @deftypefnx {Built-in Function} {[@var{s}, @var{error_estimate}] =} sqrtm (@var{A})\n\
207 Compute the matrix square root of the square matrix @var{A}.\n\
208 \n\
209 Ref: @nospell{N.J. Higham}. @cite{A New sqrtm for @sc{matlab}}. Numerical\n\
210 Analysis Report No. 336, Manchester @nospell{Centre} for Computational\n\
211 Mathematics, Manchester, England, January 1999.\n\
212 @seealso{expm, logm}\n\
213 @end deftypefn")
214 {
215  octave_value_list retval;
216 
217  int nargin = args.length ();
218 
219  if (nargin != 1)
220  {
221  print_usage ();
222  return retval;
223  }
224 
225  octave_value arg = args(0);
226 
227  octave_idx_type n = arg.rows ();
228  octave_idx_type nc = arg.columns ();
229 
230  if (n != nc || arg.ndims () > 2)
231  {
233  return retval;
234  }
235 
236  if (nargout > 1)
237  {
238  retval.resize (1, 2);
239  retval(2) = -1.0;
240  }
241 
242  if (arg.is_diag_matrix ())
243  // sqrtm of a diagonal matrix is just sqrt.
244  retval(0) = arg.sqrt ();
245  else if (arg.is_single_type ())
246  retval(0) = do_sqrtm<FloatMatrix, FloatComplexMatrix, FloatComplexSCHUR>
247  (arg);
248  else if (arg.is_numeric_type ())
249  retval(0) = do_sqrtm<Matrix, ComplexMatrix, ComplexSCHUR> (arg);
250 
251  if (nargout > 1 && ! error_state)
252  {
253  // This corresponds to generic code
254  //
255  // norm (s*s - x, "fro") / norm (x, "fro");
256 
257  octave_value s = retval(0);
258  retval(1) = xfrobnorm (s*s - arg) / xfrobnorm (arg);
259  }
260 
261  return retval;
262 }
263 
264 /*
265 %!assert (sqrtm (2*ones (2)), ones (2), 3*eps)
266 
267 ## The following two tests are from the reference in the docstring above.
268 %!test
269 %! x = [0 1; 0 0];
270 %! assert (any (isnan (sqrtm (x))(:)));
271 
272 %!test
273 %! x = eye (4); x(2,2) = x(3,3) = 2^-26; x(1,4) = 1;
274 %! z = eye (4); z(2,2) = z(3,3) = 2^-13; z(1,4) = 0.5;
275 %! [y, err] = sqrtm (x);
276 %! assert (y, z);
277 %! assert (err, 0); # Yes, this one has to hold exactly
278 */
Matrix diag(octave_idx_type k=0) const
Definition: dMatrix.cc:2712
void warning_with_id(const char *id, const char *fmt,...)
Definition: error.cc:696
bool any_element_is_negative(bool=false) const
Definition: dNDArray.cc:550
int ndims(void) const
Definition: ov.h:479
octave_idx_type rows(void) const
Definition: ov.h:473
static octave_value do_sqrtm(const octave_value &arg)
Definition: sqrtm.cc:93
bool is_unknown(void) const
Definition: MatrixType.h:132
OCTINTERP_API void print_usage(void)
Definition: defun.cc:51
octave_idx_type length(void) const
Definition: oct-obj.h:89
bool is_numeric_type(void) const
Definition: ov.h:663
#define DEFUN(name, args_name, nargout_name, doc)
Definition: defun.h:44
ComplexMatrix schur_matrix(void) const
Definition: CmplxSCHUR.h:75
void gripe_square_matrix_required(const char *name)
Definition: gripes.cc:81
octave_idx_type rows(void) const
Definition: Array.h:313
octave_idx_type columns(void) const
Definition: ov.h:475
int type(bool quiet=true)
Definition: MatrixType.cc:963
int error_state
Definition: error.cc:101
bool is_complex_type(void) const
Definition: ov.h:654
Matrix transpose(void) const
Definition: dMatrix.h:114
OCTAVE_API double xnorm(const ColumnVector &x, double p)
Definition: oct-norm.cc:536
Definition: dMatrix.h:35
ComplexMatrix transpose(void) const
Definition: CMatrix.h:151
ComplexMatrix unitary_matrix(void) const
Definition: CmplxSCHUR.h:77
#define eps(C)
double arg(double x)
Definition: lo-mappers.h:37
MatrixType matrix_type(void) const
Definition: ov.h:510
ComplexMatrix xgemm(const ComplexMatrix &a, const ComplexMatrix &b, blas_trans_type transa, blas_trans_type transb)
Definition: CMatrix.cc:3686
OCTAVE_API double xfrobnorm(const Matrix &x)
Definition: oct-norm.cc:536
void resize(octave_idx_type n, const octave_value &rfv=octave_value())
Definition: oct-obj.h:93
ColumnVector imag(const ComplexColumnVector &a)
Definition: dColVector.cc:162
static void sqrtm_utri_inplace(Matrix &T)
Definition: sqrtm.cc:44
ComplexMatrix octave_value_extract< ComplexMatrix >(const octave_value &v)
Definition: ov.h:1413
const T * fortran_vec(void) const
Definition: Array.h:481
bool is_single_type(void) const
Definition: ov.h:611
ColumnVector real(const ComplexColumnVector &a)
Definition: dColVector.cc:156
Matrix octave_value_extract< Matrix >(const octave_value &v)
Definition: ov.h:1411
octave_value sqrt(void) const
Definition: ov.h:1200
bool is_diag_matrix(void) const
Definition: ov.h:556
F77_RET_T const double * x