GNU Octave  4.2.1
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
dot.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2009-2017 VZLU Prague
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 #if defined (HAVE_CONFIG_H)
24 # include "config.h"
25 #endif
26 
27 #include "lo-blas-proto.h"
28 #include "mx-base.h"
29 #include "error.h"
30 #include "defun.h"
31 #include "parse.h"
32 
33 static void
34 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
37 {
38  int nd = x.ndims ();
39  assert (nd == y.ndims ());
40  z = dim_vector::alloc (nd);
41  m = 1, n = 1, k = 1;
42  for (int i = 0; i < nd; i++)
43  {
44  if (i < dim)
45  {
46  z(i) = x(i);
47  m *= x(i);
48  }
49  else if (i > dim)
50  {
51  z(i) = x(i);
52  n *= x(i);
53  }
54  else
55  {
56  k = x(i);
57  z(i) = 1;
58  }
59  }
60 }
61 
62 DEFUN (dot, args, ,
63  doc: /* -*- texinfo -*-
64 @deftypefn {} {} dot (@var{x}, @var{y}, @var{dim})
65 Compute the dot product of two vectors.
66 
67 If @var{x} and @var{y} are matrices, calculate the dot products along the
68 first non-singleton dimension.
69 
70 If the optional argument @var{dim} is given, calculate the dot products
71 along this dimension.
72 
73 This is equivalent to
74 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},
75 but avoids forming a temporary array and is faster. When @var{X} and
76 @var{Y} are column vectors, the result is equivalent to
77 @code{@var{X}' * @var{Y}}.
78 @seealso{cross, divergence}
79 @end deftypefn */)
80 {
81  int nargin = args.length ();
82 
83  if (nargin < 2 || nargin > 3)
84  print_usage ();
85 
87  octave_value argx = args(0);
88  octave_value argy = args(1);
89 
90  if (! argx.is_numeric_type () || ! argy.is_numeric_type ())
91  error ("dot: X and Y must be numeric");
92 
93  dim_vector dimx = argx.dims ();
94  dim_vector dimy = argy.dims ();
95  bool match = dimx == dimy;
96  if (! match && nargin == 2 && dimx.is_vector () && dimy.is_vector ())
97  {
98  // Change to column vectors.
99  dimx = dimx.redim (1);
100  argx = argx.reshape (dimx);
101  dimy = dimy.redim (1);
102  argy = argy.reshape (dimy);
103  match = dimx == dimy;
104  }
105 
106  if (! match)
107  error ("dot: sizes of X and Y must match");
108 
109  int dim;
110  if (nargin == 2)
111  dim = dimx.first_non_singleton ();
112  else
113  dim = args(2).int_value (true) - 1;
114 
115  if (dim < 0)
116  error ("dot: DIM must be a valid dimension");
117 
118  octave_idx_type m, n, k;
119  dim_vector dimz;
120  if (argx.is_complex_type () || argy.is_complex_type ())
121  {
122  if (argx.is_single_type () || argy.is_single_type ())
123  {
126  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
127  FloatComplexNDArray z (dimz);
128 
129  F77_XFCN (cdotc3, CDOTC3, (m, n, k,
131  F77_CMPLX_ARG (z.fortran_vec ())));
132  retval = z;
133  }
134  else
135  {
138  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
139  ComplexNDArray z (dimz);
140 
141  F77_XFCN (zdotc3, ZDOTC3, (m, n, k,
144  retval = z;
145  }
146  }
147  else if (argx.is_float_type () && argy.is_float_type ())
148  {
149  if (argx.is_single_type () || argy.is_single_type ())
150  {
151  FloatNDArray x = argx.float_array_value ();
152  FloatNDArray y = argy.float_array_value ();
153  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
154  FloatNDArray z (dimz);
155 
156  F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
157  z.fortran_vec ()));
158  retval = z;
159  }
160  else
161  {
162  NDArray x = argx.array_value ();
163  NDArray y = argy.array_value ();
164  get_red_dims (dimx, dimy, dim, dimz, m, n, k);
165  NDArray z (dimz);
166 
167  F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
168  z.fortran_vec ()));
169  retval = z;
170  }
171  }
172  else
173  {
174  // Non-optimized evaluation.
176  tmp(1) = dim + 1;
177  tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
178 
179  tmp = feval ("sum", tmp, 1);
180  if (! tmp.empty ())
181  retval = tmp(0);
182  }
183 
184  return retval;
185 }
186 
187 /*
188 %!assert (dot ([1, 2], [2, 3]), 8)
189 
190 %!test
191 %! x = [2, 1; 2, 1];
192 %! y = [-0.5, 2; 0.5, -2];
193 %! assert (dot (x, y), [0 0]);
194 %! assert (dot (single (x), single (y)), single ([0 0]));
195 
196 %!test
197 %! x = [1+i, 3-i; 1-i, 3-i];
198 %! assert (dot (x, x), [4, 20]);
199 %! assert (dot (single (x), single (x)), single ([4, 20]));
200 
201 %!test
202 %! x = int8 ([1 2]);
203 %! y = int8 ([2 3]);
204 %! assert (dot (x, y), 8);
205 
206 %!test
207 %! x = int8 ([1 2; 3 4]);
208 %! y = int8 ([5 6; 7 8]);
209 %! assert (dot (x, y), [26 44]);
210 %! assert (dot (x, y, 2), [17; 53]);
211 %! assert (dot (x, y, 3), [5 12; 21 32]);
212 
213 %% Test input validation
214 %!error dot ()
215 %!error dot (1)
216 %!error dot (1,2,3,4)
217 %!error <X and Y must be numeric> dot ({1,2}, [3,4])
218 %!error <X and Y must be numeric> dot ([1,2], {3,4})
219 %!error <sizes of X and Y must match> dot ([1 2], [1 2 3])
220 %!error <sizes of X and Y must match> dot ([1 2]', [1 2 3]')
221 %!error <sizes of X and Y must match> dot (ones (2,2), ones (2,3))
222 %!error <DIM must be a valid dimension> dot ([1 2], [1 2], 0)
223 */
224 
225 DEFUN (blkmm, args, ,
226  doc: /* -*- texinfo -*-
227 @deftypefn {} {} blkmm (@var{A}, @var{B})
228 Compute products of matrix blocks.
229 
230 The blocks are given as 2-dimensional subarrays of the arrays @var{A},
231 @var{B}. The size of @var{A} must have the form @code{[m,k,@dots{}]} and
232 size of @var{B} must be @code{[k,n,@dots{}]}. The result is then of size
233 @code{[m,n,@dots{}]} and is computed as follows:
234 
235 @example
236 @group
237 for i = 1:prod (size (@var{A})(3:end))
238  @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)
239 endfor
240 @end group
241 @end example
242 @end deftypefn */)
243 {
244  if (args.length () != 2)
245  print_usage ();
246 
248 
249  octave_value argx = args(0);
250  octave_value argy = args(1);
251 
252  if (! argx.is_numeric_type () || ! argy.is_numeric_type ())
253  error ("blkmm: A and B must be numeric");
254 
255  const dim_vector dimx = argx.dims ();
256  const dim_vector dimy = argy.dims ();
257  int nd = dimx.ndims ();
258  octave_idx_type m = dimx(0);
259  octave_idx_type k = dimx(1);
260  octave_idx_type n = dimy(1);
261  octave_idx_type np = 1;
262  bool match = dimy(0) == k && nd == dimy.ndims ();
263  dim_vector dimz = dim_vector::alloc (nd);
264  dimz(0) = m;
265  dimz(1) = n;
266  for (int i = 2; match && i < nd; i++)
267  {
268  match = match && dimx(i) == dimy(i);
269  dimz(i) = dimx(i);
270  np *= dimz(i);
271  }
272 
273  if (! match)
274  error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
275  dimx.str ().c_str (), dimy.str ().c_str ());
276 
277  if (argx.is_complex_type () || argy.is_complex_type ())
278  {
279  if (argx.is_single_type () || argy.is_single_type ())
280  {
283  FloatComplexNDArray z (dimz);
284 
285  F77_XFCN (cmatm3, CMATM3, (m, n, k, np,
287  F77_CMPLX_ARG (z.fortran_vec ())));
288  retval = z;
289  }
290  else
291  {
294  ComplexNDArray z (dimz);
295 
296  F77_XFCN (zmatm3, ZMATM3, (m, n, k, np,
298  F77_DBLE_CMPLX_ARG (z.fortran_vec ())));
299  retval = z;
300  }
301  }
302  else
303  {
304  if (argx.is_single_type () || argy.is_single_type ())
305  {
306  FloatNDArray x = argx.float_array_value ();
307  FloatNDArray y = argy.float_array_value ();
308  FloatNDArray z (dimz);
309 
310  F77_XFCN (smatm3, SMATM3, (m, n, k, np,
311  x.data (), y.data (),
312  z.fortran_vec ()));
313  retval = z;
314  }
315  else
316  {
317  NDArray x = argx.array_value ();
318  NDArray y = argy.array_value ();
319  NDArray z (dimz);
320 
321  F77_XFCN (dmatm3, DMATM3, (m, n, k, np,
322  x.data (), y.data (),
323  z.fortran_vec ()));
324  retval = z;
325  }
326  }
327 
328  return retval;
329 }
330 
331 /*
332 %!test
333 %! x(:,:,1) = [1 2; 3 4];
334 %! x(:,:,2) = [1 1; 1 1];
335 %! z(:,:,1) = [7 10; 15 22];
336 %! z(:,:,2) = [2 2; 2 2];
337 %! assert (blkmm (x,x), z);
338 %! assert (blkmm (single (x), single (x)), single (z));
339 %! assert (blkmm (x, single (x)), single (z));
340 
341 %!test
342 %! x(:,:,1) = [1 2; 3 4];
343 %! x(:,:,2) = [1i 1i; 1i 1i];
344 %! z(:,:,1) = [7 10; 15 22];
345 %! z(:,:,2) = [-2 -2; -2 -2];
346 %! assert (blkmm (x,x), z);
347 %! assert (blkmm (single (x), single (x)), single (z));
348 %! assert (blkmm (x, single (x)), single (z));
349 
350 %% Test input validation
351 %!error blkmm ()
352 %!error blkmm (1)
353 %!error blkmm (1,2,3)
354 %!error <A and B dimensions don't match> blkmm (ones (2,2), ones (3,3))
355 %!error <A and B must be numeric> blkmm ({1,2}, [3,4])
356 %!error <A and B must be numeric> blkmm ([3,4], {1,2})
357 */
ComplexNDArray complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:812
std::string str(char sep= 'x') const
Definition: dim-vector.cc:73
subroutine cmatm3(m, n, k, np, a, b, c)
Definition: cmatm3.f:21
octave_value reshape(const dim_vector &dv) const
Definition: ov.h:515
bool is_vector(void) const
Definition: dim-vector.h:458
subroutine cdotc3(m, n, k, a, b, c)
Definition: cdotc3.f:21
OCTINTERP_API void print_usage(void)
Definition: defun.cc:52
#define F77_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:345
bool is_numeric_type(void) const
Definition: ov.h:679
for large enough k
Definition: lu.cc:606
#define DEFUN(name, args_name, nargout_name, doc)
Definition: defun.h:46
void error(const char *fmt,...)
Definition: error.cc:570
subroutine smatm3(m, n, k, np, a, b, c)
Definition: smatm3.f:21
#define F77_XFCN(f, F, args)
Definition: f77-fcn.h:52
int first_non_singleton(int def=0) const
Definition: dim-vector.h:463
bool is_float_type(void) const
Definition: ov.h:630
JNIEnv void * args
Definition: ov-java.cc:67
FloatNDArray float_array_value(bool frc_str_conv=false) const
Definition: ov.h:796
subroutine dmatm3(m, n, k, np, a, b, c)
Definition: dmatm3.f:21
nd deftypefn *octave_map m
Definition: ov-struct.cc:2058
int nargin
Definition: graphics.cc:10115
FloatComplexNDArray float_complex_array_value(bool frc_str_conv=false) const
Definition: ov.h:816
const T * data(void) const
Definition: Array.h:582
bool is_complex_type(void) const
Definition: ov.h:670
double tmp
Definition: data.cc:6300
octave_value retval
Definition: data.cc:6294
dim_vector redim(int n) const
Definition: dim-vector.cc:275
subroutine zmatm3(m, n, k, np, a, b, c)
Definition: zmatm3.f:21
static dim_vector alloc(int n)
Definition: dim-vector.h:270
dim_vector dims(void) const
Definition: ov.h:486
subroutine sdot3(m, n, k, a, b, c)
Definition: sdot3.f:21
static void get_red_dims(const dim_vector &x, const dim_vector &y, int dim, dim_vector &z, octave_idx_type &m, octave_idx_type &n, octave_idx_type &k)
Definition: dot.cc:34
feval(ar{f}, 1) esult
Definition: oct-parse.cc:8829
#define F77_CMPLX_ARG(x)
Definition: f77-fcn.h:339
bool empty(void) const
Definition: ovl.h:98
#define F77_CONST_DBLE_CMPLX_ARG(x)
Definition: f77-fcn.h:348
NDArray array_value(bool frc_str_conv=false) const
Definition: ov.h:793
double dot(const ColumnVector &v1, const ColumnVector &v2)
Definition: graphics.cc:5191
=val(i)}if ode{val(i)}occurs in table i
Definition: lookup.cc:239
octave_idx_type ndims(void) const
Number of dimensions.
Definition: dim-vector.h:301
the element is set to zero In other the statement xample y
Definition: data.cc:5342
#define F77_CONST_CMPLX_ARG(x)
Definition: f77-fcn.h:342
const T * fortran_vec(void) const
Definition: Array.h:584
bool is_single_type(void) const
Definition: ov.h:627
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:87
subroutine ddot3(m, n, k, a, b, c)
Definition: ddot3.f:21
F77_RET_T F77_REAL &F77_RET_T F77_DBLE &F77_RET_T F77_REAL &F77_RET_T F77_DBLE &F77_RET_T F77_REAL &F77_RET_T F77_DBLE &F77_RET_T const F77_REAL const F77_REAL F77_REAL &F77_RET_T const F77_DBLE const F77_DBLE F77_DBLE &F77_RET_T F77_REAL &F77_RET_T F77_DBLE &F77_RET_T F77_DBLE &F77_RET_T F77_REAL &F77_RET_T F77_REAL &F77_RET_T F77_DBLE &F77_RET_T const F77_DBLE F77_DBLE &F77_RET_T const F77_REAL F77_REAL &F77_RET_T F77_REAL F77_REAL &F77_RET_T F77_DBLE F77_DBLE &F77_RET_T const F77_DBLE * x
subroutine zdotc3(m, n, k, a, b, c)
Definition: zdotc3.f:21
octave_value do_binary_op(octave_value::binary_op op, const octave_value &v1, const octave_value &v2)
Definition: ov.cc:2214