MSparse.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (C) 2004-2012 David Bateman
00004 Copyright (C) 1998-2004 Andy Adler
00005 
00006 This file is part of Octave.
00007 
00008 Octave is free software; you can redistribute it and/or modify it
00009 under the terms of the GNU General Public License as published by the
00010 Free Software Foundation; either version 3 of the License, or (at your
00011 option) any later version.
00012 
00013 Octave is distributed in the hope that it will be useful, but WITHOUT
00014 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00015 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00016 for more details.
00017 
00018 You should have received a copy of the GNU General Public License
00019 along with Octave; see the file COPYING.  If not, see
00020 <http://www.gnu.org/licenses/>.
00021 
00022 */
00023 
00024 #ifdef HAVE_CONFIG_H
00025 #include <config.h>
00026 #endif
00027 
00028 #include <functional>
00029 
00030 #include "quit.h"
00031 #include "lo-error.h"
00032 #include "MArray.h"
00033 #include "Array-util.h"
00034 
00035 #include "MSparse.h"
00036 #include "MSparse-defs.h"
00037 
00038 // sparse array with math ops.
00039 
00040 // Element by element MSparse by MSparse ops.
00041 
00042 template <class T, class OP>
00043 MSparse<T>&
00044 plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char* op_name)
00045 {
00046     MSparse<T> r;
00047 
00048     octave_idx_type a_nr = a.rows ();
00049     octave_idx_type a_nc = a.cols ();
00050 
00051     octave_idx_type b_nr = b.rows ();
00052     octave_idx_type b_nc = b.cols ();
00053 
00054     if (a_nr != b_nr || a_nc != b_nc)
00055       gripe_nonconformant (op_name , a_nr, a_nc, b_nr, b_nc);
00056     else
00057       {
00058         r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
00059 
00060         octave_idx_type jx = 0;
00061         for (octave_idx_type i = 0 ; i < a_nc ; i++)
00062           {
00063             octave_idx_type  ja = a.cidx(i);
00064             octave_idx_type  ja_max = a.cidx(i+1);
00065             bool ja_lt_max= ja < ja_max;
00066 
00067             octave_idx_type  jb = b.cidx(i);
00068             octave_idx_type  jb_max = b.cidx(i+1);
00069             bool jb_lt_max = jb < jb_max;
00070 
00071             while (ja_lt_max || jb_lt_max )
00072               {
00073                 octave_quit ();
00074                 if ((! jb_lt_max) ||
00075                       (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00076                   {
00077                     r.ridx(jx) = a.ridx(ja);
00078                     r.data(jx) = op (a.data(ja), 0.);
00079                     jx++;
00080                     ja++;
00081                     ja_lt_max= ja < ja_max;
00082                   }
00083                 else if (( !ja_lt_max ) ||
00084                      (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00085                   {
00086                     r.ridx(jx) = b.ridx(jb);
00087                     r.data(jx) = op (0., b.data(jb));
00088                     jx++;
00089                     jb++;
00090                     jb_lt_max= jb < jb_max;
00091                   }
00092                 else
00093                   {
00094                      if (op (a.data(ja), b.data(jb)) != 0.)
00095                        {
00096                           r.data(jx) = op (a.data(ja), b.data(jb));
00097                           r.ridx(jx) = a.ridx(ja);
00098                           jx++;
00099                        }
00100                      ja++;
00101                      ja_lt_max= ja < ja_max;
00102                      jb++;
00103                      jb_lt_max= jb < jb_max;
00104                   }
00105               }
00106             r.cidx(i+1) = jx;
00107           }
00108 
00109         a = r.maybe_compress ();
00110       }
00111 
00112     return a;
00113 }
00114 
00115 template <typename T>
00116 MSparse<T>&
00117 operator += (MSparse<T>& a, const MSparse<T>& b)
00118 {
00119   return plus_or_minus (a, b, std::plus<T> (), "operator +=");
00120 }
00121 
00122 template <typename T>
00123 MSparse<T>&
00124 operator -= (MSparse<T>& a, const MSparse<T>& b)
00125 {
00126   return plus_or_minus (a, b, std::minus<T> (), "operator -=");
00127 }
00128 
00129 
00130 // Element by element MSparse by scalar ops.
00131 
00132 template <class T, class OP>
00133 MArray<T>
00134 plus_or_minus (const MSparse<T>& a, const T& s, OP op)
00135 {
00136   octave_idx_type nr = a.rows ();
00137   octave_idx_type nc = a.cols ();
00138 
00139   MArray<T> r (dim_vector (nr, nc), op (0.0, s));
00140 
00141   for (octave_idx_type j = 0; j < nc; j++)
00142     for (octave_idx_type i = a.cidx(j); i < a.cidx(j+1); i++)
00143       r.elem (a.ridx (i), j) = op (a.data (i), s);
00144   return r;
00145 }
00146 
00147 template <typename T>
00148 MArray<T>
00149 operator + (const MSparse<T>& a, const T& s)
00150 {
00151   return plus_or_minus (a, s, std::plus<T> ());
00152 }
00153 
00154 template <typename T>
00155 MArray<T>
00156 operator - (const MSparse<T>& a, const T& s)
00157 {
00158   return plus_or_minus (a, s, std::minus<T> ());
00159 }
00160 
00161 
00162 template <class T, class OP>
00163 MSparse<T>
00164 times_or_divide (const MSparse<T>& a, const T& s, OP op)
00165 {
00166   octave_idx_type nr = a.rows ();
00167   octave_idx_type nc = a.cols ();
00168   octave_idx_type nz = a.nnz ();
00169 
00170   MSparse<T> r (nr, nc, nz);
00171 
00172   for (octave_idx_type i = 0; i < nz; i++)
00173     {
00174       r.data(i) = op (a.data(i), s);
00175       r.ridx(i) = a.ridx(i);
00176     }
00177   for (octave_idx_type i = 0; i < nc + 1; i++)
00178     r.cidx(i) = a.cidx(i);
00179   r.maybe_compress (true);
00180   return r;
00181 }
00182 
00183 template <typename T>
00184 MSparse<T>
00185 operator * (const MSparse<T>& a, const T& s)
00186 {
00187   return times_or_divide (a, s, std::multiplies<T> ());
00188 }
00189 
00190 template <typename T>
00191 MSparse<T>
00192 operator / (const MSparse<T>& a, const T& s)
00193 {
00194   return times_or_divide (a, s, std::divides<T> ());
00195 }
00196 
00197 
00198 // Element by element scalar by MSparse ops.
00199 
00200 template <class T, class OP>
00201 MArray<T>
00202 plus_or_minus (const T& s, const MSparse<T>& a, OP op)
00203 {
00204   octave_idx_type nr = a.rows ();
00205   octave_idx_type nc = a.cols ();
00206 
00207   MArray<T> r (dim_vector (nr, nc), op (s, 0.0));
00208 
00209   for (octave_idx_type j = 0; j < nc; j++)
00210     for (octave_idx_type i = a.cidx(j); i < a.cidx(j+1); i++)
00211       r.elem (a.ridx (i), j) = op (s, a.data (i));
00212   return r;
00213 }
00214 
00215 template <typename T>
00216 MArray<T>
00217 operator + (const T& s, const MSparse<T>& a)
00218 {
00219   return plus_or_minus (s, a, std::plus<T> ());
00220 }
00221 
00222 template <typename T>
00223 MArray<T>
00224 operator - (const T& s, const MSparse<T>& a)
00225 {
00226   return plus_or_minus (s, a, std::minus<T> ());
00227 }
00228 
00229 template <class T, class OP>
00230 MSparse<T>
00231 times_or_divides (const T& s, const MSparse<T>& a, OP op)
00232 {
00233   octave_idx_type nr = a.rows ();
00234   octave_idx_type nc = a.cols ();
00235   octave_idx_type nz = a.nnz ();
00236 
00237   MSparse<T> r (nr, nc, nz);
00238 
00239   for (octave_idx_type i = 0; i < nz; i++)
00240     {
00241       r.data(i) = op (s, a.data(i));
00242       r.ridx(i) = a.ridx(i);
00243     }
00244   for (octave_idx_type i = 0; i < nc + 1; i++)
00245     r.cidx(i) = a.cidx(i);
00246   r.maybe_compress (true);
00247   return r;
00248 }
00249 
00250 template <class T>
00251 MSparse<T>
00252 operator * (const T& s, const MSparse<T>& a)
00253 {
00254   return times_or_divides (s, a, std::multiplies<T> ());
00255 }
00256 
00257 template <class T>
00258 MSparse<T>
00259 operator / (const T& s, const MSparse<T>& a)
00260 {
00261   return times_or_divides (s, a, std::divides<T> ());
00262 }
00263 
00264 
00265 // Element by element MSparse by MSparse ops.
00266 
00267 template <class T, class OP>
00268 MSparse<T>
00269 plus_or_minus (const MSparse<T>& a, const MSparse<T>& b, OP op,
00270                const char* op_name, bool negate)
00271 {
00272   MSparse<T> r;
00273 
00274   octave_idx_type a_nr = a.rows ();
00275   octave_idx_type a_nc = a.cols ();
00276 
00277   octave_idx_type b_nr = b.rows ();
00278   octave_idx_type b_nc = b.cols ();
00279 
00280   if (a_nr == 1 && a_nc == 1)
00281     {
00282       if (a.elem(0,0) == 0.)
00283         if (negate)
00284           r = -MSparse<T> (b);
00285         else
00286           r = MSparse<T> (b);
00287       else
00288         {
00289           r = MSparse<T> (b_nr, b_nc, op (a.data(0), 0.));
00290 
00291           for (octave_idx_type j = 0 ; j < b_nc ; j++)
00292             {
00293               octave_quit ();
00294               octave_idx_type idxj = j * b_nr;
00295               for (octave_idx_type i = b.cidx(j) ; i < b.cidx(j+1) ; i++)
00296                 {
00297                   octave_quit ();
00298                   r.data(idxj + b.ridx(i)) = op (a.data(0), b.data(i));
00299                 }
00300             }
00301           r.maybe_compress ();
00302         }
00303     }
00304   else if (b_nr == 1 && b_nc == 1)
00305     {
00306       if (b.elem(0,0) == 0.)
00307         r = MSparse<T> (a);
00308       else
00309         {
00310           r = MSparse<T> (a_nr, a_nc, op (0.0, b.data(0)));
00311 
00312           for (octave_idx_type j = 0 ; j < a_nc ; j++)
00313             {
00314               octave_quit ();
00315               octave_idx_type idxj = j * a_nr;
00316               for (octave_idx_type i = a.cidx(j) ; i < a.cidx(j+1) ; i++)
00317                 {
00318                   octave_quit ();
00319                   r.data(idxj + a.ridx(i)) = op (a.data(i), b.data(0));
00320                 }
00321             }
00322           r.maybe_compress ();
00323         }
00324     }
00325   else if (a_nr != b_nr || a_nc != b_nc)
00326     gripe_nonconformant (op_name, a_nr, a_nc, b_nr, b_nc);
00327   else
00328     {
00329       r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ()));
00330 
00331       octave_idx_type jx = 0;
00332       r.cidx (0) = 0;
00333       for (octave_idx_type i = 0 ; i < a_nc ; i++)
00334         {
00335           octave_idx_type  ja = a.cidx(i);
00336           octave_idx_type  ja_max = a.cidx(i+1);
00337           bool ja_lt_max= ja < ja_max;
00338 
00339           octave_idx_type  jb = b.cidx(i);
00340           octave_idx_type  jb_max = b.cidx(i+1);
00341           bool jb_lt_max = jb < jb_max;
00342 
00343           while (ja_lt_max || jb_lt_max )
00344             {
00345               octave_quit ();
00346               if ((! jb_lt_max) ||
00347                   (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00348                 {
00349                   r.ridx(jx) = a.ridx(ja);
00350                   r.data(jx) = op (a.data(ja), 0.);
00351                   jx++;
00352                   ja++;
00353                   ja_lt_max= ja < ja_max;
00354                 }
00355               else if (( !ja_lt_max ) ||
00356                        (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00357                 {
00358                   r.ridx(jx) = b.ridx(jb);
00359                   r.data(jx) = op (0.,  b.data(jb));
00360                   jx++;
00361                   jb++;
00362                   jb_lt_max= jb < jb_max;
00363                 }
00364               else
00365                 {
00366                   if (op (a.data(ja), b.data(jb)) != 0.)
00367                     {
00368                       r.data(jx) = op (a.data(ja), b.data(jb));
00369                       r.ridx(jx) = a.ridx(ja);
00370                       jx++;
00371                     }
00372                   ja++;
00373                   ja_lt_max= ja < ja_max;
00374                   jb++;
00375                   jb_lt_max= jb < jb_max;
00376                 }
00377             }
00378           r.cidx(i+1) = jx;
00379         }
00380 
00381       r.maybe_compress ();
00382     }
00383 
00384   return r;
00385 }
00386 
00387 template <class T>
00388 MSparse<T>
00389 operator+ (const MSparse<T>& a, const MSparse<T>& b)
00390 {
00391   return plus_or_minus (a, b, std::plus<T> (), "operator +", false);
00392 }
00393 
00394 template <class T>
00395 MSparse<T>
00396 operator- (const MSparse<T>& a, const MSparse<T>& b)
00397 {
00398   return plus_or_minus (a, b, std::minus<T> (), "operator -", true);
00399 }
00400 
00401 template <class T>
00402 MSparse<T>
00403 product (const MSparse<T>& a, const MSparse<T>& b)
00404 {
00405   MSparse<T> r;
00406 
00407   octave_idx_type a_nr = a.rows ();
00408   octave_idx_type a_nc = a.cols ();
00409 
00410   octave_idx_type b_nr = b.rows ();
00411   octave_idx_type b_nc = b.cols ();
00412 
00413   if (a_nr == 1 && a_nc == 1)
00414     {
00415       if (a.elem(0,0) == 0.)
00416         r = MSparse<T> (b_nr, b_nc);
00417       else
00418         {
00419           r = MSparse<T> (b);
00420           octave_idx_type b_nnz = b.nnz();
00421 
00422           for (octave_idx_type i = 0 ; i < b_nnz ; i++)
00423             {
00424               octave_quit ();
00425               r.data (i) = a.data(0) * r.data(i);
00426             }
00427           r.maybe_compress ();
00428         }
00429     }
00430   else if (b_nr == 1 && b_nc == 1)
00431     {
00432       if (b.elem(0,0) == 0.)
00433         r = MSparse<T> (a_nr, a_nc);
00434       else
00435         {
00436           r = MSparse<T> (a);
00437           octave_idx_type a_nnz = a.nnz();
00438 
00439           for (octave_idx_type i = 0 ; i < a_nnz ; i++)
00440             {
00441               octave_quit ();
00442               r.data (i) = r.data(i) * b.data(0);
00443             }
00444           r.maybe_compress ();
00445         }
00446     }
00447   else if (a_nr != b_nr || a_nc != b_nc)
00448     gripe_nonconformant ("product", a_nr, a_nc, b_nr, b_nc);
00449   else
00450     {
00451       r = MSparse<T> (a_nr, a_nc, (a.nnz () > b.nnz () ? a.nnz () : b.nnz ()));
00452 
00453       octave_idx_type jx = 0;
00454       r.cidx (0) = 0;
00455       for (octave_idx_type i = 0 ; i < a_nc ; i++)
00456         {
00457           octave_idx_type  ja = a.cidx(i);
00458           octave_idx_type  ja_max = a.cidx(i+1);
00459           bool ja_lt_max= ja < ja_max;
00460 
00461           octave_idx_type  jb = b.cidx(i);
00462           octave_idx_type  jb_max = b.cidx(i+1);
00463           bool jb_lt_max = jb < jb_max;
00464 
00465           while (ja_lt_max || jb_lt_max )
00466             {
00467               octave_quit ();
00468               if ((! jb_lt_max) ||
00469                   (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00470                 {
00471                   ja++; ja_lt_max= ja < ja_max;
00472                 }
00473               else if (( !ja_lt_max ) ||
00474                        (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00475                 {
00476                   jb++; jb_lt_max= jb < jb_max;
00477                 }
00478               else
00479                 {
00480                   if ((a.data(ja) * b.data(jb)) != 0.)
00481                     {
00482                       r.data(jx) = a.data(ja) * b.data(jb);
00483                       r.ridx(jx) = a.ridx(ja);
00484                       jx++;
00485                     }
00486                   ja++; ja_lt_max= ja < ja_max;
00487                   jb++; jb_lt_max= jb < jb_max;
00488                 }
00489             }
00490           r.cidx(i+1) = jx;
00491         }
00492 
00493       r.maybe_compress ();
00494     }
00495 
00496   return r;
00497 }
00498 
00499 template <class T>
00500 MSparse<T>
00501 quotient (const MSparse<T>& a, const MSparse<T>& b)
00502 {
00503   MSparse<T> r;
00504   T Zero = T ();
00505 
00506   octave_idx_type a_nr = a.rows ();
00507   octave_idx_type a_nc = a.cols ();
00508 
00509   octave_idx_type b_nr = b.rows ();
00510   octave_idx_type b_nc = b.cols ();
00511 
00512   if (a_nr == 1 && a_nc == 1)
00513     {
00514       T val = a.elem (0,0);
00515       T fill = val / T();
00516       if (fill == T())
00517         {
00518           octave_idx_type b_nnz = b.nnz();
00519           r = MSparse<T> (b);
00520           for (octave_idx_type i = 0 ; i < b_nnz ; i++)
00521             r.data (i) = val / r.data(i);
00522           r.maybe_compress ();
00523         }
00524       else
00525         {
00526           r = MSparse<T> (b_nr, b_nc, fill);
00527           for (octave_idx_type j = 0 ; j < b_nc ; j++)
00528             {
00529               octave_quit ();
00530               octave_idx_type idxj = j * b_nr;
00531               for (octave_idx_type i = b.cidx(j) ; i < b.cidx(j+1) ; i++)
00532                 {
00533                   octave_quit ();
00534                   r.data(idxj + b.ridx(i)) = val / b.data(i);
00535                 }
00536             }
00537           r.maybe_compress ();
00538         }
00539     }
00540   else if (b_nr == 1 && b_nc == 1)
00541     {
00542       T val = b.elem (0,0);
00543       T fill = T() / val;
00544       if (fill == T())
00545         {
00546           octave_idx_type a_nnz = a.nnz();
00547           r = MSparse<T> (a);
00548           for (octave_idx_type i = 0 ; i < a_nnz ; i++)
00549             r.data (i) = r.data(i) / val;
00550           r.maybe_compress ();
00551         }
00552       else
00553         {
00554           r = MSparse<T> (a_nr, a_nc, fill);
00555           for (octave_idx_type j = 0 ; j < a_nc ; j++)
00556             {
00557               octave_quit ();
00558               octave_idx_type idxj = j * a_nr;
00559               for (octave_idx_type i = a.cidx(j) ; i < a.cidx(j+1) ; i++)
00560                 {
00561                   octave_quit ();
00562                   r.data(idxj + a.ridx(i)) = a.data(i) / val;
00563                 }
00564             }
00565           r.maybe_compress ();
00566         }
00567     }
00568   else if (a_nr != b_nr || a_nc != b_nc)
00569     gripe_nonconformant ("quotient", a_nr, a_nc, b_nr, b_nc);
00570   else
00571     {
00572       r = MSparse<T>( a_nr, a_nc, (Zero / Zero));
00573 
00574       for (octave_idx_type i = 0 ; i < a_nc ; i++)
00575         {
00576           octave_idx_type  ja = a.cidx(i);
00577           octave_idx_type  ja_max = a.cidx(i+1);
00578           bool ja_lt_max= ja < ja_max;
00579 
00580           octave_idx_type  jb = b.cidx(i);
00581           octave_idx_type  jb_max = b.cidx(i+1);
00582           bool jb_lt_max = jb < jb_max;
00583 
00584           while (ja_lt_max || jb_lt_max )
00585             {
00586               octave_quit ();
00587               if ((! jb_lt_max) ||
00588                   (ja_lt_max && (a.ridx(ja) < b.ridx(jb))))
00589                 {
00590                   r.elem (a.ridx(ja),i) = a.data(ja) / Zero;
00591                   ja++; ja_lt_max= ja < ja_max;
00592                 }
00593               else if (( !ja_lt_max ) ||
00594                        (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) )
00595                 {
00596                   r.elem (b.ridx(jb),i) = Zero / b.data(jb);
00597                   jb++; jb_lt_max= jb < jb_max;
00598                 }
00599               else
00600                 {
00601                   r.elem (a.ridx(ja),i) = a.data(ja) / b.data(jb);
00602                   ja++; ja_lt_max= ja < ja_max;
00603                   jb++; jb_lt_max= jb < jb_max;
00604                 }
00605             }
00606         }
00607 
00608       r.maybe_compress (true);
00609     }
00610 
00611   return r;
00612 }
00613 
00614 
00615 
00616 // Unary MSparse ops.
00617 
00618 template <class T>
00619 MSparse<T>
00620 operator + (const MSparse<T>& a)
00621 {
00622   return a;
00623 }
00624 
00625 template <class T>
00626 MSparse<T>
00627 operator - (const MSparse<T>& a)
00628 {
00629   MSparse<T> retval (a);
00630   octave_idx_type nz = a.nnz ();
00631   for (octave_idx_type i = 0; i < nz; i++)
00632     retval.data(i) = - retval.data(i);
00633   return retval;
00634 }
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines