GNU Octave  3.8.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
tril.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2004-2013 David Bateman
4 Copyright (C) 2009 VZLU Prague
5 
6 This file is part of Octave.
7 
8 Octave is free software; you can redistribute it and/or modify it
9 under the terms of the GNU General Public License as published by the
10 Free Software Foundation; either version 3 of the License, or (at your
11 option) any later version.
12 
13 Octave is distributed in the hope that it will be useful, but WITHOUT
14 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
15 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
16 for more details.
17 
18 You should have received a copy of the GNU General Public License
19 along with Octave; see the file COPYING. If not, see
20 <http://www.gnu.org/licenses/>.
21 
22 */
23 
24 #ifdef HAVE_CONFIG_H
25 #include <config.h>
26 #endif
27 
28 #include <algorithm>
29 #include "Array.h"
30 #include "Sparse.h"
31 #include "mx-base.h"
32 
33 #include "ov.h"
34 #include "Cell.h"
35 
36 #include "defun.h"
37 #include "error.h"
38 #include "oct-obj.h"
39 
40 // The bulk of the work.
41 template <class T>
42 static Array<T>
43 do_tril (const Array<T>& a, octave_idx_type k, bool pack)
44 {
45  octave_idx_type nr = a.rows (), nc = a.columns ();
46  const T *avec = a.fortran_vec ();
47  octave_idx_type zero = 0;
48 
49  if (pack)
50  {
51  octave_idx_type j1 = std::min (std::max (zero, k), nc);
52  octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
53  octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2;
54  Array<T> r (dim_vector (n, 1));
55  T *rvec = r.fortran_vec ();
56  for (octave_idx_type j = 0; j < nc; j++)
57  {
58  octave_idx_type ii = std::min (std::max (zero, j - k), nr);
59  rvec = std::copy (avec + ii, avec + nr, rvec);
60  avec += nr;
61  }
62 
63  return r;
64  }
65  else
66  {
67  Array<T> r (a.dims ());
68  T *rvec = r.fortran_vec ();
69  for (octave_idx_type j = 0; j < nc; j++)
70  {
71  octave_idx_type ii = std::min (std::max (zero, j - k), nr);
72  std::fill (rvec, rvec + ii, T ());
73  std::copy (avec + ii, avec + nr, rvec + ii);
74  avec += nr;
75  rvec += nr;
76  }
77 
78  return r;
79  }
80 }
81 
82 template <class T>
83 static Array<T>
84 do_triu (const Array<T>& a, octave_idx_type k, bool pack)
85 {
86  octave_idx_type nr = a.rows (), nc = a.columns ();
87  const T *avec = a.fortran_vec ();
88  octave_idx_type zero = 0;
89 
90  if (pack)
91  {
92  octave_idx_type j1 = std::min (std::max (zero, k), nc);
93  octave_idx_type j2 = std::min (std::max (zero, nr + k), nc);
95  = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr;
96  Array<T> r (dim_vector (n, 1));
97  T *rvec = r.fortran_vec ();
98  for (octave_idx_type j = 0; j < nc; j++)
99  {
100  octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
101  rvec = std::copy (avec, avec + ii, rvec);
102  avec += nr;
103  }
104 
105  return r;
106  }
107  else
108  {
109  NoAlias<Array<T> > r (a.dims ());
110  T *rvec = r.fortran_vec ();
111  for (octave_idx_type j = 0; j < nc; j++)
112  {
113  octave_idx_type ii = std::min (std::max (zero, j + 1 - k), nr);
114  std::copy (avec, avec + ii, rvec);
115  std::fill (rvec + ii, rvec + nr, T ());
116  avec += nr;
117  rvec += nr;
118  }
119 
120  return r;
121  }
122 }
123 
124 // These two are by David Bateman.
125 // FIXME: optimizations possible. "pack" support missing.
126 
127 template <class T>
128 static Sparse<T>
129 do_tril (const Sparse<T>& a, octave_idx_type k, bool pack)
130 {
131  if (pack) // FIXME
132  {
133  error ("tril: \"pack\" not implemented for sparse matrices");
134  return Sparse<T> ();
135  }
136 
137  Sparse<T> m = a;
138  octave_idx_type nc = m.cols ();
139 
140  for (octave_idx_type j = 0; j < nc; j++)
141  for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++)
142  if (m.ridx (i) < j-k)
143  m.data(i) = 0.;
144 
145  m.maybe_compress (true);
146  return m;
147 }
148 
149 template <class T>
150 static Sparse<T>
151 do_triu (const Sparse<T>& a, octave_idx_type k, bool pack)
152 {
153  if (pack) // FIXME
154  {
155  error ("triu: \"pack\" not implemented for sparse matrices");
156  return Sparse<T> ();
157  }
158 
159  Sparse<T> m = a;
160  octave_idx_type nc = m.cols ();
161 
162  for (octave_idx_type j = 0; j < nc; j++)
163  for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++)
164  if (m.ridx (i) > j-k)
165  m.data(i) = 0.;
166 
167  m.maybe_compress (true);
168  return m;
169 }
170 
171 // Convenience dispatchers.
172 template <class T>
173 static Array<T>
174 do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack)
175 {
176  return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
177 }
178 
179 template <class T>
180 static Sparse<T>
181 do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack)
182 {
183  return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
184 }
185 
186 static octave_value
187 do_trilu (const std::string& name,
188  const octave_value_list& args)
189 {
190  bool lower = name == "tril";
191 
192  octave_value retval;
193  int nargin = args.length ();
194  octave_idx_type k = 0;
195  bool pack = false;
196  if (nargin >= 2 && args(nargin-1).is_string ())
197  {
198  pack = args(nargin-1).string_value () == "pack";
199  nargin--;
200  }
201 
202  if (nargin == 2)
203  {
204  k = args(1).int_value (true);
205 
206  if (error_state)
207  return retval;
208  }
209 
210  if (nargin < 1 || nargin > 2)
211  print_usage ();
212  else
213  {
214  octave_value arg = args (0);
215 
216  dim_vector dims = arg.dims ();
217  if (dims.length () != 2)
218  error ("%s: need a 2-D matrix", name.c_str ());
219  else if (k < -dims (0) || k > dims(1))
220  error ("%s: requested diagonal out of range", name.c_str ());
221  else
222  {
223  switch (arg.builtin_type ())
224  {
225  case btyp_double:
226  if (arg.is_sparse_type ())
227  retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack);
228  else
229  retval = do_trilu (arg.array_value (), k, lower, pack);
230  break;
231  case btyp_complex:
232  if (arg.is_sparse_type ())
233  retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower,
234  pack);
235  else
236  retval = do_trilu (arg.complex_array_value (), k, lower, pack);
237  break;
238  case btyp_bool:
239  if (arg.is_sparse_type ())
240  retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower,
241  pack);
242  else
243  retval = do_trilu (arg.bool_array_value (), k, lower, pack);
244  break;
245 #define ARRAYCASE(TYP) \
246  case btyp_ ## TYP: \
247  retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \
248  break
249  ARRAYCASE (float);
250  ARRAYCASE (float_complex);
251  ARRAYCASE (int8);
252  ARRAYCASE (int16);
253  ARRAYCASE (int32);
254  ARRAYCASE (int64);
255  ARRAYCASE (uint8);
256  ARRAYCASE (uint16);
257  ARRAYCASE (uint32);
258  ARRAYCASE (uint64);
259  ARRAYCASE (char);
260 #undef ARRAYCASE
261  default:
262  {
263  // Generic code that works on octave-values, that is slow
264  // but will also work on arbitrary user types
265 
266  if (pack) // FIXME
267  {
268  error ("%s: \"pack\" not implemented for class %s",
269  name.c_str (), arg.class_name ().c_str ());
270  return octave_value ();
271  }
272 
273  octave_value tmp = arg;
274  if (arg.numel () == 0)
275  return arg;
276 
277  octave_idx_type nr = dims(0), nc = dims (1);
278 
279  // The sole purpose of the below is to force the correct
280  // matrix size. This would not be necessary if the
281  // octave_value resize function allowed a fill_value.
282  // It also allows odd attributes in some user types
283  // to be handled. With a fill_value ot should be replaced
284  // with
285  //
286  // octave_value_list ov_idx;
287  // tmp = tmp.resize(dim_vector (0,0)).resize (dims, fill_value);
288 
289  octave_value_list ov_idx;
290  std::list<octave_value_list> idx_tmp;
291  ov_idx(1) = static_cast<double> (nc+1);
292  ov_idx(0) = Range (1, nr);
293  idx_tmp.push_back (ov_idx);
294  ov_idx(1) = static_cast<double> (nc);
295  tmp = tmp.resize (dim_vector (0,0));
296  tmp = tmp.subsasgn ("(",idx_tmp, arg.do_index_op (ov_idx));
297  tmp = tmp.resize (dims);
298 
299  if (lower)
300  {
301  octave_idx_type st = nc < nr + k ? nc : nr + k;
302 
303  for (octave_idx_type j = 1; j <= st; j++)
304  {
305  octave_idx_type nr_limit = 1 > j - k ? 1 : j - k;
306  ov_idx(1) = static_cast<double> (j);
307  ov_idx(0) = Range (nr_limit, nr);
308  std::list<octave_value_list> idx;
309  idx.push_back (ov_idx);
310 
311  tmp = tmp.subsasgn ("(", idx, arg.do_index_op (ov_idx));
312 
313  if (error_state)
314  return retval;
315  }
316  }
317  else
318  {
319  octave_idx_type st = k + 1 > 1 ? k + 1 : 1;
320 
321  for (octave_idx_type j = st; j <= nc; j++)
322  {
323  octave_idx_type nr_limit = nr < j - k ? nr : j - k;
324  ov_idx(1) = static_cast<double> (j);
325  ov_idx(0) = Range (1, nr_limit);
326  std::list<octave_value_list> idx;
327  idx.push_back (ov_idx);
328 
329  tmp = tmp.subsasgn ("(", idx, arg.do_index_op (ov_idx));
330 
331  if (error_state)
332  return retval;
333  }
334  }
335 
336  retval = tmp;
337  }
338  }
339  }
340  }
341 
342  return retval;
343 }
344 
345 DEFUN (tril, args, ,
346  "-*- texinfo -*-\n\
347 @deftypefn {Function File} {} tril (@var{A})\n\
348 @deftypefnx {Function File} {} tril (@var{A}, @var{k})\n\
349 @deftypefnx {Function File} {} tril (@var{A}, @var{k}, @var{pack})\n\
350 @deftypefnx {Function File} {} triu (@var{A})\n\
351 @deftypefnx {Function File} {} triu (@var{A}, @var{k})\n\
352 @deftypefnx {Function File} {} triu (@var{A}, @var{k}, @var{pack})\n\
353 Return a new matrix formed by extracting the lower (@code{tril})\n\
354 or upper (@code{triu}) triangular part of the matrix @var{A}, and\n\
355 setting all other elements to zero. The second argument is optional,\n\
356 and specifies how many diagonals above or below the main diagonal should\n\
357 also be set to zero.\n\
358 \n\
359 The default value of @var{k} is zero, so that @code{triu} and\n\
360 @code{tril} normally include the main diagonal as part of the result.\n\
361 \n\
362 If the value of @var{k} is nonzero integer, the selection of elements\n\
363 starts at an offset of @var{k} diagonals above or below the main\n\
364 diagonal; above for positive @var{k} and below for negative @var{k}.\n\
365 \n\
366 The absolute value of @var{k} must not be greater than the number of\n\
367 sub-diagonals or super-diagonals.\n\
368 \n\
369 For example:\n\
370 \n\
371 @example\n\
372 @group\n\
373 tril (ones (3), -1)\n\
374  @result{} 0 0 0\n\
375  1 0 0\n\
376  1 1 0\n\
377 @end group\n\
378 @end example\n\
379 \n\
380 @noindent\n\
381 and\n\
382 \n\
383 @example\n\
384 @group\n\
385 tril (ones (3), 1)\n\
386  @result{} 1 1 0\n\
387  1 1 1\n\
388  1 1 1\n\
389 @end group\n\
390 @end example\n\
391 \n\
392 If the option @qcode{\"pack\"} is given as third argument, the extracted\n\
393 elements are not inserted into a matrix, but rather stacked column-wise one\n\
394 above other.\n\
395 @seealso{diag}\n\
396 @end deftypefn")
397 {
398  return do_trilu ("tril", args);
399 }
400 
401 DEFUN (triu, args, ,
402  "-*- texinfo -*-\n\
403 @deftypefn {Function File} {} triu (@var{A})\n\
404 @deftypefnx {Function File} {} triu (@var{A}, @var{k})\n\
405 @deftypefnx {Function File} {} triu (@var{A}, @var{k}, @var{pack})\n\
406 See the documentation for the @code{tril} function (@pxref{tril}).\n\
407 @end deftypefn")
408 {
409  return do_trilu ("triu", args);
410 }
411 
412 /*
413 %!test
414 %! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
415 %!
416 %! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12];
417 %! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12];
418 %! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
419 %! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12];
420 %! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0];
421 %! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0];
422 %! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0];
423 %!
424 %! assert (tril (a, -4), lm4);
425 %! assert (tril (a, -3), lm3);
426 %! assert (tril (a, -2), lm2);
427 %! assert (tril (a, -1), lm1);
428 %! assert (tril (a), l0);
429 %! assert (tril (a, 1), l1);
430 %! assert (tril (a, 2), l2);
431 
432 %!error tril ()
433 */