AFEPack
MPI.h
Go to the documentation of this file.
00001 
00046 #ifndef __MPI_h__
00047 #define __MPI_h__
00048 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <mpi.h>

#include <AFEPack/DerefIterator.h>
#include <AFEPack/BinaryBuffer.h>
00057 
00058 namespace MPI {
00059 
00063   int get_comm_tag(MPI_Comm comm);
00064   int get_comm_int(MPI_Comm comm);
00065 
00074   template <class DataIterator, class TargetIterator>
00075     void sendrecv_data(MPI_Comm comm, 
00076                        int n,
00077                        DataIterator start_send_data,
00078                        DataIterator start_recv_data,
00079                        TargetIterator start_target) {
00080     int tag = get_comm_tag(comm);
00081     if (n == 0) return;
00082 
00083     MPI_Request request[2*n];
00084     MPI_Status status[2*n];
00085 
00087     int send_data_size[n], recv_data_size[n];
00088     DataIterator the_send_data = start_send_data;
00089     for (int i = 0;i < n;++ i, ++ the_send_data) {
00090       send_data_size[i] = the_send_data->size();
00091     }
00092 
00093     TargetIterator the_target = start_target;
00094     for (int i = 0;i < n;++ i, ++ the_target) {
00095       MPI_Isend(&send_data_size[i], 1, MPI_INT,
00096                 *the_target, tag, comm, &request[i]);
00097       MPI_Irecv(&recv_data_size[i], 1, MPI_INT,
00098                 *the_target, tag, comm, &request[i + n]);
00099     }
00100     MPI_Waitall(2*n, request, status);
00101 
00103     DataIterator the_recv_data = start_recv_data;
00104     for (int i = 0;i < n;++ i, ++ the_recv_data) {
00105       the_recv_data->resize(recv_data_size[i]);
00106     }
00107 
00109     int n_request = 0;
00110     the_target = start_target;
00111     the_send_data = start_send_data;
00112     the_recv_data = start_recv_data;
00113     for (int i = 0;i < n;++ i) {
00114       if (send_data_size[i] > 0) {
00115         MPI_Isend(the_send_data->start_address(), send_data_size[i], MPI_CHAR,
00116                   *the_target, tag, comm, &request[n_request ++]);
00117       }
00118       if (recv_data_size[i] > 0) {
00119         MPI_Irecv(the_recv_data->start_address(), recv_data_size[i], MPI_CHAR,
00120                   *the_target, tag, comm, &request[n_request ++]);
00121       }
00122       ++ the_target, ++ the_send_data, ++ the_recv_data;
00123     }
00124     if (n_request > 0) MPI_Waitall(n_request, request, status);
00125   }
00126 
/**
 * Handle to one copy of a shared object.
 *
 * `ptr` is the address of that copy as seen on the rank that holds it. The
 * synchronisation code sends this pointer to that rank, which uses it as its
 * local object pointer.
 *
 * `type` is an integer category. Shared_type_filter predicates are applied
 * to it when a Transmit_map is built.
 */
template <class T>
struct Remote_pointer {
  int type;  ///< category of this copy (0 by default)
  T * ptr;   ///< address of the copy on the rank that holds it

  Remote_pointer() : type(0), ptr(NULL) {}
  Remote_pointer(int _type, T * _ptr) :
    type(_type), ptr(_ptr) {}
  Remote_pointer(const Remote_pointer<T>& rp) :
    type(rp.type), ptr(rp.ptr) {}

  /// Copy both fields and return *this so assignments chain.
  /// (A non-void function that falls off its end is undefined behaviour.)
  Remote_pointer<T>& operator=(const Remote_pointer<T>& rp) {
    type = rp.type;
    ptr = rp.ptr;
    return *this;
  }
};
00144 
/**
 * Predicates over the integer `type` of a Remote_pointer.
 *
 * Each filter is a stateless (or, for `negate`, wrapping) function object
 * with `bool operator()(int) const`.
 */
namespace Shared_type_filter {

/// Accepts every type.
struct all {
  bool operator()(int) const { return true; }
};

/// Accepts types in the half-open range [D0, D1).
template <int D0, int D1>
struct between {
  bool operator()(int type) const { return D0 <= type && type < D1; }
};

/// Accepts exactly the type D.
template <int D>
struct only {
  bool operator()(int type) const { return D == type; }
};

/// Accepts every type except D.
template <int D>
struct except {
  bool operator()(int type) const { return !(type == D); }
};

/// Accepts types strictly greater than D.
template <int D>
struct greater_than {
  bool operator()(int type) const { return D < type; }
};

/// Accepts types strictly less than D.
template <int D>
struct less_than {
  bool operator()(int type) const { return type < D; }
};

/// Logical complement of another filter.
template <class FILTER>
struct negate {
  FILTER filter;  ///< wrapped filter whose answer is inverted

  negate() {}
  negate(const FILTER& inner) : filter(inner) {}

  bool operator()(int type) const { return !filter(type); }
};

}
00194 
00203   template <class T>
00204     struct Shared_object : public std::multimap<int,Remote_pointer<T> > {
00205     typedef Remote_pointer<T> pointer_t;
00206     typedef std::pair<int,pointer_t> pair_t;
00207     typedef std::multimap<int,pointer_t> _Base;
00208 
00209     T * _ptr; 
00210 
00211     Shared_object() {}
00212     Shared_object(T& t) : _ptr(&t) {}
00213     bool add_clone(int rank, T* ptr) {
00214       return this->add_clone(rank, pointer_t(0, ptr));
00215     }
00216     bool add_clone(int rank, int type, T* ptr) {
00217       return this->add_clone(rank, pointer_t(type, ptr));
00218     }
00219     bool add_clone(int rank, const pointer_t& ptr) {
00220       bool result = false;
00221       if (! is_duplicate_entry(rank, ptr)) {
00222         this->insert(pair_t(rank, ptr));
00223         result = true;
00224       } 
00225       return result;
00226     }
00227 
00228     T *& local_pointer() { return _ptr; }
00229     T * local_pointer() const { return _ptr; }
00230     T& local_object() const { return *_ptr; }
00231 
00235     bool is_duplicate_entry(int rank, 
00236                             const pointer_t& ptr) const {
00237       typedef typename _Base::const_iterator it_t;
00238       std::pair<it_t,it_t> range = _Base::equal_range(rank);
00239       it_t the_ptr = range.first, end_ptr = range.second;
00240       for (;the_ptr != end_ptr;++ the_ptr) {
00241         if (the_ptr->second.ptr == ptr.ptr) {
00242           return true;
00243         }
00244       }
00245       return false;
00246     }
00247 
00249 
00258     int primary_rank(int rank) const {
00259       return std::min(_Base::begin()->first, rank);
00260     }
00261 
00267     bool is_on_primary_rank(int rank) const {
00268       return (_Base::begin()->first >= rank);
00269     }
00271 
00280     bool is_primary_object(int rank) const {
00281       int first_rank = _Base::begin()->first;
00282       bool result = true;
00283       if (first_rank < rank) {
00284         result = false; 
00285       } else if (first_rank == rank) { 
00289         typedef typename _Base::const_iterator it_t;
00290         it_t the_ptr = _Base::begin();
00291         it_t end_ptr = _Base::upper_bound(rank);
00292         for (;the_ptr != end_ptr;++ the_ptr) {
00293           assert (the_ptr->second.ptr != _ptr);
00294           if (the_ptr->second.ptr < _ptr) {
00295             result = false;
00296             break;
00297           }
00298         }
00299       } else { 
00300         result = true; 
00301       }
00302       return result;
00303     }
00304   };
00305 
00310   template <class T, 
00311     template <class C, typename ALLOC = std::allocator<C> > class CNT = std::list>
00312     struct Shared_list : public CNT<Shared_object<T> > {};
00313 
00318   template <class T,
00319     template <class C, typename ALLOC = std::allocator<C> > class CNT = std::list>
00320     struct Shared_ptr_list : public CNT<Shared_object<T> *> {
00321     typedef CNT<Shared_object<T> *> base_t;
00322     typedef _Deref_iterator<typename base_t::iterator, Shared_object<T> > iterator;
00323     typedef _Deref_iterator<typename base_t::const_iterator, const Shared_object<T> > const_iterator;
00324     iterator begin() { return base_t::begin(); }
00325     iterator end() { return base_t::end(); }
00326     const_iterator begin() const { return base_t::begin(); }
00327     const_iterator end() const { return base_t::end(); }
00328     typename base_t::iterator begin_ptr() { return base_t::begin(); }
00329     typename base_t::iterator end_ptr() { return base_t::end(); }
00330     typename base_t::const_iterator begin_ptr() const { return base_t::begin(); }
00331     typename base_t::const_iterator end_ptr() const { return base_t::end(); }
00332   };
00333 
00340   template <class T, class SHARED_TYPE_FILTER=Shared_type_filter::all>
00341     struct Transmit_map : 
00342     public std::map<int, std::pair<int, std::list<std::pair<T*, T*> > > > {
00343     typedef std::list<std::pair<T*, T*> > value_t;
00344     typedef std::pair<int, value_t> pair_t;
00345     typedef std::map<int, pair_t> _Base;
00346     typedef SHARED_TYPE_FILTER type_filter_t;
00347 
00348     type_filter_t type_filter;
00349 
00354     template <class CONTAINER>
00355       void build(const CONTAINER& shlist) {
00356       _Base::clear();
00357 
00358       typename CONTAINER::const_iterator 
00359         the_obj = shlist.begin(),
00360         end_obj = shlist.end();
00361       for (;the_obj != end_obj;++ the_obj) {
00362         this->add_object(*the_obj);
00363       }
00364     }
00365 
00366     template <class CONTAINER>
00367       void build(const CONTAINER& shlist,
00368                  bool (*filter)(T *)) {
00369       _Base::clear();
00370 
00371       typename CONTAINER::const_iterator 
00372         the_obj = shlist.begin(),
00373         end_obj = shlist.end();
00374       for (;the_obj != end_obj;++ the_obj) {
00375         this->add_object(*the_obj, 
00376                          (*filter)(the_obj->local_pointer()));
00377       }
00378     }
00379 
00380     template <class CONTAINER, class DATA_PACKER>
00381       void build(const CONTAINER& shlist,
00382                  DATA_PACKER& data_packer,
00383                  bool (DATA_PACKER::*filter)(T *)) {
00384       _Base::clear();
00385 
00386       typename CONTAINER::const_iterator 
00387         the_obj = shlist.begin(),
00388         end_obj = shlist.end();
00389       for (;the_obj != end_obj;++ the_obj) {
00390         this->add_object(*the_obj, 
00391                          (data_packer.*filter)(the_obj->local_pointer()));
00392       }
00393     }
00394 
00395     template <class CONTAINER, class DATA_PACKER>
00396       void build(const CONTAINER& shlist,
00397                  const DATA_PACKER& data_packer,
00398                  bool (DATA_PACKER::*filter)(T *) const) {
00399       _Base::clear();
00400 
00401       typename CONTAINER::const_iterator 
00402         the_obj = shlist.begin(),
00403         end_obj = shlist.end();
00404       for (;the_obj != end_obj;++ the_obj) {
00405         this->add_object(*the_obj, 
00406                          (data_packer.*filter)(the_obj->local_pointer()));
00407       }
00408     }
00409 
00414     template <class ITERATOR>
00415       void build(ITERATOR& begin, ITERATOR& end) {
00416       _Base::clear();
00417 
00418       ITERATOR the_obj(begin);
00419       for (;the_obj != end;++ the_obj) {
00420         this->add_object(*the_obj);
00421       }
00422     }
00423 
00427     void add_object(const Shared_object<T>& obj, 
00428                     bool is_add_entry = true) {
00429       typename Shared_object<T>::const_iterator
00430         the_ptr = obj.begin(),
00431         end_ptr = obj.end();
00432       if (! is_add_entry) {
00436         for (;the_ptr != end_ptr;++ the_ptr) {
00437           int rank = the_ptr->first;
00438           if (this->find(rank) == this->end()) {
00439             (*this)[rank] = pair_t(0, value_t());
00440           }
00441         }
00442       } else {
00443         T * obj_ptr = obj.local_pointer();
00444         for (;the_ptr != end_ptr;++ the_ptr) {
00445           const int& rank = the_ptr->first;
00446           if (this->find(rank) == this->end()) {
00447             (*this)[rank] = pair_t(0, value_t());
00448           }
00449 
00450           if (type_filter(the_ptr->second.type)) {
00451             pair_t& pair = (*this)[rank];
00452             pair.first += 1;
00453             pair.second.push_back(std::pair<T*,T*>(obj_ptr, 
00454                                                    the_ptr->second.ptr));
00455           }
00456         }
00457       }
00458     }
00459   };
00460 
00481   template <class T, class DATA_PACKER, class SHARED_TYPE_FILTER>
00482     void sync_data(MPI_Comm comm,
00483                    Transmit_map<T,SHARED_TYPE_FILTER>& map,
00484                    DATA_PACKER& data_packer,
00485                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00486                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&)) {
00487     typedef Transmit_map<T,SHARED_TYPE_FILTER> map_t;
00488     typedef typename map_t::value_t value_t;
00489 
00490     std::list<int> target_list;
00491     std::list<BinaryBuffer<> > data_buffer_in, data_buffer_out;
00492 
00493     int n = 0;
00494     typename map_t::iterator
00495       the_pair = map.begin(),
00496       end_pair = map.end();
00497     for (;the_pair != end_pair;++ the_pair, ++ n) {
00498       int rank = the_pair->first;
00499       target_list.push_back(rank);
00500 
00501       data_buffer_in.push_back(BinaryBuffer<>());
00502       data_buffer_out.push_back(BinaryBuffer<>());
00503 
00504       AFEPack::ostream<> os(data_buffer_out.back());
00505       int n_item = the_pair->second.first;
00506       if (n_item == 0) continue;
00507 
00508       value_t& lst = the_pair->second.second;
00509       os << n_item; 
00510       typename value_t::iterator
00511         the_ptr = lst.begin(), end_ptr = lst.end();
00512       for (;the_ptr != end_ptr;++ the_ptr) {
00513         T *& local_obj = the_ptr->first;
00514         T *& remote_obj = the_ptr->second;
00515         os << remote_obj;
00516         (data_packer.*pack)(local_obj, rank, os);
00517       }
00518     }
00519 
00520     sendrecv_data(comm, n, data_buffer_out.begin(), data_buffer_in.begin(),
00521                   target_list.begin());
00522 
00523     typename std::list<BinaryBuffer<> >::iterator 
00524       the_buf = data_buffer_in.begin();
00525     the_pair = map.begin();
00526     for (;the_pair != end_pair;++ the_pair, ++ the_buf) {
00527       if (the_buf->size() == 0) continue;
00528 
00529       int rank = the_pair->first;
00530       AFEPack::istream<> is(*the_buf);
00531       int n_item;
00532       T * local_obj;
00533       is >> n_item; 
00534       for (int i = 0;i < n_item;++ i) {
00535         is >> local_obj;
00536         (data_packer.*unpack)(local_obj, rank, is);
00537       }
00538     }
00539   }
00540 
00551   template <class T, class SHARED_LIST, class DATA_PACKER>
00552     void sync_data(MPI_Comm comm,
00553                    SHARED_LIST& shlist,
00554                    DATA_PACKER& data_packer,
00555                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00556                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&)) {
00557     Transmit_map<T> map;
00558     map.build(shlist);
00559     sync_data(comm, map, data_packer, pack, unpack);
00560   }
00561 
00562   template <class T, class SHARED_LIST, class DATA_PACKER, class SHARED_TYPE_FILTER>
00563     void sync_data(MPI_Comm comm,
00564                    SHARED_LIST& shlist,
00565                    DATA_PACKER& data_packer,
00566                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00567                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00568                    const SHARED_TYPE_FILTER& stf) {
00569     Transmit_map<T,SHARED_TYPE_FILTER> map;
00570     map.build(shlist);
00571     sync_data(comm, map, data_packer, pack, unpack);
00572   }
00573 
00577   template <class T, class SHARED_LIST, class DATA_PACKER>
00578     void sync_data(MPI_Comm comm,
00579                    SHARED_LIST& shlist,
00580                    DATA_PACKER& data_packer,
00581                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00582                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00583                    bool (DATA_PACKER::*filter)(T *)) {
00584     Transmit_map<T> map;
00585     map.build(shlist, data_packer, filter);
00586     sync_data(comm, map, data_packer, pack, unpack);
00587   }
00588 
00589   template <class T, class SHARED_LIST, class DATA_PACKER, class SHARED_TYPE_FILTER>
00590     void sync_data(MPI_Comm comm,
00591                    SHARED_LIST& shlist,
00592                    DATA_PACKER& data_packer,
00593                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00594                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00595                    bool (DATA_PACKER::*filter)(T *),
00596                    const SHARED_TYPE_FILTER& stf) {
00597     Transmit_map<T,SHARED_TYPE_FILTER> map;
00598     map.build(shlist, data_packer, filter);
00599     sync_data(comm, map, data_packer, pack, unpack);
00600   }
00601 
00605   template <class T, class SHARED_LIST, class DATA_PACKER>
00606     void sync_data(MPI_Comm comm,
00607                    SHARED_LIST& shlist,
00608                    DATA_PACKER& data_packer,
00609                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00610                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00611                    bool (DATA_PACKER::*filter)(T *) const) {
00612     Transmit_map<T> map;
00613     map.build(shlist, data_packer, filter);
00614     sync_data(comm, map, data_packer, pack, unpack);
00615   }
00616 
00617   template <class T, class SHARED_LIST, class DATA_PACKER, class SHARED_TYPE_FILTER>
00618     void sync_data(MPI_Comm comm,
00619                    SHARED_LIST& shlist,
00620                    DATA_PACKER& data_packer,
00621                    void (DATA_PACKER::*pack)(T *,int,AFEPack::ostream<>&),
00622                    void (DATA_PACKER::*unpack)(T *,int,AFEPack::istream<>&),
00623                    bool (DATA_PACKER::*filter)(T *) const,
00624                    const SHARED_TYPE_FILTER& stf) {
00625     Transmit_map<T,SHARED_TYPE_FILTER> map;
00626     map.build(shlist, data_packer, filter);
00627     sync_data(comm, map, data_packer, pack, unpack);
00628   }
00629 
00630 } // namespace MPI
00631 
00632 #endif // __MPI_h__
00633