// Parallel merge sort benchmark.
//
// Usage: mergesort <n> <nrThreads>
// Fills an array of n ints with pseudo-random values (in parallel), sorts it
// with a recursive merge sort that spreads work over up to nrThreads threads
// via std::async, then prints whether the result is sorted and how long the
// sort itself took in milliseconds.
//
// NOTE(review): this file was rebuilt from a copy whose `<...>` text (header
// names, template arguments, comparisons) had been stripped; the bodies of
// merge(), mergeRec(), generate() and isSorted() were reconstructed from the
// surviving fragments — confirm against the original if one exists.

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <future>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

/// Sequentially merges the sorted ranges [begin1, end1) and [begin2, end2)
/// into the buffer starting at `merged`.
///
/// `merged` must have room for (end1 - begin1) + (end2 - begin2) ints and must
/// not overlap either input. On equal keys the element of the first range is
/// written first.
void merge(int* begin1, int* end1, int* begin2, int* end2, int* merged) {
    int* curr1 = begin1;
    int* curr2 = begin2;
    while (curr1 < end1 || curr2 < end2) {
        if (curr2 >= end2 || (curr1 < end1 && *curr1 <= *curr2)) {
            *merged = *curr1;
            ++curr1;
        } else {
            *merged = *curr2;
            ++curr2;
        }
        ++merged;
    }
}

/// Merges the sorted ranges [begin1, end1) and [begin2, end2) into `merged`
/// using up to `nrThreads` threads (same buffer requirements as merge()).
///
/// The larger input is cut at a point proportional to the thread split, and
/// the matching cut in the other input is located with std::lower_bound. Every
/// element left of both cuts is then <= every element right of them, so the
/// two halves can be merged independently into disjoint parts of the output.
void mergeRec(int* begin1, int* end1, int* begin2, int* end2, int* merged,
              std::size_t nrThreads) {
    std::size_t const n1 = static_cast<std::size_t>(end1 - begin1);
    std::size_t const n2 = static_cast<std::size_t>(end2 - begin2);
    if (nrThreads < 2 || n1 + n2 < 2) {
        merge(begin1, end1, begin2, end2, merged);
        return;
    }
    std::size_t const th = nrThreads / 2;  // threads given to the left half
    int* mid1;
    int* mid2;
    if (n1 >= n2) {
        // n1 >= 1 here and th < nrThreads, so mid1 < end1 is dereferenceable.
        mid1 = begin1 + n1 * th / nrThreads;
        mid2 = std::lower_bound(begin2, end2, *mid1);
    } else {
        mid2 = begin2 + n2 * th / nrThreads;
        mid1 = std::lower_bound(begin1, end1, *mid2);
    }
    std::size_t const sz1 =
        static_cast<std::size_t>((mid1 - begin1) + (mid2 - begin2));
    std::future<void> f1 = std::async(std::launch::async, [=]() -> void {
        mergeRec(begin1, mid1, begin2, mid2, merged, th);
    });
    std::future<void> f2 = std::async(std::launch::async, [=]() -> void {
        mergeRec(mid1, end1, mid2, end2, merged + sz1, nrThreads - th);
    });
    // get() (unlike wait()) rethrows any exception raised inside the task.
    f1.get();
    f2.get();
}

/// Sorts the n ints at `in`, using `buf` (room for n ints) as scratch space
/// and up to `nrThreads` threads.
///
/// On return `isResultInInput` says where the sorted data ended up: true means
/// in `in`, false means in `buf`. Each merge writes into the array the halves
/// are NOT in, which avoids a copy-back after every level.
void mergeSortRec(std::size_t n, int* in, int* buf, bool& isResultInInput,
                  std::size_t nrThreads) {
    if (n <= 1) {
        isResultInInput = true;
        return;
    }
    std::size_t const k = n / 2;
    bool r1 = true;
    bool r2 = true;
    if (nrThreads >= 2) {
        std::size_t const th = nrThreads / 2;
        // r1/r2 are captured by reference; both futures are drained (and, on
        // unwinding, destroyed — which blocks) before r1/r2 go out of scope.
        std::future<void> f1 = std::async(std::launch::async,
            [k, in, buf, &r1, th]() -> void { mergeSortRec(k, in, buf, r1, th); });
        std::future<void> f2 = std::async(std::launch::async,
            [n, k, in, buf, &r2, nrThreads, th]() -> void {
                mergeSortRec(n - k, in + k, buf + k, r2, nrThreads - th);
            });
        f1.get();
        f2.get();
    } else {
        mergeSortRec(k, in, buf, r1, 1);
        mergeSortRec(n - k, in + k, buf + k, r2, 1);
    }
    // Both sorted halves must sit in the same array before merging; move the
    // right half next to the left half.
    if (r2 != r1) {
        if (r2) {
            std::copy(in + k, in + n, buf + k);
        } else {
            std::copy(buf + k, buf + n, in + k);
        }
    }
    if (r1) {
        mergeRec(in, in + k, in + k, in + n, buf, nrThreads);
        isResultInInput = false;
    } else {
        mergeRec(buf, buf + k, buf + k, buf + n, in, nrThreads);
        isResultInInput = true;
    }
}

/// Sorts v[0, n) in ascending order using up to `nrThreads` threads
/// (0 behaves like 1). Allocates an n-element scratch buffer.
void mergeSort(int* v, std::size_t n, std::size_t nrThreads) {
    // Deliberately default-initialized: every slot is written before it is read.
    std::unique_ptr<int[]> buf{new int[n]};
    bool isResultInInput = true;
    mergeSortRec(n, v, buf.get(), isResultInInput, nrThreads);
    if (!isResultInInput) {
        std::copy(buf.get(), buf.get() + n, v);
    }
}

/// Fills v[0, n) with pseudo-random ints in [0, INT_MAX].
///
/// The range is split into nrThreads contiguous chunks (at least one, so a 0
/// thread count no longer leaves v uninitialized) that are filled
/// concurrently. Each chunk has its own std::mt19937, seeded from a single
/// std::random_device draw plus the chunk index, so tasks share no mutable state.
void generate(int* v, std::size_t n, std::size_t nrThreads) {
    if (nrThreads == 0) {
        nrThreads = 1;
    }
    std::random_device rd;
    std::mt19937::result_type const baseSeed = rd();

    std::vector<std::future<void>> futures;
    futures.reserve(nrThreads);
    std::size_t const chunk = n / nrThreads;
    std::size_t const extra = n % nrThreads;  // first `extra` chunks get one more
    std::size_t start = 0;
    for (std::size_t th = 0; th < nrThreads; ++th) {
        std::size_t const end = chunk * (th + 1) + std::min(th + 1, extra);
        std::mt19937::result_type const seed =
            baseSeed + static_cast<std::mt19937::result_type>(th);
        futures.push_back(std::async(std::launch::async,
            [v, start, end, seed]() -> void {
                std::mt19937 engine(seed);
                std::uniform_int_distribution<int> dist;  // defaults to [0, INT_MAX]
                for (std::size_t i = start; i < end; ++i) {
                    v[i] = dist(engine);
                }
            }));
        start = end;
    }
    for (std::future<void>& f : futures) {
        f.get();
    }
}

/// Returns true when v[0, n) is in non-decreasing order.
bool isSorted(int const* const& v, std::size_t n) {
    return std::is_sorted(v, v + n);
}

int main(int argc, char** argv) {
    std::size_t n = 0;
    std::size_t nrThreads = 0;
    // "%zu" is the scanf conversion for size_t; "%u" wrote an unsigned int
    // through a size_t*, which is undefined behavior on LP64 platforms.
    if (argc != 3 || 1 != std::sscanf(argv[1], "%zu", &n) ||
        1 != std::sscanf(argv[2], "%zu", &nrThreads)) {
        std::fprintf(stderr, "usage: mergesort <n> <nrThreads>\n");
        return 1;
    }
    std::unique_ptr<int[]> v{new int[n]};
    generate(v.get(), n, nrThreads);
    std::fprintf(stderr, "generated\n");

    // steady_clock is monotonic; high_resolution_clock may not be.
    std::chrono::steady_clock::time_point const beginTime =
        std::chrono::steady_clock::now();
    mergeSort(v.get(), n, nrThreads);
    std::chrono::steady_clock::time_point const endTime =
        std::chrono::steady_clock::now();

    std::cout << (isSorted(v.get(), n) ? "ok" : "WRONG") << "; time="
              << std::chrono::duration_cast<std::chrono::milliseconds>(
                     endTime - beginTime).count()
              << "ms\n";
    return 0;
}