GNU Octave  4.2.1
A high-level interpreted language, primarily intended for numerical computations, mostly compatible with Matlab
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Pages
kron.cc
Go to the documentation of this file.
1 /*
2 
3 Copyright (C) 2002-2017 John W. Eaton
4 
5 This file is part of Octave.
6 
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11 
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16 
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20 
21 */
22 
23 // Author: Paul Kienzle <pkienzle@users.sf.net>
24 
25 #if defined (HAVE_CONFIG_H)
26 # include "config.h"
27 #endif
28 
29 #include "dMatrix.h"
30 #include "fMatrix.h"
31 #include "CMatrix.h"
32 #include "fCMatrix.h"
33 
34 #include "dSparse.h"
35 #include "CSparse.h"
36 
37 #include "dDiagMatrix.h"
38 #include "fDiagMatrix.h"
39 #include "CDiagMatrix.h"
40 #include "fCDiagMatrix.h"
41 
42 #include "PermMatrix.h"
43 
44 #include "mx-inlines.cc"
45 #include "quit.h"
46 
47 #include "defun.h"
48 #include "error.h"
49 #include "ovl.h"
50 
51 template <typename R, typename T>
52 static MArray<T>
53 kron (const MArray<R>& a, const MArray<T>& b)
54 {
55  assert (a.ndims () == 2);
56  assert (b.ndims () == 2);
57 
58  octave_idx_type nra = a.rows ();
59  octave_idx_type nrb = b.rows ();
60  octave_idx_type nca = a.cols ();
61  octave_idx_type ncb = b.cols ();
62 
63  MArray<T> c (dim_vector (nra*nrb, nca*ncb));
64  T *cv = c.fortran_vec ();
65 
66  for (octave_idx_type ja = 0; ja < nca; ja++)
67  for (octave_idx_type jb = 0; jb < ncb; jb++)
68  for (octave_idx_type ia = 0; ia < nra; ia++)
69  {
70  octave_quit ();
71  mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
72  cv += nrb;
73  }
74 
75  return c;
76 }
77 
78 template <typename R, typename T>
79 static MArray<T>
80 kron (const MDiagArray2<R>& a, const MArray<T>& b)
81 {
82  assert (b.ndims () == 2);
83 
84  octave_idx_type nra = a.rows ();
85  octave_idx_type nrb = b.rows ();
86  octave_idx_type dla = a.diag_length ();
87  octave_idx_type nca = a.cols ();
88  octave_idx_type ncb = b.cols ();
89 
90  MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
91 
92  for (octave_idx_type ja = 0; ja < dla; ja++)
93  for (octave_idx_type jb = 0; jb < ncb; jb++)
94  {
95  octave_quit ();
96  mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja),
97  b.data () + nrb*jb);
98  }
99 
100  return c;
101 }
102 
103 template <typename T>
104 static MSparse<T>
105 kron (const MSparse<T>& A, const MSparse<T>& B)
106 {
107  octave_idx_type idx = 0;
108  MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
109  A.nnz () * B.nnz ());
110 
111  C.cidx (0) = 0;
112 
113  for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
114  for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
115  {
116  octave_quit ();
117  for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
118  {
119  octave_idx_type Ci = A.ridx (Ai) * B.rows ();
120  const T v = A.data (Ai);
121 
122  for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
123  {
124  C.data (idx) = v * B.data (Bi);
125  C.ridx (idx++) = Ci + B.ridx (Bi);
126  }
127  }
128  C.cidx (Aj * B.columns () + Bj + 1) = idx;
129  }
130 
131  return C;
132 }
133 
134 static PermMatrix
135 kron (const PermMatrix& a, const PermMatrix& b)
136 {
137  octave_idx_type na = a.rows ();
138  octave_idx_type nb = b.rows ();
139  const Array<octave_idx_type>& pa = a.col_perm_vec ();
140  const Array<octave_idx_type>& pb = b.col_perm_vec ();
141  Array<octave_idx_type> res_perm (dim_vector (na * nb, 1));
142  octave_idx_type rescol = 0;
143  for (octave_idx_type i = 0; i < na; i++)
144  {
145  octave_idx_type a_add = pa(i) * nb;
146  for (octave_idx_type j = 0; j < nb; j++)
147  res_perm.xelem (rescol++) = a_add + pb(j);
148  }
149 
150  return PermMatrix (res_perm, true);
151 }
152 
153 template <typename MTA, typename MTB>
156 {
157  MTA am = octave_value_extract<MTA> (a);
158  MTB bm = octave_value_extract<MTB> (b);
159 
160  return octave_value (kron (am, bm));
161 }
162 
165 {
167  if (a.is_perm_matrix () && b.is_perm_matrix ())
168  retval = do_kron<PermMatrix, PermMatrix> (a, b);
169  else if (a.is_sparse_type () || b.is_sparse_type ())
170  {
171  if (a.is_complex_type () || b.is_complex_type ())
172  retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
173  else
174  retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
175  }
176  else if (a.is_diag_matrix ())
177  {
178  if (b.is_diag_matrix () && a.rows () == a.columns ()
179  && b.rows () == b.columns ())
180  {
181  // We have two diagonal matrices, the product of those will be
182  // another diagonal matrix. To do that efficiently, extract
183  // the diagonals as vectors and compute the product. That
184  // will be another vector, which we then use to construct a
185  // diagonal matrix object. Note that this will fail if our
186  // digaonal matrix object is modified to allow the nonzero
187  // values to be stored off of the principal diagonal (i.e., if
188  // diag ([1,2], 3) is modified to return a diagonal matrix
189  // object instead of a full matrix object).
190 
191  octave_value tmp = dispatch_kron (a.diag (), b.diag ());
192  retval = tmp.diag ();
193  }
194  else if (a.is_single_type () || b.is_single_type ())
195  {
196  if (a.is_complex_type ())
197  retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
198  else if (b.is_complex_type ())
199  retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
200  else
201  retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
202  }
203  else
204  {
205  if (a.is_complex_type ())
206  retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
207  else if (b.is_complex_type ())
208  retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
209  else
210  retval = do_kron<DiagMatrix, Matrix> (a, b);
211  }
212  }
213  else if (a.is_single_type () || b.is_single_type ())
214  {
215  if (a.is_complex_type ())
216  retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
217  else if (b.is_complex_type ())
218  retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
219  else
220  retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
221  }
222  else
223  {
224  if (a.is_complex_type ())
225  retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
226  else if (b.is_complex_type ())
227  retval = do_kron<Matrix, ComplexMatrix> (a, b);
228  else
229  retval = do_kron<Matrix, Matrix> (a, b);
230  }
231  return retval;
232 }
233 
234 
235 DEFUN (kron, args, ,
236  doc: /* -*- texinfo -*-
237 @deftypefn {} {} kron (@var{A}, @var{B})
238 @deftypefnx {} {} kron (@var{A1}, @var{A2}, @dots{})
239 Form the Kronecker product of two or more matrices.
240 
241 This is defined block by block as
242 
243 @example
244 x = [ a(i,j)*b ]
245 @end example
246 
247 For example:
248 
249 @example
250 @group
251 kron (1:4, ones (3, 1))
252  @result{} 1 2 3 4
253  1 2 3 4
254  1 2 3 4
255 @end group
256 @end example
257 
258 If there are more than two input arguments @var{A1}, @var{A2}, @dots{},
259 @var{An} the Kronecker product is computed as
260 
261 @example
262 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})
263 @end example
264 
265 @noindent
266 Since the Kronecker product is associative, this is well-defined.
267 @end deftypefn */)
268 {
269  int nargin = args.length ();
270 
271  if (nargin < 2)
272  print_usage ();
273 
275 
278 
279  retval = dispatch_kron (a, b);
280 
281  for (octave_idx_type i = 2; i < nargin; i++)
282  retval = dispatch_kron (retval, args(i));
283 
284  return retval;
285 }
286 
287 /*
288 %!test
289 %! x = ones (2);
290 %! assert (kron (x, x), ones (4));
291 
292 %!shared x, y, z, p1, p2, d1, d2
293 %! x = [1, 2];
294 %! y = [-1, -2];
295 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
296 %! p1 = eye (3)([2, 3, 1], :); ## Permutation matrix
297 %! p2 = [0 1 0; 0 0 1; 1 0 0]; ## Non-permutation equivalent
298 %! d1 = diag ([1 2 3]); ## Diag type matrix
299 %! d2 = [1 0 0; 0 2 0; 0 0 3]; ## Non-diag equivalent
300 %!assert (kron (1:4, ones (3, 1)), z)
301 %!assert (kron (single (1:4), ones (3, 1)), single (z))
302 %!assert (kron (sparse (1:4), ones (3, 1)), sparse (z))
303 %!assert (kron (complex (1:4), ones (3, 1)), z)
304 %!assert (kron (complex (single(1:4)), ones (3, 1)), single(z))
305 %!assert (kron (x, y, z), kron (kron (x, y), z))
306 %!assert (kron (x, y, z), kron (x, kron (y, z)))
307 %!assert (kron (p1, p1), kron (p2, p2))
308 %!assert (kron (p1, p2), kron (p2, p1))
309 %!assert (kron (d1, d1), kron (d2, d2))
310 %!assert (kron (d1, d2), kron (d2, d1))
311 
312 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
313 
314 %% Test for two diag matrices. See the comments above in
315 %% dispatch_kron for this case.
316 %%
317 %!test
318 %! expected = zeros (16, 16);
319 %! expected (1, 11) = 3;
320 %! expected (2, 12) = 4;
321 %! expected (5, 15) = 6;
322 %! expected (6, 16) = 8;
323 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected);
324 */
octave_value dispatch_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:164
T * data(void)
Definition: Sparse.h:521
octave_idx_type rows(void) const
Definition: Sparse.h:271
#define C(a, b)
Definition: Faddeeva.cc:246
octave_idx_type rows(void) const
Definition: ov.h:489
octave_idx_type rows(void) const
Definition: PermMatrix.h:59
int ndims(void) const
Definition: Array.h:590
OCTINTERP_API void print_usage(void)
Definition: defun.cc:52
octave_value diag(octave_idx_type k=0) const
Definition: ov.h:1311
#define DEFUN(name, args_name, nargout_name, doc)
Definition: defun.h:46
bool is_perm_matrix(void) const
Definition: ov.h:575
octave_idx_type rows(void) const
Definition: DiagArray2.h:88
octave_idx_type * cidx(void)
Definition: Sparse.h:543
octave_idx_type columns(void) const
Definition: Sparse.h:273
Template for N-dimensional array classes with like-type math operators.
Definition: MArray.h:32
octave_value b
Definition: kron.cc:277
octave_idx_type rows(void) const
Definition: Array.h:401
void mx_inline_mul(size_t n, R *r, const X *x, const Y *y)
Definition: mx-inlines.cc:110
octave_idx_type nnz(void) const
Actual number of nonzero terms.
Definition: Sparse.h:253
JNIEnv void * args
Definition: ov-java.cc:67
T dgelem(octave_idx_type i) const
Definition: DiagArray2.h:121
F77_RET_T const F77_INT F77_CMPLX const F77_INT F77_CMPLX * B
octave_idx_type columns(void) const
Definition: ov.h:491
bool is_sparse_type(void) const
Definition: ov.h:682
const Array< octave_idx_type > & col_perm_vec(void) const
Definition: PermMatrix.h:79
int nargin
Definition: graphics.cc:10115
const T * data(void) const
Definition: Array.h:582
bool is_complex_type(void) const
Definition: ov.h:670
double tmp
Definition: data.cc:6300
Template for two dimensional diagonal array with math operators.
Definition: MDiagArray2.h:33
the sparsity preserving column transformation such that that defines the pivoting threshold can be given in which case it defines the c
Definition: lu.cc:138
static MArray< T > kron(const MArray< R > &a, const MArray< T > &b)
Definition: kron.cc:53
T & xelem(octave_idx_type n)
Definition: Array.h:455
octave_idx_type cols(void) const
Definition: DiagArray2.h:89
octave_idx_type * ridx(void)
Definition: Sparse.h:530
octave_value a
Definition: kron.cc:276
octave_value do_kron(const octave_value &a, const octave_value &b)
Definition: kron.cc:155
octave_value retval
Definition: kron.cc:274
=val(i)}if ode{val(i)}occurs in table i
Definition: lookup.cc:239
octave_idx_type diag_length(void) const
Definition: DiagArray2.h:92
const T * fortran_vec(void) const
Definition: Array.h:584
bool is_single_type(void) const
Definition: ov.h:627
octave_idx_type cols(void) const
Definition: Array.h:409
Vector representing the dimensions (size) of an Array.
Definition: dim-vector.h:87
bool is_diag_matrix(void) const
Definition: ov.h:572
return octave_value(v1.char_array_value().concat(v2.char_array_value(), ra_idx),((a1.is_sq_string()||a2.is_sq_string())? '\'': '"'))
F77_RET_T const F77_INT F77_CMPLX * A