tril.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 2004-2012 David Bateman
00004 Copyright (C) 2009 VZLU Prague
00005 
00006 This file is part of Octave.
00007 
00008 Octave is free software; you can redistribute it and/or modify it
00009 under the terms of the GNU General Public License as published by the
00010 Free Software Foundation; either version 3 of the License, or (at your
00011 option) any later version.
00012 
00013 Octave is distributed in the hope that it will be useful, but WITHOUT
00014 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00015 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00016 for more details.
00017 
00018 You should have received a copy of the GNU General Public License
00019 along with Octave; see the file COPYING.  If not, see
00020 <http://www.gnu.org/licenses/>.
00021 
00022 */
00023 
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027 
00028 #include <algorithm>
00029 #include "Array.h"
00030 #include "Sparse.h"
00031 #include "mx-base.h"
00032 
00033 #include "ov.h"
00034 #include "Cell.h"
00035 
00036 #include "defun-dld.h"
00037 #include "error.h"
00038 #include "oct-obj.h"
00039 
00040 // The bulk of the work.
00041 template <class T>
00042 static Array<T>
00043 do_tril (const Array<T>& a, octave_idx_type k, bool pack)
00044 {
00045   octave_idx_type nr = a.rows (), nc = a.columns ();
00046   const T *avec = a.fortran_vec ();
00047   octave_idx_type zero = 0;
00048 
00049   if (pack)
00050     {
00051       octave_idx_type j1 = std::min (std::max (zero, k), nc);
00052       octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
00053       octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2;
00054       Array<T> r (dim_vector (n, 1));
00055       T *rvec = r.fortran_vec ();
00056       for (octave_idx_type j = 0; j < nc; j++)
00057         {
00058           octave_idx_type ii = std::min (std::max (zero, j - k), nr);
00059           rvec = std::copy (avec + ii, avec + nr, rvec);
00060           avec += nr;
00061         }
00062 
00063       return r;
00064     }
00065   else
00066     {
00067       Array<T> r (a.dims ());
00068       T *rvec = r.fortran_vec ();
00069       for (octave_idx_type j = 0; j < nc; j++)
00070         {
00071           octave_idx_type ii = std::min (std::max (zero, j - k), nr);
00072           std::fill (rvec, rvec + ii, T());
00073           std::copy (avec + ii, avec + nr, rvec + ii);
00074           avec += nr;
00075           rvec += nr;
00076         }
00077 
00078       return r;
00079     }
00080 }
00081 
00082 template <class T>
00083 static Array<T>
00084 do_triu (const Array<T>& a, octave_idx_type k, bool pack)
00085 {
00086   octave_idx_type nr = a.rows (), nc = a.columns ();
00087   const T *avec = a.fortran_vec ();
00088   octave_idx_type zero = 0;
00089 
00090   if (pack)
00091     {
00092       octave_idx_type j1 = std::min (std::max (zero, k), nc);
00093       octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
00094       octave_idx_type n = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr;
00095       Array<T> r (dim_vector (n, 1));
00096       T *rvec = r.fortran_vec ();
00097       for (octave_idx_type j = 0; j < nc; j++)
00098         {
00099           octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
00100           rvec = std::copy (avec, avec + ii, rvec);
00101           avec += nr;
00102         }
00103 
00104       return r;
00105     }
00106   else
00107     {
00108       NoAlias<Array<T> > r (a.dims ());
00109       T *rvec = r.fortran_vec ();
00110       for (octave_idx_type j = 0; j < nc; j++)
00111         {
00112           octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
00113           std::copy (avec, avec + ii, rvec);
00114           std::fill (rvec + ii, rvec + nr, T());
00115           avec += nr;
00116           rvec += nr;
00117         }
00118 
00119       return r;
00120     }
00121 }
00122 
00123 // These two are by David Bateman.
00124 // FIXME: optimizations possible. "pack" support missing.
00125 
00126 template <class T>
00127 static Sparse<T>
00128 do_tril (const Sparse<T>& a, octave_idx_type k, bool pack)
00129 {
00130   if (pack) // FIXME
00131     {
00132       error ("tril: \"pack\" not implemented for sparse matrices");
00133       return Sparse<T> ();
00134     }
00135 
00136   Sparse<T> m = a;
00137   octave_idx_type nc = m.cols();
00138 
00139   for (octave_idx_type j = 0; j < nc; j++)
00140     for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++)
00141       if (m.ridx(i) < j-k)
00142         m.data(i) = 0.;
00143 
00144   m.maybe_compress (true);
00145   return m;
00146 }
00147 
00148 template <class T>
00149 static Sparse<T>
00150 do_triu (const Sparse<T>& a, octave_idx_type k, bool pack)
00151 {
00152   if (pack) // FIXME
00153     {
00154       error ("triu: \"pack\" not implemented for sparse matrices");
00155       return Sparse<T> ();
00156     }
00157 
00158   Sparse<T> m = a;
00159   octave_idx_type nc = m.cols();
00160 
00161   for (octave_idx_type j = 0; j < nc; j++)
00162     for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++)
00163       if (m.ridx(i) > j-k)
00164         m.data(i) = 0.;
00165 
00166   m.maybe_compress (true);
00167   return m;
00168 }
00169 
00170 // Convenience dispatchers.
00171 template <class T>
00172 static Array<T>
00173 do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack)
00174 {
00175   return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
00176 }
00177 
00178 template <class T>
00179 static Sparse<T>
00180 do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack)
00181 {
00182   return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
00183 }
00184 
00185 static octave_value
00186 do_trilu (const std::string& name,
00187           const octave_value_list& args)
00188 {
00189   bool lower = name == "tril";
00190 
00191   octave_value retval;
00192   int nargin = args.length ();
00193   octave_idx_type k = 0;
00194   bool pack = false;
00195   if (nargin >= 2 && args(nargin-1).is_string ())
00196     {
00197       pack = args(nargin-1).string_value () == "pack";
00198       nargin--;
00199     }
00200 
00201   if (nargin == 2)
00202     {
00203       k = args(1).int_value (true);
00204 
00205       if (error_state)
00206         return retval;
00207     }
00208 
00209   if (nargin < 1 || nargin > 2)
00210     print_usage ();
00211   else
00212     {
00213       octave_value arg = args (0);
00214 
00215       dim_vector dims = arg.dims ();
00216       if (dims.length () != 2)
00217         error ("%s: need a 2-D matrix", name.c_str ());
00218       else if (k < -dims (0) || k > dims(1))
00219         error ("%s: requested diagonal out of range", name.c_str ());
00220       else
00221         {
00222           switch (arg.builtin_type ())
00223             {
00224             case btyp_double:
00225               if (arg.is_sparse_type ())
00226                 retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack);
00227               else
00228                 retval = do_trilu (arg.array_value (), k, lower, pack);
00229               break;
00230             case btyp_complex:
00231               if (arg.is_sparse_type ())
00232                 retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower, pack);
00233               else
00234                 retval = do_trilu (arg.complex_array_value (), k, lower, pack);
00235               break;
00236             case btyp_bool:
00237               if (arg.is_sparse_type ())
00238                 retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower, pack);
00239               else
00240                 retval = do_trilu (arg.bool_array_value (), k, lower, pack);
00241               break;
00242 #define ARRAYCASE(TYP) \
00243             case btyp_ ## TYP: \
00244               retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \
00245               break
00246             ARRAYCASE (float);
00247             ARRAYCASE (float_complex);
00248             ARRAYCASE (int8);
00249             ARRAYCASE (int16);
00250             ARRAYCASE (int32);
00251             ARRAYCASE (int64);
00252             ARRAYCASE (uint8);
00253             ARRAYCASE (uint16);
00254             ARRAYCASE (uint32);
00255             ARRAYCASE (uint64);
00256             ARRAYCASE (char);
00257 #undef ARRAYCASE
00258             default:
00259               {
00260                 // Generic code that works on octave-values, that is slow
00261                 // but will also work on arbitrary user types
00262 
00263                 if (pack) // FIXME
00264                   {
00265                     error ("%s: \"pack\" not implemented for class %s",
00266                            name.c_str (), arg.class_name ().c_str ());
00267                     return octave_value ();
00268                   }
00269 
00270                 octave_value tmp = arg;
00271                 if (arg.numel () == 0)
00272                   return arg;
00273 
00274                 octave_idx_type nr = dims(0), nc = dims (1);
00275 
00276                 // The sole purpose of the below is to force the correct
00277                 // matrix size. This would not be necessary if the
00278                 // octave_value resize function allowed a fill_value.
00279                 // It also allows odd attributes in some user types
00280                 // to be handled. With a fill_value ot should be replaced
00281                 // with
00282                 //
00283                 // octave_value_list ov_idx;
00284                 // tmp = tmp.resize(dim_vector (0,0)).resize (dims, fill_value);
00285 
00286                 octave_value_list ov_idx;
00287                 std::list<octave_value_list> idx_tmp;
00288                 ov_idx(1) = static_cast<double> (nc+1);
00289                 ov_idx(0) = Range (1, nr);
00290                 idx_tmp.push_back (ov_idx);
00291                 ov_idx(1) = static_cast<double> (nc);
00292                 tmp = tmp.resize (dim_vector (0,0));
00293                 tmp = tmp.subsasgn("(",idx_tmp, arg.do_index_op (ov_idx));
00294                 tmp = tmp.resize(dims);
00295 
00296                 if (lower)
00297                   {
00298                     octave_idx_type st = nc < nr + k ? nc : nr + k;
00299 
00300                     for (octave_idx_type j = 1; j <= st; j++)
00301                       {
00302                         octave_idx_type nr_limit = 1 > j - k ? 1 : j - k;
00303                         ov_idx(1) = static_cast<double> (j);
00304                         ov_idx(0) = Range (nr_limit, nr);
00305                         std::list<octave_value_list> idx;
00306                         idx.push_back (ov_idx);
00307 
00308                         tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx));
00309 
00310                         if (error_state)
00311                           return retval;
00312                       }
00313                   }
00314                 else
00315                   {
00316                     octave_idx_type st = k + 1 > 1 ? k + 1 : 1;
00317 
00318                     for (octave_idx_type j = st; j <= nc; j++)
00319                       {
00320                         octave_idx_type nr_limit = nr < j - k ? nr : j - k;
00321                         ov_idx(1) = static_cast<double> (j);
00322                         ov_idx(0) = Range (1, nr_limit);
00323                         std::list<octave_value_list> idx;
00324                         idx.push_back (ov_idx);
00325 
00326                         tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx));
00327 
00328                         if (error_state)
00329                           return retval;
00330                       }
00331                   }
00332 
00333                 retval = tmp;
00334               }
00335             }
00336         }
00337     }
00338 
00339   return retval;
00340 }
00341 
00342 DEFUN_DLD (tril, args, ,
00343   "-*- texinfo -*-\n\
00344 @deftypefn  {Function File} {} tril (@var{A})\n\
00345 @deftypefnx {Function File} {} tril (@var{A}, @var{k})\n\
00346 @deftypefnx {Function File} {} tril (@var{A}, @var{k}, @var{pack})\n\
00347 @deftypefnx {Function File} {} triu (@var{A})\n\
00348 @deftypefnx {Function File} {} triu (@var{A}, @var{k})\n\
00349 @deftypefnx {Function File} {} triu (@var{A}, @var{k}, @var{pack})\n\
00350 Return a new matrix formed by extracting the lower (@code{tril})\n\
00351 or upper (@code{triu}) triangular part of the matrix @var{A}, and\n\
00352 setting all other elements to zero.  The second argument is optional,\n\
00353 and specifies how many diagonals above or below the main diagonal should\n\
00354 also be set to zero.\n\
00355 \n\
00356 The default value of @var{k} is zero, so that @code{triu} and\n\
00357 @code{tril} normally include the main diagonal as part of the result.\n\
00358 \n\
00359 If the value of @var{k} is nonzero integer, the selection of elements\
00360 starts at an offset of @var{k} diagonals above or below the main\
00361 diagonal; above for positive @var{k} and below for negative @var{k}.\
00362 \n\
00363 The absolute value of @var{k} must not be greater than the number of\n\
00364 sub-diagonals or super-diagonals.\n\
00365 \n\
00366 For example:\n\
00367 \n\
00368 @example\n\
00369 @group\n\
00370 tril (ones (3), -1)\n\
00371      @result{}  0  0  0\n\
00372          1  0  0\n\
00373          1  1  0\n\
00374 @end group\n\
00375 @end example\n\
00376 \n\
00377 @noindent\n\
00378 and\n\
00379 \n\
00380 @example\n\
00381 @group\n\
00382 tril (ones (3), 1)\n\
00383      @result{}  1  1  0\n\
00384          1  1  1\n\
00385          1  1  1\n\
00386 @end group\n\
00387 @end example\n\
00388 \n\
00389 If the option \"pack\" is given as third argument, the extracted elements\n\
00390 are not inserted into a matrix, but rather stacked column-wise one above\n\
00391 other.\n\
00392 @seealso{diag}\n\
00393 @end deftypefn")
00394 {
00395   return do_trilu ("tril", args);
00396 }
00397 
00398 DEFUN_DLD (triu, args, ,
00399   "-*- texinfo -*-\n\
00400 @deftypefn  {Function File} {} triu (@var{A})\n\
00401 @deftypefnx {Function File} {} triu (@var{A}, @var{k})\n\
00402 @deftypefnx {Function File} {} triu (@var{A}, @var{k}, @var{pack})\n\
00403 See the documentation for the @code{tril} function (@pxref{tril}).\n\
00404 @end deftypefn")
00405 {
00406   return do_trilu ("triu", args);
00407 }
00408 
00409 /*
00410 
00411 %!test
00412 %! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
00413 %!
00414 %! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12];
00415 %! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12];
00416 %! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
00417 %! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12];
00418 %! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0];
00419 %! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0];
00420 %! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0];
00421 %!
00422 %! assert((tril (a, -4) == lm4 && tril (a, -3) == lm3
00423 %! && tril (a, -2) == lm2 && tril (a, -1) == lm1
00424 %! && tril (a) == l0 && tril (a, 1) == l1 && tril (a, 2) == l2));
00425 
00426 %!error tril ();
00427 
00428 */
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines