NGSolve 4.9
00001 #ifndef FILE_BANDMATRIX 00002 #define FILE_BANDMATRIX 00003 00004 /****************************************************************************/ 00005 /* File: bandmatrix.hpp */ 00006 /* Author: Joachim Schoeberl */ 00007 /* Date: 14. Aug. 2002 */ 00008 /****************************************************************************/ 00009 00010 namespace ngbla 00011 { 00012 00016 template <class T = double> 00017 class FlatSymBandMatrix 00018 { 00019 protected: 00021 int n; 00023 int bw; 00025 T *data; 00026 public: 00028 typedef typename mat_traits<T>::TV_COL TV; 00029 00031 FlatSymBandMatrix (int an, int abw, T * adata) 00032 : n(an), bw(abw), data(adata) 00033 { ; } 00034 00036 void Mult (const FlatVector<TV> & x, FlatVector<TV> & y) const 00037 { 00038 for (int i = 0; i < n; i++) 00039 y(i) = (*this)(i,i) * x(i); 00040 for (int i = 0; i < n; i++) 00041 for (int j = max2(i-bw+1, 0); j < i; j++) 00042 { 00043 y(i) += (*this)(i,j) * x(j); 00044 y(j) += Trans((*this)(i,j)) * x(i); 00045 } 00046 } 00047 00049 ostream & Print (ostream & ost) const 00050 { 00051 for (int i = 0; i < n; i++) 00052 { 00053 for (int j = 0; j < n; j++) 00054 if (Used (i, j)) 00055 ost << setw(8) << (*this)(i,j) << " "; 00056 else if (Used (j,i)) 00057 ost << setw(8) << "sym" << " "; 00058 else 00059 ost << setw(8) << 0; 00060 ost << endl; 00061 } 00062 return ost; 00063 } 00064 00066 int Height() const { return n; } 00067 00069 int BandWidth() const { return bw; } 00070 00072 const T & operator() (int i, int j) const 00073 { return data[i * bw + j - i + bw-1]; } 00074 00076 T & operator() (int i, int j) 00077 { return data[i * bw + j - i + bw-1]; } 00078 00080 bool Used (int i, int j) const 00081 { 00082 return (n > i && i >= j && j >= 0 && i-j < bw); 00083 } 00084 00086 FlatSymBandMatrix & operator= (const T & val) 00087 { 00088 for (int i = 0; i < bw * n; i++) 00089 data[i] = val; 00090 return *this; 00091 } 00092 00093 00095 static int RequiredMem (int n, int bw) 00096 { return 
n*bw; } 00097 }; 00098 00099 00100 00102 template<typename T> 00103 inline std::ostream & operator<< (std::ostream & s, const FlatSymBandMatrix<T> & m) 00104 { 00105 m.Print (s); 00106 return s; 00107 } 00108 00109 00110 00111 00115 template <class T = double> 00116 class SymBandMatrix : public FlatSymBandMatrix<T> 00117 { 00118 public: 00119 typedef typename mat_traits<T>::TV_COL TV; 00120 00122 SymBandMatrix (int an, int abw) 00123 : FlatSymBandMatrix<T> (an, abw, new T[an*abw]) 00124 { ; } 00125 00127 ~SymBandMatrix () 00128 { delete [] this->data; } 00129 00131 SymBandMatrix & operator= (const T & val) 00132 { 00133 for (int i = 0; i < this->bw * this->n; i++) 00134 this->data[i] = val; 00135 return *this; 00136 } 00137 }; 00138 00139 00140 00141 00142 00143 00144 00145 00146 00147 00148 00149 00150 00167 template <class T = double> 00168 class FlatBandCholeskyFactors 00169 { 00170 protected: 00172 int n; 00174 int bw; 00176 T * mem; 00177 public: 00178 // typedef typename mat_traits<T>::TV_COL TV; 00179 typedef typename mat_traits<T>::TSCAL TSCAL; 00180 00182 FlatBandCholeskyFactors (int an, int abw, T * amem) 00183 { n = an, bw = abw, mem = amem; } 00184 00186 FlatBandCholeskyFactors () 00187 { n = bw = 0; mem = 0; } 00188 00190 NGS_DLL_HEADER void Factor (const FlatSymBandMatrix<T> & a); 00191 00193 template <class TVX, class TVY> 00194 void Mult (const FlatVector<TVX> & x, FlatVector<TVY> & y) const 00195 { 00196 const TVX * hx = x.Addr(0); 00197 TVY * hy = y.Addr(0); 00198 const T * hm = &mem[0]; 00199 00200 for (int i = 0; i < n; i++) 00201 hy[i] = hx[i]; 00202 00203 int i, jj = n; 00204 for (i = 0; i < bw-1; i++) 00205 { 00206 typedef typename mat_traits<TVY>::TSCAL TTSCAL; 00207 TVY sum = TTSCAL(0.0); 00208 00209 for (int j = 0; j < i; j++, jj++) 00210 sum += hm[jj] * hy[j]; 00211 00212 hy[i] -= sum; 00213 } 00214 00215 for ( ; i < n; i++) 00216 { 00217 typedef typename mat_traits<TVY>::TSCAL TTSCAL; 00218 TVY sum = TTSCAL(0.0); 00219 00220 for (int j = 
i-bw+1; j < i; j++, jj++) 00221 sum += hm[jj] * hy[j]; 00222 00223 hy[i] -= sum; 00224 } 00225 00226 for (int i = 0; i < n; i++) 00227 { 00228 TVY sum = mem[i] * hy[i]; 00229 hy[i] = sum; 00230 } 00231 00232 // jj = n + (n-1) * (bw-1) - bw*(bw-1)/2; 00233 for (i = n-1; i >= bw-1; i--) 00234 { 00235 jj -= bw-1; 00236 TVY val = hy[i]; 00237 00238 int firstj = i-bw+1; 00239 for (int j = 0; j < bw-1; j++) 00240 hy[firstj+j] -= Trans (mem[jj+j]) * val; 00241 } 00242 00243 for ( ; i >= 0; i--) 00244 { 00245 jj -= i; 00246 TVY val = hy[i]; 00247 00248 for (int j = 0; j < i; j++) 00249 hy[j] -= Trans (mem[jj+j]) * val; 00250 } 00251 } 00252 00253 00254 00256 ostream & Print (ostream & ost) const; 00257 00259 int Index (int i, int j) const 00260 { 00261 if (i < bw) 00262 return n + (i * (i-1)) / 2 + j; 00263 else 00264 return n + i * (bw-2) + j - ((bw-1)*(bw-2))/2; 00265 } 00266 00268 const T & operator() (int i, int j) const 00269 { 00270 if (i < bw) 00271 return mem[n + (i * (i-1)) / 2 + j]; 00272 else 00273 return mem[n + i * (bw-2) + j - ((bw-1)*(bw-2))/2]; 00274 } 00275 00277 T & operator() (int i, int j) 00278 { 00279 if (i < bw) 00280 return mem[n + (i * (i-1)) / 2 + j]; 00281 else 00282 return mem[n + i * (bw-2) + j - ((bw-1)*(bw-2))/2]; 00283 } 00284 00286 int Size() const { return n; } 00288 int BandWidth() const { return bw; } 00290 static int RequiredMem (int n, int bw) 00291 { return n*bw - (bw * (bw-1)) / 2 + n; } 00292 }; 00293 00294 00296 template<typename T> 00297 inline std::ostream & operator<< (std::ostream & s, const FlatBandCholeskyFactors<T> & m) 00298 { 00299 m.Print (s); 00300 return s; 00301 } 00302 00303 00304 00305 00306 00310 template <class T = double> 00311 class BandCholeskyFactors : public FlatBandCholeskyFactors<T> 00312 { 00313 public: 00315 BandCholeskyFactors (const SymBandMatrix<T> & a) 00316 : FlatBandCholeskyFactors<T> (a.Height(), 00317 a.BandWidth(), 00318 new T[FlatBandCholeskyFactors<T>::RequiredMem (a.Height(), a.BandWidth())]) 
00319 { 00320 this->Factor (a); 00321 } 00322 00324 ~BandCholeskyFactors () 00325 { 00326 delete [] this->mem; 00327 } 00328 }; 00329 00330 } 00331 00332 #endif