Sparse-diag-op-defs.h

Go to the documentation of this file.
00001 /* -*- C++ -*-
00002 
00003 Copyright (C) 2009-2012 Jason Riedy, Jaroslav Hajek
00004 
00005 This file is part of Octave.
00006 
00007 Octave is free software; you can redistribute it and/or modify it
00008 under the terms of the GNU General Public License as published by the
00009 Free Software Foundation; either version 3 of the License, or (at your
00010 option) any later version.
00011 
00012 Octave is distributed in the hope that it will be useful, but WITHOUT
00013 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
00014 FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
00015 for more details.
00016 
00017 You should have received a copy of the GNU General Public License
00018 along with Octave; see the file COPYING.  If not, see
00019 <http://www.gnu.org/licenses/>.
00020 
00021 */
00022 
00023 #if !defined (octave_sparse_diag_op_defs_h)
00024 #define octave_sparse_diag_op_defs_h 1
00025 
00026 // Matrix multiplication
00027 
00028 template <typename RT, typename DM, typename SM>
00029 RT do_mul_dm_sm (const DM& d, const SM& a)
00030 {
00031   const octave_idx_type nr = d.rows ();
00032   const octave_idx_type nc = d.cols ();
00033 
00034   const octave_idx_type a_nr = a.rows ();
00035   const octave_idx_type a_nc = a.cols ();
00036 
00037   if (nc != a_nr)
00038     {
00039       gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
00040       return RT ();
00041     }
00042   else
00043    {
00044      RT r (nr, a_nc, a.nnz ());
00045 
00046      octave_idx_type l = 0;
00047 
00048      for (octave_idx_type j = 0; j < a_nc; j++)
00049        {
00050          r.xcidx (j) = l;
00051          const octave_idx_type colend = a.cidx (j+1);
00052          for (octave_idx_type k = a.cidx (j); k < colend; k++)
00053            {
00054              const octave_idx_type i = a.ridx (k);
00055              if (i >= nr) break;
00056              r.xdata (l) = d.dgelem (i) * a.data (k);
00057              r.xridx (l) = i;
00058              l++;
00059            }
00060        }
00061 
00062      r.xcidx (a_nc) = l;
00063 
00064      r.maybe_compress (true);
00065      return r;
00066    }
00067 }
00068 
00069 template <typename RT, typename SM, typename DM>
00070 RT do_mul_sm_dm (const SM& a, const DM& d)
00071 {
00072   const octave_idx_type nr = d.rows ();
00073   const octave_idx_type nc = d.cols ();
00074 
00075   const octave_idx_type a_nr = a.rows ();
00076   const octave_idx_type a_nc = a.cols ();
00077 
00078   if (nr != a_nc)
00079     {
00080       gripe_nonconformant ("operator *", a_nr, a_nc, nr, nc);
00081       return RT ();
00082     }
00083   else
00084    {
00085 
00086      const octave_idx_type mnc = nc < a_nc ? nc: a_nc;
00087      RT r (a_nr, nc, a.cidx (mnc));
00088 
00089      for (octave_idx_type j = 0; j < mnc; ++j)
00090        {
00091          const typename DM::element_type s = d.dgelem (j);
00092          const octave_idx_type colend = a.cidx (j+1);
00093          r.xcidx (j) = a.cidx (j);
00094          for (octave_idx_type k = a.cidx (j); k < colend; ++k)
00095            {
00096              r.xdata (k) = s * a.data (k);
00097              r.xridx (k) = a.ridx (k);
00098            }
00099        }
00100      for (octave_idx_type j = mnc; j <= nc; ++j)
00101        r.xcidx (j) = a.cidx (mnc);
00102 
00103      r.maybe_compress (true);
00104      return r;
00105    }
00106 }
00107 
// FIXME: functors such as this should be gathered somewhere

// Unary function object that returns its argument unchanged.  The
// sparse + diagonal helpers in this file pass it as the "leave the operand
// as is" transform.
//
// It no longer derives from std::unary_function, which was deprecated in
// C++11 and removed in C++17.  The nested typedefs that base used to supply
// are declared directly so existing users of them keep compiling.
template <typename T>
struct identity_val
{
  typedef T argument_type;
  typedef T result_type;

  // const so the functor can also be invoked through a const object/reference.
  T operator () (const T x) const { return x; }
};
00115 
00116 // Matrix addition
00117 
00118 template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
00119 RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
00120 {
00121   using std::min;
00122   const octave_idx_type nr = d.rows ();
00123   const octave_idx_type nc = d.cols ();
00124   const octave_idx_type n = min (nr, nc);
00125 
00126   const octave_idx_type a_nr = a.rows ();
00127   const octave_idx_type a_nc = a.cols ();
00128 
00129   const octave_idx_type nz = a.nnz ();
00130   RT r (a_nr, a_nc, nz + n);
00131   octave_idx_type k = 0;
00132 
00133   for (octave_idx_type j = 0; j < nc; ++j)
00134     {
00135       octave_quit ();
00136       const octave_idx_type colend = a.cidx (j+1);
00137       r.xcidx (j) = k;
00138       octave_idx_type k_src = a.cidx (j), k_split;
00139 
00140       for (k_split = k_src; k_split < colend; k_split++)
00141         if (a.ridx (k_split) >= j)
00142           break;
00143 
00144       for (; k_src < k_split; k_src++, k++)
00145         {
00146           r.xridx (k) = a.ridx (k_src);
00147           r.xdata (k) = opa (a.data (k_src));
00148         }
00149 
00150       if (k_src < colend && a.ridx (k_src) == j)
00151         {
00152           r.xridx (k) = j;
00153           r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
00154           k++; k_src++;
00155         }
00156       else
00157         {
00158           r.xridx (k) = j;
00159           r.xdata (k) = opd (d.dgelem (j));
00160           k++;
00161         }
00162 
00163       for (; k_src < colend; k_src++, k++)
00164         {
00165           r.xridx (k) = a.ridx (k_src);
00166           r.xdata (k) = opa (a.data (k_src));
00167         }
00168 
00169     }
00170   r.xcidx (nc) = k;
00171 
00172   r.maybe_compress (true);
00173   return r;
00174 }
00175 
00176 template <typename RT, typename DM, typename SM>
00177 RT do_commutative_add_dm_sm (const DM& d, const SM& a)
00178 {
00179   // Extra function to ensure this is only emitted once.
00180   return inner_do_add_sm_dm<RT> (a, d,
00181                                  identity_val<typename SM::element_type> (),
00182                                  identity_val<typename DM::element_type> ());
00183 }
00184 
00185 template <typename RT, typename DM, typename SM>
00186 RT do_add_dm_sm (const DM& d, const SM& a)
00187 {
00188   if (a.rows () != d.rows () || a.cols () != d.cols ())
00189     {
00190       gripe_nonconformant ("operator +", d.rows (), d.cols (), a.rows (), a.cols ());
00191       return RT ();
00192     }
00193   else
00194     return do_commutative_add_dm_sm<RT> (d, a);
00195 }
00196 
00197 template <typename RT, typename DM, typename SM>
00198 RT do_sub_dm_sm (const DM& d, const SM& a)
00199 {
00200   if (a.rows () != d.rows () || a.cols () != d.cols ())
00201     {
00202       gripe_nonconformant ("operator -", d.rows (), d.cols (), a.rows (), a.cols ());
00203       return RT ();
00204     }
00205   else
00206     return inner_do_add_sm_dm<RT> (a, d, std::negate<typename SM::element_type> (),
00207                                    identity_val<typename DM::element_type> ());
00208 }
00209 
00210 template <typename RT, typename SM, typename DM>
00211 RT do_add_sm_dm (const SM& a, const DM& d)
00212 {
00213   if (a.rows () != d.rows () || a.cols () != d.cols ())
00214     {
00215       gripe_nonconformant ("operator +", a.rows (), a.cols (), d.rows (), d.cols ());
00216       return RT ();
00217     }
00218   else
00219     return do_commutative_add_dm_sm<RT> (d, a);
00220 }
00221 
00222 template <typename RT, typename SM, typename DM>
00223 RT do_sub_sm_dm (const SM& a, const DM& d)
00224 {
00225   if (a.rows () != d.rows () || a.cols () != d.cols ())
00226     {
00227       gripe_nonconformant ("operator -", a.rows (), a.cols (), d.rows (), d.cols ());
00228       return RT ();
00229     }
00230   else
00231     return inner_do_add_sm_dm<RT> (a, d,
00232                                    identity_val<typename SM::element_type> (),
00233                                    std::negate<typename DM::element_type> ());
00234 }
00235 
00236 #endif // octave_sparse_diag_op_defs_h
 All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Defines