kron.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 2002-2012 John W. Eaton
00004 
00005 This file is part of Octave.
00006 
00007 Octave is free software; you can redistribute it and/or modify it
00008 under the terms of the GNU General Public License as published by the
00009 Free Software Foundation; either version 3 of the License, or (at your
00010 option) any later version.
00011 
00012 Octave is distributed in the hope that it will be useful, but WITHOUT
00013 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00014 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00015 for more details.
00016 
00017 You should have received a copy of the GNU General Public License
00018 along with Octave; see the file COPYING.  If not, see
00019 <http://www.gnu.org/licenses/>.
00020 
00021 */
00022 
00023 // Author: Paul Kienzle <pkienzle@users.sf.net>
00024 
00025 #ifdef HAVE_CONFIG_H
00026 #include <config.h>
00027 #endif
00028 
00029 #include "dMatrix.h"
00030 #include "fMatrix.h"
00031 #include "CMatrix.h"
00032 #include "fCMatrix.h"
00033 
00034 #include "dSparse.h"
00035 #include "CSparse.h"
00036 
00037 #include "dDiagMatrix.h"
00038 #include "fDiagMatrix.h"
00039 #include "CDiagMatrix.h"
00040 #include "fCDiagMatrix.h"
00041 
00042 #include "PermMatrix.h"
00043 
00044 #include "mx-inlines.cc"
00045 #include "quit.h"
00046 
00047 #include "defun-dld.h"
00048 #include "error.h"
00049 #include "oct-obj.h"
00050 
00051 template <class R, class T>
00052 static MArray<T>
00053 kron (const MArray<R>& a, const MArray<T>& b)
00054 {
00055   assert (a.ndims () == 2);
00056   assert (b.ndims () == 2);
00057 
00058   octave_idx_type nra = a.rows (), nrb = b.rows ();
00059   octave_idx_type nca = a.cols (), ncb = b.cols ();
00060 
00061   MArray<T> c (dim_vector (nra*nrb, nca*ncb));
00062   T *cv = c.fortran_vec ();
00063 
00064   for (octave_idx_type ja = 0; ja < nca; ja++)
00065     for (octave_idx_type jb = 0; jb < ncb; jb++)
00066       for (octave_idx_type ia = 0; ia < nra; ia++)
00067         {
00068           octave_quit ();
00069           mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
00070           cv += nrb;
00071         }
00072 
00073   return c;
00074 }
00075 
00076 template <class R, class T>
00077 static MArray<T>
00078 kron (const MDiagArray2<R>& a, const MArray<T>& b)
00079 {
00080   assert (b.ndims () == 2);
00081 
00082   octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length ();
00083   octave_idx_type nca = a.cols (), ncb = b.cols ();
00084 
00085   MArray<T> c (dim_vector (nra*nrb, nca*ncb), T());
00086 
00087   for (octave_idx_type ja = 0; ja < dla; ja++)
00088     for (octave_idx_type jb = 0; jb < ncb; jb++)
00089       {
00090         octave_quit ();
00091         mx_inline_mul (nrb, &c.xelem(ja*nrb, ja*ncb + jb), a.dgelem (ja), b.data () + nrb*jb);
00092       }
00093 
00094   return c;
00095 }
00096 
00097 template <class T>
00098 static MSparse<T>
00099 kron (const MSparse<T>& A, const MSparse<T>& B)
00100 {
00101   octave_idx_type idx = 0;
00102   MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
00103                 A.nnz () * B.nnz ());
00104 
00105   C.cidx (0) = 0;
00106 
00107   for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
00108     for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
00109       {
00110         octave_quit ();
00111         for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
00112           {
00113             octave_idx_type Ci = A.ridx(Ai) * B.rows ();
00114             const T v = A.data (Ai);
00115 
00116             for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
00117               {
00118                 C.data (idx) = v * B.data (Bi);
00119                 C.ridx (idx++) = Ci + B.ridx (Bi);
00120               }
00121           }
00122         C.cidx (Aj * B.columns () + Bj + 1) = idx;
00123       }
00124 
00125   return C;
00126 }
00127 
00128 static PermMatrix
00129 kron (const PermMatrix& a, const PermMatrix& b)
00130 {
00131   octave_idx_type na = a.rows (), nb = b.rows ();
00132   const octave_idx_type *pa = a.data (), *pb = b.data ();
00133   PermMatrix c(na*nb); // Row permutation.
00134   octave_idx_type *pc = c.fortran_vec ();
00135 
00136   bool cola = a.is_col_perm (), colb = b.is_col_perm ();
00137   if (cola && colb)
00138     {
00139       for (octave_idx_type i = 0; i < na; i++)
00140         for (octave_idx_type j = 0; j < nb; j++)
00141           pc[pa[i]*nb+pb[j]] = i*nb+j;
00142     }
00143   else if (cola)
00144     {
00145       for (octave_idx_type i = 0; i < na; i++)
00146         for (octave_idx_type j = 0; j < nb; j++)
00147           pc[pa[i]*nb+j] = i*nb+pb[j];
00148     }
00149   else if (colb)
00150     {
00151       for (octave_idx_type i = 0; i < na; i++)
00152         for (octave_idx_type j = 0; j < nb; j++)
00153           pc[i*nb+pb[j]] = pa[i]*nb+j;
00154     }
00155   else
00156     {
00157       for (octave_idx_type i = 0; i < na; i++)
00158         for (octave_idx_type j = 0; j < nb; j++)
00159           pc[i*nb+j] = pa[i]*nb+pb[j];
00160     }
00161 
00162   return c;
00163 }
00164 
00165 template <class MTA, class MTB>
00166 octave_value
00167 do_kron (const octave_value& a, const octave_value& b)
00168 {
00169   MTA am = octave_value_extract<MTA> (a);
00170   MTB bm = octave_value_extract<MTB> (b);
00171   return octave_value (kron (am, bm));
00172 }
00173 
00174 octave_value
00175 dispatch_kron (const octave_value& a, const octave_value& b)
00176 {
00177   octave_value retval;
00178   if (a.is_perm_matrix () && b.is_perm_matrix ())
00179     retval = do_kron<PermMatrix, PermMatrix> (a, b);
00180   else if (a.is_diag_matrix ())
00181     {
00182       if (b.is_diag_matrix () && a.rows () == a.columns ()
00183           && b.rows () == b.columns ())
00184         {
00185           // We have two diagonal matrices, the product of those will be
00186           // another diagonal matrix.  To do that efficiently, extract
00187           // the diagonals as vectors and compute the product.  That
00188           // will be another vector, which we then use to construct a
00189           // diagonal matrix object.  Note that this will fail if our
00190           // digaonal matrix object is modified to allow the non-zero
00191           // values to be stored off of the principal diagonal (i.e., if
00192           // diag ([1,2], 3) is modified to return a diagonal matrix
00193           // object instead of a full matrix object).
00194 
00195           octave_value tmp = dispatch_kron (a.diag (), b.diag ());
00196           retval = tmp.diag ();
00197         }
00198       else if (a.is_single_type () || b.is_single_type ())
00199         {
00200           if (a.is_complex_type ())
00201             retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
00202           else if (b.is_complex_type ())
00203             retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
00204           else
00205             retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
00206         }
00207       else
00208         {
00209           if (a.is_complex_type ())
00210             retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
00211           else if (b.is_complex_type ())
00212             retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
00213           else
00214             retval = do_kron<DiagMatrix, Matrix> (a, b);
00215         }
00216     }
00217   else if (a.is_sparse_type () || b.is_sparse_type ())
00218     {
00219       if (a.is_complex_type () || b.is_complex_type ())
00220         retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
00221       else
00222         retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
00223     }
00224   else if (a.is_single_type () || b.is_single_type ())
00225     {
00226       if (a.is_complex_type ())
00227         retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
00228       else if (b.is_complex_type ())
00229         retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
00230       else
00231         retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
00232     }
00233   else
00234     {
00235       if (a.is_complex_type ())
00236         retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
00237       else if (b.is_complex_type ())
00238         retval = do_kron<Matrix, ComplexMatrix> (a, b);
00239       else
00240         retval = do_kron<Matrix, Matrix> (a, b);
00241     }
00242   return retval;
00243 }
00244 
00245 
00246 DEFUN_DLD (kron, args, , "-*- texinfo -*-\n\
00247 @deftypefn  {Loadable Function} {} kron (@var{A}, @var{B})\n\
00248 @deftypefnx {Loadable Function} {} kron (@var{A1}, @var{A2}, @dots{})\n\
00249 Form the Kronecker product of two or more matrices, defined block by \n\
00250 block as\n\
00251 \n\
00252 @example\n\
00253 x = [a(i, j) b]\n\
00254 @end example\n\
00255 \n\
00256 For example:\n\
00257 \n\
00258 @example\n\
00259 @group\n\
00260 kron (1:4, ones (3, 1))\n\
00261       @result{}  1  2  3  4\n\
00262           1  2  3  4\n\
00263           1  2  3  4\n\
00264 @end group\n\
00265 @end example\n\
00266 \n\
00267 If there are more than two input arguments @var{A1}, @var{A2}, @dots{}, \n\
00268 @var{An} the Kronecker product is computed as\n\
00269 \n\
00270 @example\n\
00271 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})\n\
00272 @end example\n\
00273 \n\
00274 @noindent\n\
00275 Since the Kronecker product is associative, this is well-defined.\n\
00276 @end deftypefn")
00277 {
00278   octave_value retval;
00279 
00280   int nargin = args.length ();
00281 
00282   if (nargin >= 2)
00283     {
00284       octave_value a = args(0), b = args(1);
00285       retval = dispatch_kron (a, b);
00286       for (octave_idx_type i = 2; i < nargin; i++)
00287         retval = dispatch_kron (retval, args(i));
00288     }
00289   else
00290     print_usage ();
00291 
00292   return retval;
00293 }
00294 
00295 
00296 /*
00297 %!test
00298 %! x = ones(2);
00299 %! assert( kron (x, x), ones (4));
00300 
00301 %!shared x, y, z
00302 %! x =  [1, 2];
00303 %! y =  [-1, -2];
00304 %! z =  [1,  2,  3,  4; 1,  2,  3,  4; 1,  2,  3,  4];
00305 %!assert (kron (1:4, ones (3, 1)), z)
00306 %!assert (kron (x, y, z), kron (kron (x, y), z))
00307 %!assert (kron (x, y, z), kron (x, kron (y, z)))
00308 
00309 
00310 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
00311 
00312 %% Test for two diag matrices.  See the comments above in
00313 %% dispatch_kron for this case.
00314 %%
00315 %!test
00316 %! expected = zeros (16, 16);
00317 %! expected (1, 11) = 3;
00318 %! expected (2, 12) = 4;
00319 %! expected (5, 15) = 6;
00320 %! expected (6, 16) = 8;
00321 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected)
00322 */
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines