1 #ifndef _PEXSI_TREE_HPP_
2 #define _PEXSI_TREE_HPP_
/// Communicator-size threshold used by the Create() factories in this file:
/// in their size-based selection branch, a communicator with at most
/// FTREE_LIMIT processes gets a flat tree (FTree*), larger ones get a
/// ModBTree* variant.
16 #define FTREE_LIMIT 16
/// Per-communicator cache of rank translations (defined in another
/// translation unit). SetGlobalComm() fills the entry for a communicator
/// with the result of MPI_Group_translate_ranks, i.e. entry [comm][r] is
/// the rank in the global communicator of local rank r of comm. It is read
/// back when profiling sends (PROFILE_COMM).
24 extern std::map< MPI_Comm , std::vector<int> > commGlobRanks;
/// Request of the single receive posted by PostRecv() (MPI_Irecv from
/// myRoot_) and polled with MPI_Test in Progress(); MPI_REQUEST_NULL when
/// no receive is pending.
33 MPI_Request recvRequest_;
/// Byte buffer resized to msgSize_ by PostRecv(); myData_ is pointed at
/// its first element when the message is received into it.
34 NumVec<char> myRecvBuffer_;
/// One send request per destination, filled by MPI_Isend in
/// ForwardMessage() and completed with MPI_Testall in Progress().
36 NumVec<MPI_Request> myRequests_;
/// Status array handed to MPI_Testall alongside myRequests_; resized to
/// GetDestCount() in Progress().
37 NumVec<MPI_Status> myStatuses_;
53 #ifdef COMM_PROFILE_BCAST
59 void SetGlobalComm(
const MPI_Comm & pGComm){
60 if(commGlobRanks.count(comm_)==0){
61 MPI_Group group2 = MPI_GROUP_NULL;
62 MPI_Comm_group(pGComm, &group2);
63 MPI_Group group1 = MPI_GROUP_NULL;
64 MPI_Comm_group(comm_, &group1);
67 MPI_Comm_size(comm_,&size);
68 vector<int> globRanks(size);
69 vector<int> Lranks(size);
70 for(
int i = 0; i<size;++i){Lranks[i]=i;}
71 MPI_Group_translate_ranks(group1, size, &Lranks[0],group2, &globRanks[0]);
72 commGlobRanks[comm_] = globRanks;
74 myGRoot_ = commGlobRanks[comm_][myRoot_];
75 myGRank_ = commGlobRanks[comm_][myRank_];
/// Pure virtual hook that derives this rank's parent (myRoot_) and its
/// list of children (myDests_) from the participant list.
/// @param ranks     array of participating ranks; ranks[0] is used as the
///                  tree root by the implementations in this file. The
///                  ModBTree implementation rotates ranks[1..rank_cnt) in
///                  place, so callers should not rely on the order being
///                  preserved.
/// @param rank_cnt  number of entries in ranks.
89 virtual void buildTree(Int * ranks, Int rank_cnt)=0;
98 comm_ = MPI_COMM_WORLD;
107 recvRequest_ = MPI_REQUEST_NULL;
114 TreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt,Int msgSize):TreeBcast2(){
116 MPI_Comm_rank(comm_,&myRank_);
/// Pure virtual polymorphic copy. The overrides visible in this file build
/// the copy with `new` through the concrete class's copy constructor;
/// presumably the caller takes ownership of the result — TODO confirm.
126 virtual TreeBcast2 * clone()
const = 0;
128 TreeBcast2(
const TreeBcast2 & Tree){
132 virtual void Copy(
const TreeBcast2 & Tree){
133 this->comm_ = Tree.comm_;
134 this->myRank_ = Tree.myRank_;
135 this->myRoot_ = Tree.myRoot_;
136 this->msgSize_ = Tree.msgSize_;
138 this->numRecv_ = Tree.numRecv_;
139 this->tag_= Tree.tag_;
140 this->mainRoot_= Tree.mainRoot_;
141 this->isReady_ = Tree.isReady_;
142 this->myDests_ = Tree.myDests_;
145 this->recvRequest_ = Tree.recvRequest_;
146 this->myRecvBuffer_ = Tree.myRecvBuffer_;
147 this->myRequests_ = Tree.myRequests_;
148 this->myStatuses_ = Tree.myStatuses_;
149 this->myData_ = Tree.myData_;
150 if(Tree.myData_==(T*)&Tree.myRecvBuffer_[0]){
151 this->myData_=(T*)&this->myRecvBuffer_[0];
157 this->fwded_= Tree.fwded_;
159 this->done_= Tree.done_;
167 recvRequest_ = MPI_REQUEST_NULL;
176 virtual ~TreeBcast2(){
/// Factory that allocates (with `new`) a concrete broadcast tree for the
/// given participants. The definition further down selects the tree type
/// from compile-time macros (MODBTREE, BTREE, ...) or, in its size-based
/// branch, from the size of pComm compared with FTREE_LIMIT.
/// @param pComm     communicator the tree operates on.
/// @param ranks     participating ranks, forwarded to the tree constructor.
/// @param rank_cnt  number of entries in ranks.
/// @param msgSize   message size, forwarded to the tree constructor
///                  (messages are sent as MPI_BYTE with this count).
/// @param rseed     forwarded only to the ModBTreeBcast2 constructor.
/// @return pointer to the newly allocated tree.
181 static TreeBcast2<T> * Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed);
183 virtual inline Int GetNumRecvMsg(){
return numRecv_;}
184 virtual inline Int GetNumMsgToSend(){
return GetDestCount();}
185 inline void SetDataReady(
bool rdy){
189 inline void SetTag(Int tag){ tag_ = tag;}
190 inline int GetTag(){
return tag_;}
192 bool IsDone(){
return done_;}
193 bool IsDataReady(){
return isReady_;}
194 bool IsDataReceived(){
return numRecv_==1;}
196 Int * GetDests(){
return &myDests_[0];}
197 Int GetDest(Int i){
return myDests_[i];}
198 Int GetDestCount(){
return myDests_.size();}
199 Int GetRoot(){
return myRoot_;}
201 bool IsRoot(){
return myRoot_==myRank_;}
202 Int GetMsgSize(){
return msgSize_;}
204 void ForwardMessage( ){
205 if(myRequests_.m()!=GetDestCount()){
206 myRequests_.Resize(GetDestCount());
207 SetValue(myRequests_,MPI_REQUEST_NULL);
209 for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
210 Int iProc = myDests_[idxRecv];
212 MPI_Isend( myData_, msgSize_, MPI_BYTE,
213 iProc, tag_,comm_, &myRequests_[idxRecv] );
215 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
216 statusOFS<<myRank_<<
" FWD to "<<iProc<<
" on tag "<<tag_<<std::endl;
218 #ifdef COMM_PROFILE_BCAST
223 PROFILE_COMM(myGRank_,commGlobRanks[comm_][iProc],tag_,msgSize_);
229 void CleanupBuffers(){
232 myRecvBuffer_.Clear();
236 void SetLocalBuffer(T * locBuffer){
237 if(myData_!=NULL && myData_!=locBuffer){
239 CopyLocalBuffer(locBuffer);
242 myRecvBuffer_.Clear();
250 virtual bool Progress(){
257 if (myRank_==myRoot_){
260 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
261 statusOFS<<myRank_<<
" FORWARDING on tag "<<tag_<<std::endl;
267 if(myStatuses_.m()!=GetDestCount()){
268 myStatuses_.Resize(GetDestCount());
269 recvRequest_ = MPI_REQUEST_NULL;
273 int reqCnt = GetDestCount();
275 assert(reqCnt == myRequests_.m());
276 MPI_Testall(reqCnt,&myRequests_[0],&flag,&myStatuses_[0]);
286 bool received = (numRecv_==1);
289 if(recvRequest_ == MPI_REQUEST_NULL ){
290 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
291 statusOFS<<myRank_<<
" POSTING RECV on tag "<<tag_<<std::endl;
298 if(myStatuses_.m()!=GetDestCount()){
299 myStatuses_.Resize(GetDestCount());
300 recvRequest_ = MPI_REQUEST_NULL;
302 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
303 statusOFS<<myRank_<<
" TESTING RECV on tag "<<tag_<<std::endl;
308 int test = MPI_Test(&recvRequest_,&flag,&stat);
309 assert(test==MPI_SUCCESS);
316 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
317 statusOFS<<myRank_<<
" FORWARDING on tag "<<tag_<<std::endl;
328 int reqCnt = GetDestCount();
330 assert(reqCnt == myRequests_.m());
331 MPI_Testall(reqCnt,&myRequests_[0],&flag,&myStatuses_[0]);
344 #if ( _DEBUGlevel_ >= 1 ) || defined(BCAST_VERBOSE)
345 statusOFS<<myRank_<<
" EVERYTHING COMPLETED on tag "<<tag_<<std::endl;
361 T * GetLocalBuffer(){
365 virtual void PostRecv()
367 if(this->numRecv_<1 && this->recvRequest_==MPI_REQUEST_NULL && myRank_!=myRoot_){
370 myRecvBuffer_.Resize(msgSize_);
371 myData_ = (T*)&myRecvBuffer_[0];
373 MPI_Irecv( (
char*)this->myData_, this->msgSize_, MPI_BYTE,
374 this->myRoot_, this->tag_,this->comm_, &this->recvRequest_ );
380 void CopyLocalBuffer(T* destBuffer){
381 std::copy((
char*)myData_,(
char*)myData_+GetMsgSize(),(
char*)destBuffer);
390 template<
typename T>
391 class FTreeBcast2:
public TreeBcast2<T>{
393 virtual void buildTree(Int * ranks, Int rank_cnt){
396 Int idxEnd = rank_cnt;
400 this->myRoot_ = ranks[0];
402 if(this->myRank_==this->myRoot_){
403 this->myDests_.insert(this->myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
406 #if (defined(BCAST_VERBOSE))
407 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
408 statusOFS<<
"My dests are ";
409 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
410 statusOFS<<std::endl;
417 FTreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast2<T>(pComm,ranks,rank_cnt,msgSize){
419 buildTree(ranks,rank_cnt);
423 virtual FTreeBcast2 * clone()
const{
424 FTreeBcast2 * out =
new FTreeBcast2(*
this);
429 template<
typename T>
430 class BTreeBcast2:
public TreeBcast2<T>{
455 virtual void buildTree(Int * ranks, Int rank_cnt){
458 Int idxEnd = rank_cnt;
462 Int prevRoot = ranks[0];
463 while(idxStart<idxEnd){
464 Int curRoot = ranks[idxStart];
465 Int listSize = idxEnd - idxStart;
468 if(curRoot == this->myRank_){
469 this->myRoot_ = prevRoot;
474 Int halfList = floor(ceil(
double(listSize) / 2.0));
475 Int idxStartL = idxStart+1;
476 Int idxStartH = idxStart+halfList;
478 if(curRoot == this->myRank_){
479 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
480 Int childL = ranks[idxStartL];
481 Int childR = ranks[idxStartH];
483 this->myDests_.push_back(childL);
484 this->myDests_.push_back(childR);
486 else if ((idxEnd - idxStartH) > 0){
487 Int childR = ranks[idxStartH];
488 this->myDests_.push_back(childR);
491 Int childL = ranks[idxStartL];
492 this->myDests_.push_back(childL);
494 this->myRoot_ = prevRoot;
498 if( this->myRank_ < ranks[idxStartH]){
499 idxStart = idxStartL;
503 idxStart = idxStartH;
510 #if (defined(BCAST_VERBOSE))
511 statusOFS<<
"My root is "<<myRoot_<<std::endl;
512 statusOFS<<
"My dests are ";
513 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
514 statusOFS<<std::endl;
521 BTreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast2<T>(pComm,ranks,rank_cnt,msgSize){
523 buildTree(ranks,rank_cnt);
526 virtual BTreeBcast2<T> * clone()
const{
527 BTreeBcast2<T> * out =
new BTreeBcast2<T>(*this);
535 template<
typename T>
536 class ModBTreeBcast2:
public TreeBcast2<T>{
540 virtual void buildTree(Int * ranks, Int rank_cnt){
543 Int idxEnd = rank_cnt;
550 Int new_idx = (Int)rseed_ % (rank_cnt - 1) + 1;
556 Int * new_start = &ranks[new_idx];
563 std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
568 Int prevRoot = ranks[0];
569 while(idxStart<idxEnd){
570 Int curRoot = ranks[idxStart];
571 Int listSize = idxEnd - idxStart;
574 if(curRoot == this->myRank_){
575 this->myRoot_ = prevRoot;
580 Int halfList = floor(ceil(
double(listSize) / 2.0));
581 Int idxStartL = idxStart+1;
582 Int idxStartH = idxStart+halfList;
584 if(curRoot == this->myRank_){
585 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
586 Int childL = ranks[idxStartL];
587 Int childR = ranks[idxStartH];
589 this->myDests_.push_back(childL);
590 this->myDests_.push_back(childR);
592 else if ((idxEnd - idxStartH) > 0){
593 Int childR = ranks[idxStartH];
594 this->myDests_.push_back(childR);
597 Int childL = ranks[idxStartL];
598 this->myDests_.push_back(childL);
600 this->myRoot_ = prevRoot;
606 TIMER_START(FIND_RANK);
607 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], this->myRank_);
608 TIMER_STOP(FIND_RANK);
609 if( pos != &ranks[idxStartH]){
610 idxStart = idxStartL;
614 idxStart = idxStartH;
621 #if (defined(REDUCE_VERBOSE))
622 statusOFS<<
"My root is "<<myRoot_<<std::endl;
623 statusOFS<<
"My dests are ";
624 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
625 statusOFS<<std::endl;
632 ModBTreeBcast2(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed):TreeBcast2<T>(pComm,ranks,rank_cnt,msgSize){
635 buildTree(ranks,rank_cnt);
662 virtual ModBTreeBcast2 * clone()
const{
663 ModBTreeBcast2 * out =
new ModBTreeBcast2(*
this);
/// Ranks of this process's children in the tree, filled by buildTree().
/// ForwardMessage() sends to each of them; TreeReduce posts one receive
/// per entry instead.
709 vector<Int> myDests_;
718 #if defined(COMM_PROFILE_BCAST) || defined(COMM_PROFILE)
724 void SetGlobalComm(
const MPI_Comm & pGComm){
725 if(commGlobRanks.count(comm_)==0){
726 MPI_Group group2 = MPI_GROUP_NULL;
727 MPI_Comm_group(pGComm, &group2);
728 MPI_Group group1 = MPI_GROUP_NULL;
729 MPI_Comm_group(comm_, &group1);
732 MPI_Comm_size(comm_,&size);
733 vector<int> globRanks(size);
734 vector<int> Lranks(size);
735 for(
int i = 0; i<size;++i){Lranks[i]=i;}
736 MPI_Group_translate_ranks(group1, size, &Lranks[0],group2, &globRanks[0]);
737 commGlobRanks[comm_] = globRanks;
739 myGRoot_ = commGlobRanks[comm_][myRoot_];
740 myGRank_ = commGlobRanks[comm_][myRank_];
/// Pure virtual hook that fills myDests_ (and the root information) for
/// this rank from the participant list.
/// @param ranks     array of participating ranks; ranks[0] is treated as
///                  the tree root by the implementations in this file.
///                  Some implementations reorder ranks[1..rank_cnt) in
///                  place (std::rotate in ModBTree*, std::random_shuffle
///                  in RandBTreeBcast).
/// @param rank_cnt  number of entries in ranks.
752 virtual void buildTree(Int * ranks, Int rank_cnt)=0;
755 comm_ = MPI_COMM_WORLD;
766 TreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt,Int msgSize){
768 MPI_Comm_rank(comm_,&myRank_);
782 virtual void Copy(
const TreeBcast & Tree){
784 myRank_ = Tree.myRank_;
785 myRoot_ = Tree.myRoot_;
786 msgSize_ = Tree.msgSize_;
789 mainRoot_= Tree.mainRoot_;
790 myDests_ = Tree.myDests_;
803 this->isReady_ =
false;
/// Factory that allocates (with `new`) a concrete TreeBcast. The
/// definition further down selects the type from compile-time macros
/// (MODBTREE, BTREE, PALMTREE, ...) or, in its size-based branch, picks
/// FTreeBcast when the size of pComm is at most FTREE_LIMIT and
/// ModBTreeBcast otherwise.
/// @param pComm     communicator the tree operates on.
/// @param ranks     participating ranks, forwarded to the constructor.
/// @param rank_cnt  number of entries in ranks.
/// @param msgSize   message size, forwarded to the constructor.
/// @param rseed     forwarded only to the ModBTreeBcast constructor.
/// @return pointer to the newly allocated tree.
808 static TreeBcast * Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed);
810 virtual inline Int GetNumRecvMsg(){
return numRecv_;}
811 virtual inline Int GetNumMsgToRecv(){
return 1;}
812 inline void SetDataReady(
bool rdy){ isReady_=rdy; }
813 inline void SetTag(Int tag){ tag_ = tag;}
814 inline int GetTag(){
return tag_;}
817 Int * GetDests(){
return &myDests_[0];}
818 Int GetDest(Int i){
return myDests_[i];}
819 Int GetDestCount(){
return myDests_.size();}
820 Int GetRoot(){
return myRoot_;}
821 Int GetMsgSize(){
return msgSize_;}
823 void ForwardMessage(
char * data,
size_t size,
int tag, MPI_Request * requests ){
825 for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
826 Int iProc = myDests_[idxRecv];
828 MPI_Isend( data, size, MPI_BYTE,
829 iProc, tag,comm_, &requests[2*iProc+1] );
831 #if defined(COMM_PROFILE_BCAST) || defined(COMM_PROFILE)
835 PROFILE_COMM(myGRank_,commGlobRanks[comm_][iProc],tag_,msgSize_);
845 virtual void buildTree(Int * ranks, Int rank_cnt){
848 Int idxEnd = rank_cnt;
854 if(myRank_==myRoot_){
855 myDests_.insert(myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
858 #if (defined(BCAST_VERBOSE))
859 statusOFS<<
"My root is "<<myRoot_<<std::endl;
860 statusOFS<<
"My dests are ";
861 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
862 statusOFS<<std::endl;
869 FTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
871 buildTree(ranks,rank_cnt);
908 virtual void buildTree(Int * ranks, Int rank_cnt){
911 Int idxEnd = rank_cnt;
915 Int prevRoot = ranks[0];
916 while(idxStart<idxEnd){
917 Int curRoot = ranks[idxStart];
918 Int listSize = idxEnd - idxStart;
921 if(curRoot == myRank_){
927 Int halfList = floor(ceil(
double(listSize) / 2.0));
928 Int idxStartL = idxStart+1;
929 Int idxStartH = idxStart+halfList;
931 if(curRoot == myRank_){
932 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
933 Int childL = ranks[idxStartL];
934 Int childR = ranks[idxStartH];
936 myDests_.push_back(childL);
937 myDests_.push_back(childR);
939 else if ((idxEnd - idxStartH) > 0){
940 Int childR = ranks[idxStartH];
941 myDests_.push_back(childR);
944 Int childL = ranks[idxStartL];
945 myDests_.push_back(childL);
951 if( myRank_ < ranks[idxStartH]){
952 idxStart = idxStartL;
956 idxStart = idxStartH;
963 #if (defined(BCAST_VERBOSE))
964 statusOFS<<
"My root is "<<myRoot_<<std::endl;
965 statusOFS<<
"My dests are ";
966 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
967 statusOFS<<std::endl;
974 BTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
976 buildTree(ranks,rank_cnt);
992 virtual void buildTree(Int * ranks, Int rank_cnt){
995 Int idxEnd = rank_cnt;
1004 Int new_idx = (int)(this->rseed_)%(rank_cnt-1)+1;
1007 assert(new_idx<rank_cnt && new_idx>=1);
1009 Int * new_start = &ranks[new_idx];
1016 std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
1021 Int prevRoot = ranks[0];
1022 while(idxStart<idxEnd){
1023 assert(idxStart<rank_cnt && idxStart>=0);
1024 Int curRoot = ranks[idxStart];
1025 Int listSize = idxEnd - idxStart;
1028 if(curRoot == myRank_){
1034 Int halfList = floor(ceil(
double(listSize) / 2.0));
1035 Int idxStartL = idxStart+1;
1036 Int idxStartH = idxStart+halfList;
1037 assert(idxStartL<rank_cnt && idxStartL>=0);
1038 assert(idxStartH<rank_cnt && idxStartH>=0);
1040 if(curRoot == myRank_){
1041 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1042 Int childL = ranks[idxStartL];
1043 Int childR = ranks[idxStartH];
1045 myDests_.push_back(childL);
1046 myDests_.push_back(childR);
1048 else if ((idxEnd - idxStartH) > 0){
1049 Int childR = ranks[idxStartH];
1050 myDests_.push_back(childR);
1053 Int childL = ranks[idxStartL];
1054 myDests_.push_back(childL);
1062 TIMER_START(FIND_RANK);
1063 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], myRank_);
1064 TIMER_STOP(FIND_RANK);
1065 if( pos != &ranks[idxStartH]){
1066 idxStart = idxStartL;
1070 idxStart = idxStartH;
1077 #if (defined(REDUCE_VERBOSE))
1078 statusOFS<<
"My root is "<<myRoot_<<std::endl;
1079 statusOFS<<
"My dests are ";
1080 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
1081 statusOFS<<std::endl;
1088 ModBTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1091 buildTree(ranks,rank_cnt);
1118 rseed_ = Tree.rseed_;
1132 virtual void buildTree(Int * ranks, Int rank_cnt){
1135 Int idxEnd = rank_cnt;
1139 for(
int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<
" ";} statusOFS<<std::endl;
1141 std::random_shuffle(&ranks[1],&ranks[0]+rank_cnt);
1142 for(
int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<
" ";} statusOFS<<std::endl;
1146 Int prevRoot = ranks[0];
1147 while(idxStart<idxEnd){
1148 Int curRoot = ranks[idxStart];
1149 Int listSize = idxEnd - idxStart;
1152 if(curRoot == myRank_){
1158 Int halfList = floor(ceil(
double(listSize) / 2.0));
1159 Int idxStartL = idxStart+1;
1160 Int idxStartH = idxStart+halfList;
1162 if(curRoot == myRank_){
1163 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1164 Int childL = ranks[idxStartL];
1165 Int childR = ranks[idxStartH];
1167 myDests_.push_back(childL);
1168 myDests_.push_back(childR);
1170 else if ((idxEnd - idxStartH) > 0){
1171 Int childR = ranks[idxStartH];
1172 myDests_.push_back(childR);
1175 Int childL = ranks[idxStartL];
1176 myDests_.push_back(childL);
1184 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], myRank_);
1185 if( pos != &ranks[idxStartH]){
1186 idxStart = idxStartL;
1190 idxStart = idxStartH;
1197 #if (defined(REDUCE_VERBOSE))
1198 statusOFS<<
"My root is "<<myRoot_<<std::endl;
1199 statusOFS<<
"My dests are ";
1200 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
1201 statusOFS<<std::endl;
1208 RandBTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1210 buildTree(ranks,rank_cnt);
1225 virtual void buildTree(Int * ranks, Int rank_cnt){
1226 Int numLevel = floor(log2(rank_cnt));
1228 for(Int level=0;level<numLevel;++level){
1229 numRoots = std::min( rank_cnt, numRoots + (Int)pow(2.0,level));
1230 Int numNextRoots = std::min(rank_cnt,numRoots + (Int)pow(2.0,(level+1)));
1231 Int numReceivers = numNextRoots - numRoots;
1232 for(Int ip = 0; ip<numRoots;++ip){
1234 for(Int ir = ip; ir<numReceivers;ir+=numRoots){
1235 Int r = ranks[numRoots+ir];
1241 myDests_.push_back(r);
1247 #if (defined(BCAST_VERBOSE))
1248 statusOFS<<
"My root is "<<myRoot_<<std::endl;
1249 statusOFS<<
"My dests are ";
1250 for(
int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<
" ";}
1251 statusOFS<<std::endl;
1258 PalmTreeBcast(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1260 buildTree(ranks,rank_cnt);
1271 template<
typename T>
/// Request of the MPI_Isend that forwards this rank's contribution to
/// myRoot_; reset to MPI_REQUEST_NULL on construction and cleanup, and
/// polled with MPI_Test in IsDone().
1275 MPI_Request sendRequest_;
1289 TreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeBcast(pComm,ranks,rank_cnt,msgSize){
1291 sendRequest_ = MPI_REQUEST_NULL;
1319 this->myData_ = NULL;
1320 this->sendRequest_ = MPI_REQUEST_NULL;
1321 this->fwded_=
false;
1323 this->isAllocated_= Tree.isAllocated_;
1324 this->numRecvPosted_= 0;
1335 bool IsAllocated(){
return isAllocated_;}
/// Factory that allocates (with `new`) a concrete reduction tree. The
/// definition further down selects the type from compile-time macros
/// (MODBTREE, BTREE, PALMTREE, ...) or, in its size-based branch, picks
/// FTreeReduce when the size of pComm is at most FTREE_LIMIT and
/// ModBTreeReduce otherwise.
/// @param pComm     communicator the tree operates on.
/// @param ranks     participating ranks, forwarded to the constructor.
/// @param rank_cnt  number of entries in ranks.
/// @param msgSize   message size in bytes, forwarded to the constructor.
/// @param rseed     forwarded only to the ModBTreeReduce constructor.
/// @return pointer to the newly allocated tree.
1342 static TreeReduce<T> * Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed);
1344 virtual inline Int GetNumMsgToRecv(){
return GetDestCount();}
1346 virtual void AllocRecvBuffers(){
1347 remoteData_.Resize(GetDestCount());
1354 myRecvBuffers_.Resize(GetDestCount()*msgSize_);
1357 for( Int idxRecv = 0; idxRecv < GetDestCount(); ++idxRecv ){
1358 remoteData_[idxRecv] = (T*)&(myRecvBuffers_[idxRecv*msgSize_]);
1363 myRequests_.Resize(GetDestCount());
1364 SetValue(myRequests_,MPI_REQUEST_NULL);
1365 myStatuses_.Resize(GetDestCount());
1366 recvIdx_.Resize(GetDestCount());
1368 sendRequest_ = MPI_REQUEST_NULL;
1370 isAllocated_ =
true;
1373 void CleanupBuffers(){
1374 myLocalBuffer_.Clear();
1381 remoteData_.Clear();
1389 myRequests_.Clear();
1390 myStatuses_.Clear();
1408 sendRequest_ = MPI_REQUEST_NULL;
1418 void SetLocalBuffer(T * locBuffer){
1419 if(myData_!=NULL && myData_!=locBuffer){
1423 blas::Axpy(msgSize_/
sizeof(T), ONE<T>(), myData_, 1, locBuffer, 1 );
1424 myLocalBuffer_.Clear();
1427 myData_ = locBuffer;
1430 inline bool AccumulationDone(){
1431 if(myRank_==myRoot_ && isAllocated_){
1434 return isReady_ && (numRecv_ == GetDestCount());
1438 inline bool IsDone(){
1439 if(myRank_==myRoot_ && isAllocated_){
1443 bool retVal = AccumulationDone();
1444 if(myRoot_ != myRank_ && !fwded_){
1448 if (retVal && myRoot_ != myRank_ && fwded_){
1451 MPI_Test(&sendRequest_,&flag,MPI_STATUS_IGNORE);
1459 virtual bool Progress(){
1468 if(myRank_==myRoot_ && isAllocated_){
1476 bool retVal = AccumulationDone();
1477 if(isReady_ && !retVal){
1483 int reqCnt = GetDestCount();
1484 assert(reqCnt == myRequests_.m());
1485 MPI_Testsome(reqCnt,&myRequests_[0],&recvCount,&recvIdx_[0],&myStatuses_[0]);
1487 for(Int i = 0;i<recvCount;++i ){
1488 Int idx = recvIdx_[i];
1490 if(idx!=MPI_UNDEFINED){
1493 MPI_Get_count(&myStatuses_[i], MPI_BYTE, &size);
1496 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1497 statusOFS<<myRank_<<
" RECVD from "<<myStatuses_[i].MPI_SOURCE<<
" on tag "<<tag_<<std::endl;
1503 myLocalBuffer_.Resize(msgSize_);
1505 myData_ = (T*)&myLocalBuffer_[0];
1506 Int nelem = +msgSize_/
sizeof(T);
1507 std::fill(myData_,myData_+nelem,ZERO<T>());
1520 else if (isReady_ && sendRequest_ == MPI_REQUEST_NULL && myRoot_ != myRank_ && !fwded_){
1522 myRecvBuffers_.Clear();
1523 myRequests_.Clear();
1524 myStatuses_.Clear();
1537 myRecvBuffers_.Clear();
1538 myRequests_.Clear();
1539 myStatuses_.Clear();
1556 T * GetLocalBuffer(){
1562 void CopyLocalBuffer(T* destBuffer){
1563 std::copy((
char*)myData_,(
char*)myData_+GetMsgSize(),(
char*)destBuffer);
1567 virtual void PostFirstRecv()
1569 if(this->GetDestCount()>this->numRecvPosted_){
1570 for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
1571 Int iProc = myDests_[idxRecv];
1573 MPI_Irecv( (
char*)remoteData_[idxRecv], msgSize_, MPI_BYTE,
1574 iProc, tag_,comm_, &myRequests_[idxRecv] );
1575 this->numRecvPosted_++;
1584 virtual void Reduce( Int idxRecv, Int idReq){
1586 blas::Axpy(msgSize_/
sizeof(T), ONE<T>(), remoteData_[idxRecv], 1, myData_, 1 );
1591 Int iProc = myRoot_;
1594 MPI_Isend( NULL, 0, MPI_BYTE,
1595 iProc, tag_,comm_, &sendRequest_ );
1597 PROFILE_COMM(myGRank_,myGRoot_,tag_,0);
1601 MPI_Isend( (
char*)myData_, msgSize_, MPI_BYTE,
1602 iProc, tag_,comm_, &sendRequest_ );
1604 PROFILE_COMM(myGRank_,myGRoot_,tag_,msgSize_);
1608 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1609 statusOFS<<myRank_<<
" FWD to "<<iProc<<
" on tag "<<tag_<<std::endl;
1619 template<
typename T>
1622 virtual void buildTree(Int * ranks, Int rank_cnt){
1625 Int idxEnd = rank_cnt;
1629 this->myRoot_ = ranks[0];
1631 if(this->myRank_==this->myRoot_){
1632 this->myDests_.insert(this->myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
1635 #if (defined(REDUCE_VERBOSE))
1636 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
1637 statusOFS<<
"My dests are ";
1638 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
1639 statusOFS<<std::endl;
1643 virtual void Reduce( ){
1645 blas::Axpy(this->msgSize_/
sizeof(T), ONE<T>(), this->remoteData_[0], 1, this->myData_, 1 );
1648 #if (defined(REDUCE_DEBUG))
1649 statusOFS << std::endl <<
" Recv contrib"<< std::endl;
1650 for(
int i = 0; i < this->msgSize_/
sizeof(T); ++i){
1651 statusOFS<< this->remoteData_[0][i]<<
" ";
1652 if(i%3==0){statusOFS<<std::endl;}
1654 statusOFS<<std::endl;
1656 statusOFS << std::endl <<
" Reduce buffer now is"<< std::endl;
1657 for(
int i = 0; i < this->msgSize_/
sizeof(T); ++i){
1658 statusOFS<< this->myData_[i]<<
" ";
1659 if(i%3==0){statusOFS<<std::endl;}
1661 statusOFS<<std::endl;
1669 FTreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1670 buildTree(ranks,rank_cnt);
1673 virtual void PostFirstRecv()
1678 if(this->isAllocated_ && this->GetDestCount()>this->numRecvPosted_){
1679 MPI_Irecv( (
char*)this->remoteData_[0], this->msgSize_, MPI_BYTE,
1680 MPI_ANY_SOURCE, this->tag_,this->comm_, &this->myRequests_[0] );
1681 this->numRecvPosted_++;
1685 virtual void AllocRecvBuffers(){
1686 if(this->GetDestCount()>0){
1687 this->remoteData_.Resize(1);
1689 this->myRecvBuffers_.Resize(this->msgSize_);
1691 this->remoteData_[0] = (T*)&(this->myRecvBuffers_[0]);
1693 this->myRequests_.Resize(1);
1694 SetValue(this->myRequests_,MPI_REQUEST_NULL);
1695 this->myStatuses_.Resize(1);
1696 this->recvIdx_.Resize(1);
1699 this->sendRequest_ = MPI_REQUEST_NULL;
1701 this->isAllocated_ =
true;
1704 virtual bool Progress(){
1706 if(!this->isAllocated_){
1711 if(this->myRank_==this->myRoot_ && this->isAllocated_){
1712 this->isReady_=
true;
1719 bool retVal = this->AccumulationDone();
1720 if(this->isReady_ && !retVal){
1728 MPI_Testsome(reqCnt,&this->myRequests_[0],&recvCount,&this->recvIdx_[0],&this->myStatuses_[0]);
1731 for(Int i = 0;i<recvCount;++i ){
1732 Int idx = this->recvIdx_[i];
1734 if(idx!=MPI_UNDEFINED){
1737 MPI_Get_count(&this->myStatuses_[i], MPI_BYTE, &size);
1740 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1742 statusOFS<<this->myRank_<<
" RECVD from "<<this->myStatuses_[i].MPI_SOURCE<<
" on tag "<<this->tag_<<std::endl;
1746 if(this->myData_==NULL){
1748 this->myLocalBuffer_.Resize(this->msgSize_);
1750 this->myData_ = (T*)&this->myLocalBuffer_[0];
1751 Int nelem = this->msgSize_/
sizeof(T);
1752 std::fill(this->myData_,this->myData_+nelem,ZERO<T>());
1763 this->PostFirstRecv();
1766 else if (this->isReady_ && this->sendRequest_ == MPI_REQUEST_NULL && this->myRoot_ != this->myRank_ && !this->fwded_){
1768 this->myRecvBuffers_.Clear();
1769 this->myRequests_.Clear();
1770 this->myStatuses_.Clear();
1771 this->recvIdx_.Clear();
1778 retVal = this->IsDone();
1781 this->myRecvBuffers_.Clear();
1782 this->myRequests_.Clear();
1783 this->myStatuses_.Clear();
1784 this->recvIdx_.Clear();
1803 template<
typename T>
1806 virtual void buildTree(Int * ranks, Int rank_cnt){
1808 Int idxEnd = rank_cnt;
1812 Int prevRoot = ranks[0];
1813 while(idxStart<idxEnd){
1814 Int curRoot = ranks[idxStart];
1815 Int listSize = idxEnd - idxStart;
1818 if(curRoot == this->myRank_){
1819 this->myRoot_ = prevRoot;
1824 Int halfList = floor(ceil(
double(listSize) / 2.0));
1825 Int idxStartL = idxStart+1;
1826 Int idxStartH = idxStart+halfList;
1828 if(curRoot == this->myRank_){
1829 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1830 Int childL = ranks[idxStartL];
1831 Int childR = ranks[idxStartH];
1833 this->myDests_.push_back(childL);
1834 this->myDests_.push_back(childR);
1836 else if ((idxEnd - idxStartH) > 0){
1837 Int childR = ranks[idxStartH];
1838 this->myDests_.push_back(childR);
1841 Int childL = ranks[idxStartL];
1842 this->myDests_.push_back(childL);
1844 this->myRoot_ = prevRoot;
1848 if( this->myRank_ < ranks[idxStartH]){
1849 idxStart = idxStartL;
1853 idxStart = idxStartH;
1860 #if (defined(REDUCE_VERBOSE))
1861 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
1862 statusOFS<<
"My dests are ";
1863 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
1864 statusOFS<<std::endl;
1868 BTreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):
TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1869 buildTree(ranks,rank_cnt);
1879 template<
typename T>
1883 virtual void buildTree(Int * ranks, Int rank_cnt){
1886 Int idxEnd = rank_cnt;
1898 Int new_idx = (int)(this->rseed_)%(rank_cnt-1)+1;
1900 Int * new_start = &ranks[new_idx];
1906 std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
1910 Int prevRoot = ranks[0];
1911 while(idxStart<idxEnd){
1912 Int curRoot = ranks[idxStart];
1913 Int listSize = idxEnd - idxStart;
1916 if(curRoot == this->myRank_){
1917 this->myRoot_ = prevRoot;
1922 Int halfList = floor(ceil(
double(listSize) / 2.0));
1923 Int idxStartL = idxStart+1;
1924 Int idxStartH = idxStart+halfList;
1926 if(curRoot == this->myRank_){
1927 if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1928 Int childL = ranks[idxStartL];
1929 Int childR = ranks[idxStartH];
1931 this->myDests_.push_back(childL);
1932 this->myDests_.push_back(childR);
1934 else if ((idxEnd - idxStartH) > 0){
1935 Int childR = ranks[idxStartH];
1936 this->myDests_.push_back(childR);
1939 Int childL = ranks[idxStartL];
1940 this->myDests_.push_back(childL);
1942 this->myRoot_ = prevRoot;
1948 TIMER_START(FIND_RANK);
1949 Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], this->myRank_);
1950 TIMER_STOP(FIND_RANK);
1951 if( pos != &ranks[idxStartH]){
1952 idxStart = idxStartL;
1956 idxStart = idxStartH;
1963 #if (defined(REDUCE_VERBOSE))
1964 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
1965 statusOFS<<
"My dests are ";
1966 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
1967 statusOFS<<std::endl;
1971 ModBTreeReduce(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed):
TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1972 this->rseed_ = rseed;
1973 buildTree(ranks,rank_cnt);
2002 this->rseed_ = Tree.rseed_;
2016 template<
typename T>
2020 virtual void buildTree(Int * ranks, Int rank_cnt){
2021 Int numLevel = floor(log2(rank_cnt));
2023 for(Int level=0;level<numLevel;++level){
2024 numRoots = std::min( rank_cnt, numRoots + (Int)pow(2,level));
2025 Int numNextRoots = std::min(rank_cnt,numRoots + (Int)pow(2,(level+1)));
2026 Int numReceivers = numNextRoots - numRoots;
2027 for(Int ip = 0; ip<numRoots;++ip){
2029 for(Int ir = ip; ir<numReceivers;ir+=numRoots){
2030 Int r = ranks[numRoots+ir];
2031 if(r==this->myRank_){
2035 if(p==this->myRank_){
2036 this->myDests_.push_back(r);
2042 #if (defined(BCAST_VERBOSE))
2043 statusOFS<<
"My root is "<<this->myRoot_<<std::endl;
2044 statusOFS<<
"My dests are ";
2045 for(
int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<
" ";}
2046 statusOFS<<std::endl;
2055 buildTree(ranks,rank_cnt);
2101 inline TreeBcast * TreeBcast::Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed){
2104 MPI_Comm_size(pComm, &nprocs);
2108 return new FTreeBcast(pComm,ranks,rank_cnt,msgSize);
2109 #elif defined(MODBTREE)
2110 return new ModBTreeBcast(pComm,ranks,rank_cnt,msgSize, rseed);
2111 #elif defined(BTREE)
2112 return new BTreeBcast(pComm,ranks,rank_cnt,msgSize);
2113 #elif defined(PALMTREE)
2122 if(nprocs<=FTREE_LIMIT){
2123 return new FTreeBcast(pComm,ranks,rank_cnt,msgSize);
2126 return new ModBTreeBcast(pComm,ranks,rank_cnt,msgSize, rseed);
2137 template<
typename T>
2138 inline TreeReduce<T> * TreeReduce<T>::Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed){
2141 MPI_Comm_size(pComm, &nprocs);
2144 return new FTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2145 #elif defined(MODBTREE)
2146 return new ModBTreeReduce<T>(pComm,ranks,rank_cnt,msgSize, rseed);
2147 #elif defined(BTREE)
2148 return new BTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2149 #elif defined(PALMTREE)
2150 return new PalmTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2154 if(nprocs<=FTREE_LIMIT){
2155 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2156 statusOFS<<
"FLAT TREE USED"<<endl;
2158 return new FTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
2161 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2162 statusOFS<<
"BINARY TREE USED"<<endl;
2164 return new ModBTreeReduce<T>(pComm,ranks,rank_cnt,msgSize, rseed);
2171 template<
typename T>
2172 inline TreeBcast2<T> * TreeBcast2<T>::Create(
const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,
double rseed){
2175 MPI_Comm_size(pComm, &nprocs);
2181 return new FTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize);
2182 #elif defined(MODBTREE)
2183 return new ModBTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize,rseed);
2184 #elif defined(BTREE)
2185 return new BTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize);
2189 if(nprocs<=FTREE_LIMIT){
2190 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2191 statusOFS<<
"FLAT TREE USED"<<endl;
2194 return new FTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize);
2198 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
2199 statusOFS<<
"BINARY TREE USED"<<endl;
2201 return new ModBTreeBcast2<T>(pComm,ranks,rank_cnt,msgSize, rseed);
Definition: TreeBcast.hpp:705
Definition: TreeBcast.hpp:1620
Definition: TreeBcast.hpp:1130
virtual void Copy(const ModBTreeBcast &Tree)
Definition: TreeBcast.hpp:1094
Definition: TreeBcast.hpp:2017
Profiling and timing using TAU.
void SetValue(NumMat< F > &M, F val)
SetValue sets a numerical matrix to a constant val.
Definition: NumMat_impl.hpp:153
Definition: TreeBcast.hpp:843
Definition: TreeBcast.hpp:988
Definition: TreeBcast.hpp:1880
Definition: TreeBcast.hpp:883
Definition: TreeBcast.hpp:1223
Definition: TreeBcast.hpp:1272
Definition: TreeBcast.hpp:1804