dot.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 2009-2012 VZLU Prague
00004 
00005 This file is part of Octave.
00006 
00007 Octave is free software; you can redistribute it and/or modify it
00008 under the terms of the GNU General Public License as published by the
00009 Free Software Foundation; either version 3 of the License, or (at your
00010 option) any later version.
00011 
00012 Octave is distributed in the hope that it will be useful, but WITHOUT
00013 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00014 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00015 for more details.
00016 
00017 You should have received a copy of the GNU General Public License
00018 along with Octave; see the file COPYING.  If not, see
00019 <http://www.gnu.org/licenses/>.
00020 
00021 */
00022 
00023 #ifdef HAVE_CONFIG_H
00024 #include <config.h>
00025 #endif
00026 
00027 #include "f77-fcn.h"
00028 #include "mx-base.h"
00029 #include "error.h"
00030 #include "defun-dld.h"
00031 #include "parse.h"
00032 
00033 extern "C"
00034 {
00035   F77_RET_T
00036   F77_FUNC (ddot3, DDOT3) (const octave_idx_type&, const octave_idx_type&,
00037                            const octave_idx_type&, const double*,
00038                            const double*, double*);
00039 
00040   F77_RET_T
00041   F77_FUNC (sdot3, SDOT3) (const octave_idx_type&, const octave_idx_type&,
00042                            const octave_idx_type&, const float*,
00043                            const float*, float*);
00044 
00045   F77_RET_T
00046   F77_FUNC (zdotc3, ZDOTC3) (const octave_idx_type&, const octave_idx_type&,
00047                              const octave_idx_type&, const Complex*,
00048                              const Complex*, Complex*);
00049 
00050   F77_RET_T
00051   F77_FUNC (cdotc3, CDOTC3) (const octave_idx_type&, const octave_idx_type&,
00052                              const octave_idx_type&, const FloatComplex*,
00053                              const FloatComplex*, FloatComplex*);
00054 
00055   F77_RET_T
00056   F77_FUNC (dmatm3, DMATM3) (const octave_idx_type&, const octave_idx_type&,
00057                              const octave_idx_type&, const octave_idx_type&,
00058                              const double*, const double*, double*);
00059 
00060   F77_RET_T
00061   F77_FUNC (smatm3, SMATM3) (const octave_idx_type&, const octave_idx_type&,
00062                              const octave_idx_type&, const octave_idx_type&,
00063                              const float*, const float*, float*);
00064 
00065   F77_RET_T
00066   F77_FUNC (zmatm3, ZMATM3) (const octave_idx_type&, const octave_idx_type&,
00067                              const octave_idx_type&, const octave_idx_type&,
00068                              const Complex*, const Complex*, Complex*);
00069 
00070   F77_RET_T
00071   F77_FUNC (cmatm3, CMATM3) (const octave_idx_type&, const octave_idx_type&,
00072                              const octave_idx_type&, const octave_idx_type&,
00073                              const FloatComplex*, const FloatComplex*,
00074                              FloatComplex*);
00075 }
00076 
00077 static void
00078 get_red_dims (const dim_vector& x, const dim_vector& y, int dim,
00079               dim_vector& z, octave_idx_type& m, octave_idx_type& n,
00080               octave_idx_type& k)
00081 {
00082   int nd = x.length ();
00083   assert (nd == y.length ());
00084   z = dim_vector::alloc (nd);
00085   m = 1, n = 1, k = 1;
00086   for (int i = 0; i < nd; i++)
00087     {
00088       if (i < dim)
00089         {
00090           z(i) = x(i);
00091           m *= x(i);
00092         }
00093       else if (i > dim)
00094         {
00095           z(i) = x(i);
00096           n *= x(i);
00097         }
00098       else
00099         {
00100           k = x(i);
00101           z(i) = 1;
00102         }
00103     }
00104 }
00105 
00106 DEFUN_DLD (dot, args, ,
00107   "-*- texinfo -*-\n\
00108 @deftypefn {Loadable Function} {} dot (@var{x}, @var{y}, @var{dim})\n\
00109 Compute the dot product of two vectors.  If @var{x} and @var{y}\n\
00110 are matrices, calculate the dot products along the first\n\
00111 non-singleton dimension.  If the optional argument @var{dim} is\n\
00112 given, calculate the dot products along this dimension.\n\
00113 \n\
00114 This is equivalent to\n\
00115 @code{sum (conj (@var{X}) .* @var{Y}, @var{dim})},\n\
00116 but avoids forming a temporary array and is faster.  When @var{X} and\n\
00117 @var{Y} are column vectors, the result is equivalent to\n\
00118 @code{@var{X}' * @var{Y}}.\n\
00119 @seealso{cross, divergence}\n\
00120 @end deftypefn")
00121 {
00122   octave_value retval;
00123   int nargin = args.length ();
00124 
00125   if (nargin < 2 || nargin > 3)
00126     {
00127       print_usage ();
00128       return retval;
00129     }
00130 
00131   octave_value argx = args(0), argy = args(1);
00132 
00133   if (argx.is_numeric_type () && argy.is_numeric_type ())
00134     {
00135       dim_vector dimx = argx.dims (), dimy = argy.dims ();
00136       bool match = dimx == dimy;
00137       if (! match && nargin == 2
00138           && dimx.is_vector () && dimy.is_vector ())
00139         {
00140           // Change to column vectors.
00141           dimx = dimx.redim (1);
00142           argx = argx.reshape (dimx);
00143           dimy = dimy.redim (1);
00144           argy = argy.reshape (dimy);
00145           match = ! error_state;
00146         }
00147 
00148       if (match)
00149         {
00150           int dim;
00151           if (nargin == 2)
00152             dim = dimx.first_non_singleton ();
00153           else
00154             dim = args(2).int_value (true) - 1;
00155 
00156           if (error_state)
00157             ;
00158           else if (dim < 0)
00159             error ("dot: DIM must be a valid dimension");
00160           else
00161             {
00162               octave_idx_type m, n, k;
00163               dim_vector dimz;
00164               if (argx.is_complex_type () || argy.is_complex_type ())
00165                 {
00166                   if (argx.is_single_type () || argy.is_single_type ())
00167                     {
00168                       FloatComplexNDArray x = argx.float_complex_array_value ();
00169                       FloatComplexNDArray y = argy.float_complex_array_value ();
00170                       get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00171                       FloatComplexNDArray z(dimz);
00172                       if (! error_state)
00173                         F77_XFCN (cdotc3, CDOTC3, (m, n, k, x.data (), y.data (),
00174                                                    z.fortran_vec ()));
00175                       retval = z;
00176                     }
00177                   else
00178                     {
00179                       ComplexNDArray x = argx.complex_array_value ();
00180                       ComplexNDArray y = argy.complex_array_value ();
00181                       get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00182                       ComplexNDArray z(dimz);
00183                       if (! error_state)
00184                         F77_XFCN (zdotc3, ZDOTC3, (m, n, k, x.data (), y.data (),
00185                                                    z.fortran_vec ()));
00186                       retval = z;
00187                     }
00188                 }
00189               else if (argx.is_float_type () && argy.is_float_type ())
00190                 {
00191                   if (argx.is_single_type () || argy.is_single_type ())
00192                     {
00193                       FloatNDArray x = argx.float_array_value ();
00194                       FloatNDArray y = argy.float_array_value ();
00195                       get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00196                       FloatNDArray z(dimz);
00197                       if (! error_state)
00198                         F77_XFCN (sdot3, SDOT3, (m, n, k, x.data (), y.data (),
00199                                                  z.fortran_vec ()));
00200                       retval = z;
00201                     }
00202                   else
00203                     {
00204                       NDArray x = argx.array_value ();
00205                       NDArray y = argy.array_value ();
00206                       get_red_dims (dimx, dimy, dim, dimz, m, n, k);
00207                       NDArray z(dimz);
00208                       if (! error_state)
00209                         F77_XFCN (ddot3, DDOT3, (m, n, k, x.data (), y.data (),
00210                                                  z.fortran_vec ()));
00211                       retval = z;
00212                     }
00213                 }
00214               else
00215                 {
00216                   // Non-optimized evaluation.
00217                   octave_value_list tmp;
00218                   tmp(1) = args(2);
00219                   tmp(0) = do_binary_op (octave_value::op_el_mul, argx, argy);
00220                   if (! error_state)
00221                     {
00222                       tmp = feval ("sum", tmp, 1);
00223                       if (! tmp.empty ())
00224                         retval = tmp(0);
00225                     }
00226                 }
00227             }
00228         }
00229       else
00230         error ("dot: sizes of X and Y must match");
00231 
00232     }
00233   else
00234     error ("dot: X and Y must be numeric");
00235 
00236   return retval;
00237 }
00238 
00239 /*
00240 
00241 %!assert (dot ([1, 2], [2, 3]), 11)
00242 
00243 %!test
00244 %! x = [2, 1; 2, 1];
00245 %! y = [-0.5, 2; 0.5, -2];
00246 %! assert(dot (x, y), [0 0]);
00247 
00248 %!test
00249 %! x = [ 1+i, 3-i; 1-i, 3-i];
00250 %! assert(dot (x, x), [4, 20]);
00251 
00252 */
00253 
00254 DEFUN_DLD (blkmm, args, ,
00255   "-*- texinfo -*-\n\
00256 @deftypefn {Loadable Function} {} blkmm (@var{A}, @var{B})\n\
00257 Compute products of matrix blocks.  The blocks are given as\n\
00258 2-dimensional subarrays of the arrays @var{A}, @var{B}.\n\
00259 The size of @var{A} must have the form @code{[m,k,@dots{}]} and\n\
00260 size of @var{B} must be @code{[k,n,@dots{}]}.  The result is\n\
00261 then of size @code{[m,n,@dots{}]} and is computed as follows:\n\
00262 \n\
00263 @example\n\
00264 @group\n\
00265   for i = 1:prod (size (@var{A})(3:end))\n\
00266     @var{C}(:,:,i) = @var{A}(:,:,i) * @var{B}(:,:,i)\n\
00267   endfor\n\
00268 @end group\n\
00269 @end example\n\
00270 @end deftypefn")
00271 {
00272   octave_value retval;
00273   int nargin = args.length ();
00274 
00275   if (nargin != 2)
00276     {
00277       print_usage ();
00278       return retval;
00279     }
00280 
00281   octave_value argx = args(0), argy = args(1);
00282 
00283   if (argx.is_numeric_type () && argy.is_numeric_type ())
00284     {
00285       const dim_vector dimx = argx.dims (), dimy = argy.dims ();
00286       int nd = dimx.length ();
00287       octave_idx_type m = dimx(0), k = dimx(1), n = dimy(1), np = 1;
00288       bool match = dimy(0) == k && nd == dimy.length ();
00289       dim_vector dimz = dim_vector::alloc (nd);
00290       dimz(0) = m;
00291       dimz(1) = n;
00292       for (int i = 2; match && i < nd; i++)
00293         {
00294           match = match && dimx(i) == dimy(i);
00295           dimz(i) = dimx(i);
00296           np *= dimz(i);
00297         }
00298 
00299       if (match)
00300         {
00301           if (argx.is_complex_type () || argy.is_complex_type ())
00302             {
00303               if (argx.is_single_type () || argy.is_single_type ())
00304                 {
00305                   FloatComplexNDArray x = argx.float_complex_array_value ();
00306                   FloatComplexNDArray y = argy.float_complex_array_value ();
00307                   FloatComplexNDArray z(dimz);
00308                   if (! error_state)
00309                     F77_XFCN (cmatm3, CMATM3, (m, n, k, np, x.data (), y.data (),
00310                                                z.fortran_vec ()));
00311                   retval = z;
00312                 }
00313               else
00314                 {
00315                   ComplexNDArray x = argx.complex_array_value ();
00316                   ComplexNDArray y = argy.complex_array_value ();
00317                   ComplexNDArray z(dimz);
00318                   if (! error_state)
00319                     F77_XFCN (zmatm3, ZMATM3, (m, n, k, np, x.data (), y.data (),
00320                                                z.fortran_vec ()));
00321                   retval = z;
00322                 }
00323             }
00324           else
00325             {
00326               if (argx.is_single_type () || argy.is_single_type ())
00327                 {
00328                   FloatNDArray x = argx.float_array_value ();
00329                   FloatNDArray y = argy.float_array_value ();
00330                   FloatNDArray z(dimz);
00331                   if (! error_state)
00332                     F77_XFCN (smatm3, SMATM3, (m, n, k, np, x.data (), y.data (),
00333                                                z.fortran_vec ()));
00334                   retval = z;
00335                 }
00336               else
00337                 {
00338                   NDArray x = argx.array_value ();
00339                   NDArray y = argy.array_value ();
00340                   NDArray z(dimz);
00341                   if (! error_state)
00342                     F77_XFCN (dmatm3, DMATM3, (m, n, k, np, x.data (), y.data (),
00343                                                z.fortran_vec ()));
00344                   retval = z;
00345                 }
00346             }
00347         }
00348       else
00349         error ("blkmm: A and B dimensions don't match: (%s) and (%s)",
00350                dimx.str ().c_str (), dimy.str ().c_str ());
00351 
00352     }
00353   else
00354     error ("blkmm: A and B must be numeric");
00355 
00356   return retval;
00357 }
00358 
00359 /*
00360 
00361 %!test
00362 %! x(:,:,1) = [1 2; 3 4];
00363 %! x(:,:,2) = [1 1; 1 1];
00364 %! z(:,:,1) = [7 10; 15 22];
00365 %! z(:,:,2) = [2 2; 2 2];
00366 %! assert(blkmm (x,x),z);
00367 
00368 */
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines