NGSolve
4.9
|
00001 #ifndef FILE_NGS_BASEMATRIX 00002 #define FILE_NGS_BASEMATRIX 00003 00004 00005 /*********************************************************************/ 00006 /* File: basematrix.hpp */ 00007 /* Author: Joachim Schoeberl */ 00008 /* Date: 25. Mar. 2000 */ 00009 /*********************************************************************/ 00010 00011 namespace ngla 00012 { 00013 00014 00015 // sets the solver which is used for InverseMatrix 00016 enum INVERSETYPE { PARDISO, PARDISOSPD, SPARSECHOLESKY, SUPERLU, SUPERLU_DIST, MUMPS, MASTERINVERSE }; 00017 00018 00022 class NGS_DLL_HEADER BaseMatrix 00023 { 00024 protected: 00025 const ParallelDofs * paralleldofs; 00026 00027 public: 00029 BaseMatrix (); 00031 // BaseMatrix (const BaseMatrix & amat); 00032 // 00033 BaseMatrix (ParallelDofs * aparalleldofs); 00035 virtual ~BaseMatrix (); 00036 00038 virtual int VHeight() const; 00039 00041 virtual int VWidth() const; 00042 00044 int Height() const 00045 { 00046 return VHeight(); 00047 } 00048 00050 int Width() const 00051 { 00052 return VWidth(); 00053 } 00054 00056 BaseMatrix & operator= (double s) 00057 { 00058 AsVector().SetScalar(s); 00059 return *this; 00060 } 00061 00063 virtual BaseVector & AsVector(); 00065 virtual const BaseVector & AsVector() const; 00066 00067 virtual ostream & Print (ostream & ost) const; 00068 virtual void MemoryUsage (Array<MemoryUsageStruct*> & mu) const; 00069 00070 // virtual const void * Data() const; 00071 // virtual void * Data(); 00072 00074 virtual BaseMatrix * CreateMatrix () const; 00076 // virtual BaseMatrix * CreateMatrix (const Array<int> & elsperrow) const; 00078 virtual BaseVector * CreateRowVector () const; 00080 virtual BaseVector * CreateColVector () const; 00082 virtual BaseVector * CreateVector () const; 00083 00085 virtual void Mult (const BaseVector & x, BaseVector & y) const; 00087 virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const; 00089 virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const; 00090 00092 virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const; 00094 virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const; 00095 00096 00097 00098 00103 virtual void MultAdd1 (double s, const BaseVector & x, BaseVector & y, 00104 const BitArray * ainner = NULL, 00105 const Array<int> * acluster = NULL) const; 00106 00108 virtual void MultAdd2 (double s, const BaseVector & x, BaseVector & y, 00109 const BitArray * ainner = NULL, 00110 const Array<int> * acluster = NULL) const; 00111 00112 00113 void SetParallelDofs (const ParallelDofs * pardofs) { paralleldofs = pardofs; } 00114 const ParallelDofs * GetParallelDofs () const { return paralleldofs; } 00115 00116 virtual BaseMatrix * InverseMatrix (const BitArray * subset = 0) const; 00117 virtual BaseMatrix * InverseMatrix (const Array<int> * clusters) const; 00118 virtual INVERSETYPE SetInverseType ( INVERSETYPE ainversetype ) const; 00119 virtual INVERSETYPE SetInverseType ( string ainversetype ) const; 00120 virtual INVERSETYPE GetInverseType () const; 00121 00122 00123 }; 00124 00125 00126 00127 00128 00129 00131 template <typename SCAL> 00132 class NGS_DLL_HEADER S_BaseMatrix : virtual public BaseMatrix 00133 { 00134 public: 00136 S_BaseMatrix (); 00138 virtual ~S_BaseMatrix (); 00139 }; 00140 00141 // specifies the scalar type Complex. 00142 template <> 00143 class S_BaseMatrix<Complex> : virtual public BaseMatrix 00144 { 00145 public: 00147 S_BaseMatrix (); 00149 virtual ~S_BaseMatrix (); 00150 00152 virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const; 00154 virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const; 00155 00157 virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const; 00159 virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const; 00160 }; 00161 00162 00163 00164 00165 00166 00167 00168 /* *************************** Matrix * Vector ******************** */ 00169 00170 00172 class VMatVecExpr 00173 { 00174 const BaseMatrix & m; 00175 const BaseVector & x; 00176 00177 public: 00178 VMatVecExpr (const BaseMatrix & am, const BaseVector & ax) : m(am), x(ax) { ; } 00179 00180 template <class TS> 00181 void AssignTo (TS s, BaseVector & v) const 00182 { 00183 #ifdef DEBUG 00184 if (m.Height() != v.Size() || m.Width() != x.Size()) 00185 throw Exception (ToString ("matrix-vector: size does not fit\n") + 00186 "Matrix: " + ToString(m.Height()) + " x " + ToString(m.Width()) + "\n" 00187 "Vector in : " + ToString(x.Size()) + "\n" 00188 "Vector res: " + ToString(v.Size())); 00189 #endif 00190 m.Mult (x, v); 00191 v *= s; 00192 } 00193 00194 template <class TS> 00195 void AddTo (TS s, BaseVector & v) const 00196 { 00197 #ifdef DEBUG 00198 if (m.Height() != v.Size() || m.Width() != x.Size()) 00199 throw Exception ("matrix-vector MultAdd: size does not fit"); 00200 #endif 00201 m.MultAdd (s, x, v); 00202 } 00203 }; 00204 00205 00207 inline VVecExpr<VMatVecExpr> 00208 operator* (const BaseMatrix & a, const BaseVector & b) 00209 { 00210 return VMatVecExpr (a, b); 00211 } 00212 00213 00214 /* ************************** Transpose ************************* */ 00215 00219 class Transpose : public BaseMatrix 00220 { 00221 const BaseMatrix & bm; 00222 public: 00224 Transpose (const BaseMatrix & abm) : bm(abm) { ; } 00226 virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const 00227 { 00228 bm.MultTransAdd (s, x, y); 00229 } 00231 virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const 00232 { 00233 bm.MultTransAdd (s, x, y); 00234 } 00236 virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const 00237 { 00238 bm.MultAdd (s, x, y); 00239 } 00241 virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const 00242 { 00243 bm.MultAdd (s, x, y); 00244 } 00245 00246 virtual ostream & Print (ostream & ost) const 00247 { 00248 ost << "Transpose of " << endl; 00249 bm.Print(ost); 00250 return ost; 00251 } 00252 }; 00253 00254 00255 00256 class VScaleMatrix : public BaseMatrix 00257 { 00258 const BaseMatrix & bm; 00259 double scale; 00260 public: 00262 VScaleMatrix (const BaseMatrix & abm, double ascale) : bm(abm), scale(ascale) { ; } 00264 virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const 00265 { 00266 bm.MultAdd (s*scale, x, y); 00267 } 00269 virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const 00270 { 00271 bm.MultAdd (s*scale, x, y); 00272 } 00274 virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const 00275 { 00276 bm.MultTransAdd (s*scale, x, y); 00277 } 00279 virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const 00280 { 00281 bm.MultTransAdd (s*scale, x, y); 00282 } 00283 }; 00284 00285 inline VScaleMatrix operator* (double d, const BaseMatrix & m) 00286 { 00287 return VScaleMatrix (m, d); 00288 } 00289 00290 /* *********************** operator<< ********************** */ 00291 00293 inline ostream & operator<< (ostream & ost, const BaseMatrix & m) 00294 { 00295 return m.Print(ost); 00296 } 00297 00298 } 00299 00300 #endif