oct-convn.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 2010-2012 VZLU Prague
00004 
00005 This file is part of Octave.
00006 
00007 Octave is free software; you can redistribute it and/or modify it
00008 under the terms of the GNU General Public License as published by the
00009 Free Software Foundation; either version 3 of the License, or (at your
00010 option) any later version.
00011 
00012 Octave is distributed in the hope that it will be useful, but WITHOUT
00013 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00014 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00015 for more details.
00016 
00017 You should have received a copy of the GNU General Public License
00018 along with Octave; see the file COPYING.  If not, see
00019 <http://www.gnu.org/licenses/>.
00020 
00021 */
00022 
00023 #ifdef HAVE_CONFIG_H
00024 #include <config.h>
00025 #endif
00026 
00027 #include <iostream>
00028 #include <algorithm>
00029 
00030 #include "f77-fcn.h"
00031 
00032 #include "oct-convn.h"
00033 #include "oct-locbuf.h"
00034 
00035 // 2d convolution with a matrix kernel.
00036 template <class T, class R>
00037 static void
00038 convolve_2d (const T *a, octave_idx_type ma, octave_idx_type na,
00039              const R *b, octave_idx_type mb, octave_idx_type nb,
00040              T *c, bool inner);
00041 
00042 // Forward instances to our Fortran implementations.
00043 #define FORWARD_IMPL(T,R,f,F) \
00044 extern "C" \
00045 F77_RET_T \
00046 F77_FUNC (f##conv2o, F##CONV2O) (const octave_idx_type&, \
00047                                  const octave_idx_type&, \
00048                                  const T*, const octave_idx_type&, \
00049                                  const octave_idx_type&, const R*, T *); \
00050 \
00051 extern "C" \
00052 F77_RET_T \
00053 F77_FUNC (f##conv2i, F##CONV2I) (const octave_idx_type&, \
00054                                  const octave_idx_type&, \
00055                                  const T*, const octave_idx_type&, \
00056                                  const octave_idx_type&, const R*, T *); \
00057 \
00058 template <> void \
00059 convolve_2d<T, R> (const T *a, octave_idx_type ma, octave_idx_type na, \
00060                    const R *b, octave_idx_type mb, octave_idx_type nb, \
00061                    T *c, bool inner) \
00062 { \
00063   if (inner) \
00064     F77_XFCN (f##conv2i, F##CONV2I, (ma, na, a, mb, nb, b, c)); \
00065   else \
00066     F77_XFCN (f##conv2o, F##CONV2O, (ma, na, a, mb, nb, b, c)); \
00067 }
00068 
00069 FORWARD_IMPL (double, double, d, D)
00070 FORWARD_IMPL (float, float, s, S)
00071 FORWARD_IMPL (Complex, Complex, z, Z)
00072 FORWARD_IMPL (FloatComplex, FloatComplex, c, C)
00073 FORWARD_IMPL (Complex, double, zd, ZD)
00074 FORWARD_IMPL (FloatComplex, float, cs, CS)
00075 
00076 template <class T, class R>
00077 void convolve_nd (const T *a, const dim_vector& ad, const dim_vector& acd,
00078                   const R *b, const dim_vector& bd, const dim_vector& bcd,
00079                   T *c, const dim_vector& ccd, int nd, bool inner)
00080 {
00081   if (nd == 2)
00082     convolve_2d<T, R> (a, ad(0), ad(1), b, bd(0), bd(1), c, inner);
00083   else
00084     {
00085       octave_idx_type ma = acd(nd-2), na = ad(nd-1), mb = bcd(nd-2), nb = bd(nd-1);
00086       octave_idx_type ldc = ccd(nd-2);
00087       if (inner)
00088         {
00089           for (octave_idx_type ja = 0; ja < na - nb + 1; ja++)
00090             for (octave_idx_type jb = 0; jb < nb; jb++)
00091               convolve_nd<T, R> (a + ma*(ja + jb), ad, acd, b + mb*jb, bd, bcd,
00092                                  c + ldc*ja, ccd, nd-1, inner);
00093         }
00094       else
00095         {
00096           for (octave_idx_type ja = 0; ja < na; ja++)
00097             for (octave_idx_type jb = 0; jb < nb; jb++)
00098               convolve_nd<T, R> (a + ma*ja, ad, acd, b + mb*jb, bd, bcd,
00099                                  c + ldc*(ja+jb), ccd, nd-1, inner);
00100         }
00101     }
00102 }
00103 
00104 // Arbitrary convolutor.
00105 // The 2nd array is assumed to be the smaller one.
00106 template <class T, class R>
00107 static MArray<T>
00108 convolve (const MArray<T>& a, const MArray<R>& b,
00109           convn_type ct)
00110 {
00111   if (a.is_empty () || b.is_empty ())
00112     return MArray<T> ();
00113 
00114   int nd = std::max (a.ndims (), b.ndims ());
00115   const dim_vector adims = a.dims ().redim (nd), bdims = b.dims ().redim (nd);
00116   dim_vector cdims = dim_vector::alloc (nd);
00117 
00118   for (int i = 0; i < nd; i++)
00119     {
00120       if (ct == convn_valid)
00121         cdims(i) = std::max (adims(i) - bdims(i) + 1,
00122                              static_cast<octave_idx_type> (0));
00123       else
00124         cdims(i) = std::max (adims(i) + bdims(i) - 1,
00125                              static_cast<octave_idx_type> (0));
00126     }
00127 
00128   MArray<T> c (cdims, T());
00129 
00130   convolve_nd<T, R> (a.fortran_vec (), adims, adims.cumulative (),
00131                      b.fortran_vec (), bdims, bdims.cumulative (),
00132                      c.fortran_vec (), cdims.cumulative (), nd, ct == convn_valid);
00133 
00134   if (ct == convn_same)
00135     {
00136       // Pick the relevant part.
00137       Array<idx_vector> sidx (dim_vector (nd, 1));
00138 
00139       for (int i = 0; i < nd; i++)
00140         sidx(i) = idx_vector::make_range (bdims(i)/2, 1, adims(i));
00141       c = c.index (sidx);
00142     }
00143 
00144   return c;
00145 }
00146 
00147 #define CONV_DEFS(TPREF, RPREF) \
00148 TPREF ## NDArray \
00149 convn (const TPREF ## NDArray& a, const RPREF ## NDArray& b, convn_type ct) \
00150 { \
00151   return convolve (a, b, ct); \
00152 } \
00153 TPREF ## Matrix \
00154 convn (const TPREF ## Matrix& a, const RPREF ## Matrix& b, convn_type ct) \
00155 { \
00156   return convolve (a, b, ct); \
00157 } \
00158 TPREF ## Matrix \
00159 convn (const TPREF ## Matrix& a, const RPREF ## ColumnVector& c, \
00160        const RPREF ## RowVector& r, convn_type ct) \
00161 { \
00162   return convolve (a, c * r, ct); \
00163 }
00164 
00165 CONV_DEFS ( , )
00166 CONV_DEFS (Complex, )
00167 CONV_DEFS (Complex, Complex)
00168 CONV_DEFS (Float, Float)
00169 CONV_DEFS (FloatComplex, Float)
00170 CONV_DEFS (FloatComplex, FloatComplex)
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines