PEXSI
 All Classes Namespaces Files Functions Variables Typedefs Pages
TreeBcast.hpp
1 #ifndef _PEXSI_TREE_HPP_
2 #define _PEXSI_TREE_HPP_
3 
4 #include "pexsi/environment.hpp"
5 #include "pexsi/timer.h"
6 
7 #include <vector>
8 #include <algorithm>
9 #include <string>
10 #include <random>
11 
12 // options to switch from a flat bcast/reduce tree to a binary tree
13 
14 #ifndef FTREE_LIMIT
15 #define FTREE_LIMIT 24
16 #endif
17 
18 
19 
20 namespace PEXSI{
21 
22 class TreeBcast{
23  protected:
24  Int myRoot_;
25  MPI_Comm comm_;
26  vector<Int> myDests_;
27  Int myRank_;
28  Int msgSize_;
29  bool isReady_;
30  Int mainRoot_;
31  Int tag_;
32  Int numRecv_;
33 
34 
35 #ifdef COMM_PROFILE
36 protected:
37  Int myGRank_;
38  vector<int> Granks_;
39 public:
40  void SetGlobalComm(const MPI_Comm & pGComm){
41  MPI_Comm_rank(pGComm,&myGRank_);
42  MPI_Group group2 = MPI_GROUP_NULL;
43  MPI_Comm_group(pGComm, &group2);
44  MPI_Group group1 = MPI_GROUP_NULL;
45  MPI_Comm_group(comm_, &group1);
46 
47  Int size;
48  MPI_Comm_size(comm_,&size);
49  Granks_.resize(size);
50  vector<int> Lranks(size);
51  for(int i = 0; i<size;++i){Lranks[i]=i;}
52  MPI_Group_translate_ranks(group1, size, &Lranks[0],group2, &Granks_[0]);
53  }
54 #endif
55 
56 
57 
58  virtual void buildTree(Int * ranks, Int rank_cnt)=0;
59  public:
60  TreeBcast(){
61  comm_ = MPI_COMM_WORLD;
62  myRank_=-1;
63  myRoot_ = -1;
64  msgSize_ = -1;
65  numRecv_ = -1;
66  tag_=-1;
67  mainRoot_=-1;
68  isReady_ = false;
69  }
70 
71 
72  TreeBcast(const MPI_Comm & pComm, Int * ranks, Int rank_cnt,Int msgSize){
73  comm_ = pComm;
74  MPI_Comm_rank(comm_,&myRank_);
75  myRoot_ = -1;
76  msgSize_ = msgSize;
77 
78  numRecv_ = 0;
79  tag_=-1;
80  mainRoot_=ranks[0];
81  isReady_ = false;
82  }
83 
84  TreeBcast(const TreeBcast & Tree){
85  Copy(Tree);
86  }
87 
88  virtual void Copy(const TreeBcast & Tree){
89  comm_ = Tree.comm_;
90  myRank_ = Tree.myRank_;
91  myRoot_ = Tree.myRoot_;
92  msgSize_ = Tree.msgSize_;
93 
94  numRecv_ = Tree.numRecv_;
95  tag_= Tree.tag_;
96  mainRoot_= Tree.mainRoot_;
97  isReady_ = Tree.isReady_;
98  myDests_ = Tree.myDests_;
99  }
100 
101  virtual TreeBcast * clone() const = 0;
102 
103 
104 
105  static TreeBcast * Create(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,double rseed);
106 
107  virtual inline Int GetNumRecvMsg(){return numRecv_;}
108  virtual inline Int GetNumMsgToRecv(){return 1;}
109  inline void SetDataReady(bool rdy){ isReady_=rdy;}
110  inline void SetTag(Int tag){ tag_ = tag;}
111 
112 
113  Int * GetDests(){ return &myDests_[0];}
114  Int GetDest(Int i){ return myDests_[i];}
115  Int GetDestCount(){ return myDests_.size();}
116  Int GetRoot(){ return myRoot_;}
117  Int GetMsgSize(){ return msgSize_;}
118 
119  void ForwardMessage( char * data, size_t size, int tag, MPI_Request * requests ){
120  for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
121  Int iProc = myDests_[idxRecv];
122  // Use Isend to send to multiple targets
123  MPI_Isend( data, size, MPI_BYTE,
124  iProc, tag,comm_, &requests[2*iProc+1] );
125 
126 #ifdef COMM_PROFILE
127  PROFILE_COMM(myGRank_,Granks_[iProc],tag,msgSize_);
128 #endif
129  } // for (iProc)
130  }
131 
132 
133 };
134 
135 class FTreeBcast: public TreeBcast{
136  protected:
137  virtual void buildTree(Int * ranks, Int rank_cnt){
138 
139  Int idxStart = 0;
140  Int idxEnd = rank_cnt;
141 
142 
143 
144  myRoot_ = ranks[0];
145 
146  if(myRank_==myRoot_){
147  myDests_.insert(myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
148  }
149 
150 #if ( _DEBUGlevel_ >= 1 )
151  statusOFS<<"My root is "<<myRoot_<<std::endl;
152  statusOFS<<"My dests are ";
153  for(int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<" ";}
154  statusOFS<<std::endl;
155 #endif
156  }
157 
158 
159 
160  public:
161  FTreeBcast(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast(pComm,ranks,rank_cnt,msgSize){
162  //build the binary tree;
163  buildTree(ranks,rank_cnt);
164  }
165 
166 
167  virtual FTreeBcast * clone() const{
168  FTreeBcast * out = new FTreeBcast(*this);
169  return out;
170  }
171 };
172 
173 
174 
175 class BTreeBcast: public TreeBcast{
176  protected:
200  virtual void buildTree(Int * ranks, Int rank_cnt){
201 
202  Int idxStart = 0;
203  Int idxEnd = rank_cnt;
204 
205 
206 
207  Int prevRoot = ranks[0];
208  while(idxStart<idxEnd){
209  Int curRoot = ranks[idxStart];
210  Int listSize = idxEnd - idxStart;
211 
212  if(listSize == 1){
213  if(curRoot == myRank_){
214  myRoot_ = prevRoot;
215  break;
216  }
217  }
218  else{
219  Int halfList = floor(ceil(double(listSize) / 2.0));
220  Int idxStartL = idxStart+1;
221  Int idxStartH = idxStart+halfList;
222 
223  if(curRoot == myRank_){
224  if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
225  Int childL = ranks[idxStartL];
226  Int childR = ranks[idxStartH];
227 
228  myDests_.push_back(childL);
229  myDests_.push_back(childR);
230  }
231  else if ((idxEnd - idxStartH) > 0){
232  Int childR = ranks[idxStartH];
233  myDests_.push_back(childR);
234  }
235  else{
236  Int childL = ranks[idxStartL];
237  myDests_.push_back(childL);
238  }
239  myRoot_ = prevRoot;
240  break;
241  }
242 
243  if( myRank_ < ranks[idxStartH]){
244  idxStart = idxStartL;
245  idxEnd = idxStartH;
246  }
247  else{
248  idxStart = idxStartH;
249  }
250  prevRoot = curRoot;
251  }
252 
253  }
254 
255 #if ( _DEBUGlevel_ >= 1 )
256  statusOFS<<"My root is "<<myRoot_<<std::endl;
257  statusOFS<<"My dests are ";
258  for(int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<" ";}
259  statusOFS<<std::endl;
260 #endif
261  }
262 
263 
264 
265  public:
266  BTreeBcast(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast(pComm,ranks,rank_cnt,msgSize){
267  //build the binary tree;
268  buildTree(ranks,rank_cnt);
269  }
270 
271  virtual BTreeBcast * clone() const{
272  BTreeBcast * out = new BTreeBcast(*this);
273  return out;
274  }
275 
276 };
277 
278 
279 
280 class ModBTreeBcast: public TreeBcast{
281  protected:
282  double rseed_;
283 
284  virtual void buildTree(Int * ranks, Int rank_cnt){
285 
286  Int idxStart = 0;
287  Int idxEnd = rank_cnt;
288 
289  //sort the ranks with the modulo like operation
290  if(rank_cnt>1){
291  //Int new_idx = (int)((rand()+1.0) * (double)rank_cnt / ((double)RAND_MAX+1.0));
292 
293 // srand(ranks[0]+rank_cnt);
294  Int new_idx = (int)((rank_cnt - 0) * ( (double)this->rseed_ / (double)RAND_MAX ) + 0);// (this->rseed_)%(rank_cnt-1)+1;
295 // statusOFS<<new_idx<<endl;
296 
297 
298 
299  Int * new_start = &ranks[new_idx];
300 
301  //for(int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<" ";} statusOFS<<std::endl;
302 
303 // Int * new_start = std::lower_bound(&ranks[1],&ranks[0]+rank_cnt,ranks[0]);
304  //just swap the two chunks r[0] | r[1] --- r[new_start-1] | r[new_start] --- r[end]
305  // becomes r[0] | r[new_start] --- r[end] | r[1] --- r[new_start-1]
306  std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
307 
308  //for(int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<" ";} statusOFS<<std::endl;
309  }
310 
311  Int prevRoot = ranks[0];
312  while(idxStart<idxEnd){
313  Int curRoot = ranks[idxStart];
314  Int listSize = idxEnd - idxStart;
315 
316  if(listSize == 1){
317  if(curRoot == myRank_){
318  myRoot_ = prevRoot;
319  break;
320  }
321  }
322  else{
323  Int halfList = floor(ceil(double(listSize) / 2.0));
324  Int idxStartL = idxStart+1;
325  Int idxStartH = idxStart+halfList;
326 
327  if(curRoot == myRank_){
328  if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
329  Int childL = ranks[idxStartL];
330  Int childR = ranks[idxStartH];
331 
332  myDests_.push_back(childL);
333  myDests_.push_back(childR);
334  }
335  else if ((idxEnd - idxStartH) > 0){
336  Int childR = ranks[idxStartH];
337  myDests_.push_back(childR);
338  }
339  else{
340  Int childL = ranks[idxStartL];
341  myDests_.push_back(childL);
342  }
343  myRoot_ = prevRoot;
344  break;
345  }
346 
347  //not true anymore ?
348  //first half to
349 TIMER_START(FIND_RANK);
350  Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], myRank_);
351 TIMER_STOP(FIND_RANK);
352  if( pos != &ranks[idxStartH]){
353  idxStart = idxStartL;
354  idxEnd = idxStartH;
355  }
356  else{
357  idxStart = idxStartH;
358  }
359  prevRoot = curRoot;
360  }
361 
362  }
363 
364 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
365  statusOFS<<"My root is "<<myRoot_<<std::endl;
366  statusOFS<<"My dests are ";
367  for(int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<" ";}
368  statusOFS<<std::endl;
369 #endif
370  }
371 
372 
373 
374  public:
375  ModBTreeBcast(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize, double rseed):TreeBcast(pComm,ranks,rank_cnt,msgSize){
376  //build the binary tree;
377  rseed_ = rseed;
378  buildTree(ranks,rank_cnt);
379  }
380 
381  virtual void Copy(const ModBTreeBcast & Tree){
382  comm_ = Tree.comm_;
383  myRank_ = Tree.myRank_;
384  myRoot_ = Tree.myRoot_;
385  msgSize_ = Tree.msgSize_;
386 
387  numRecv_ = Tree.numRecv_;
388  tag_= Tree.tag_;
389  mainRoot_= Tree.mainRoot_;
390  isReady_ = Tree.isReady_;
391  myDests_ = Tree.myDests_;
392 
393  rseed_ = Tree.rseed_;
394  myRank_ = Tree.myRank_;
395  myRoot_ = Tree.myRoot_;
396  msgSize_ = Tree.msgSize_;
397 
398  numRecv_ = Tree.numRecv_;
399  tag_= Tree.tag_;
400  mainRoot_= Tree.mainRoot_;
401  isReady_ = Tree.isReady_;
402  myDests_ = Tree.myDests_;
403  }
404 
405  virtual ModBTreeBcast * clone() const{
406  ModBTreeBcast * out = new ModBTreeBcast(*this);
407  return out;
408  }
409 
410 };
411 
412 
413 class RandBTreeBcast: public TreeBcast{
414  protected:
415  virtual void buildTree(Int * ranks, Int rank_cnt){
416 
417  Int idxStart = 0;
418  Int idxEnd = rank_cnt;
419 
420  //random permute ranks
421  if(rank_cnt>1){
422  for(int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<" ";} statusOFS<<std::endl;
423  srand(ranks[0]);
424  std::random_shuffle(&ranks[1],&ranks[0]+rank_cnt);
425  for(int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<" ";} statusOFS<<std::endl;
426 
427  }
428 
429  Int prevRoot = ranks[0];
430  while(idxStart<idxEnd){
431  Int curRoot = ranks[idxStart];
432  Int listSize = idxEnd - idxStart;
433 
434  if(listSize == 1){
435  if(curRoot == myRank_){
436  myRoot_ = prevRoot;
437  break;
438  }
439  }
440  else{
441  Int halfList = floor(ceil(double(listSize) / 2.0));
442  Int idxStartL = idxStart+1;
443  Int idxStartH = idxStart+halfList;
444 
445  if(curRoot == myRank_){
446  if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
447  Int childL = ranks[idxStartL];
448  Int childR = ranks[idxStartH];
449 
450  myDests_.push_back(childL);
451  myDests_.push_back(childR);
452  }
453  else if ((idxEnd - idxStartH) > 0){
454  Int childR = ranks[idxStartH];
455  myDests_.push_back(childR);
456  }
457  else{
458  Int childL = ranks[idxStartL];
459  myDests_.push_back(childL);
460  }
461  myRoot_ = prevRoot;
462  break;
463  }
464 
465  //not true anymore ?
466  //first half to
467  Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], myRank_);
468  if( pos != &ranks[idxStartH]){
469  idxStart = idxStartL;
470  idxEnd = idxStartH;
471  }
472  else{
473  idxStart = idxStartH;
474  }
475  prevRoot = curRoot;
476  }
477 
478  }
479 
480 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
481  statusOFS<<"My root is "<<myRoot_<<std::endl;
482  statusOFS<<"My dests are ";
483  for(int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<" ";}
484  statusOFS<<std::endl;
485 #endif
486  }
487 
488 
489 
490  public:
491  RandBTreeBcast(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast(pComm,ranks,rank_cnt,msgSize){
492  //build the binary tree;
493  buildTree(ranks,rank_cnt);
494  }
495 
496  virtual RandBTreeBcast * clone() const{
497  RandBTreeBcast * out = new RandBTreeBcast(*this);
498  return out;
499  }
500 
501 };
502 
503 
504 
505 
506 class PalmTreeBcast: public TreeBcast{
507  protected:
508  virtual void buildTree(Int * ranks, Int rank_cnt){
509  Int numLevel = floor(log2(rank_cnt));
510  Int numRoots = 0;
511  for(Int level=0;level<numLevel;++level){
512  numRoots = std::min( rank_cnt, numRoots + (Int)pow(2,level));
513  Int numNextRoots = std::min(rank_cnt,numRoots + (Int)pow(2,(level+1)));
514  Int numReceivers = numNextRoots - numRoots;
515  for(Int ip = 0; ip<numRoots;++ip){
516  Int p = ranks[ip];
517  for(Int ir = ip; ir<numReceivers;ir+=numRoots){
518  Int r = ranks[numRoots+ir];
519  if(r==myRank_){
520  myRoot_ = p;
521  }
522 
523  if(p==myRank_){
524  myDests_.push_back(r);
525  }
526  }
527  }
528  }
529 
530 #if ( _DEBUGlevel_ >= 1 )
531  statusOFS<<"My root is "<<myRoot_<<std::endl;
532  statusOFS<<"My dests are ";
533  for(int i =0;i<myDests_.size();++i){statusOFS<<myDests_[i]<<" ";}
534  statusOFS<<std::endl;
535 #endif
536  }
537 
538 
539 
540  public:
541  PalmTreeBcast(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast(pComm,ranks,rank_cnt,msgSize){
542  //build the binary tree;
543  buildTree(ranks,rank_cnt);
544  }
545 
546  virtual PalmTreeBcast * clone() const{
547  PalmTreeBcast * out = new PalmTreeBcast(*this);
548  return out;
549  }
550 
551 };
552 
553 
554 template< typename T>
555 class TreeReduce: public TreeBcast{
556  protected:
557 
558  T * myData_;
559  MPI_Request sendRequest_;
560 
561  //char * myLocalBuffer_;
562  NumVec<char> myLocalBuffer_;
563  //char * myRecvBuffers_;
564  NumVec<char> myRecvBuffers_;
565  NumVec<T *> remoteData_;
566  NumVec<MPI_Request> myRequests_;
567  NumVec<MPI_Status> myStatuses_;
568  NumVec<int> recvIdx_;
569 
570  bool fwded_;
571  bool isAllocated_;
572  Int numRecvPosted_;
573 
574  public:
575  TreeReduce(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeBcast(pComm,ranks,rank_cnt,msgSize){
576  myData_ = NULL;
577  sendRequest_ = MPI_REQUEST_NULL;
578  fwded_=false;
579  isAllocated_=false;
580  numRecvPosted_= 0;
581  }
582 
583 
584  virtual TreeReduce * clone() const = 0;
585 
586  TreeReduce(const TreeReduce & Tree){
587  this->Copy(Tree);
588  }
589 
590  virtual void Copy(const TreeReduce & Tree){
591  this->comm_ = Tree.comm_;
592  this->myRank_ = Tree.myRank_;
593  this->myRoot_ = Tree.myRoot_;
594  this->msgSize_ = Tree.msgSize_;
595 
596  this->numRecv_ = Tree.numRecv_;
597  this->tag_= Tree.tag_;
598  this->mainRoot_= Tree.mainRoot_;
599  this->isReady_ = Tree.isReady_;
600  this->myDests_ = Tree.myDests_;
601 
602 
603  this->myData_ = Tree.myData_;
604  this->sendRequest_ = Tree.sendRequest_;
605  this->fwded_= Tree.fwded_;
606  this->isAllocated_= Tree.isAllocated_;
607  this->numRecvPosted_= Tree.numRecvPosted_;
608 
609  this->myLocalBuffer_ = Tree.myLocalBuffer_;
610  this->myRecvBuffers_ = Tree.myRecvBuffers_;
611  this->remoteData_ = Tree.remoteData_;
612  this->myRequests_ = Tree.myRequests_;
613  this->myStatuses_ = Tree.myStatuses_;
614  this->recvIdx_ = Tree.recvIdx_;
615  }
616 
617 
618 
619  bool IsAllocated(){return isAllocated_;}
620 
621  virtual ~TreeReduce(){
622  CleanupBuffers();
623  }
624 
625 
626  static TreeReduce<T> * Create(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize,double rseed);
627 
628  virtual inline Int GetNumMsgToRecv(){return GetDestCount();}
629 
630  virtual void AllocRecvBuffers(){
631  remoteData_.Resize(GetDestCount());
632  //SetValue(remoteData_,(T*)NULL);
633 
634  //assert(myRecvBuffers_==NULL);
635  //myRecvBuffers_ = new char[GetDestCount()*msgSize_];
636 
637 
638  myRecvBuffers_.Resize(GetDestCount()*msgSize_);
639  //SetValue(myRecvBuffers_,(char)0);
640 
641  for( Int idxRecv = 0; idxRecv < GetDestCount(); ++idxRecv ){
642  remoteData_[idxRecv] = (T*)&(myRecvBuffers_[idxRecv*msgSize_]);
643  //Int nelem = msgSize_ / sizeof(T);
644  //std::fill(remoteData_[idxRecv],remoteData_[idxRecv]+nelem,ZERO<T>());
645  }
646 
647  myRequests_.Resize(GetDestCount());
648  SetValue(myRequests_,MPI_REQUEST_NULL);
649  myStatuses_.Resize(GetDestCount());
650  recvIdx_.Resize(GetDestCount());
651 
652  sendRequest_ = MPI_REQUEST_NULL;
653 
654  isAllocated_ = true;
655  }
656 
657  void CleanupBuffers(){
658  myLocalBuffer_.Clear();
659 // if(myLocalBuffer_!=NULL){
660 // delete []myLocalBuffer_;
661 // myLocalBuffer_=NULL;
662 // }
663 
664 
665  remoteData_.Clear();
666 // myRecvBuffers_.Clear();
667 // if(myRecvBuffers_!=NULL){
668 // delete []myRecvBuffers_;
669 // myRecvBuffers_=NULL;
670 // }
671 
672 
673  myRequests_.Clear();
674  myStatuses_.Clear();
675  recvIdx_.Clear();
676 
677 
678 // if(myLocalBuffer_!=NULL){
679 // delete [] myLocalBuffer_;
680 // }
681 // myLocalBuffer_=NULL;
682 
683 
684  myData_ = NULL;
685  sendRequest_ = MPI_REQUEST_NULL;
686  fwded_=false;
687  isAllocated_=false;
688  isReady_=false;
689  numRecv_ = 0;
690  numRecvPosted_= 0;
691  }
692 
693 
694  void SetLocalBuffer(T * locBuffer){
695  if(myData_!=NULL && myData_!=locBuffer){
696  blas::Axpy(msgSize_/sizeof(T), ONE<T>(), myData_, 1, locBuffer, 1 );
697  myLocalBuffer_.Clear();
698  }
699 
700  myData_ = locBuffer;
701  }
702 
703  inline bool AccumulationDone(){
704  if(myRank_==myRoot_ && isAllocated_){
705  isReady_=true;
706  }
707  return isReady_ && (numRecv_ == GetDestCount());
708  }
709 
710 
711  inline bool IsDone(){
712  if(myRank_==myRoot_ && isAllocated_){
713  isReady_=true;
714  }
715 
716  bool retVal = AccumulationDone();
717  if(myRoot_ != myRank_ && !fwded_){
718  retVal = false;
719  }
720 
721  if (retVal && myRoot_ != myRank_ && fwded_){
722  //test the send request
723  int flag = 0;
724  MPI_Test(&sendRequest_,&flag,MPI_STATUS_IGNORE);
725  retVal = flag==1;
726  }
727 
728  return retVal;
729  }
730 
731  //async wait and forward
732  virtual bool Progress(){
733  if(!isAllocated_){
734  return true;
735  }
736 
737  if(myRank_==myRoot_ && isAllocated_){
738  isReady_=true;
739  }
740 
741 // if(this->numRecvPosted_==0){
742 // this->PostFirstRecv();
743 // }
744 
745  bool retVal = AccumulationDone();
746  if(isReady_ && !retVal){
747 
748  //assert(isAllocated_);
749 
750  //mpi_test_some on my requests
751  int recvCount = -1;
752  int reqCnt = GetDestCount();
753  assert(reqCnt == myRequests_.m());
754  MPI_Testsome(reqCnt,&myRequests_[0],&recvCount,&recvIdx_[0],&myStatuses_[0]);
755  //if something has been received, accumulate and potentially forward it
756  for(Int i = 0;i<recvCount;++i ){
757  Int idx = recvIdx_[i];
758 
759  if(idx!=MPI_UNDEFINED){
760 
761  Int size = 0;
762  MPI_Get_count(&myStatuses_[i], MPI_BYTE, &size);
763 
764 
765 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
766  statusOFS<<myRank_<<" RECVD from "<<myStatuses_[i].MPI_SOURCE<<" on tag "<<tag_<<std::endl;
767 #endif
768  if(size>0){
769  //If myData is 0, allocate to the size of what has been received
770  if(myData_==NULL){
771  //assert(size==msgSize_);
772  myLocalBuffer_.Resize(msgSize_);
773 
774  myData_ = (T*)&myLocalBuffer_[0];
775  Int nelem = +msgSize_/sizeof(T);
776  std::fill(myData_,myData_+nelem,ZERO<T>());
777  }
778 
779  Reduce(idx,i);
780 
781  }
782 
783  numRecv_++;
784  //MPI_Request_free(&myRequests_[idx]);
785  }
786  }
787 
788  }
789  else if (isReady_ && sendRequest_ == MPI_REQUEST_NULL && myRoot_ != myRank_ && !fwded_){
790  //free the unnecessary arrays
791  myRecvBuffers_.Clear();
792  myRequests_.Clear();
793  myStatuses_.Clear();
794  recvIdx_.Clear();
795 
796  //assert(isAllocated_);
797 
798  //Forward
799  Forward();
800  retVal = false;
801  }
802  else{
803  retVal = IsDone();
804  if(retVal){
805  //free the unnecessary arrays
806  myRecvBuffers_.Clear();
807  myRequests_.Clear();
808  myStatuses_.Clear();
809  recvIdx_.Clear();
810  }
811  }
812 
813  return retVal;
814  }
815 
816  //blocking wait
817  void Wait(){
818  if(isAllocated_){
819  while(!Progress());
820  }
821  }
822 
823  T * GetLocalBuffer(){
824  return myData_;
825  }
826 
827 
828 
829  void CopyLocalBuffer(T* destBuffer){
830  std::copy((char*)myData_,(char*)myData_+GetMsgSize(),(char*)destBuffer);
831  }
832 
833 
834  virtual void PostFirstRecv()
835  {
836  if(this->GetDestCount()>this->numRecvPosted_){
837  for( Int idxRecv = 0; idxRecv < myDests_.size(); ++idxRecv ){
838  Int iProc = myDests_[idxRecv];
839  //assert(msgSize_>=0);
840  MPI_Irecv( (char*)remoteData_[idxRecv], msgSize_, MPI_BYTE,
841  iProc, tag_,comm_, &myRequests_[idxRecv] );
842  this->numRecvPosted_++;
843  } // for (iProc)
844  }
845  }
846 
847 
848 
849 
850  protected:
851  void Reduce( Int idxRecv, Int idReq){
852  //add thing to my data
853  blas::Axpy(msgSize_/sizeof(T), ONE<T>(), remoteData_[idxRecv], 1, myData_, 1 );
854  }
855 
856  void Forward(){
857  //forward to my root if I have reseived everything
858  Int iProc = myRoot_;
859  // Use Isend to send to multiple targets
860  if(myData_==NULL){
861  MPI_Isend( NULL, 0, MPI_BYTE,
862  iProc, tag_,comm_, &sendRequest_ );
863 #ifdef COMM_PROFILE
864  PROFILE_COMM(myGRank_,Granks_[iProc],tag_,0);
865 #endif
866  }
867  else{
868  MPI_Isend( (char*)myData_, msgSize_, MPI_BYTE,
869  iProc, tag_,comm_, &sendRequest_ );
870 #ifdef COMM_PROFILE
871  PROFILE_COMM(myGRank_,Granks_[iProc],tag_,msgSize_);
872 #endif
873  }
874 
875 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
876  statusOFS<<myRank_<<" FWD to "<<iProc<<" on tag "<<tag_<<std::endl;
877 #endif
878 
879  fwded_ = true;
880 
881  }
882 
883 };
884 
885 
886 template< typename T>
887 class FTreeReduce: public TreeReduce<T>{
888  protected:
889  virtual void buildTree(Int * ranks, Int rank_cnt){
890 
891  Int idxStart = 0;
892  Int idxEnd = rank_cnt;
893 
894 
895 
896  this->myRoot_ = ranks[0];
897 
898  if(this->myRank_==this->myRoot_){
899  this->myDests_.insert(this->myDests_.begin(),&ranks[1],&ranks[0]+rank_cnt);
900  }
901 
902 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
903  statusOFS<<"My root is "<<this->myRoot_<<std::endl;
904  statusOFS<<"My dests are ";
905  for(int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<" ";}
906  statusOFS<<std::endl;
907 #endif
908  }
909 
910  void Reduce( ){
911  //add thing to my data
912  blas::Axpy(this->msgSize_/sizeof(T), ONE<T>(), this->remoteData_[0], 1, this->myData_, 1 );
913  }
914 
915 
916 
917  public:
918  FTreeReduce(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
919  buildTree(ranks,rank_cnt);
920  }
921 
922  virtual void PostFirstRecv()
923  {
924 // if(!this->isAllocated_){
925 // this->AllocRecvBuffers();
926 // }
927  if(this->isAllocated_ && this->GetDestCount()>this->numRecvPosted_){
928  MPI_Irecv( (char*)this->remoteData_[0], this->msgSize_, MPI_BYTE,
929  MPI_ANY_SOURCE, this->tag_,this->comm_, &this->myRequests_[0] );
930  this->numRecvPosted_++;
931  }
932  }
933 
934  virtual void AllocRecvBuffers(){
935  if(this->GetDestCount()>0){
936  this->remoteData_.Resize(1);
937 
938  this->myRecvBuffers_.Resize(this->msgSize_);
939 
940  this->remoteData_[0] = (T*)&(this->myRecvBuffers_[0]);
941 
942  this->myRequests_.Resize(1);
943  SetValue(this->myRequests_,MPI_REQUEST_NULL);
944  this->myStatuses_.Resize(1);
945  this->recvIdx_.Resize(1);
946  }
947 
948  this->sendRequest_ = MPI_REQUEST_NULL;
949 
950  this->isAllocated_ = true;
951  }
952 
953  virtual bool Progress(){
954 
955  if(!this->isAllocated_){
956  return true;
957  }
958 
959 
960  if(this->myRank_==this->myRoot_ && this->isAllocated_){
961  this->isReady_=true;
962  }
963 
964 // if(this->numRecvPosted_==0){
965 // this->PostFirstRecv();
966 // }
967 
968  bool retVal = this->AccumulationDone();
969  if(this->isReady_ && !retVal){
970 
971  //assert(this->isAllocated_);
972 
973  //mpi_test_some on my requests
974  int recvCount = -1;
975  int reqCnt = 1;
976 
977  MPI_Testsome(reqCnt,&this->myRequests_[0],&recvCount,&this->recvIdx_[0],&this->myStatuses_[0]);
978  //MPI_Waitsome(reqCnt,&myRequests_[0],&recvCount,&recvIdx_[0],&myStatuses_[0]);
979  //if something has been received, accumulate and potentially forward it
980  for(Int i = 0;i<recvCount;++i ){
981  Int idx = this->recvIdx_[i];
982 
983  if(idx!=MPI_UNDEFINED){
984 
985  Int size = 0;
986  MPI_Get_count(&this->myStatuses_[i], MPI_BYTE, &size);
987 
988 
989 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
990 
991  statusOFS<<this->myRank_<<" RECVD from "<<this->myStatuses_[i].MPI_SOURCE<<" on tag "<<this->tag_<<std::endl;
992 #endif
993  if(size>0){
994  //If myData is 0, allocate to the size of what has been received
995  if(this->myData_==NULL){
996  //assert(size==this->msgSize_);
997  this->myLocalBuffer_.Resize(this->msgSize_);
998 
999  this->myData_ = (T*)&this->myLocalBuffer_[0];
1000  Int nelem = this->msgSize_/sizeof(T);
1001  std::fill(this->myData_,this->myData_+nelem,ZERO<T>());
1002  }
1003 
1004  this->Reduce();
1005  }
1006 
1007  this->numRecv_++;
1008  }
1009  }
1010 
1011  if(recvCount>0){
1012  this->PostFirstRecv();
1013  }
1014  }
1015  else if (this->isReady_ && this->sendRequest_ == MPI_REQUEST_NULL && this->myRoot_ != this->myRank_ && !this->fwded_){
1016  //free the unnecessary arrays
1017  this->myRecvBuffers_.Clear();
1018  this->myRequests_.Clear();
1019  this->myStatuses_.Clear();
1020  this->recvIdx_.Clear();
1021 
1022  //Forward
1023  this->Forward();
1024  retVal = false;
1025  }
1026  else{
1027  retVal = this->IsDone();
1028  if(retVal){
1029  //free the unnecessary arrays
1030  this->myRecvBuffers_.Clear();
1031  this->myRequests_.Clear();
1032  this->myStatuses_.Clear();
1033  this->recvIdx_.Clear();
1034  }
1035  }
1036 
1037  return retVal;
1038  }
1039 
1040 
1041  virtual FTreeReduce * clone() const{
1042  FTreeReduce * out = new FTreeReduce(*this);
1043  return out;
1044  }
1045 
1046 
1047 
1048 };
1049 
1050 
1051 
1052 template< typename T>
1053 class BTreeReduce: public TreeReduce<T>{
1054  protected:
1055  virtual void buildTree(Int * ranks, Int rank_cnt){
1056  Int idxStart = 0;
1057  Int idxEnd = rank_cnt;
1058 
1059 
1060 
1061  Int prevRoot = ranks[0];
1062  while(idxStart<idxEnd){
1063  Int curRoot = ranks[idxStart];
1064  Int listSize = idxEnd - idxStart;
1065 
1066  if(listSize == 1){
1067  if(curRoot == this->myRank_){
1068  this->myRoot_ = prevRoot;
1069  break;
1070  }
1071  }
1072  else{
1073  Int halfList = floor(ceil(double(listSize) / 2.0));
1074  Int idxStartL = idxStart+1;
1075  Int idxStartH = idxStart+halfList;
1076 
1077  if(curRoot == this->myRank_){
1078  if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1079  Int childL = ranks[idxStartL];
1080  Int childR = ranks[idxStartH];
1081 
1082  this->myDests_.push_back(childL);
1083  this->myDests_.push_back(childR);
1084  }
1085  else if ((idxEnd - idxStartH) > 0){
1086  Int childR = ranks[idxStartH];
1087  this->myDests_.push_back(childR);
1088  }
1089  else{
1090  Int childL = ranks[idxStartL];
1091  this->myDests_.push_back(childL);
1092  }
1093  this->myRoot_ = prevRoot;
1094  break;
1095  }
1096 
1097  if( this->myRank_ < ranks[idxStartH]){
1098  idxStart = idxStartL;
1099  idxEnd = idxStartH;
1100  }
1101  else{
1102  idxStart = idxStartH;
1103  }
1104  prevRoot = curRoot;
1105  }
1106 
1107  }
1108 
1109 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1110  statusOFS<<"My root is "<<this->myRoot_<<std::endl;
1111  statusOFS<<"My dests are ";
1112  for(int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<" ";}
1113  statusOFS<<std::endl;
1114 #endif
1115  }
1116  public:
1117  BTreeReduce(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize):TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1118  buildTree(ranks,rank_cnt);
1119  }
1120 
1121  virtual BTreeReduce * clone() const{
1122  BTreeReduce * out = new BTreeReduce(*this);
1123  return out;
1124  }
1125 };
1126 
1127 
1128 template< typename T>
1129 class ModBTreeReduce: public TreeReduce<T>{
1130  protected:
1131  double rseed_;
1132  virtual void buildTree(Int * ranks, Int rank_cnt){
1133 
1134  Int idxStart = 0;
1135  Int idxEnd = rank_cnt;
1136 
1137  //sort the ranks with the modulo like operation
1138  if(rank_cnt>1){
1139  //generate a random position in [1 .. rand_cnt]
1140  //Int new_idx = (int)((rand()+1.0) * (double)rank_cnt / ((double)RAND_MAX+1.0));
1141  //srand(ranks[0]+rank_cnt);
1142  //Int new_idx = rseed_%(rank_cnt-1)+1;
1143  Int new_idx = (int)((rank_cnt - 0) * ( (double)this->rseed_ / (double)RAND_MAX ) + 0);// (this->rseed_)%(rank_cnt-1)+1;
1144 
1145 
1146  Int * new_start = &ranks[new_idx];
1147 // for(int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<" ";} statusOFS<<std::endl;
1148 
1149 // Int * new_start = std::lower_bound(&ranks[1],&ranks[0]+rank_cnt,ranks[0]);
1150  //just swap the two chunks r[0] | r[1] --- r[new_start-1] | r[new_start] --- r[end]
1151  // becomes r[0] | r[new_start] --- r[end] | r[1] --- r[new_start-1]
1152  std::rotate(&ranks[1], new_start, &ranks[0]+rank_cnt);
1153 // for(int i =0;i<rank_cnt;++i){statusOFS<<ranks[i]<<" ";} statusOFS<<std::endl;
1154  }
1155 
1156  Int prevRoot = ranks[0];
1157  while(idxStart<idxEnd){
1158  Int curRoot = ranks[idxStart];
1159  Int listSize = idxEnd - idxStart;
1160 
1161  if(listSize == 1){
1162  if(curRoot == this->myRank_){
1163  this->myRoot_ = prevRoot;
1164  break;
1165  }
1166  }
1167  else{
1168  Int halfList = floor(ceil(double(listSize) / 2.0));
1169  Int idxStartL = idxStart+1;
1170  Int idxStartH = idxStart+halfList;
1171 
1172  if(curRoot == this->myRank_){
1173  if ((idxEnd - idxStartH) > 0 && (idxStartH - idxStartL)>0){
1174  Int childL = ranks[idxStartL];
1175  Int childR = ranks[idxStartH];
1176 
1177  this->myDests_.push_back(childL);
1178  this->myDests_.push_back(childR);
1179  }
1180  else if ((idxEnd - idxStartH) > 0){
1181  Int childR = ranks[idxStartH];
1182  this->myDests_.push_back(childR);
1183  }
1184  else{
1185  Int childL = ranks[idxStartL];
1186  this->myDests_.push_back(childL);
1187  }
1188  this->myRoot_ = prevRoot;
1189  break;
1190  }
1191 
1192  //not true anymore ?
1193  //first half to
1194 TIMER_START(FIND_RANK);
1195  Int * pos = std::find(&ranks[idxStartL], &ranks[idxStartH], this->myRank_);
1196 TIMER_STOP(FIND_RANK);
1197  if( pos != &ranks[idxStartH]){
1198  idxStart = idxStartL;
1199  idxEnd = idxStartH;
1200  }
1201  else{
1202  idxStart = idxStartH;
1203  }
1204  prevRoot = curRoot;
1205  }
1206 
1207  }
1208 
1209 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1210  statusOFS<<"My root is "<<this->myRoot_<<std::endl;
1211  statusOFS<<"My dests are ";
1212  for(int i =0;i<this->myDests_.size();++i){statusOFS<<this->myDests_[i]<<" ";}
1213  statusOFS<<std::endl;
1214 #endif
1215  }
1216  public:
1217  ModBTreeReduce(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize, double rseed):TreeReduce<T>(pComm, ranks, rank_cnt, msgSize){
1218  this->rseed_ = rseed;
1219  buildTree(ranks,rank_cnt);
1220  }
1221 
1222  virtual void Copy(const ModBTreeReduce & Tree){
1223  this->comm_ = Tree.comm_;
1224  this->myRank_ = Tree.myRank_;
1225  this->myRoot_ = Tree.myRoot_;
1226  this->msgSize_ = Tree.msgSize_;
1227 
1228  this->numRecv_ = Tree.numRecv_;
1229  this->tag_= Tree.tag_;
1230  this->mainRoot_= Tree.mainRoot_;
1231  this->isReady_ = Tree.isReady_;
1232  this->myDests_ = Tree.myDests_;
1233 
1234 
1235  this->myData_ = Tree.myData_;
1236  this->sendRequest_ = Tree.sendRequest_;
1237  this->fwded_= Tree.fwded_;
1238  this->isAllocated_= Tree.isAllocated_;
1239  this->numRecvPosted_= Tree.numRecvPosted_;
1240 
1241  this->myLocalBuffer_ = Tree.myLocalBuffer_;
1242  this->myRecvBuffers_ = Tree.myRecvBuffers_;
1243  this->remoteData_ = Tree.remoteData_;
1244  this->myRequests_ = Tree.myRequests_;
1245  this->myStatuses_ = Tree.myStatuses_;
1246  this->recvIdx_ = Tree.recvIdx_;
1247  this->rseed_ = Tree.rseed_;
1248  }
1249 
1250 
1251 
1252 
1253  virtual ModBTreeReduce * clone() const{
1254  ModBTreeReduce * out = new ModBTreeReduce(*this);
1255  return out;
1256  }
1257 
1258 };
1259 
1260 
1261 
1262  inline TreeBcast * TreeBcast::Create(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize, double rseed){
1263  //get communicator size
1264  Int nprocs = 0;
1265  MPI_Comm_size(pComm, &nprocs);
1266 
1267 // return new PalmTreeBcast(pComm,ranks,rank_cnt,msgSize);
1268 // return new ModBTreeBcast(pComm,ranks,rank_cnt,msgSize, rseed);
1269 // return new RandBTreeBcast(pComm,ranks,rank_cnt,msgSize);
1270 
1271  if(nprocs<=FTREE_LIMIT){
1272  return new FTreeBcast(pComm,ranks,rank_cnt,msgSize);
1273  }
1274  else{
1275  return new ModBTreeBcast(pComm,ranks,rank_cnt,msgSize, rseed);
1276  //return new BTreeBcast(pComm,ranks,rank_cnt,msgSize);
1277  }
1278 
1279 
1280 
1281 
1282  }
1283 
1284 
1285 
1286 
1287 template< typename T>
1288  inline TreeReduce<T> * TreeReduce<T>::Create(const MPI_Comm & pComm, Int * ranks, Int rank_cnt, Int msgSize, double rseed){
1289  //get communicator size
1290  Int nprocs = 0;
1291  MPI_Comm_size(pComm, &nprocs);
1292 
1293 
1294  if(nprocs<=FTREE_LIMIT){
1295 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1296 statusOFS<<"FLAT TREE USED"<<endl;
1297 #endif
1298  return new FTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
1299  }
1300  else{
1301 #if ( _DEBUGlevel_ >= 1 ) || defined(REDUCE_VERBOSE)
1302 statusOFS<<"BINARY TREE USED"<<endl;
1303 #endif
1304  return new ModBTreeReduce<T>(pComm,ranks,rank_cnt,msgSize, rseed);
1305  //return new BTreeReduce<T>(pComm,ranks,rank_cnt,msgSize);
1306  }
1307  }
1308 
1309 
1310 
1311 
1312 
1313 
1314 
1315 
1316 }
1317 
1318 #endif
Environmental variables.
Definition: TreeBcast.hpp:22
Definition: TreeBcast.hpp:887
Definition: TreeBcast.hpp:413
Profiling and timing using TAU.
void SetValue(NumMat< F > &M, F val)
SetValue sets a numerical matrix to a constant val.
Definition: NumMat_impl.hpp:171
Definition: TreeBcast.hpp:135
Definition: TreeBcast.hpp:280
Definition: TreeBcast.hpp:1129
Definition: TreeBcast.hpp:175
Definition: TreeBcast.hpp:506
Definition: TreeBcast.hpp:555
Definition: TreeBcast.hpp:1053