GNU Octave  3.8.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
dot.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2009-2013 VZLU Prague
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26 
27 #include "f77-fcn.h"
28 #include "mx-base.h"
29 #include "error.h"
30 #include "defun.h"
31 #include "parse.h"
32 
33 extern "C"
34 {
35  F77_RET_T
36  F77_FUNC (ddot3, DDOT3) (const octave_idx_type&, const octave_idx_type&,
37  const octave_idx_type&, const double*,
38  const double*, double*);
39 
40  F77_RET_T
41  F77_FUNC (sdot3, SDOT3) (const octave_idx_type&, const octave_idx_type&,
42  const octave_idx_type&, const float*,
43  const float*, float*);
44 
45  F77_RET_T
46  F77_FUNC (zdotc3, ZDOTC3) (const octave_idx_type&, const octave_idx_type&,
47  const octave_idx_type&, const Complex*,
48  const Complex*, Complex*);
49 
50  F77_RET_T
51  F77_FUNC (cdotc3, CDOTC3) (const octave_idx_type&, const octave_idx_type&,
52  const octave_idx_type&, const FloatComplex*,
53  const FloatComplex*, FloatComplex*);
54 
55  F77_RET_T
56  F77_FUNC (dmatm3, DMATM3) (const octave_idx_type&, const octave_idx_type&,
57  const octave_idx_type&, const octave_idx_type&,
58  const double*, const double*, double*);
59 
60  F77_RET_T
61  F77_FUNC (smatm3, SMATM3) (const octave_idx_type&, const octave_idx_type&,
62  const octave_idx_type&, const octave_idx_type&,
63  const float*, const float*, float*);
64 
65  F77_RET_T
66  F77_FUNC (zmatm3, ZMATM3) (const octave_idx_type&, const octave_idx_type&,
67  const octave_idx_type&, const octave_idx_type&,
68  const Complex*, const Complex*, Complex*);
69 
70  F77_RET_T
71  F77_FUNC (cmatm3, CMATM3) (const octave_idx_type&, const octave_idx_type&,
72  const octave_idx_type&, const octave_idx_type&,
73  const FloatComplex*, const FloatComplex*,
74  FloatComplex*);
75 }
76 
77 static void
78 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
80  octave_idx_type& k)
81 {
82  int nd = x.length ();
83  assert (nd == y.length ());
84  z = dim_vector::alloc (nd);
85  m = 1, n = 1, k = 1;
86  for (int i = 0; i < nd; i++)
87  {
88  if (i < dim)
89  {
90  z(i) = x(i);
91  m *= x(i);
92  }
93  else if (i > dim)
94  {
95  z(i) = x(i);
96  n *= x(i);
97  }
98  else
99  {
100  k = x(i);
101  z(i) = 1;
102  }
103  }
104 }
105 
106 DEFUN (dot, args, ,
107  "-*- texinfo -*-\n\
108 @deftypefn {Built-in Function} {} dot (@var{x}, @var{y}, @var{dim})\n\
109 Compute the dot product of two vectors. If @var{x} and @var{y}\n\
110 are matrices, calculate the dot products along the first\n\
111 non-singleton dimension. If the optional argument @var{dim} is\n\
112 given, calculate the dot products along this dimension.\n\
113 \n\
114 This is equivalent to\n\
115 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},\n\
116 but avoids forming a temporary array and is faster. When @var{X} and\n\
117 @var{Y} are column vectors, the result is equivalent to\n\
118 @code{@var{X}' * @var{Y}}.\n\
119 @seealso{cross, divergence}\n\
120 @end deftypefn")
121 {
122  octave_value retval;
123  int nargin = args.length ();
124 
125  if (nargin < 2 || nargin > 3)
126  {
127  print_usage ();
128  return retval;
129  }
130 
131  octave_value argx = args(0), argy = args(1);
132 
133  if (argx.is_numeric_type () && argy.is_numeric_type ())
134  {
135  dim_vector dimx = argx.dims (), dimy = argy.dims ();
136  bool match = dimx == dimy;
137  if (! match && nargin == 2
138  && dimx.is_vector () && dimy.is_vector ())
139  {
140  // Change to column vectors.
141  dimx = dimx.redim (1);
142  argx = argx.reshape (dimx);
143  dimy = dimy.redim (1);
144  argy = argy.reshape (dimy);
145  match = ! error_state;
146  }
147 
148  if (match)
149  {
150  int dim;
151  if (nargin == 2)
152  dim = dimx.first_non_singleton ();
153  else
154  dim = args(2).int_value (true) - 1;
155 
156  if (error_state)
157  ;
158  else if (dim < 0)
159  error ("dot: DIM must be a valid dimension");
160  else
161  {
162  octave_idx_type m, n, k;
163  dim_vector dimz;
164  if (argx.is_complex_type () || argy.is_complex_type ())
165  {
166  if (argx.is_single_type () || argy.is_single_type ())
167  {
169  FloatComplexNDArray y = argy.float_complex_array_value ();
170  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
171  FloatComplexNDArray z(dimz);
172  if (! error_state)
173  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
174  x.data (), y.data (),
175  z.fortran_vec ()));
176  retval = z;
177  }
178  else
179  {
181  ComplexNDArray y = argy.complex_array_value ();
182  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
183  ComplexNDArray z(dimz);
184  if (! error_state)
185  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
186  x.data (), y.data (),
187  z.fortran_vec ()));
188  retval = z;
189  }
190  }
191  else if (argx.is_float_type () && argy.is_float_type ())
192  {
193  if (argx.is_single_type () || argy.is_single_type ())
194  {
195  FloatNDArray x = argx.float_array_value ();
196  FloatNDArray y = argy.float_array_value ();
197  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
198  FloatNDArray z(dimz);
199  if (! error_state)
200  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
201  z.fortran_vec ()));
202  retval = z;
203  }
204  else
205  {
206  NDArray x = argx.array_value ();
207  NDArray y = argy.array_value ();
208  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
209  NDArray z(dimz);
210  if (! error_state)
211  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
212  z.fortran_vec ()));
213  retval = z;
214  }
215  }
216  else
217  {
218  // Non-optimized evaluation.
219  octave_value_list tmp;
220  tmp(1) = dim + 1;
221  tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
222  if (! error_state)
223  {
224  tmp = feval ("sum", tmp, 1);
225  if (! tmp.empty ())
226  retval = tmp(0);
227  }
228  }
229  }
230  }
231  else
232  error ("dot: sizes of X and Y must match");
233 
234  }
235  else
236  error ("dot: X and Y must be numeric");
237 
238  return retval;
239 }
240 
241 /*
242 %!assert (dot ([1, 2], [2, 3]), 8)
243 
244 %!test
245 %! x = [2, 1; 2, 1];
246 %! y = [-0.5, 2; 0.5, -2];
247 %! assert (dot (x, y), [0 0]);
248 
249 %!test
250 %! x = [1+i, 3-i; 1-i, 3-i];
251 %! assert (dot (x, x), [4, 20]);
252 
253 %!test
254 %! x = int8 ([1 2]);
255 %! y = int8 ([2 3]);
256 %! assert (dot (x, y), 8);
257 
258 %!test
259 %! x = int8 ([1 2; 3 4]);
260 %! y = int8 ([5 6; 7 8]);
261 %! assert (dot (x, y), [26 44]);
262 %! assert (dot (x, y, 2), [17; 53]);
263 %! assert (dot (x, y, 3), [5 12; 21 32]);
264 
265 */
266 
267 DEFUN (blkmm, args, ,
268  "-*- texinfo -*-\n\
269 @deftypefn {Built-in Function} {} blkmm (@var{A}, @var{B})\n\
270 Compute products of matrix blocks. The blocks are given as\n\
271 2-dimensional subarrays of the arrays @var{A}, @var{B}.\n\
272 The size of @var{A} must have the form @code{[m,k,@dots{}]} and\n\
273 size of @var{B} must be @code{[k,n,@dots{}]}. The result is\n\
274 then of size @code{[m,n,@dots{}]} and is computed as follows:\n\
275 \n\
276 @example\n\
277 @group\n\
278 for i = 1:prod (size (@var{A})(3:end))\n\
279  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)\n\
280 endfor\n\
281 @end group\n\
282 @end example\n\
283 @end deftypefn")
284 {
285  octave_value retval;
286  int nargin = args.length ();
287 
288  if (nargin != 2)
289  {
290  print_usage ();
291  return retval;
292  }
293 
294  octave_value argx = args(0), argy = args(1);
295 
296  if (argx.is_numeric_type () && argy.is_numeric_type ())
297  {
298  const dim_vector dimx = argx.dims (), dimy = argy.dims ();
299  int nd = dimx.length ();
300  octave_idx_type m = dimx(0), k = dimx(1), n = dimy(1), np = 1;
301  bool match = dimy(0) == k && nd == dimy.length ();
302  dim_vector dimz = dim_vector::alloc (nd);
303  dimz(0) = m;
304  dimz(1) = n;
305  for (int i = 2; match && i < nd; i++)
306  {
307  match = match && dimx(i) == dimy(i);
308  dimz(i) = dimx(i);
309  np *= dimz(i);
310  }
311 
312  if (match)
313  {
314  if (argx.is_complex_type () || argy.is_complex_type ())
315  {
316  if (argx.is_single_type () || argy.is_single_type ())
317  {
319  FloatComplexNDArray y = argy.float_complex_array_value ();
320  FloatComplexNDArray z(dimz);
321  if (! error_state)
322  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
323  x.data (), y.data (),
324  z.fortran_vec ()));
325  retval = z;
326  }
327  else
328  {
330  ComplexNDArray y = argy.complex_array_value ();
331  ComplexNDArray z(dimz);
332  if (! error_state)
333  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
334  x.data (), y.data (),
335  z.fortran_vec ()));
336  retval = z;
337  }
338  }
339  else
340  {
341  if (argx.is_single_type () || argy.is_single_type ())
342  {
343  FloatNDArray x = argx.float_array_value ();
344  FloatNDArray y = argy.float_array_value ();
345  FloatNDArray z(dimz);
346  if (! error_state)
347  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
348  x.data (), y.data (),
349  z.fortran_vec ()));
350  retval = z;
351  }
352  else
353  {
354  NDArray x = argx.array_value ();
355  NDArray y = argy.array_value ();
356  NDArray z(dimz);
357  if (! error_state)
358  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
359  x.data (), y.data (),
360  z.fortran_vec ()));
361  retval = z;
362  }
363  }
364  }
365  else
366  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
367  dimx.str ().c_str (), dimy.str ().c_str ());
368 
369  }
370  else
371  error ("blkmm: A and B must be numeric");
372 
373  return retval;
374 }
375 
376 /*
377 %!test
378 %! x(:,:,1) = [1 2; 3 4];
379 %! x(:,:,2) = [1 1; 1 1];
380 %! z(:,:,1) = [7 10; 15 22];
381 %! z(:,:,2) = [2 2; 2 2];
382 %! assert (blkmm (x,x), z);
383 */