NGSolve  4.9
linalg/basematrix.hpp
00001 #ifndef FILE_NGS_BASEMATRIX
00002 #define FILE_NGS_BASEMATRIX
00003 
00004 
00005 /*********************************************************************/
00006 /* File:   basematrix.hpp                                            */
00007 /* Author: Joachim Schoeberl                                         */
00008 /* Date:   25. Mar. 2000                                             */
00009 /*********************************************************************/
00010 
00011 namespace ngla
00012 {
00013 
00014 
// Selects the solver that is used for InverseMatrix.
// The enumerator values are the implicit ones (0..6), written out here so the
// numbering is visible at a glance.
enum INVERSETYPE
{
  PARDISO        = 0,
  PARDISOSPD     = 1,
  SPARSECHOLESKY = 2,
  SUPERLU        = 3,
  SUPERLU_DIST   = 4,
  MUMPS          = 5,
  MASTERINVERSE  = 6
};
00017 
00018 
00022   class NGS_DLL_HEADER BaseMatrix
00023   {
00024   protected:
00025     const ParallelDofs * paralleldofs;
00026 
00027   public:
00029     BaseMatrix ();
00031     // BaseMatrix (const BaseMatrix & amat);
00032     //
00033     BaseMatrix (ParallelDofs * aparalleldofs); 
00035     virtual ~BaseMatrix ();
00036   
00038     virtual int VHeight() const;
00039 
00041     virtual int VWidth() const;
00042 
00044     int Height() const
00045     {
00046       return VHeight();
00047     }
00048   
00050     int Width() const
00051     {
00052       return VWidth();
00053     }
00054 
00056     BaseMatrix & operator= (double s)
00057     {
00058       AsVector().SetScalar(s);
00059       return *this;
00060     }
00061 
00063     virtual BaseVector & AsVector();
00065     virtual const BaseVector & AsVector() const;
00066   
00067     virtual ostream & Print (ostream & ost) const;
00068     virtual void MemoryUsage (Array<MemoryUsageStruct*> & mu) const;
00069 
00070     // virtual const void * Data() const;
00071     // virtual void * Data();
00072 
00074     virtual BaseMatrix * CreateMatrix () const;
00076     // virtual BaseMatrix * CreateMatrix (const Array<int> & elsperrow) const;
00078     virtual BaseVector * CreateRowVector () const;
00080     virtual BaseVector * CreateColVector () const;
00082     virtual BaseVector * CreateVector () const;
00083 
00085     virtual void Mult (const BaseVector & x, BaseVector & y) const;
00087     virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const;
00089     virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const;
00090   
00092     virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const;
00094     virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const;
00095 
00096 
00097 
00098 
00103     virtual void MultAdd1 (double s, const BaseVector & x, BaseVector & y,
00104                            const BitArray * ainner = NULL,
00105                            const Array<int> * acluster = NULL) const;
00106 
00108     virtual void MultAdd2 (double s, const BaseVector & x, BaseVector & y,
00109                            const BitArray * ainner = NULL,
00110                            const Array<int> * acluster = NULL) const;
00111 
00112 
00113     void SetParallelDofs (const ParallelDofs * pardofs) { paralleldofs = pardofs; }
00114     const ParallelDofs * GetParallelDofs () const { return paralleldofs; }
00115 
00116     virtual BaseMatrix * InverseMatrix (const BitArray * subset = 0) const;
00117     virtual BaseMatrix * InverseMatrix (const Array<int> * clusters) const;
00118     virtual INVERSETYPE SetInverseType ( INVERSETYPE ainversetype ) const;
00119     virtual INVERSETYPE SetInverseType ( string ainversetype ) const;
00120     virtual INVERSETYPE  GetInverseType () const;
00121     
00122     
00123   };
00124 
00125 
00126 
00127 
00128 
00129 
00131   template <typename SCAL>
00132   class NGS_DLL_HEADER S_BaseMatrix : virtual public BaseMatrix
00133   {
00134   public:
00136     S_BaseMatrix ();
00138     virtual ~S_BaseMatrix ();
00139   };
00140 
00141   // specifies the scalar type Complex.
00142   template <>
00143   class S_BaseMatrix<Complex> : virtual public BaseMatrix
00144   {
00145   public:
00147     S_BaseMatrix ();
00149     virtual ~S_BaseMatrix ();
00150 
00152     virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const;
00154     virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const;
00155   
00157     virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const;
00159     virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const;
00160   };
00161 
00162 
00163 
00164 
00165 
00166 
00167 
00168   /* *************************** Matrix * Vector ******************** */
00169 
00170 
00172   class VMatVecExpr
00173   {
00174     const BaseMatrix & m;
00175     const BaseVector & x;
00176   
00177   public:
00178     VMatVecExpr (const BaseMatrix & am, const BaseVector & ax) : m(am), x(ax) { ; }
00179 
00180     template <class TS>
00181     void AssignTo (TS s, BaseVector & v) const
00182     { 
00183 #ifdef DEBUG
00184       if (m.Height() != v.Size() || m.Width() != x.Size())
00185         throw Exception (ToString ("matrix-vector: size does not fit\n") +
00186                          "Matrix:     " + ToString(m.Height()) + " x " + ToString(m.Width()) + "\n"
00187                          "Vector in : " + ToString(x.Size()) + "\n"
00188                          "Vector res: " + ToString(v.Size()));
00189 #endif
00190       m.Mult (x, v);
00191       v *= s;
00192     }
00193 
00194     template <class TS>
00195     void AddTo (TS s, BaseVector & v) const
00196     { 
00197 #ifdef DEBUG
00198       if (m.Height() != v.Size() || m.Width() != x.Size())
00199         throw Exception ("matrix-vector MultAdd: size does not fit");
00200 #endif
00201       m.MultAdd (s, x, v);
00202     }
00203   };
00204 
00205 
00207   inline VVecExpr<VMatVecExpr>
00208   operator* (const BaseMatrix & a, const BaseVector & b)
00209   {
00210     return VMatVecExpr (a, b);
00211   }
00212 
00213 
00214   /* ************************** Transpose ************************* */
00215 
00219   class Transpose : public BaseMatrix
00220   {
00221     const BaseMatrix & bm;
00222   public:
00224     Transpose (const BaseMatrix & abm) : bm(abm) { ; }
00226     virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const
00227     {
00228       bm.MultTransAdd (s, x, y);
00229     }
00231     virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const 
00232     {
00233       bm.MultTransAdd (s, x, y);
00234     }
00236     virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const
00237     {
00238       bm.MultAdd (s, x, y);
00239     }
00241     virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const
00242     {
00243       bm.MultAdd (s, x, y);
00244     }  
00245 
00246     virtual ostream & Print (ostream & ost) const
00247     {
00248       ost << "Transpose of " << endl;
00249       bm.Print(ost);
00250       return ost;
00251     }
00252   };
00253 
00254 
00255 
00256   class VScaleMatrix : public BaseMatrix
00257   {
00258     const BaseMatrix & bm;
00259     double scale;
00260   public:
00262     VScaleMatrix (const BaseMatrix & abm, double ascale) : bm(abm), scale(ascale) { ; }
00264     virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const
00265     {
00266       bm.MultAdd (s*scale, x, y);
00267     }
00269     virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const 
00270     {
00271       bm.MultAdd (s*scale, x, y);
00272     }
00274     virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const
00275     {
00276       bm.MultTransAdd (s*scale, x, y);
00277     }
00279     virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const
00280     {
00281       bm.MultTransAdd (s*scale, x, y);
00282     }  
00283   };
00284 
00285   inline VScaleMatrix operator* (double d, const BaseMatrix & m)
00286   {
00287     return VScaleMatrix (m, d);
00288   }
00289 
00290   /* *********************** operator<< ********************** */
00291 
00293   inline ostream & operator<< (ostream & ost, const BaseMatrix & m)
00294   {
00295     return m.Print(ost);
00296   }
00297 
00298 }
00299 
00300 #endif