floatSVD.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 1994-2012 John W. Eaton
00004 
00005 This file is part of Octave.
00006 
00007 Octave is free software; you can redistribute it and/or modify it
00008 under the terms of the GNU General Public License as published by the
00009 Free Software Foundation; either version 3 of the License, or (at your
00010 option) any later version.
00011 
00012 Octave is distributed in the hope that it will be useful, but WITHOUT
00013 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00014 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00015 for more details.
00016 
00017 You should have received a copy of the GNU General Public License
00018 along with Octave; see the file COPYING.  If not, see
00019 <http://www.gnu.org/licenses/>.
00020 
00021 */
00022 
00023 #ifdef HAVE_CONFIG_H
00024 #include <config.h>
00025 #endif
00026 
00027 #include <iostream>
00028 
00029 #include "floatSVD.h"
00030 #include "f77-fcn.h"
00031 #include "oct-locbuf.h"
00032 
00033 extern "C"
00034 {
00035   F77_RET_T
00036   F77_FUNC (sgesvd, SGESVD) (F77_CONST_CHAR_ARG_DECL,
00037                              F77_CONST_CHAR_ARG_DECL,
00038                              const octave_idx_type&, const octave_idx_type&,
00039                              float*, const octave_idx_type&, float*,
00040                              float*, const octave_idx_type&, float*,
00041                              const octave_idx_type&, float*,
00042                              const octave_idx_type&, octave_idx_type&
00043                              F77_CHAR_ARG_LEN_DECL
00044                              F77_CHAR_ARG_LEN_DECL);
00045 
00046   F77_RET_T
00047   F77_FUNC (sgesdd, SGESDD) (F77_CONST_CHAR_ARG_DECL,
00048                              const octave_idx_type&, const octave_idx_type&,
00049                              float*, const octave_idx_type&, float*,
00050                              float*, const octave_idx_type&, float*,
00051                              const octave_idx_type&, float*,
00052                              const octave_idx_type&, octave_idx_type *,
00053                              octave_idx_type&
00054                              F77_CHAR_ARG_LEN_DECL);
00055 }
00056 
00057 FloatMatrix
00058 FloatSVD::left_singular_matrix (void) const
00059 {
00060   if (type_computed == SVD::sigma_only)
00061     {
00062       (*current_liboctave_error_handler)
00063         ("FloatSVD: U not computed because type == SVD::sigma_only");
00064       return FloatMatrix ();
00065     }
00066   else
00067     return left_sm;
00068 }
00069 
00070 FloatMatrix
00071 FloatSVD::right_singular_matrix (void) const
00072 {
00073   if (type_computed == SVD::sigma_only)
00074     {
00075       (*current_liboctave_error_handler)
00076         ("FloatSVD: V not computed because type == SVD::sigma_only");
00077       return FloatMatrix ();
00078     }
00079   else
00080     return right_sm;
00081 }
00082 
00083 octave_idx_type
00084 FloatSVD::init (const FloatMatrix& a, SVD::type svd_type, SVD::driver svd_driver)
00085 {
00086   octave_idx_type info;
00087 
00088   octave_idx_type m = a.rows ();
00089   octave_idx_type n = a.cols ();
00090 
00091   FloatMatrix atmp = a;
00092   float *tmp_data = atmp.fortran_vec ();
00093 
00094   octave_idx_type min_mn = m < n ? m : n;
00095 
00096   char jobu = 'A';
00097   char jobv = 'A';
00098 
00099   octave_idx_type ncol_u = m;
00100   octave_idx_type nrow_vt = n;
00101   octave_idx_type nrow_s = m;
00102   octave_idx_type ncol_s = n;
00103 
00104   switch (svd_type)
00105     {
00106     case SVD::economy:
00107       jobu = jobv = 'S';
00108       ncol_u = nrow_vt = nrow_s = ncol_s = min_mn;
00109       break;
00110 
00111     case SVD::sigma_only:
00112 
00113       // Note:  for this case, both jobu and jobv should be 'N', but
00114       // there seems to be a bug in dgesvd from Lapack V2.0.  To
00115       // demonstrate the bug, set both jobu and jobv to 'N' and find
00116       // the singular values of [eye(3), eye(3)].  The result is
00117       // [-sqrt(2), -sqrt(2), -sqrt(2)].
00118       //
00119       // For Lapack 3.0, this problem seems to be fixed.
00120 
00121       jobu = jobv = 'N';
00122       ncol_u = nrow_vt = 1;
00123       break;
00124 
00125     default:
00126       break;
00127     }
00128 
00129   type_computed = svd_type;
00130 
00131   if (! (jobu == 'N' || jobu == 'O'))
00132     left_sm.resize (m, ncol_u);
00133 
00134   float *u = left_sm.fortran_vec ();
00135 
00136   sigma.resize (nrow_s, ncol_s);
00137   float *s_vec  = sigma.fortran_vec ();
00138 
00139   if (! (jobv == 'N' || jobv == 'O'))
00140     right_sm.resize (nrow_vt, n);
00141 
00142   float *vt = right_sm.fortran_vec ();
00143 
00144   // Query SGESVD for the correct dimension of WORK.
00145 
00146   octave_idx_type lwork = -1;
00147 
00148   Array<float> work (dim_vector (1, 1));
00149 
00150   octave_idx_type one = 1;
00151   octave_idx_type m1 = std::max (m, one);
00152   octave_idx_type nrow_vt1 = std::max (nrow_vt, one);
00153 
00154   if (svd_driver == SVD::GESVD)
00155     {
00156       F77_XFCN (sgesvd, SGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
00157                                  F77_CONST_CHAR_ARG2 (&jobv, 1),
00158                                  m, n, tmp_data, m1, s_vec, u, m1, vt,
00159                                  nrow_vt1, work.fortran_vec (), lwork, info
00160                                  F77_CHAR_ARG_LEN (1)
00161                                  F77_CHAR_ARG_LEN (1)));
00162 
00163       lwork = static_cast<octave_idx_type> (work(0));
00164       work.resize (dim_vector (lwork, 1));
00165 
00166       F77_XFCN (sgesvd, SGESVD, (F77_CONST_CHAR_ARG2 (&jobu, 1),
00167                                  F77_CONST_CHAR_ARG2 (&jobv, 1),
00168                                  m, n, tmp_data, m1, s_vec, u, m1, vt,
00169                                  nrow_vt1, work.fortran_vec (), lwork, info
00170                                  F77_CHAR_ARG_LEN (1)
00171                                  F77_CHAR_ARG_LEN (1)));
00172 
00173     }
00174   else if (svd_driver == SVD::GESDD)
00175     {
00176       assert (jobu == jobv);
00177       char jobz = jobu;
00178       OCTAVE_LOCAL_BUFFER (octave_idx_type, iwork, 8*min_mn);
00179 
00180       F77_XFCN (sgesdd, SGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
00181                                  m, n, tmp_data, m1, s_vec, u, m1, vt,
00182                                  nrow_vt1, work.fortran_vec (), lwork, iwork, info
00183                                  F77_CHAR_ARG_LEN (1)));
00184 
00185       lwork = static_cast<octave_idx_type> (work(0));
00186       work.resize (dim_vector (lwork, 1));
00187 
00188       F77_XFCN (sgesdd, SGESDD, (F77_CONST_CHAR_ARG2 (&jobz, 1),
00189                                  m, n, tmp_data, m1, s_vec, u, m1, vt,
00190                                  nrow_vt1, work.fortran_vec (), lwork, iwork, info
00191                                  F77_CHAR_ARG_LEN (1)));
00192 
00193     }
00194   else
00195     assert (0); // impossible
00196 
00197   if (! (jobv == 'N' || jobv == 'O'))
00198     right_sm = right_sm.transpose ();
00199 
00200   return info;
00201 }
00202 
00203 std::ostream&
00204 operator << (std::ostream& os, const FloatSVD& a)
00205 {
00206   os << a.left_singular_matrix () << "\n";
00207   os << a.singular_values () << "\n";
00208   os << a.right_singular_matrix () << "\n";
00209 
00210   return os;
00211 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines