GNU Octave  4.0.0
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
dot.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2009-2015 VZLU Prague
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 #ifdef HAVE_CONFIG_H
24 #include <config.h>
25 #endif
26 
27 #include "f77-fcn.h"
28 #include "mx-base.h"
29 #include "error.h"
30 #include "defun.h"
31 #include "parse.h"
32 
33 extern "C"
34 {
35  F77_RET_T
36  F77_FUNC (ddot3, DDOT3) (const octave_idx_type&, const octave_idx_type&,
37  const octave_idx_type&, const double*,
38  const double*, double*);
39 
40  F77_RET_T
41  F77_FUNC (sdot3, SDOT3) (const octave_idx_type&, const octave_idx_type&,
42  const octave_idx_type&, const float*,
43  const float*, float*);
44 
45  F77_RET_T
46  F77_FUNC (zdotc3, ZDOTC3) (const octave_idx_type&, const octave_idx_type&,
47  const octave_idx_type&, const Complex*,
48  const Complex*, Complex*);
49 
50  F77_RET_T
51  F77_FUNC (cdotc3, CDOTC3) (const octave_idx_type&, const octave_idx_type&,
52  const octave_idx_type&, const FloatComplex*,
53  const FloatComplex*, FloatComplex*);
54 
55  F77_RET_T
56  F77_FUNC (dmatm3, DMATM3) (const octave_idx_type&, const octave_idx_type&,
57  const octave_idx_type&, const octave_idx_type&,
58  const double*, const double*, double*);
59 
60  F77_RET_T
61  F77_FUNC (smatm3, SMATM3) (const octave_idx_type&, const octave_idx_type&,
62  const octave_idx_type&, const octave_idx_type&,
63  const float*, const float*, float*);
64 
65  F77_RET_T
66  F77_FUNC (zmatm3, ZMATM3) (const octave_idx_type&, const octave_idx_type&,
67  const octave_idx_type&, const octave_idx_type&,
68  const Complex*, const Complex*, Complex*);
69 
70  F77_RET_T
71  F77_FUNC (cmatm3, CMATM3) (const octave_idx_type&, const octave_idx_type&,
72  const octave_idx_type&, const octave_idx_type&,
73  const FloatComplex*, const FloatComplex*,
74  FloatComplex*);
75 }
76 
77 static void
78 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
80  octave_idx_type& k)
81 {
82  int nd = x.length ();
83  assert (nd == y.length ());
84  z = dim_vector::alloc (nd);
85  m = 1, n = 1, k = 1;
86  for (int i = 0; i < nd; i++)
87  {
88  if (i < dim)
89  {
90  z(i) = x(i);
91  m *= x(i);
92  }
93  else if (i > dim)
94  {
95  z(i) = x(i);
96  n *= x(i);
97  }
98  else
99  {
100  k = x(i);
101  z(i) = 1;
102  }
103  }
104 }
105 
106 DEFUN (dot, args, ,
107  "-*- texinfo -*-\n\
108 @deftypefn {Built-in Function} {} dot (@var{x}, @var{y}, @var{dim})\n\
109 Compute the dot product of two vectors.\n\
110 \n\
111 If @var{x} and @var{y} are matrices, calculate the dot products along the\n\
112 first non-singleton dimension.\n\
113 \n\
114 If the optional argument @var{dim} is given, calculate the dot products\n\
115 along this dimension.\n\
116 \n\
117 This is equivalent to\n\
118 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},\n\
119 but avoids forming a temporary array and is faster. When @var{X} and\n\
120 @var{Y} are column vectors, the result is equivalent to\n\
121 @code{@var{X}' * @var{Y}}.\n\
122 @seealso{cross, divergence}\n\
123 @end deftypefn")
124 {
125  octave_value retval;
126  int nargin = args.length ();
127 
128  if (nargin < 2 || nargin > 3)
129  {
130  print_usage ();
131  return retval;
132  }
133 
134  octave_value argx = args(0);
135  octave_value argy = args(1);
136 
137  if (argx.is_numeric_type () && argy.is_numeric_type ())
138  {
139  dim_vector dimx = argx.dims ();
140  dim_vector dimy = argy.dims ();
141  bool match = dimx == dimy;
142  if (! match && nargin == 2
143  && dimx.is_vector () && dimy.is_vector ())
144  {
145  // Change to column vectors.
146  dimx = dimx.redim (1);
147  argx = argx.reshape (dimx);
148  dimy = dimy.redim (1);
149  argy = argy.reshape (dimy);
150  match = ! error_state && (dimx == dimy);
151  }
152 
153  if (match)
154  {
155  int dim;
156  if (nargin == 2)
157  dim = dimx.first_non_singleton ();
158  else
159  dim = args(2).int_value (true) - 1;
160 
161  if (error_state)
162  ;
163  else if (dim < 0)
164  error ("dot: DIM must be a valid dimension");
165  else
166  {
167  octave_idx_type m, n, k;
168  dim_vector dimz;
169  if (argx.is_complex_type () || argy.is_complex_type ())
170  {
171  if (argx.is_single_type () || argy.is_single_type ())
172  {
175  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
176  FloatComplexNDArray z(dimz);
177  if (! error_state)
178  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
179  x.data (), y.data (),
180  z.fortran_vec ()));
181  retval = z;
182  }
183  else
184  {
187  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
188  ComplexNDArray z(dimz);
189  if (! error_state)
190  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
191  x.data (), y.data (),
192  z.fortran_vec ()));
193  retval = z;
194  }
195  }
196  else if (argx.is_float_type () && argy.is_float_type ())
197  {
198  if (argx.is_single_type () || argy.is_single_type ())
199  {
200  FloatNDArray x = argx.float_array_value ();
201  FloatNDArray y = argy.float_array_value ();
202  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
203  FloatNDArray z(dimz);
204  if (! error_state)
205  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
206  z.fortran_vec ()));
207  retval = z;
208  }
209  else
210  {
211  NDArray x = argx.array_value ();
212  NDArray y = argy.array_value ();
213  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
214  NDArray z(dimz);
215  if (! error_state)
216  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
217  z.fortran_vec ()));
218  retval = z;
219  }
220  }
221  else
222  {
223  // Non-optimized evaluation.
224  octave_value_list tmp;
225  tmp(1) = dim + 1;
226  tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
227  if (! error_state)
228  {
229  tmp = feval ("sum", tmp, 1);
230  if (! tmp.empty ())
231  retval = tmp(0);
232  }
233  }
234  }
235  }
236  else
237  error ("dot: sizes of X and Y must match");
238 
239  }
240  else
241  error ("dot: X and Y must be numeric");
242 
243  return retval;
244 }
245 
246 /*
247 %!assert (dot ([1, 2], [2, 3]), 8)
248 
249 %!test
250 %! x = [2, 1; 2, 1];
251 %! y = [-0.5, 2; 0.5, -2];
252 %! assert (dot (x, y), [0 0]);
253 %! assert (dot (single (x), single (y)), single ([0 0]));
254 
255 %!test
256 %! x = [1+i, 3-i; 1-i, 3-i];
257 %! assert (dot (x, x), [4, 20]);
258 %! assert (dot (single (x), single (x)), single ([4, 20]));
259 
260 %!test
261 %! x = int8 ([1 2]);
262 %! y = int8 ([2 3]);
263 %! assert (dot (x, y), 8);
264 
265 %!test
266 %! x = int8 ([1 2; 3 4]);
267 %! y = int8 ([5 6; 7 8]);
268 %! assert (dot (x, y), [26 44]);
269 %! assert (dot (x, y, 2), [17; 53]);
270 %! assert (dot (x, y, 3), [5 12; 21 32]);
271 
272 %% Test input validation
273 %!error dot ()
274 %!error dot (1)
275 %!error dot (1,2,3,4)
276 %!error <X and Y must be numeric> dot ({1,2}, [3,4])
277 %!error <X and Y must be numeric> dot ([1,2], {3,4})
278 %!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
279 %!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
280 %!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
281 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
282 */
283 
284 DEFUN (blkmm, args, ,
285  "-*- texinfo -*-\n\
286 @deftypefn {Built-in Function} {} blkmm (@var{A}, @var{B})\n\
287 Compute products of matrix blocks.\n\
288 \n\
289 The blocks are given as 2-dimensional subarrays of the arrays @var{A},\n\
290 @var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and\n\
291 size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size\n\
292 @code{[m,n,@dots{}]} and is computed as follows:\n\
293 \n\
294 @example\n\
295 @group\n\
296 for i = 1:prod (size (@var{A})(3:end))\n\
297  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)\n\
298 endfor\n\
299 @end group\n\
300 @end example\n\
301 @end deftypefn")
302 {
303  octave_value retval;
304  int nargin = args.length ();
305 
306  if (nargin != 2)
307  {
308  print_usage ();
309  return retval;
310  }
311 
312  octave_value argx = args(0);
313  octave_value argy = args(1);
314 
315  if (argx.is_numeric_type () && argy.is_numeric_type ())
316  {
317  const dim_vector dimx = argx.dims ();
318  const dim_vector dimy = argy.dims ();
319  int nd = dimx.length ();
320  octave_idx_type m = dimx(0);
321  octave_idx_type k = dimx(1);
322  octave_idx_type n = dimy(1);
323  octave_idx_type np = 1;
324  bool match = dimy(0) == k && nd == dimy.length ();
325  dim_vector dimz = dim_vector::alloc (nd);
326  dimz(0) = m;
327  dimz(1) = n;
328  for (int i = 2; match && i < nd; i++)
329  {
330  match = match && dimx(i) == dimy(i);
331  dimz(i) = dimx(i);
332  np *= dimz(i);
333  }
334 
335  if (match)
336  {
337  if (argx.is_complex_type () || argy.is_complex_type ())
338  {
339  if (argx.is_single_type () || argy.is_single_type ())
340  {
343  FloatComplexNDArray z(dimz);
344  if (! error_state)
345  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
346  x.data (), y.data (),
347  z.fortran_vec ()));
348  retval = z;
349  }
350  else
351  {
354  ComplexNDArray z(dimz);
355  if (! error_state)
356  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
357  x.data (), y.data (),
358  z.fortran_vec ()));
359  retval = z;
360  }
361  }
362  else
363  {
364  if (argx.is_single_type () || argy.is_single_type ())
365  {
366  FloatNDArray x = argx.float_array_value ();
367  FloatNDArray y = argy.float_array_value ();
368  FloatNDArray z(dimz);
369  if (! error_state)
370  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
371  x.data (), y.data (),
372  z.fortran_vec ()));
373  retval = z;
374  }
375  else
376  {
377  NDArray x = argx.array_value ();
378  NDArray y = argy.array_value ();
379  NDArray z(dimz);
380  if (! error_state)
381  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
382  x.data (), y.data (),
383  z.fortran_vec ()));
384  retval = z;
385  }
386  }
387  }
388  else
389  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
390  dimx.str ().c_str (), dimy.str ().c_str ());
391 
392  }
393  else
394  error ("blkmm: A and B must be numeric");
395 
396  return retval;
397 }
398 
399 /*
400 %!test
401 %! x(:,:,1) = [1 2; 3 4];
402 %! x(:,:,2) = [1 1; 1 1];
403 %! z(:,:,1) = [7 10; 15 22];
404 %! z(:,:,2) = [2 2; 2 2];
405 %! assert (blkmm (x,x), z);
406 %! assert (blkmm (single (x), single (x)), single (z));
407 %! assert (blkmm (x, single (x)), single (z));
408 
409 %!test
410 %! x(:,:,1) = [1 2; 3 4];
411 %! x(:,:,2) = [1i 1i; 1i 1i];
412 %! z(:,:,1) = [7 10; 15 22];
413 %! z(:,:,2) = [-2 -2; -2 -2];
414 %! assert (blkmm (x,x), z);
415 %! assert (blkmm (single (x), single (x)), single (z));
416 %! assert (blkmm (x, single (x)), single (z));
417 
418 %% Test input validation
419 %!error blkmm ()
420 %!error blkmm (1)
421 %!error blkmm (1,2,3)
422 %!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
423 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
424 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
425 */
426 
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:798
std::string str(char sep= 'x') const
Definition: dim-vector.cc:63
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:21
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:498
bool is_vector(void) const
Definition: dim-vector.h:430
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:21
OCTINTERP_API void print_usage(void)
Definition: defun.cc:51
bool is_numeric_type(void) const
Definition: ov.h:663
#define DEFUN(name, args_name, nargout_name, doc)
Definition: defun.h:44
void error(const char *fmt,...)
Definition: error.cc:476
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:21
octave_value_list feval(const std::string &name, const octave_value_list &args, int nargout)
Definition: oct-parse.cc:8625
static void get_red_dims(const dim_vector &x, const dim_vector &y, int dim, dim_vector &z, octave_idx_type &m, octave_idx_type &n, octave_idx_type &k)
Definition: dot.cc:78
F77_RET_T F77_FUNC(ddot3, DDOT3)(const octave_idx_type &, const octave_idx_type &, const octave_idx_type &, const double *, const double *, double *)
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:51
int first_non_singleton(int def=0) const
Definition: dim-vector.h:435
bool is_float_type(void) const
Definition: ov.h:614
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:782
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:21
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:802
const T * data(void) const
Definition: Array.h:479
int error_state
Definition: error.cc:101
bool is_complex_type(void) const
Definition: ov.h:654
dim_vector redim(int n) const
Definition: dim-vector.cc:266
octave_idx_type length(void) const
Definition: ov.cc:1525
#define F77_RET_T
Definition: f77-fcn.h:264
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:21
static dim_vector alloc(int n)
Definition: dim-vector.h:256
dim_vector dims(void) const
Definition: ov.h:470
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:21
bool empty(void) const
Definition: oct-obj.h:91
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:779
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5322
std::complex< float > FloatComplex
Definition: oct-cmplx.h:30
static bool match(const std::string &filename_arg, const std::string &path_elt_arg)
Definition: kpse.cc:1738
std::complex< double > Complex
Definition: oct-cmplx.h:29
const T * fortran_vec(void) const
Definition: Array.h:481
bool is_single_type(void) const
Definition: ov.h:611
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:21
int length(void) const
Definition: dim-vector.h:281
F77_RET_T const double * x
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:21
octave_value do_binary_op(octave_value::binary_op op, const octave_value &v1, const octave_value &v2)
Definition: ov.cc:1978