conv2.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 1999-2012 Andy Adler
00004 Copyright (C) 2010 VZLU Prague
00005 
00006 This file is part of Octave.
00007 
00008 Octave is free software; you can redistribute it and/or modify it
00009 under the terms of the GNU General Public License as published by the
00010 Free Software Foundation; either version 3 of the License, or (at your
00011 option) any later version.
00012 
00013 Octave is distributed in the hope that it will be useful, but WITHOUT
00014 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00015 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00016 for more details.
00017 
00018 You should have received a copy of the GNU General Public License
00019 along with Octave; see the file COPYING.  If not, see
00020 <http://www.gnu.org/licenses/>.
00021 
00022 */
00023 
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027 
00028 #include "oct-convn.h"
00029 
00030 #include "defun-dld.h"
00031 #include "error.h"
00032 #include "oct-obj.h"
00033 #include "utils.h"
00034 
// Selector for the portion of a convolution result.  The enumerator names
// mirror the "full", "same" and "valid" SHAPE strings accepted below.
// NOTE(review): nothing in the visible part of this file refers to this
// enum; the functions below use convn_type instead.
enum Shape
{
  SHAPE_FULL,
  SHAPE_SAME,
  SHAPE_VALID
};
00036 
00037 DEFUN_DLD (conv2, args, ,
00038   "-*- texinfo -*-\n\
00039 @deftypefn  {Loadable Function} {} conv2 (@var{A}, @var{B})\n\
00040 @deftypefnx {Loadable Function} {} conv2 (@var{v1}, @var{v2}, @var{m})\n\
00041 @deftypefnx {Loadable Function} {} conv2 (@dots{}, @var{shape})\n\
00042 Return the 2-D convolution of @var{A} and @var{B}.  The size of the result\n\
00043 is determined by the optional @var{shape} argument which takes the following\n\
00044 values\n\
00045 \n\
00046 @table @asis\n\
00047 @item @var{shape} = \"full\"\n\
00048 Return the full convolution.  (default)\n\
00049 \n\
00050 @item @var{shape} = \"same\"\n\
00051 Return the central part of the convolution with the same size as @var{A}.\n\
00052 The central part of the convolution begins at the indices\n\
00053 @code{floor ([size(@var{B})/2] + 1)}.\n\
00054 \n\
00055 @item @var{shape} = \"valid\"\n\
00056 Return only the parts which do not include zero-padded edges.\n\
00057 The size of the result is @code{max (size (A) - size (B) + 1, 0)}.\n\
00058 @end table\n\
00059 \n\
00060 When the third argument is a matrix, return the convolution of the matrix\n\
00061 @var{m} by the vector @var{v1} in the column direction and by the vector\n\
00062 @var{v2} in the row direction.\n\
00063 @seealso{conv, convn}\n\
00064 @end deftypefn")
00065 {
00066   octave_value retval;
00067   octave_value tmp;
00068   int nargin = args.length ();
00069   std::string shape = "full";   // default
00070   bool separable = false;
00071   convn_type ct;
00072 
00073   if (nargin < 2)
00074     {
00075      print_usage ();
00076      return retval;
00077     }
00078   else if (nargin == 3)
00079     {
00080       if (args(2).is_string ())
00081         shape = args(2).string_value ();
00082       else
00083         separable = true;
00084     }
00085   else if (nargin >= 4)
00086     {
00087       separable = true;
00088       shape = args(3).string_value ();
00089     }
00090 
00091   if (shape == "full")
00092     ct = convn_full;
00093   else if (shape == "same")
00094     ct = convn_same;
00095   else if (shape == "valid")
00096     ct = convn_valid;
00097   else
00098     {
00099       error ("conv2: SHAPE type not valid");
00100       print_usage ();
00101       return retval;
00102     }
00103 
00104    if (separable)
00105      {
00106       // If user requests separable, check first two params are vectors
00107 
00108        if (! (1 == args(0).rows () || 1 == args(0).columns ())
00109            || ! (1 == args(1).rows () || 1 == args(1).columns ()))
00110          {
00111            print_usage ();
00112            return retval;
00113          }
00114 
00115        if (args(0).is_single_type () || args(1).is_single_type ()
00116            || args(2).is_single_type ())
00117          {
00118            if (args(0).is_complex_type () || args(1).is_complex_type ()
00119                || args(2).is_complex_type ())
00120              {
00121                FloatComplexMatrix a (args(2).float_complex_matrix_value ());
00122                if (args(1).is_real_type () && args(2).is_real_type ())
00123                  {
00124                    FloatColumnVector v1 (args(0).float_vector_value ());
00125                    FloatRowVector v2 (args(1).float_vector_value ());
00126                    retval = convn (a, v1, v2, ct);
00127                  }
00128                else
00129                  {
00130                    FloatComplexColumnVector v1 (args(0).float_complex_vector_value ());
00131                    FloatComplexRowVector v2 (args(1).float_complex_vector_value ());
00132                    retval = convn (a, v1, v2, ct);
00133                  }
00134              }
00135            else
00136              {
00137                FloatColumnVector v1 (args(0).float_vector_value ());
00138                FloatRowVector v2 (args(1).float_vector_value ());
00139                FloatMatrix a (args(2).float_matrix_value ());
00140                retval = convn (a, v1, v2, ct);
00141              }
00142          }
00143        else
00144          {
00145            if (args(0).is_complex_type () || args(1).is_complex_type ()
00146                || args(2).is_complex_type ())
00147              {
00148                ComplexMatrix a (args(2).complex_matrix_value ());
00149                if (args(1).is_real_type () && args(2).is_real_type ())
00150                  {
00151                    ColumnVector v1 (args(0).vector_value ());
00152                    RowVector v2 (args(1).vector_value ());
00153                    retval = convn (a, v1, v2, ct);
00154                  }
00155                else
00156                  {
00157                    ComplexColumnVector v1 (args(0).complex_vector_value ());
00158                    ComplexRowVector v2 (args(1).complex_vector_value ());
00159                    retval = convn (a, v1, v2, ct);
00160                  }
00161              }
00162            else
00163              {
00164                ColumnVector v1 (args(0).vector_value ());
00165                RowVector v2 (args(1).vector_value ());
00166                Matrix a (args(2).matrix_value ());
00167                retval = convn (a, v1, v2, ct);
00168              }
00169          }
00170      } // if (separable)
00171    else
00172      {
00173        if (args(0).is_single_type () || args(1).is_single_type ())
00174          {
00175            if (args(0).is_complex_type () || args(1).is_complex_type ())
00176              {
00177                FloatComplexMatrix a (args(0).float_complex_matrix_value ());
00178                if (args(1).is_real_type ())
00179                  {
00180                    FloatMatrix b (args(1).float_matrix_value ());
00181                    retval = convn (a, b, ct);
00182                  }
00183                else
00184                  {
00185                    FloatComplexMatrix b (args(1).float_complex_matrix_value ());
00186                    retval = convn (a, b, ct);
00187                  }
00188              }
00189            else
00190              {
00191                FloatMatrix a (args(0).float_matrix_value ());
00192                FloatMatrix b (args(1).float_matrix_value ());
00193                retval = convn (a, b, ct);
00194              }
00195          }
00196        else
00197          {
00198            if (args(0).is_complex_type () || args(1).is_complex_type ())
00199              {
00200                ComplexMatrix a (args(0).complex_matrix_value ());
00201                if (args(1).is_real_type ())
00202                  {
00203                    Matrix b (args(1).matrix_value ());
00204                    retval = convn (a, b, ct);
00205                  }
00206                else
00207                  {
00208                    ComplexMatrix b (args(1).complex_matrix_value ());
00209                    retval = convn (a, b, ct);
00210                  }
00211              }
00212            else
00213              {
00214                Matrix a (args(0).matrix_value ());
00215                Matrix b (args(1).matrix_value ());
00216                retval = convn (a, b, ct);
00217              }
00218          }
00219 
00220      } // if (separable)
00221 
00222    return retval;
00223 }
00224 
00225 /*
00226 %!test
00227 %! c = [0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18];
00228 %! assert (conv2 ([0,1;1,2], [1,2,3;4,5,6;7,8,9]), c);
00229 
00230 %!test
00231 %! c = single ([0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18]);
00232 %! assert (conv2 (single ([0,1;1,2]), single ([1,2,3;4,5,6;7,8,9])), c);
00233 
00234 %!test
00235 %! c = [1,4,4;5,18,16;14,48,40;19,62,48;15,48,36];
00236 %! assert (conv2 (1:3, 1:2, [1,2;3,4;5,6]), c);
00237 
00238 %!assert (conv2 (1:3, 1:2, [1,2;3,4;5,6], "full"),
00239 %!        conv2 (1:3, 1:2, [1,2;3,4;5,6]));
00240 
00241 %% Test shapes
00242 %!shared A, B, C
00243 %! A = rand (3, 4);
00244 %! B = rand (4);
00245 %! C = conv2 (A, B);
00246 %!assert (conv2 (A,B, "full"), C)
00247 %!assert (conv2 (A,B, "same"), C(3:5,3:6))
00248 %!assert (conv2 (A,B, "valid"), zeros (0, 1))
00249 %!assert (size (conv2 (B,A, "valid")), [2 1])
00250 
00251 %!test
00252 %! B = rand (5);
00253 %! C = conv2 (A, B);
00254 %!assert (conv2 (A,B, "full"), C)
00255 %!assert (conv2 (A,B, "same"), C(3:5,3:6))
00256 %!assert (conv2 (A,B, "valid"), zeros (0, 0))
00257 %!assert (size (conv2 (B,A, "valid")), [3 2])
00258 
00259 %% Clear shared variables so they are not reported for tests below
00260 %!shared
00261 
00262 %% Test cases from Bug #34893
00263 %!assert (conv2 ([1:5;1:5], [1:2], 'same'), [4 7 10 13 10; 4 7 10 13 10])
00264 %!assert (conv2 ([1:5;1:5]', [1:2]', 'same'), [4 7 10 13 10; 4 7 10 13 10]')
00265 %!assert (conv2 ([1:5;1:5], [1:2], 'valid'), [4 7 10 13; 4 7 10 13])
00266 %!assert (conv2 ([1:5;1:5]', [1:2]', 'valid'), [4 7 10 13; 4 7 10 13]')
00267 
00268 %!test
00269 %! rand ("seed", 42);
00270 %! x = rand (100);
00271 %! y = ones (5);
00272 %! A = conv2 (x, y)(5:end-4,5:end-4);
00273 %! B = conv2 (x, y, "valid");
00274 %! assert (B, A); ## Yes, this test is for *exact* equivalence.
00275 
00276 
00277 %% Test input validation
00278 %!error conv2 ()
00279 %!error conv2 (1)
00280 %!error <SHAPE type not valid> conv2 (1,2, "NOT_A_SHAPE")
00281 %% Test alternate calling form which should be 2 vectors and a matrix
00282 %!error conv2 (ones (2), 1, 1)
00283 %!error conv2 (1, ones (2), 1)
00284 
00285 */
00286 
00287 DEFUN_DLD (convn, args, ,
00288   "-*- texinfo -*-\n\
00289 @deftypefn  {Loadable Function} {@var{C} =} convn (@var{A}, @var{B})\n\
00290 @deftypefnx {Loadable Function} {@var{C} =} convn (@var{A}, @var{B}, @var{shape})\n\
00291 Return the n-D convolution of @var{A} and @var{B}.  The size of the result\n\
00292 is determined by the optional @var{shape} argument which takes the following\n\
00293 values\n\
00294 \n\
00295 @table @asis\n\
00296 @item @var{shape} = \"full\"\n\
00297 Return the full convolution.  (default)\n\
00298 \n\
00299 @item @var{shape} = \"same\"\n\
00300 Return central part of the convolution with the same size as @var{A}.\n\
00301 The central part of the convolution begins at the indices\n\
00302 @code{floor ([size(@var{B})/2] + 1)}.\n\
00303 \n\
00304 @item @var{shape} = \"valid\"\n\
00305 Return only the parts which do not include zero-padded edges.\n\
00306 The size of the result is @code{max (size (A) - size (B) + 1, 0)}.\n\
00307 @end table\n\
00308 \n\
00309 @seealso{conv2, conv}\n\
00310 @end deftypefn")
00311 {
00312   octave_value retval;
00313   octave_value tmp;
00314   int nargin = args.length ();
00315   std::string shape = "full";   // default
00316   convn_type ct;
00317 
00318   if (nargin < 2 || nargin > 3)
00319     {
00320      print_usage ();
00321      return retval;
00322     }
00323   else if (nargin == 3)
00324     {
00325       if (args(2).is_string ())
00326         shape = args(2).string_value ();
00327     }
00328 
00329   if (shape == "full")
00330     ct = convn_full;
00331   else if (shape == "same")
00332     ct = convn_same;
00333   else if (shape == "valid")
00334     ct = convn_valid;
00335   else
00336     {
00337       error ("convn: SHAPE type not valid");
00338       print_usage ();
00339       return retval;
00340     }
00341 
00342   if (args(0).is_single_type () || args(1).is_single_type ())
00343     {
00344       if (args(0).is_complex_type () || args(1).is_complex_type ())
00345         {
00346           FloatComplexNDArray a (args(0).float_complex_array_value ());
00347           if (args(1).is_real_type ())
00348             {
00349               FloatNDArray b (args(1).float_array_value ());
00350               retval = convn (a, b, ct);
00351             }
00352           else
00353             {
00354               FloatComplexNDArray b (args(1).float_complex_array_value ());
00355               retval = convn (a, b, ct);
00356             }
00357         }
00358       else
00359         {
00360           FloatNDArray a (args(0).float_array_value ());
00361           FloatNDArray b (args(1).float_array_value ());
00362           retval = convn (a, b, ct);
00363         }
00364     }
00365   else
00366     {
00367       if (args(0).is_complex_type () || args(1).is_complex_type ())
00368         {
00369           ComplexNDArray a (args(0).complex_array_value ());
00370           if (args(1).is_real_type ())
00371             {
00372               NDArray b (args(1).array_value ());
00373               retval = convn (a, b, ct);
00374             }
00375           else
00376             {
00377               ComplexNDArray b (args(1).complex_array_value ());
00378               retval = convn (a, b, ct);
00379             }
00380         }
00381       else
00382         {
00383           NDArray a (args(0).array_value ());
00384           NDArray b (args(1).array_value ());
00385           retval = convn (a, b, ct);
00386         }
00387     }
00388 
00389    return retval;
00390 }
00391 
00392 /*
00393  FIXME: Need tests for convn in addition to conv2.
00394 */
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines