1 #ifndef _PEXSI_TREE_HPP_
2 #define _PEXSI_TREE_HPP_
16 #define FTREE_LIMIT 16
24 extern std::map< MPI_Comm , std::vector<int> > commGlobRanks;
33 MPI_Request recvRequest_;
34 NumVec<char> myRecvBuffer_;
36 NumVec<MPI_Request> myRequests_;
37 NumVec<MPI_Status> myStatuses_;
53 #ifdef COMM_PROFILE_BCAST
59 void SetGlobalComm(
const MPI_Comm & pGComm){
60 if(commGlobRanks.count(comm_)==0){
61 MPI_Group group2 = MPI_GROUP_NULL;
62 MPI_Comm_group(pGComm, &group2);
63 MPI_Group group1 = MPI_GROUP_NULL;
64 MPI_Comm_group(comm_, &group1);
67 MPI_Comm_size(comm_,&size);
68 vector<int> globRanks(size);
69 vector<int> Lranks(size);
70 for(
int i = 0; i<size;++i){Lranks[i]=i;}
71 MPI_Group_translate_ranks(group1, size, &Lranks[0],group2, &globRanks[0]);
72 commGlobRanks[comm_] = globRanks;
74 myGRoot_ = commGlobRanks[comm_][myRoot_];
75 myGRank_ = commGlobRanks[comm_][myRank_];
89 virtual void buildTree(Int * ranks, Int rank_cnt)=0;
98 comm_ = MPI_COMM_WORLD;
107 recvRequest_ = MPI_REQUEST_NULL;
114 TreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt,Int msgSize):TreeBcast2(){
116 MPI_Comm_rank(comm_,&myRank_);
126 virtual TreeBcast2 * clone()
const = 0;
128 TreeBcast2(
const TreeBcast2 & Tree){
132 virtual void Copy(
const TreeBcast2 & Tree){
133 this->comm_ = Tree.comm_;
134 this->myRank_ = Tree.myRank_;
135 this->myRoot_ = Tree.myRoot_;
136 this->msgSize_ = Tree.msgSize_;
138 this->numRecv_ = Tree.numRecv_;
139 this->tag_= Tree.tag_;
140 this->mainRoot_= Tree.mainRoot_;
141 this->isReady_ = Tree.isReady_;
142 this->myDests_ = Tree.myDests_;
145 this->recvRequest_ = Tree.recvRequest_;
146 this->myRecvBuffer_ = Tree.myRecvBuffer_;
147 this->myRequests_ = Tree.myRequests_;
148 this->myStatuses_ = Tree.myStatuses_;
149 this->myData_ = Tree.myData_;
150 if(Tree.myData_==(T*)&Tree.myRecvBuffer_[0]){
151 this->myData_=(T*)&this->myRecvBuffer_[0];
157 this->fwded_= Tree.fwded_;
159 this->done_= Tree.done_;
167 recvRequest_ = MPI_REQUEST_NULL;
176 virtual ~TreeBcast2(){
181 static TreeBcast2<T> * Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed);
183 virtual inline Int GetNumRecvMsg(){
return numRecv_;}
184 virtual inline Int GetNumMsgToSend(){
return GetDestCount();}
185 inline void SetDataReady(
bool rdy){
189 inline void SetTag(Int tag){ tag_ = tag;}
190 inline int GetTag(){
return tag_;}
192 bool IsDone(){
return done_;}
193 bool IsDataReady(){
return isReady_;}
194 bool IsDataReceived(){
return numRecv_==1;}
196 Int * GetDests(){
return &myDests_[0];}
197 Int GetDest(Int i){
return myDests_[i];}
198 Int GetDestCount(){
return myDests_.size();}
199 Int GetRoot(){
return myRoot_;}
201 bool IsRoot(){
return myRoot_==myRank_;}
202 Int GetMsgSize(){
return msgSize_;}
204 void ForwardMessage( ){
205 if(myRequests_.m()!=GetDestCount()){
206 myRequests_.Resize(GetDestCount());
207 SetValue(myRequests_,MPI_REQUEST_NULL);
209 for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
210 Int iProc = myDests_[idxRecv];
212 MPI_Isend( myData_, msgSize_, MPI_BYTE,
213 iProc, tag_,comm_, &myRequests_[idxRecv] );
215 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
216 statusOFS<<myRank_<<
" FWD to "<<iProc<<
" on tag "<<tag_<<std::endl;
218 #ifdef COMM_PROFILE_BCAST
223 PROFILE_COMM(myGRank_,commGlobRanks[comm_][iProc],tag_,msgSize_);
229 void CleanupBuffers(){
232 myRecvBuffer_.Clear();
236 void SetLocalBuffer(T * locBuffer){
237 if(myData_!=NULL && myData_!=locBuffer){
239 CopyLocalBuffer(locBuffer);
242 myRecvBuffer_.Clear();
250 virtual bool Progress(){
257 if (myRank_==myRoot_){
260 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
261 statusOFS<<myRank_<<
" FORWARDING on tag "<<tag_<<std::endl;
267 if(myStatuses_.m()!=GetDestCount()){
268 myStatuses_.Resize(GetDestCount());
269 recvRequest_ = MPI_REQUEST_NULL;
273 int reqCnt = GetDestCount();
275 assert(reqCnt == myRequests_.m());
276 MPI_Testall(reqCnt,&myRequests_[0],&flag,&myStatuses_[0]);
286 bool received = (numRecv_==1);
289 if(recvRequest_ == MPI_REQUEST_NULL ){
290 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
291 statusOFS<<myRank_<<
" POSTING RECV on tag "<<tag_<<std::endl;
298 if(myStatuses_.m()!=GetDestCount()){
299 myStatuses_.Resize(GetDestCount());
300 recvRequest_ = MPI_REQUEST_NULL;
302 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
303 statusOFS<<myRank_<<
" TESTING RECV on tag "<<tag_<<std::endl;
308 int test = MPI_Test(&recvRequest_,&flag,&stat);
309 assert(test==MPI_SUCCESS);
316 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
317 statusOFS<<myRank_<<
" FORWARDING on tag "<<tag_<<std::endl;
328 int reqCnt = GetDestCount();
330 assert(reqCnt == myRequests_.m());
331 MPI_Testall(reqCnt,&myRequests_[0],&flag,&myStatuses_[0]);
344 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
345 statusOFS<<myRank_<<
" EVERYTHING COMPLETED on tag "<<tag_<<std::endl;
361 T * GetLocalBuffer(){
365 virtual void PostRecv()
367 if(this->numRecv_<1 && this->recvRequest_==MPI_REQUEST_NULL && myRank_!=myRoot_){
370 myRecvBuffer_.Resize(msgSize_);
371 myData_ = (T*)&myRecvBuffer_[0];
373 MPI_Irecv( (
char*)this->myData_, this->msgSize_, MPI_BYTE,
374 this->myRoot_, this->tag_,this->comm_, &this->recvRequest_ );
380 void CopyLocalBuffer(T* destBuffer){
381 std::copy((
char*)myData_,(
char*)myData_+GetMsgSize(),(
char*)destBuffer);
390 template<
typename T>
391 class FTreeBcast2:
public TreeBcast2<T>{
393 virtual void buildTree(Int * ranks, Int rank_cnt){
396 Int idxEnd = rank_cnt;
400 this->myRoot_ = ranks[0];
402 if(this->myRank_==this->myRoot_){
403 this->myDests_.insert(this->myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
406 #if (defined(BCAST_VERBOSE))
407 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
408 statusOFS<<
"My dests are ";
409 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
410 statusOFS<<std::endl;
417 FTreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast2<T>(pComm,ranks,rank_cnt,msgSize){
419 buildTree(ranks,rank_cnt);
423 virtual FTreeBcast2 * clone()
const{
424 FTreeBcast2 * out =
new FTreeBcast2(*
this);
429 template<
typename T>
430 class BTreeBcast2:
public TreeBcast2<T>{
455 virtual void buildTree(Int * ranks, Int rank_cnt){
458 Int idxEnd = rank_cnt;
462 Int prevRoot = ranks[0];
463 while(idxStart<idxEnd){
464 Int curRoot = ranks[idxStart];
465 Int listSize = idxEnd - idxStart;
468 if(curRoot == this->myRank_){
469 this->myRoot_ = prevRoot;
474 Int halfList = floor(ceil(
double(listSize) / 2.0));
475 Int idxStartL = idxStart+1;
476 Int idxStartH = idxStart+halfList;
478 if(curRoot == this->myRank_){
479 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
480 Int childL = ranks[idxStartL];
481 Int childR = ranks[idxStartH];
483 this->myDests_.push_back(childL);
484 this->myDests_.push_back(childR);
486 else if ((idxEnd - idxStartH) > 0){
487 Int childR = ranks[idxStartH];
488 this->myDests_.push_back(childR);
491 Int childL = ranks[idxStartL];
492 this->myDests_.push_back(childL);
494 this->myRoot_ = prevRoot;
498 if( this->myRank_ < ranks[idxStartH]){
499 idxStart = idxStartL;
503 idxStart = idxStartH;
510 #if (defined(BCAST_VERBOSE))
511 statusOFS<<
"My root is "<<myRoot_<<std::endl;
512 statusOFS<<
"My dests are ";
513 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
514 statusOFS<<std::endl;
521 BTreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast2<T>(pComm,ranks,rank_cnt,msgSize){
523 buildTree(ranks,rank_cnt);
526 virtual BTreeBcast2<T> * clone()
const{
527 BTreeBcast2<T> * out =
new BTreeBcast2<T>(*this);
535 template<
typename T>
536 class ModBTreeBcast2:
public TreeBcast2<T>{
540 virtual void buildTree(Int * ranks, Int rank_cnt){
543 Int idxEnd = rank_cnt;
550 Int new_idx = (Int)rseed_ % (rank_cnt - 1) + 1;
556 Int * new_start = &ranks[new_idx];
563 std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
568 Int prevRoot = ranks[0];
569 while(idxStart<idxEnd){
570 Int curRoot = ranks[idxStart];
571 Int listSize = idxEnd - idxStart;
574 if(curRoot == this->myRank_){
575 this->myRoot_ = prevRoot;
580 Int halfList = floor(ceil(
double(listSize) / 2.0));
581 Int idxStartL = idxStart+1;
582 Int idxStartH = idxStart+halfList;
584 if(curRoot == this->myRank_){
585 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
586 Int childL = ranks[idxStartL];
587 Int childR = ranks[idxStartH];
589 this->myDests_.push_back(childL);
590 this->myDests_.push_back(childR);
592 else if ((idxEnd - idxStartH) > 0){
593 Int childR = ranks[idxStartH];
594 this->myDests_.push_back(childR);
597 Int childL = ranks[idxStartL];
598 this->myDests_.push_back(childL);
600 this->myRoot_ = prevRoot;
606 TIMER_START(FIND_RANK);
607 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], this->myRank_);
608 TIMER_STOP(FIND_RANK);
609 if( pos != &ranks[idxStartH]){
610 idxStart = idxStartL;
614 idxStart = idxStartH;
621 #if (defined(REDUCE_VERBOSE))
622 statusOFS<<
"My root is "<<myRoot_<<std::endl;
623 statusOFS<<
"My dests are ";
624 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
625 statusOFS<<std::endl;
632 ModBTreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed):TreeBcast2<T>(pComm,ranks,rank_cnt,msgSize){
635 buildTree(ranks,rank_cnt);
662 virtual ModBTreeBcast2 * clone()
const{
663 ModBTreeBcast2 * out =
new ModBTreeBcast2(*
this);
709 vector<Int> myDests_;
718 #if defined(COMM_PROFILE_BCAST) || defined(COMM_PROFILE)
724 void SetGlobalComm(
const MPI_Comm & pGComm){
725 if(commGlobRanks.count(comm_)==0){
726 MPI_Group group2 = MPI_GROUP_NULL;
727 MPI_Comm_group(pGComm, &group2);
728 MPI_Group group1 = MPI_GROUP_NULL;
729 MPI_Comm_group(comm_, &group1);
732 MPI_Comm_size(comm_,&size);
733 vector<int> globRanks(size);
734 vector<int> Lranks(size);
735 for(
int i = 0; i<size;++i){Lranks[i]=i;}
736 MPI_Group_translate_ranks(group1, size, &Lranks[0],group2, &globRanks[0]);
737 commGlobRanks[comm_] = globRanks;
739 myGRoot_ = commGlobRanks[comm_][myRoot_];
740 myGRank_ = commGlobRanks[comm_][myRank_];
752 virtual void buildTree(Int * ranks, Int rank_cnt)=0;
755 comm_ = MPI_COMM_WORLD;
766 TreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt,Int msgSize){
768 MPI_Comm_rank(comm_,&myRank_);
782 virtual void Copy(
const TreeBcast & Tree){
784 myRank_ = Tree.myRank_;
785 myRoot_ = Tree.myRoot_;
786 msgSize_ = Tree.msgSize_;
789 mainRoot_= Tree.mainRoot_;
790 myDests_ = Tree.myDests_;
803 this->isReady_ =
false;
808 static TreeBcast * Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed);
810 virtual inline Int GetNumRecvMsg(){
return numRecv_;}
811 virtual inline Int GetNumMsgToRecv(){
return 1;}
812 inline void SetDataReady(
bool rdy){ isReady_=rdy; }
813 inline void SetTag(Int tag){ tag_ = tag;}
814 inline int GetTag(){
return tag_;}
817 Int * GetDests(){
return &myDests_[0];}
818 Int GetDest(Int i){
return myDests_[i];}
819 Int GetDestCount(){
return myDests_.size();}
820 Int GetRoot(){
return myRoot_;}
821 Int GetMsgSize(){
return msgSize_;}
823 void ForwardMessage(
char * data,
size_t size,
int tag, MPI_Request * requests ){
825 for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
826 Int iProc = myDests_[idxRecv];
828 MPI_Isend( data, size, MPI_BYTE,
829 iProc, tag,comm_, &requests[2*iProc+1] );
831 #if defined(COMM_PROFILE_BCAST) || defined(COMM_PROFILE)
835 PROFILE_COMM(myGRank_,commGlobRanks[comm_][iProc],tag_,msgSize_);
845 virtual void buildTree(Int * ranks, Int rank_cnt){
848 Int idxEnd = rank_cnt;
854 if(myRank_==myRoot_){
855 myDests_.insert(myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
858 #if (defined(BCAST_VERBOSE))
859 statusOFS<<
"My root is "<<myRoot_<<std::endl;
860 statusOFS<<
"My dests are ";
861 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
862 statusOFS<<std::endl;
869 FTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
871 buildTree(ranks,rank_cnt);
908 virtual void buildTree(Int * ranks, Int rank_cnt){
911 Int idxEnd = rank_cnt;
915 Int prevRoot = ranks[0];
916 while(idxStart<idxEnd){
917 Int curRoot = ranks[idxStart];
918 Int listSize = idxEnd - idxStart;
921 if(curRoot == myRank_){
927 Int halfList = floor(ceil(
double(listSize) / 2.0));
928 Int idxStartL = idxStart+1;
929 Int idxStartH = idxStart+halfList;
931 if(curRoot == myRank_){
932 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
933 Int childL = ranks[idxStartL];
934 Int childR = ranks[idxStartH];
936 myDests_.push_back(childL);
937 myDests_.push_back(childR);
939 else if ((idxEnd - idxStartH) > 0){
940 Int childR = ranks[idxStartH];
941 myDests_.push_back(childR);
944 Int childL = ranks[idxStartL];
945 myDests_.push_back(childL);
951 if( myRank_ < ranks[idxStartH]){
952 idxStart = idxStartL;
956 idxStart = idxStartH;
963 #if (defined(BCAST_VERBOSE))
964 statusOFS<<
"My root is "<<myRoot_<<std::endl;
965 statusOFS<<
"My dests are ";
966 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
967 statusOFS<<std::endl;
974 BTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
976 buildTree(ranks,rank_cnt);
992 virtual void buildTree(Int * ranks, Int rank_cnt){
995 Int idxEnd = rank_cnt;
1003 Int new_idx = (int)((rank_cnt - 0) * ( (double)this->rseed_ / (
double)RAND_MAX ) + 0);
1008 Int * new_start = &ranks[new_idx];
1015 std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
1020 Int prevRoot = ranks[0];
1021 while(idxStart<idxEnd){
1022 Int curRoot = ranks[idxStart];
1023 Int listSize = idxEnd - idxStart;
1026 if(curRoot == myRank_){
1032 Int halfList = floor(ceil(
double(listSize) / 2.0));
1033 Int idxStartL = idxStart+1;
1034 Int idxStartH = idxStart+halfList;
1036 if(curRoot == myRank_){
1037 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1038 Int childL = ranks[idxStartL];
1039 Int childR = ranks[idxStartH];
1041 myDests_.push_back(childL);
1042 myDests_.push_back(childR);
1044 else if ((idxEnd - idxStartH) > 0){
1045 Int childR = ranks[idxStartH];
1046 myDests_.push_back(childR);
1049 Int childL = ranks[idxStartL];
1050 myDests_.push_back(childL);
1058 TIMER_START(FIND_RANK);
1059 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], myRank_);
1060 TIMER_STOP(FIND_RANK);
1061 if( pos != &ranks[idxStartH]){
1062 idxStart = idxStartL;
1066 idxStart = idxStartH;
1073 #if (defined(REDUCE_VERBOSE))
1074 statusOFS<<
"My root is "<<myRoot_<<std::endl;
1075 statusOFS<<
"My dests are ";
1076 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
1077 statusOFS<<std::endl;
1084 ModBTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1087 buildTree(ranks,rank_cnt);
1114 rseed_ = Tree.rseed_;
1128 virtual void buildTree(Int * ranks, Int rank_cnt){
1131 Int idxEnd = rank_cnt;
1135 for(
int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<
" ";} statusOFS<<std::endl;
1137 std::random_shuffle(&ranks[1],&ranks[0]+rank_cnt);
1138 for(
int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<
" ";} statusOFS<<std::endl;
1142 Int prevRoot = ranks[0];
1143 while(idxStart<idxEnd){
1144 Int curRoot = ranks[idxStart];
1145 Int listSize = idxEnd - idxStart;
1148 if(curRoot == myRank_){
1154 Int halfList = floor(ceil(
double(listSize) / 2.0));
1155 Int idxStartL = idxStart+1;
1156 Int idxStartH = idxStart+halfList;
1158 if(curRoot == myRank_){
1159 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1160 Int childL = ranks[idxStartL];
1161 Int childR = ranks[idxStartH];
1163 myDests_.push_back(childL);
1164 myDests_.push_back(childR);
1166 else if ((idxEnd - idxStartH) > 0){
1167 Int childR = ranks[idxStartH];
1168 myDests_.push_back(childR);
1171 Int childL = ranks[idxStartL];
1172 myDests_.push_back(childL);
1180 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], myRank_);
1181 if( pos != &ranks[idxStartH]){
1182 idxStart = idxStartL;
1186 idxStart = idxStartH;
1193 #if (defined(REDUCE_VERBOSE))
1194 statusOFS<<
"My root is "<<myRoot_<<std::endl;
1195 statusOFS<<
"My dests are ";
1196 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
1197 statusOFS<<std::endl;
1204 RandBTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1206 buildTree(ranks,rank_cnt);
1221 virtual void buildTree(Int * ranks, Int rank_cnt){
1222 Int numLevel = floor(log2(rank_cnt));
1224 for(Int level=0;level<numLevel;++level){
1225 numRoots = std::min( rank_cnt, numRoots + (Int)pow(2.0,level));
1226 Int numNextRoots = std::min(rank_cnt,numRoots + (Int)pow(2.0,(level+1)));
1227 Int numReceivers = numNextRoots - numRoots;
1228 for(Int ip = 0; ip<numRoots;++ip){
1230 for(Int ir = ip; ir<numReceivers;ir+=numRoots){
1231 Int r = ranks[numRoots+ir];
1237 myDests_.push_back(r);
1243 #if (defined(BCAST_VERBOSE))
1244 statusOFS<<
"My root is "<<myRoot_<<std::endl;
1245 statusOFS<<
"My dests are ";
1246 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
1247 statusOFS<<std::endl;
1254 PalmTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1256 buildTree(ranks,rank_cnt);
1267 template<
typename T>
1271 MPI_Request sendRequest_;
1285 TreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1287 sendRequest_ = MPI_REQUEST_NULL;
1315 this->myData_ = NULL;
1316 this->sendRequest_ = MPI_REQUEST_NULL;
1317 this->fwded_=
false;
1319 this->isAllocated_= Tree.isAllocated_;
1320 this->numRecvPosted_= 0;
1331 bool IsAllocated(){
return isAllocated_;}
1338 static TreeReduce<T> * Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed);
1340 virtual inline Int GetNumMsgToRecv(){
return GetDestCount();}
1342 virtual void AllocRecvBuffers(){
1343 remoteData_.Resize(GetDestCount());
1350 myRecvBuffers_.Resize(GetDestCount()*msgSize_);
1353 for( Int idxRecv = 0; idxRecv < GetDestCount(); ++idxRecv ){
1354 remoteData_[idxRecv] = (T*)&(myRecvBuffers_[idxRecv*msgSize_]);
1359 myRequests_.Resize(GetDestCount());
1360 SetValue(myRequests_,MPI_REQUEST_NULL);
1361 myStatuses_.Resize(GetDestCount());
1362 recvIdx_.Resize(GetDestCount());
1364 sendRequest_ = MPI_REQUEST_NULL;
1366 isAllocated_ =
true;
1369 void CleanupBuffers(){
1370 myLocalBuffer_.Clear();
1377 remoteData_.Clear();
1385 myRequests_.Clear();
1386 myStatuses_.Clear();
1404 sendRequest_ = MPI_REQUEST_NULL;
1414 void SetLocalBuffer(T * locBuffer){
1415 if(myData_!=NULL && myData_!=locBuffer){
1419 blas::Axpy(msgSize_/
sizeof(T), ONE<T>(), myData_, 1, locBuffer, 1 );
1420 myLocalBuffer_.Clear();
1423 myData_ = locBuffer;
1426 inline bool AccumulationDone(){
1427 if(myRank_==myRoot_ && isAllocated_){
1430 return isReady_ && (numRecv_ == GetDestCount());
1434 inline bool IsDone(){
1435 if(myRank_==myRoot_ && isAllocated_){
1439 bool retVal = AccumulationDone();
1440 if(myRoot_ != myRank_ && !fwded_){
1444 if (retVal && myRoot_ != myRank_ && fwded_){
1447 MPI_Test(&sendRequest_,&flag,MPI_STATUS_IGNORE);
1455 virtual bool Progress(){
1464 if(myRank_==myRoot_ && isAllocated_){
1472 bool retVal = AccumulationDone();
1473 if(isReady_ && !retVal){
1479 int reqCnt = GetDestCount();
1480 assert(reqCnt == myRequests_.m());
1481 MPI_Testsome(reqCnt,&myRequests_[0],&recvCount,&recvIdx_[0],&myStatuses_[0]);
1483 for(Int i = 0;i<recvCount;++i ){
1484 Int idx = recvIdx_[i];
1486 if(idx!=MPI_UNDEFINED){
1489 MPI_Get_count(&myStatuses_[i], MPI_BYTE, &size);
1492 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1493 statusOFS<<myRank_<<
" RECVD from "<<myStatuses_[i].MPI_SOURCE<<
" on tag "<<tag_<<std::endl;
1499 myLocalBuffer_.Resize(msgSize_);
1501 myData_ = (T*)&myLocalBuffer_[0];
1502 Int nelem = +msgSize_/
sizeof(T);
1503 std::fill(myData_,myData_+nelem,ZERO<T>());
1516 else if (isReady_ && sendRequest_ == MPI_REQUEST_NULL && myRoot_ != myRank_ && !fwded_){
1518 myRecvBuffers_.Clear();
1519 myRequests_.Clear();
1520 myStatuses_.Clear();
1533 myRecvBuffers_.Clear();
1534 myRequests_.Clear();
1535 myStatuses_.Clear();
1552 T * GetLocalBuffer(){
1558 void CopyLocalBuffer(T* destBuffer){
1559 std::copy((
char*)myData_,(
char*)myData_+GetMsgSize(),(
char*)destBuffer);
1563 virtual void PostFirstRecv()
1565 if(this->GetDestCount()>this->numRecvPosted_){
1566 for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
1567 Int iProc = myDests_[idxRecv];
1569 MPI_Irecv( (
char*)remoteData_[idxRecv], msgSize_, MPI_BYTE,
1570 iProc, tag_,comm_, &myRequests_[idxRecv] );
1571 this->numRecvPosted_++;
1580 virtual void Reduce( Int idxRecv, Int idReq){
1582 blas::Axpy(msgSize_/
sizeof(T), ONE<T>(), remoteData_[idxRecv], 1, myData_, 1 );
1587 Int iProc = myRoot_;
1590 MPI_Isend( NULL, 0, MPI_BYTE,
1591 iProc, tag_,comm_, &sendRequest_ );
1593 PROFILE_COMM(myGRank_,myGRoot_,tag_,0);
1597 MPI_Isend( (
char*)myData_, msgSize_, MPI_BYTE,
1598 iProc, tag_,comm_, &sendRequest_ );
1600 PROFILE_COMM(myGRank_,myGRoot_,tag_,msgSize_);
1604 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1605 statusOFS<<myRank_<<
" FWD to "<<iProc<<
" on tag "<<tag_<<std::endl;
1615 template<
typename T>
1618 virtual void buildTree(Int * ranks, Int rank_cnt){
1621 Int idxEnd = rank_cnt;
1625 this->myRoot_ = ranks[0];
1627 if(this->myRank_==this->myRoot_){
1628 this->myDests_.insert(this->myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
1631 #if (defined(REDUCE_VERBOSE))
1632 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
1633 statusOFS<<
"My dests are ";
1634 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
1635 statusOFS<<std::endl;
1639 virtual void Reduce( ){
1641 blas::Axpy(this->msgSize_/
sizeof(T), ONE<T>(), this->remoteData_[0], 1, this->myData_, 1 );
1644 #if (defined(REDUCE_DEBUG))
1645 statusOFS << std::endl <<
" Recv contrib"<< std::endl;
1646 for(
int i = 0; i < this->msgSize_/
sizeof(T); ++i){
1647 statusOFS<< this->remoteData_[0][i]<<
" ";
1648 if(i%3==0){statusOFS<<std::endl;}
1650 statusOFS<<std::endl;
1652 statusOFS << std::endl <<
" Reduce buffer now is"<< std::endl;
1653 for(
int i = 0; i < this->msgSize_/
sizeof(T); ++i){
1654 statusOFS<< this->myData_[i]<<
" ";
1655 if(i%3==0){statusOFS<<std::endl;}
1657 statusOFS<<std::endl;
1665 FTreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1666 buildTree(ranks,rank_cnt);
1669 virtual void PostFirstRecv()
1674 if(this->isAllocated_ && this->GetDestCount()>this->numRecvPosted_){
1675 MPI_Irecv( (
char*)this->remoteData_[0], this->msgSize_, MPI_BYTE,
1676 MPI_ANY_SOURCE, this->tag_,this->comm_, &this->myRequests_[0] );
1677 this->numRecvPosted_++;
1681 virtual void AllocRecvBuffers(){
1682 if(this->GetDestCount()>0){
1683 this->remoteData_.Resize(1);
1685 this->myRecvBuffers_.Resize(this->msgSize_);
1687 this->remoteData_[0] = (T*)&(this->myRecvBuffers_[0]);
1689 this->myRequests_.Resize(1);
1690 SetValue(this->myRequests_,MPI_REQUEST_NULL);
1691 this->myStatuses_.Resize(1);
1692 this->recvIdx_.Resize(1);
1695 this->sendRequest_ = MPI_REQUEST_NULL;
1697 this->isAllocated_ =
true;
1700 virtual bool Progress(){
1702 if(!this->isAllocated_){
1707 if(this->myRank_==this->myRoot_ && this->isAllocated_){
1708 this->isReady_=
true;
1715 bool retVal = this->AccumulationDone();
1716 if(this->isReady_ && !retVal){
1724 MPI_Testsome(reqCnt,&this->myRequests_[0],&recvCount,&this->recvIdx_[0],&this->myStatuses_[0]);
1727 for(Int i = 0;i<recvCount;++i ){
1728 Int idx = this->recvIdx_[i];
1730 if(idx!=MPI_UNDEFINED){
1733 MPI_Get_count(&this->myStatuses_[i], MPI_BYTE, &size);
1736 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1738 statusOFS<<this->myRank_<<
" RECVD from "<<this->myStatuses_[i].MPI_SOURCE<<
" on tag "<<this->tag_<<std::endl;
1742 if(this->myData_==NULL){
1744 this->myLocalBuffer_.Resize(this->msgSize_);
1746 this->myData_ = (T*)&this->myLocalBuffer_[0];
1747 Int nelem = this->msgSize_/
sizeof(T);
1748 std::fill(this->myData_,this->myData_+nelem,ZERO<T>());
1759 this->PostFirstRecv();
1762 else if (this->isReady_ && this->sendRequest_ == MPI_REQUEST_NULL && this->myRoot_ != this->myRank_ && !this->fwded_){
1764 this->myRecvBuffers_.Clear();
1765 this->myRequests_.Clear();
1766 this->myStatuses_.Clear();
1767 this->recvIdx_.Clear();
1774 retVal = this->IsDone();
1777 this->myRecvBuffers_.Clear();
1778 this->myRequests_.Clear();
1779 this->myStatuses_.Clear();
1780 this->recvIdx_.Clear();
1799 template<
typename T>
1802 virtual void buildTree(Int * ranks, Int rank_cnt){
1804 Int idxEnd = rank_cnt;
1808 Int prevRoot = ranks[0];
1809 while(idxStart<idxEnd){
1810 Int curRoot = ranks[idxStart];
1811 Int listSize = idxEnd - idxStart;
1814 if(curRoot == this->myRank_){
1815 this->myRoot_ = prevRoot;
1820 Int halfList = floor(ceil(
double(listSize) / 2.0));
1821 Int idxStartL = idxStart+1;
1822 Int idxStartH = idxStart+halfList;
1824 if(curRoot == this->myRank_){
1825 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1826 Int childL = ranks[idxStartL];
1827 Int childR = ranks[idxStartH];
1829 this->myDests_.push_back(childL);
1830 this->myDests_.push_back(childR);
1832 else if ((idxEnd - idxStartH) > 0){
1833 Int childR = ranks[idxStartH];
1834 this->myDests_.push_back(childR);
1837 Int childL = ranks[idxStartL];
1838 this->myDests_.push_back(childL);
1840 this->myRoot_ = prevRoot;
1844 if( this->myRank_ < ranks[idxStartH]){
1845 idxStart = idxStartL;
1849 idxStart = idxStartH;
1856 #if (defined(REDUCE_VERBOSE))
1857 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
1858 statusOFS<<
"My dests are ";
1859 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
1860 statusOFS<<std::endl;
1864 BTreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1865 buildTree(ranks,rank_cnt);
1875 template<
typename T>
1879 virtual void buildTree(Int * ranks, Int rank_cnt){
1882 Int idxEnd = rank_cnt;
1893 Int new_idx = (int)((rank_cnt - 0) * ( (double)this->rseed_ / (
double)RAND_MAX ) + 0);
1897 Int * new_start = &ranks[new_idx];
1903 std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
1907 Int prevRoot = ranks[0];
1908 while(idxStart<idxEnd){
1909 Int curRoot = ranks[idxStart];
1910 Int listSize = idxEnd - idxStart;
1913 if(curRoot == this->myRank_){
1914 this->myRoot_ = prevRoot;
1919 Int halfList = floor(ceil(
double(listSize) / 2.0));
1920 Int idxStartL = idxStart+1;
1921 Int idxStartH = idxStart+halfList;
1923 if(curRoot == this->myRank_){
1924 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1925 Int childL = ranks[idxStartL];
1926 Int childR = ranks[idxStartH];
1928 this->myDests_.push_back(childL);
1929 this->myDests_.push_back(childR);
1931 else if ((idxEnd - idxStartH) > 0){
1932 Int childR = ranks[idxStartH];
1933 this->myDests_.push_back(childR);
1936 Int childL = ranks[idxStartL];
1937 this->myDests_.push_back(childL);
1939 this->myRoot_ = prevRoot;
1945 TIMER_START(FIND_RANK);
1946 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], this->myRank_);
1947 TIMER_STOP(FIND_RANK);
1948 if( pos != &ranks[idxStartH]){
1949 idxStart = idxStartL;
1953 idxStart = idxStartH;
1960 #if (defined(REDUCE_VERBOSE))
1961 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
1962 statusOFS<<
"My dests are ";
1963 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
1964 statusOFS<<std::endl;
1968 ModBTreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed):
TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1969 this->rseed_ = rseed;
1970 buildTree(ranks,rank_cnt);
1999 this->rseed_ = Tree.rseed_;
2013 template<
typename T>
2017 virtual void buildTree(Int * ranks, Int rank_cnt){
2018 Int numLevel = floor(log2(rank_cnt));
2020 for(Int level=0;level<numLevel;++level){
2021 numRoots = std::min( rank_cnt, numRoots + (Int)pow(2,level));
2022 Int numNextRoots = std::min(rank_cnt,numRoots + (Int)pow(2,(level+1)));
2023 Int numReceivers = numNextRoots - numRoots;
2024 for(Int ip = 0; ip<numRoots;++ip){
2026 for(Int ir = ip; ir<numReceivers;ir+=numRoots){
2027 Int r = ranks[numRoots+ir];
2028 if(r==this->myRank_){
2032 if(p==this->myRank_){
2033 this->myDests_.push_back(r);
2039 #if (defined(BCAST_VERBOSE))
2040 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
2041 statusOFS<<
"My dests are ";
2042 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
2043 statusOFS<<std::endl;
2052 buildTree(ranks,rank_cnt);
2098 inline TreeBcast * TreeBcast::Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed){
2101 MPI_Comm_size(pComm, &nprocs);
2105 return new FTreeBcast(pComm,ranks,rank_cnt,msgSize);
2106 #elif defined(MODBTREE)
2107 return new ModBTreeBcast(pComm,ranks,rank_cnt,msgSize, rseed);
2108 #elif defined(BTREE)
2109 return new BTreeBcast(pComm,ranks,rank_cnt,msgSize);
2110 #elif defined(PALMTREE)
2119 if(nprocs<=FTREE_LIMIT){
2120 return new FTreeBcast(pComm,ranks,rank_cnt,msgSize);
2123 return new ModBTreeBcast(pComm,ranks,rank_cnt,msgSize, rseed);
2134 template<
typename T>
2135 inline TreeReduce<T> * TreeReduce<T>::Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed){
2138 MPI_Comm_size(pComm, &nprocs);
2141 return new FTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2142 #elif defined(MODBTREE)
2143 return new ModBTreeReduce<T>(pComm,ranks,rank_cnt,msgSize, rseed);
2144 #elif defined(BTREE)
2145 return new BTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2146 #elif defined(PALMTREE)
2147 return new PalmTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2151 if(nprocs<=FTREE_LIMIT){
2152 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2153 statusOFS<<
"FLAT TREE USED"<<endl;
2155 return new FTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2158 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2159 statusOFS<<
"BINARY TREE USED"<<endl;
2161 return new ModBTreeReduce<T>(pComm,ranks,rank_cnt,msgSize, rseed);
2168 template<
typename T>
2169 inline TreeBcast2<T> * TreeBcast2<T>::Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed){
2172 MPI_Comm_size(pComm, &nprocs);
2178 return new FTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize);
2179 #elif defined(MODBTREE)
2180 return new ModBTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize,rseed);
2181 #elif defined(BTREE)
2182 return new BTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize);
2186 if(nprocs<=FTREE_LIMIT){
2187 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2188 statusOFS<<
"FLAT TREE USED"<<endl;
2191 return new FTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize);
2195 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2196 statusOFS<<
"BINARY TREE USED"<<endl;
2198 return new ModBTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize, rseed);
Definition: TreeBcast.hpp:705
Definition: TreeBcast.hpp:1616
Definition: TreeBcast.hpp:1126
virtual void Copy(const ModBTreeBcast &Tree)
Definition: TreeBcast.hpp:1090
Definition: TreeBcast.hpp:2014
Profiling and timing using TAU.
void SetValue(NumMat< F > &M, F val)
SetValue sets a numerical matrix to a constant val.
Definition: NumMat_impl.hpp:171
Definition: TreeBcast.hpp:843
Definition: TreeBcast.hpp:988
Definition: TreeBcast.hpp:1876
Definition: TreeBcast.hpp:883
Definition: TreeBcast.hpp:1219
Definition: TreeBcast.hpp:1268
Definition: TreeBcast.hpp:1800