diff --git a/contrib/seekable_format/examples/parallel_compression.c b/contrib/seekable_format/examples/parallel_compression.c index 4e06fae32..d54704c11 100644 --- a/contrib/seekable_format/examples/parallel_compression.c +++ b/contrib/seekable_format/examples/parallel_compression.c @@ -23,6 +23,8 @@ #include "xxhash.h" +#define ZSTD_MULTITHREAD 1 +#include "threading.h" #include "pool.h" // use zstd thread pool for demo #include "../zstd_seekable.h" @@ -72,114 +74,87 @@ static size_t fclose_orDie(FILE* file) exit(6); } -static void fseek_orDie(FILE* file, long int offset, int origin) -{ - if (!fseek(file, offset, origin)) { - if (!fflush(file)) return; - } - /* error */ - perror("fseek"); - exit(7); -} - -static long int ftell_orDie(FILE* file) -{ - long int off = ftell(file); - if (off != -1) return off; - /* error */ - perror("ftell"); - exit(8); -} +struct state { + FILE* fout; + ZSTD_pthread_mutex_t mutex; + size_t nextID; + struct job* pending; + ZSTD_frameLog* frameLog; + const int compressionLevel; +}; struct job { - const void* src; + size_t id; + struct job* next; + struct state* state; + + void* src; size_t srcSize; void* dst; size_t dstSize; unsigned checksum; - - int compressionLevel; - int done; }; +static void addPending_inmutex(struct state* state, struct job* job) +{ + struct job** p = &state->pending; + while (*p && (*p)->id < job->id) + p = &(*p)->next; + job->next = *p; + *p = job; +} + +static void flushFrame(struct state* state, struct job* job) +{ + fwrite_orDie(job->dst, job->dstSize, state->fout); + free(job->dst); + + size_t ret = ZSTD_seekable_logFrame(state->frameLog, job->dstSize, job->srcSize, job->checksum); + if (ZSTD_isError(ret)) { + fprintf(stderr, "ZSTD_seekable_logFrame() error : %s \n", ZSTD_getErrorName(ret)); + exit(12); + } +} + +static void flushPending_inmutex(struct state* state) +{ + while (state->pending && state->pending->id == state->nextID) { + struct job* p = state->pending; + state->pending = p->next; + flushFrame(state, p); + free(p); + state->nextID++; + } +} + +static void finishFrame(struct job* job) +{ + struct state *state = job->state; + ZSTD_pthread_mutex_lock(&state->mutex); + addPending_inmutex(state, job); + flushPending_inmutex(state); + ZSTD_pthread_mutex_unlock(&state->mutex); +} + static void compressFrame(void* opaque) { struct job* job = opaque; job->checksum = XXH64(job->src, job->srcSize, 0); - size_t ret = ZSTD_compress(job->dst, job->dstSize, job->src, job->srcSize, job->compressionLevel); + size_t ret = ZSTD_compress(job->dst, job->dstSize, job->src, job->srcSize, job->state->compressionLevel); if (ZSTD_isError(ret)) { fprintf(stderr, "ZSTD_compress() error : %s \n", ZSTD_getErrorName(ret)); exit(20); } - job->dstSize = ret; - job->done = 1; -} -static void compressFile_orDie(const char* fname, const char* outName, int cLevel, unsigned frameSize, int nbThreads) -{ - POOL_ctx* pool = POOL_create(nbThreads, nbThreads); - if (pool == NULL) { fprintf(stderr, "POOL_create() error \n"); exit(9); } + // No longer need + free(job->src); + job->src = NULL; - FILE* const fin = fopen_orDie(fname, "rb"); - FILE* const fout = fopen_orDie(outName, "wb"); - - if (ZSTD_compressBound(frameSize) > 0xFFFFFFFFU) { fprintf(stderr, "Frame size too large \n"); exit(10); } - unsigned dstSize = ZSTD_compressBound(frameSize); - - - fseek_orDie(fin, 0, SEEK_END); - long int length = ftell_orDie(fin); - fseek_orDie(fin, 0, SEEK_SET); - - size_t numFrames = (length + frameSize - 1) / frameSize; - - struct job* jobs = malloc_orDie(sizeof(struct job) * numFrames); - - size_t i; - for(i = 0; i < numFrames; i++) { - void* in = malloc_orDie(frameSize); - void* out = malloc_orDie(dstSize); - - size_t inSize = fread_orDie(in, frameSize, fin); - - jobs[i].src = in; - jobs[i].srcSize = inSize; - jobs[i].dst = out; - jobs[i].dstSize = dstSize; - jobs[i].compressionLevel = cLevel; - jobs[i].done = 0; - POOL_add(pool, compressFrame, &jobs[i]); - } - - ZSTD_frameLog* fl = ZSTD_seekable_createFrameLog(1); - if (fl == NULL) { fprintf(stderr, "ZSTD_seekable_createFrameLog() failed \n"); exit(11); } - for (i = 0; i < numFrames; i++) { - while (!jobs[i].done) SLEEP(5); /* wake up every 5 milliseconds to check */ - fwrite_orDie(jobs[i].dst, jobs[i].dstSize, fout); - free((void*)jobs[i].src); - free(jobs[i].dst); - - size_t ret = ZSTD_seekable_logFrame(fl, jobs[i].dstSize, jobs[i].srcSize, jobs[i].checksum); - if (ZSTD_isError(ret)) { fprintf(stderr, "ZSTD_seekable_logFrame() error : %s \n", ZSTD_getErrorName(ret)); } - } - - { unsigned char seekTableBuff[1024]; - ZSTD_outBuffer out = {seekTableBuff, 1024, 0}; - while (ZSTD_seekable_writeSeekTable(fl, &out) != 0) { - fwrite_orDie(seekTableBuff, out.pos, fout); - out.pos = 0; - } - fwrite_orDie(seekTableBuff, out.pos, fout); - } - - ZSTD_seekable_freeFrameLog(fl); - free(jobs); - fclose_orDie(fout); - fclose_orDie(fin); + finishFrame(job); } static const char* createOutFilename_orDie(const char* filename) @@ -193,6 +168,71 @@ static const char* createOutFilename_orDie(const char* filename) return (const char*)outSpace; } +static void openInOut_orDie(const char* fname, FILE** fin, FILE** fout) { + if (strcmp(fname, "-") == 0) { + *fin = stdin; + *fout = stdout; + } else { + *fin = fopen_orDie(fname, "rb"); + const char* outName = createOutFilename_orDie(fname); + *fout = fopen_orDie(outName, "wb"); + } +} + +static void compressFile_orDie(const char* fname, int cLevel, unsigned frameSize, int nbThreads) +{ + struct state state = { + .nextID = 0, + .pending = NULL, + .compressionLevel = cLevel, + }; + ZSTD_pthread_mutex_init(&state.mutex, NULL); + state.frameLog = ZSTD_seekable_createFrameLog(1); + if (state.frameLog == NULL) { fprintf(stderr, "ZSTD_seekable_createFrameLog() failed \n"); exit(11); } + + POOL_ctx* pool = POOL_create(nbThreads, nbThreads); + if (pool == NULL) { fprintf(stderr, "POOL_create() error \n"); exit(9); } + + FILE* fin; + openInOut_orDie(fname, &fin, &state.fout); + + if (ZSTD_compressBound(frameSize) > 0xFFFFFFFFU) { fprintf(stderr, "Frame size too large \n"); exit(10); } + unsigned dstSize = ZSTD_compressBound(frameSize); + + for (size_t id = 0; 1; id++) { + struct job* job = malloc_orDie(sizeof(struct job)); + job->id = id; + job->next = NULL; + job->state = &state; + job->src = malloc_orDie(frameSize); + job->dst = malloc_orDie(dstSize); + job->srcSize = fread_orDie(job->src, frameSize, fin); + job->dstSize = dstSize; + POOL_add(pool, compressFrame, job); + if (feof(fin)) + break; + } + + POOL_joinJobs(pool); + if (state.pending) { + fprintf(stderr, "Unexpected leftover output blocks!\n"); + exit(13); + } + + { unsigned char seekTableBuff[1024]; + ZSTD_outBuffer out = {seekTableBuff, 1024, 0}; + while (ZSTD_seekable_writeSeekTable(state.frameLog, &out) != 0) { + fwrite_orDie(seekTableBuff, out.pos, state.fout); + out.pos = 0; + } + fwrite_orDie(seekTableBuff, out.pos, state.fout); + } + + ZSTD_seekable_freeFrameLog(state.frameLog); + fclose_orDie(state.fout); + fclose_orDie(fin); +} + int main(int argc, const char** argv) { const char* const exeName = argv[0]; if (argc!=4) { @@ -206,8 +246,7 @@ int main(int argc, const char** argv) { unsigned const frameSize = (unsigned)atoi(argv[2]); int const nbThreads = atoi(argv[3]); - const char* const outFileName = createOutFilename_orDie(inFileName); - compressFile_orDie(inFileName, outFileName, 5, frameSize, nbThreads); + compressFile_orDie(inFileName, 5, frameSize, nbThreads); } return 0;