GNU Octave  4.4.1
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
dot.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2009-2018 VZLU Prague
4 
5 This file is part of Octave.
6 
7 Octave is free software: you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by
9 the Free Software Foundation, either version 3 of the License, or
10 (at your option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but
13 WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <https://www.gnu.org/licenses/>.
20 
21 */
22 
23 #if defined (HAVE_CONFIG_H)
24 # include "config.h"
25 #endif
26 
27 #include "lo-blas-proto.h"
28 #include "mx-base.h"
29 #include "error.h"
30 #include "defun.h"
31 #include "parse.h"
32 
33 static void
34 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
35  dim_vector& z, F77_INT& m, F77_INT& n, F77_INT& k)
36 {
37  int nd = x.ndims ();
38  assert (nd == y.ndims ());
39  z = dim_vector::alloc (nd);
40  octave_idx_type tmp_m = 1;
41  octave_idx_type tmp_n = 1;
42  octave_idx_type tmp_k = 1;
43  for (int i = 0; i < nd; i++)
44  {
45  if (i < dim)
46  {
47  z(i) = x(i);
48  tmp_m *= x(i);
49  }
50  else if (i > dim)
51  {
52  z(i) = x(i);
53  tmp_n *= x(i);
54  }
55  else
56  {
57  tmp_k = x(i);
58  z(i) = 1;
59  }
60  }
61 
62  m = octave::to_f77_int (tmp_m);
63  n = octave::to_f77_int (tmp_n);
64  k = octave::to_f77_int (tmp_k);
65 }
66 
67 DEFUN (dot, args, ,
68  doc: /* -*- texinfo -*-
69 @deftypefn {} {} dot (@var{x}, @var{y}, @var{dim})
70 Compute the dot product of two vectors.
71 
72 If @var{x} and @var{y} are matrices, calculate the dot products along the
73 first non-singleton dimension.
74 
75 If the optional argument @var{dim} is given, calculate the dot products
76 along this dimension.
77 
78 This is equivalent to
79 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},
80 but avoids forming a temporary array and is faster. When @var{X} and
81 @var{Y} are column vectors, the result is equivalent to
82 @code{@var{X}' * @var{Y}}.
83 @seealso{cross, divergence}
84 @end deftypefn */)
85 {
86  int nargin = args.length ();
87 
89  print_usage ();
90 
92  octave_value argx = args(0);
93  octave_value argy = args(1);
94 
95  if (! argx.isnumeric () || ! argy.isnumeric ())
96  error ("dot: X and Y must be numeric");
97 
98  dim_vector dimx = argx.dims ();
99  dim_vector dimy = argy.dims ();
100  bool match = dimx == dimy;
101  if (! match && nargin == 2 && dimx.isvector () && dimy.isvector ())
102  {
103  // Change to column vectors.
104  dimx = dimx.redim (1);
105  argx = argx.reshape (dimx);
106  dimy = dimy.redim (1);
107  argy = argy.reshape (dimy);
108  match = dimx == dimy;
109  }
110 
111  if (! match)
112  error ("dot: sizes of X and Y must match");
113 
114  int dim;
115  if (nargin == 2)
116  dim = dimx.first_non_singleton ();
117  else
118  dim = args(2).int_value (true) - 1;
119 
120  if (dim < 0)
121  error ("dot: DIM must be a valid dimension");
122 
123  F77_INT m, n, k;
124  dim_vector dimz;
125  if (argx.iscomplex () || argy.iscomplex ())
126  {
127  if (argx.is_single_type () || argy.is_single_type ())
128  {
131  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
132  FloatComplexNDArray z (dimz);
133 
134  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
135  F77_CONST_CMPLX_ARG (x.data ()), F77_CONST_CMPLX_ARG (y.data ()),
136  F77_CMPLX_ARG (z.fortran_vec ())));
137  retval = z;
138  }
139  else
140  {
143  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
144  ComplexNDArray z (dimz);
145 
146  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
149  retval = z;
150  }
151  }
152  else if (argx.isfloat () && argy.isfloat ())
153  {
154  if (argx.is_single_type () || argy.is_single_type ())
155  {
156  FloatNDArray x = argx.float_array_value ();
157  FloatNDArray y = argy.float_array_value ();
158  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
159  FloatNDArray z (dimz);
160 
161  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
162  z.fortran_vec ()));
163  retval = z;
164  }
165  else
166  {
167  NDArray x = argx.array_value ();
168  NDArray y = argy.array_value ();
169  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
170  NDArray z (dimz);
171 
172  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
173  z.fortran_vec ()));
174  retval = z;
175  }
176  }
177  else
178  {
179  // Non-optimized evaluation.
181  tmp(1) = dim + 1;
182  tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
183 
184  tmp = octave::feval ("sum", tmp, 1);
185  if (! tmp.empty ())
186  retval = tmp(0);
187  }
188 
189  return retval;
190 }
191 
192 /*
193 %!assert (dot ([1, 2], [2, 3]), 8)
194 
195 %!test
196 %! x = [2, 1; 2, 1];
197 %! y = [-0.5, 2; 0.5, -2];
198 %! assert (dot (x, y), [0 0]);
199 %! assert (dot (single (x), single (y)), single ([0 0]));
200 
201 %!test
202 %! x = [1+i, 3-i; 1-i, 3-i];
203 %! assert (dot (x, x), [4, 20]);
204 %! assert (dot (single (x), single (x)), single ([4, 20]));
205 
206 %!test
207 %! x = int8 ([1 2]);
208 %! y = int8 ([2 3]);
209 %! assert (dot (x, y), 8);
210 
211 %!test
212 %! x = int8 ([1 2; 3 4]);
213 %! y = int8 ([5 6; 7 8]);
214 %! assert (dot (x, y), [26 44]);
215 %! assert (dot (x, y, 2), [17; 53]);
216 %! assert (dot (x, y, 3), [5 12; 21 32]);
217 
218 ## Test input validation
219 %!error dot ()
220 %!error dot (1)
221 %!error dot (1,2,3,4)
222 %!error <X and Y must be numeric> dot ({1,2}, [3,4])
223 %!error <X and Y must be numeric> dot ([1,2], {3,4})
224 %!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
225 %!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
226 %!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
227 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
228 */
229 
230 template <typename T>
231 void
232 blkmm_internal (const T& x, const T& y, T& z,
233  F77_INT m, F77_INT n, F77_INT k, F77_INT np);
234 
235 template <>
236 void
239  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
240 {
241  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
242  F77_CONST_CMPLX_ARG (x.data ()),
243  F77_CONST_CMPLX_ARG (y.data ()),
244  F77_CMPLX_ARG (z.fortran_vec ())));
245 }
246 
247 template <>
248 void
250  ComplexNDArray& z,
251  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
252 {
253  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
254  F77_CONST_DBLE_CMPLX_ARG (x.data ()),
255  F77_CONST_DBLE_CMPLX_ARG (y.data ()),
257 }
258 
259 template <>
260 void
262  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
263 {
264  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
265  x.data (), y.data (),
266  z.fortran_vec ()));
267 }
268 
269 template <>
270 void
271 blkmm_internal (const NDArray& x, const NDArray& y, NDArray& z,
272  F77_INT m, F77_INT n, F77_INT k, F77_INT np)
273 {
274  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
275  x.data (), y.data (),
276  z.fortran_vec ()));
277 }
278 
279 static void
280 get_blkmm_dims (const dim_vector& dimx, const dim_vector& dimy,
281  F77_INT& m, F77_INT& n, F77_INT& k, F77_INT& np,
282  dim_vector& dimz)
283 {
284  int nd = dimx.ndims ();
285 
286  m = octave::to_f77_int (dimx(0));
287  k = octave::to_f77_int (dimx(1));
288  n = octave::to_f77_int (dimy(1));
289 
290  octave_idx_type tmp_np = 1;
291 
292  bool match = dimy(0) == k && nd == dimy.ndims ();
293 
294  dimz = dim_vector::alloc (nd);
295 
296  dimz(0) = m;
297  dimz(1) = n;
298  for (int i = 2; match && i < nd; i++)
299  {
300  match = match && dimx(i) == dimy(i);
301  dimz(i) = dimx(i);
302  tmp_np *= dimz(i);
303  }
304 
305  np = octave::to_f77_int (tmp_np);
306 
307  if (! match)
308  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
309  dimx.str ().c_str (), dimy.str ().c_str ());
310 }
311 
312 template <typename T>
313 T
314 do_blkmm (const octave_value& xov, const octave_value& yov)
315 {
316  const T x = octave_value_extract<T> (xov);
317  const T y = octave_value_extract<T> (yov);
318  F77_INT m, n, k, np;
319  dim_vector dimz;
320 
321  get_blkmm_dims (x.dims (), y.dims (), m, n, k, np, dimz);
322 
323  T z (dimz);
324 
325  if (n != 0 && m != 0)
326  blkmm_internal<T> (x, y, z, m, n, k, np);
327 
328  return z;
329 }
330 
331 DEFUN (blkmm, args, ,
332  doc: /* -*- texinfo -*-
333 @deftypefn {} {} blkmm (@var{A}, @var{B})
334 Compute products of matrix blocks.
335 
336 The blocks are given as 2-dimensional subarrays of the arrays @var{A},
337 @var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and
338 size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size
339 @code{[m,n,@dots{}]} and is computed as follows:
340 
341 @example
342 @group
343 for i = 1:prod (size (@var{A})(3:end))
344  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)
345 endfor
346 @end group
347 @end example
348 @end deftypefn */)
349 {
350  if (args.length () != 2)
351  print_usage ();
352 
354 
355  octave_value argx = args(0);
356  octave_value argy = args(1);
357 
358  if (! argx.isnumeric () || ! argy.isnumeric ())
359  error ("blkmm: A and B must be numeric");
360 
361  if (argx.iscomplex () || argy.iscomplex ())
362  {
363  if (argx.is_single_type () || argy.is_single_type ())
364  retval = do_blkmm<FloatComplexNDArray> (argx, argy);
365  else
366  retval = do_blkmm<ComplexNDArray> (argx, argy);
367  }
368  else
369  {
370  if (argx.is_single_type () || argy.is_single_type ())
371  retval = do_blkmm<FloatNDArray> (argx, argy);
372  else
373  retval = do_blkmm<NDArray> (argx, argy);
374  }
375 
376  return retval;
377 }
378 
379 /*
380 %!test
381 %! x(:,:,1) = [1 2; 3 4];
382 %! x(:,:,2) = [1 1; 1 1];
383 %! z(:,:,1) = [7 10; 15 22];
384 %! z(:,:,2) = [2 2; 2 2];
385 %! assert (blkmm (x,x), z);
386 %! assert (blkmm (single (x), single (x)), single (z));
387 %! assert (blkmm (x, single (x)), single (z));
388 
389 %!test
390 %! x(:,:,1) = [1 2; 3 4];
391 %! x(:,:,2) = [1i 1i; 1i 1i];
392 %! z(:,:,1) = [7 10; 15 22];
393 %! z(:,:,2) = [-2 -2; -2 -2];
394 %! assert (blkmm (x,x), z);
395 %! assert (blkmm (single (x), single (x)), single (z));
396 %! assert (blkmm (x, single (x)), single (z));
397 
398 %!test <*54261>
399 %! x = ones (0, 3, 3);
400 %! y = ones (3, 5, 3);
401 %! z = blkmm (x,y);
402 %! assert (size (z), [0, 5, 3]);
403 %! x = ones (1, 3, 3);
404 %! y = ones (3, 0, 3);
405 %! z = blkmm (x,y);
406 %! assert (size (z), [1, 0, 3]);
407 
408 ## Test input validation
409 %!error blkmm ()
410 %!error blkmm (1)
411 %!error blkmm (1,2,3)
412 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
413 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
414 %!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
415 */
OCTINTERP_API octave_value_list feval(const std::string &name, const octave_value_list &args=octave_value_list(), int nargout=0)
std::string str(char sep='x') const
Definition: dim-vector.cc:73
static void get_red_dims(const dim_vector &x, const dim_vector &y, int dim, dim_vector &z, F77_INT &m, F77_INT &n, F77_INT &k)
Definition: dot.cc:34
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:22
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:22
OCTINTERP_API void print_usage(void)
Definition: defun.cc:54
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:315
for large enough k
Definition: lu.cc:617
const T * fortran_vec(void) const
Definition: Array.h:584
#define DEFUN(name, args_name, nargout_name, doc)
Macro to define a builtin function.
Definition: defun.h:53
void error(const char *fmt,...)
Definition: error.cc:578
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:22
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:843
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:41
static void get_blkmm_dims(const dim_vector &dimx, const dim_vector &dimy, F77_INT &m, F77_INT &n, F77_INT &k, F77_INT &np, dim_vector &dimz)
Definition: dot.cc:280
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:22
OCTINTERP_API octave_value do_binary_op(octave::type_info &ti, octave_value::binary_op op, const octave_value &a, const octave_value &b)
bool is_single_type(void) const
Definition: ov.h:651
int first_non_singleton(int def=0) const
Definition: dim-vector.h:475
dim_vector dims(void) const
Definition: ov.h:469
double tmp
Definition: data.cc:6252
octave_value retval
Definition: data.cc:6246
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:22
static dim_vector alloc(int n)
Definition: dim-vector.h:264
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:22
dim_vector redim(int n) const
Force certain dimensionality, preserving numel ().
Definition: dim-vector.cc:233
bool isfloat(void) const
Definition: ov.h:654
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:863
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:502
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:309
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:318
T do_blkmm(const octave_value &xov, const octave_value &yov)
Definition: dot.cc:314
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5483
octave_f77_int_type F77_INT
Definition: f77-fcn.h:305
void blkmm_internal(const FloatComplexNDArray &x, const FloatComplexNDArray &y, FloatComplexNDArray &z, F77_INT m, F77_INT n, F77_INT k, F77_INT np)
Definition: dot.cc:237
bool isvector(void) const
Definition: dim-vector.h:422
the element is set to zero In other the statement xample y
Definition: data.cc:5264
args.length() nargin
Definition: file-io.cc:589
#define F77_CONST_CMPLX_ARG(x)
Definition: f77-fcn.h:312
bool iscomplex(void) const
Definition: ov.h:710
for i
Definition: data.cc:5264
octave_idx_type ndims(void) const
Number of dimensions.
Definition: dim-vector.h:295
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:859
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:87
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:22
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:840
bool isnumeric(void) const
Definition: ov.h:723
F77_RET_T const F77_REAL const F77_REAL F77_REAL &F77_RET_T const F77_DBLE const F77_DBLE F77_DBLE &F77_RET_T const F77_DBLE F77_DBLE &F77_RET_T const F77_REAL F77_REAL &F77_RET_T const F77_DBLE * x
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:22