NGSolve  4.9
basiclinalg/bandmatrix.hpp
00001 #ifndef FILE_BANDMATRIX
00002 #define FILE_BANDMATRIX
00003 
00004 /****************************************************************************/
00005 /* File:   bandmatrix.hpp                                                   */
00006 /* Author: Joachim Schoeberl                                                */
00007 /* Date:   14. Aug. 2002                                                    */
00008 /****************************************************************************/
00009 
00010 namespace ngbla
00011 {
00012 
00016   template <class T = double>
00017   class FlatSymBandMatrix
00018   {
00019   protected:
00021     int n;
00023     int bw;
00025     T *data;
00026   public:
00028     typedef typename mat_traits<T>::TV_COL TV;
00029 
00031     FlatSymBandMatrix (int an, int abw, T * adata)
00032       : n(an), bw(abw), data(adata)
00033     { ; }
00034 
00036     void Mult (const FlatVector<TV> & x, FlatVector<TV> & y) const
00037     {
00038       for (int i = 0; i < n; i++)
00039         y(i) = (*this)(i,i) * x(i);
00040       for (int i = 0; i < n; i++)
00041         for (int j = max2(i-bw+1, 0); j < i; j++)
00042           {
00043             y(i) += (*this)(i,j) * x(j);
00044             y(j) += Trans((*this)(i,j)) * x(i);
00045           }
00046     }
00047 
00049     ostream & Print (ostream & ost) const
00050     {
00051       for (int i = 0; i < n; i++)
00052         {
00053           for (int j = 0; j < n; j++)
00054             if (Used (i, j))
00055               ost << setw(8) << (*this)(i,j) << " ";
00056             else if (Used (j,i))
00057               ost << setw(8) << "sym" << " ";
00058             else
00059               ost << setw(8) << 0;
00060           ost << endl;
00061         }
00062       return ost;
00063     }
00064 
00066     int Height() const { return n; }
00067 
00069     int BandWidth() const { return bw; }
00070   
00072     const T & operator() (int i, int j) const
00073     { return data[i * bw + j - i + bw-1]; }
00074 
00076     T & operator() (int i, int j) 
00077     { return data[i * bw + j - i + bw-1]; }
00078 
00080     bool Used (int i, int j) const
00081     {
00082       return (n > i && i >= j && j >= 0 && i-j < bw);
00083     }
00084   
00086     FlatSymBandMatrix & operator= (const T & val)
00087     {
00088       for (int i = 0; i < bw * n; i++)
00089         data[i] = val;
00090       return *this;
00091     }
00092 
00093 
00095     static int RequiredMem (int n, int bw)
00096     { return n*bw; }
00097   };
00098 
00099 
00100 
00102   template<typename T>
00103   inline std::ostream & operator<< (std::ostream & s, const FlatSymBandMatrix<T> & m)
00104   {
00105     m.Print (s);
00106     return s;
00107   }
00108 
00109 
00110 
00111 
00115   template <class T = double>
00116   class SymBandMatrix : public FlatSymBandMatrix<T>
00117   {
00118   public:
00119     typedef typename mat_traits<T>::TV_COL TV;
00120 
00122     SymBandMatrix (int an, int abw)
00123       : FlatSymBandMatrix<T> (an, abw, new T[an*abw])
00124     { ; }
00125 
00127     ~SymBandMatrix ()
00128     { delete [] this->data; }
00129 
00131     SymBandMatrix & operator= (const T & val)
00132     {
00133       for (int i = 0; i < this->bw * this->n; i++)
00134         this->data[i] = val;
00135       return *this;
00136     }
00137   };
00138 
00139 
00140 
00141 
00142 
00143 
00144 
00145 
00146 
00147 
00148 
00149 
00150 
00167   template <class T = double>
00168   class FlatBandCholeskyFactors
00169   {
00170   protected:
00172     int n;
00174     int bw;
00176     T * mem; 
00177   public:
00178     // typedef typename mat_traits<T>::TV_COL TV;
00179     typedef typename mat_traits<T>::TSCAL TSCAL;
00180 
00182     FlatBandCholeskyFactors (int an, int abw, T * amem)
00183     { n = an, bw = abw, mem = amem; }
00184 
00186     FlatBandCholeskyFactors ()
00187     { n = bw = 0; mem = 0; }
00188 
00190     NGS_DLL_HEADER void Factor (const FlatSymBandMatrix<T> & a);
00191 
00193     template <class TVX, class TVY>
00194     void Mult (const FlatVector<TVX> & x, FlatVector<TVY> & y) const
00195     {
00196       const TVX * hx = x.Addr(0);
00197       TVY * hy = y.Addr(0);
00198       const T * hm = &mem[0];
00199 
00200       for (int i = 0; i < n; i++)
00201         hy[i] = hx[i];
00202 
00203       int i, jj = n;
00204       for (i = 0; i < bw-1; i++)
00205         {
00206           typedef typename mat_traits<TVY>::TSCAL TTSCAL;
00207           TVY sum = TTSCAL(0.0);  
00208 
00209           for (int j = 0; j < i; j++, jj++)
00210             sum += hm[jj] * hy[j];
00211 
00212           hy[i] -= sum;
00213         }
00214 
00215       for (  ; i < n; i++)
00216         {
00217           typedef typename mat_traits<TVY>::TSCAL TTSCAL;
00218           TVY sum = TTSCAL(0.0);
00219 
00220           for (int j = i-bw+1; j < i; j++, jj++)
00221             sum += hm[jj] * hy[j];
00222 
00223           hy[i] -= sum;
00224         }
00225 
00226       for (int i = 0; i < n; i++)
00227         {
00228           TVY sum = mem[i] * hy[i];
00229           hy[i] = sum;
00230         }
00231 
00232       // jj = n + (n-1) * (bw-1) - bw*(bw-1)/2;   
00233       for (i = n-1; i >= bw-1; i--)
00234         {
00235           jj -= bw-1;
00236           TVY val = hy[i];
00237 
00238           int firstj = i-bw+1;
00239           for (int j = 0; j < bw-1; j++)
00240             hy[firstj+j] -= Trans (mem[jj+j]) * val;
00241         }
00242     
00243       for (  ; i >= 0; i--) 
00244         {
00245           jj -= i;
00246           TVY val = hy[i];
00247 
00248           for (int j = 0; j < i; j++)
00249             hy[j] -= Trans (mem[jj+j]) * val;
00250         }
00251     }
00252 
00253 
00254 
00256     ostream & Print (ostream & ost) const;
00257 
00259     int Index (int i, int j) const
00260     {
00261       if (i < bw)
00262         return n + (i * (i-1)) / 2 + j;
00263       else
00264         return n + i * (bw-2) + j - ((bw-1)*(bw-2))/2;   
00265     }
00266   
00268     const T & operator() (int i, int j) const
00269     { 
00270       if (i < bw)
00271         return mem[n + (i * (i-1)) / 2 + j];
00272       else
00273         return mem[n + i * (bw-2) + j - ((bw-1)*(bw-2))/2];   
00274     }
00275 
00277     T & operator() (int i, int j) 
00278     {
00279       if (i < bw)
00280         return mem[n + (i * (i-1)) / 2 + j];
00281       else
00282         return mem[n + i * (bw-2) + j - ((bw-1)*(bw-2))/2];   
00283     }
00284 
00286     int Size() const { return n; }
00288     int BandWidth() const { return bw; }
00290     static int RequiredMem (int n, int bw)
00291     { return n*bw - (bw * (bw-1)) / 2 + n; }
00292   };
00293   
00294 
00296   template<typename T>
00297   inline std::ostream & operator<< (std::ostream & s, const FlatBandCholeskyFactors<T> & m)
00298   {
00299     m.Print (s);
00300     return s;
00301   }
00302 
00303 
00304 
00305 
00306 
00310   template <class T = double>
00311   class BandCholeskyFactors : public FlatBandCholeskyFactors<T>
00312   {
00313   public:
00315     BandCholeskyFactors (const SymBandMatrix<T> & a)
00316       : FlatBandCholeskyFactors<T> (a.Height(), 
00317                                     a.BandWidth(),
00318                                     new T[FlatBandCholeskyFactors<T>::RequiredMem (a.Height(), a.BandWidth())])
00319     {
00320       this->Factor (a);
00321     }
00322 
00324     ~BandCholeskyFactors ()
00325     {
00326       delete [] this->mem;
00327     }
00328   };
00329 
00330 }
00331 
00332 #endif