// Code:
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include <pthread.h>
#ifndef NDEBUG
// the asserts inside the <algorithm> sorts *really* *really* slow them down
#define WILL_NDEBUG
#define NDEBUG
#endif
#include <algorithm>
#ifdef WILL_NDEBUG
#undef WILL_NDEBUG
#undef NDEBUG
#endif
// Build-time switch for the per-partition sort: when defined, std::stable_sort
// is used; when undefined, std::sort is used instead.
#define STABLE_SORT // else uses std::sort
// Ordering predicate over suffixes of one shared byte buffer, usable with the
// <algorithm> sorts and merges.
//
// A pointer p handed to operator() stands for the suffix [p, end).  Two
// suffixes are compared bytewise over their common length; when one is a
// prefix of the other, the lower address (i.e. the longer suffix) orders
// first.
//
// Precondition: both pointers lie inside the buffer, i.e. are <= end.
struct stl_memcmp {
    // e: one past the last byte of the buffer the compared pointers refer into.
    stl_memcmp(const unsigned char* e) : end(e) {}

    // Returns true when the suffix starting at a orders strictly before the
    // suffix starting at b.
    inline bool operator()(const unsigned char *a, const unsigned char *b) const {
        // Keep the lengths in size_t: narrowing the pointer differences to int
        // truncates them (or makes them negative) once the buffer exceeds 2 GiB.
        const size_t lena = static_cast<size_t>(end - a);
        const size_t lenb = static_cast<size_t>(end - b);
        // find first differing byte
        if(const int diff = ::memcmp(a, b, std::min(lena, lenb)))
            return (diff < 0);
        // one is a prefix of the other: the one starting earlier goes first
        return (a < b);
    }
private:
    const unsigned char* const end;
};
// Work description handed to one sorting worker.
struct thread_sort_t {
    const unsigned char* S;    // start of the whole input buffer
    const unsigned char** P;   // first suffix pointer of this worker's partition
    size_t N;                  // number of suffix pointers in this partition
    const unsigned char* end;  // one past the last byte of the whole input buffer
};

// Sorts `count` suffix pointers starting at `first`.  `end` must be one past
// the last byte of the buffer the pointers refer into, so that every suffix is
// compared over its full length.
static void sort_suffix_range(const unsigned char** first, size_t count, const unsigned char* end) {
#ifdef STABLE_SORT
    std::stable_sort(first, first + count, stl_memcmp(end));
#else
    std::sort(first, first + count, stl_memcmp(end));
#endif
}

// pthread entry point: sorts the partition described by the thread_sort_t that
// vparams points to.  Always returns NULL.
void* thread_sort(void* vparams) {
    const thread_sort_t* const params = static_cast<const thread_sort_t*>(vparams);
    // The comparator needs the end of the WHOLE buffer.  Deriving it from the
    // partition size (S + N) truncates the suffixes of the first partition and
    // yields negative lengths for every later one.
    sort_suffix_range(params->P, params->N, params->end);
    return NULL;
}

// Sorts all suffixes of S[0..N) with stl_memcmp, splitting the work over
// num_threads pthreads, and builds the output column from the sorted order.
//
//   S            input bytes
//   N            number of input bytes
//   I            out: position in the sorted order of the suffix that starts
//                at S (narrowed to int); -1 when N is 0
//   num_threads  number of worker threads; values below 1 are treated as 1
//
// Returns a new[]-allocated array of N bytes where entry i is the byte that
// precedes the i-th sorted suffix (the last input byte for the suffix that
// starts at S).  The caller releases it with delete[].
unsigned char* bwt_sort(const unsigned char* S,size_t N,int& I,char num_threads) {
    I = -1;
    // nothing to sort; avoids S[N-1] and the assert below on empty input
    if(0 == N)
        return new unsigned char[0];
    // guards the division below against a zero or negative thread count
    const size_t workers = (num_threads > 0)? static_cast<size_t>(num_threads): 1;
    const unsigned char* const end = S + N;
    // make a list to sort
    const unsigned char** P = new const unsigned char*[N];
    for(size_t i=0; i<N; i++)
        P[i] = S + i;
    // dispatch worker threads; the last one also takes the remainder
    const size_t part = N / workers;
    thread_sort_t* params = new thread_sort_t[workers];
    pthread_t* threads = new pthread_t[workers];
    bool* started = new bool[workers];
    for(size_t i=0; i<workers; i++) {
        const size_t start = i * part;
        params[i].S = S;
        params[i].P = P + start;
        params[i].N = (i == workers-1)? N - start: part;
        params[i].end = end;
        started[i] = (0 == pthread_create(threads+i,NULL,thread_sort,params+i));
        // if the thread could not be created, do its share right here rather
        // than leaving the partition unsorted and joining an invalid handle
        if(!started[i])
            thread_sort(params+i);
    }
    // wait for the workers to finish sorting
    for(size_t i=0; i<workers; i++)
        if(started[i])
            pthread_join(threads[i],NULL);
    delete[] started;
    delete[] threads;
    delete[] params;
    // merge parts one after another into the growing sorted prefix
    for(size_t i=1; i<workers; i++) {
        const size_t start = i * part;
        const size_t stop = (i == workers-1)? N: (start + part);
        std::inplace_merge(P,P+start,P+stop,stl_memcmp(end));
    }
    // create output
    unsigned char* L = new unsigned char[N];
    for(size_t i=0; i<N; i++) {
        if(S == P[i]) {
            // the full string has no preceding byte: wrap to the last one
            L[i] = S[N-1];
            I = static_cast<int>(i);
        } else {
            L[i] = P[i][-1];
        }
    }
    assert(0 <= I);
    delete[] P;
    return L;
}
// Command line driver: num_threads src dest
//
// Reads the whole of src, runs bwt_sort over it with num_threads workers and,
// unless dest is ".", writes to dest the raw bytes of N (size_t), the raw
// bytes of I (int) and then the N output bytes.
// Returns 0 on success and -1 on a usage, open, read or write error.
int main(int argc,char** args) {
    if(4!=argc) {
        fprintf(stderr,"Usage: %s num_threads src dest\n",args[0]);
        return -1;
    }
    const int num_threads = atoi(args[1]);
    if(1>num_threads || 127<num_threads) {
        fprintf(stderr,"number of threads is crazy\n");
        return -1;
    }
    // binary mode: the input is arbitrary bytes, and ftell only reports a
    // byte count for binary streams
    FILE* file = fopen(args[2],"rb");
    if(!file) {
        fprintf(stderr,"could not open src \"%s\"\n",args[2]);
        return -1;
    }
    // measure the file; ftell reports failure as -1, which must not be turned
    // into a huge size_t
    long size = -1;
    if(0 == fseek(file,0,SEEK_END))
        size = ftell(file);
    if(size < 0 || 0 != fseek(file,0,SEEK_SET)) {
        fprintf(stderr,"could not determine size of src \"%s\"\n",args[2]);
        fclose(file);
        return -1;
    }
    const size_t N = static_cast<size_t>(size);
    unsigned char* S = new unsigned char[N+1];
    const size_t got = fread(S,1,N,file);
    fclose(file);
    if(got != N) {
        // a short read would leave uninitialised bytes in S
        fprintf(stderr,"could not read src \"%s\"\n",args[2]);
        delete[] S;
        return -1;
    }
    S[N] = 0;
    int I = -1;
    unsigned char* L = bwt_sort(S,N,I,static_cast<char>(num_threads));
    delete[] S;
    int status = 0;
    // a dest of "." means: transform only, write nothing
    if(strcmp(".",args[3])) {
        file = fopen(args[3],"wb");
        if(!file) {
            fprintf(stderr,"Could not open dest \"%s\"\n",args[3]);
            delete[] L;
            return -1;
        }
        const bool wrote = (1 == fwrite(&N,sizeof(N),1,file))
                        && (1 == fwrite(&I,sizeof(I),1,file))
                        && (N == fwrite(L,1,N,file));
        // fclose flushes buffered data, so its result matters too
        const bool closed = (0 == fclose(file));
        if(!wrote || !closed) {
            fprintf(stderr,"could not write dest \"%s\"\n",args[3]);
            status = -1;
        }
    }
    delete[] L;
    return status;
}