// Patch for src/nvtt/TaskDispatcher.h (four hunks). Text ahead of the
// "diff --git" header is commentary only.
//
// Hunk 1 (-14,11 +14,14): the preprocessor test `#if HAVE_PPL` becomes
//   `#if defined(HAVE_PPL)`, and a new `#if defined(HAVE_TBB)` section with
//   one `#include` is added before `namespace nvtt`.
// Hunk 2 (-50,12 +53,9): OpenMPTaskDispatcher::dispatch replaces the nested
//   `#pragma omp parallel` { `#pragma omp for` } pair with a single
//   `#pragma omp parallel for` over the same loop, which still calls
//   task(context, i) for i in [0, count).
// Hunk 3 (-83,7 +83,7): TaskFunctor::operator() now invokes task(context, n)
//   instead of squaring its argument (`n *= n`).
// Hunk 4 (-105,5 +105,25): adds a HAVE_TBB-guarded section containing a
//   TaskFunctor struct and IntelTaskDispatcher, whose dispatch() calls
//   parallel_for(blocked_range(0, count, 1), TaskFunctor(task, context)).
//
// NOTE(review): every `#include` in this copy has no header name, and
//   `blocked_range` has no template argument -- presumably angle-bracket text
//   was stripped when the patch was captured; restore it from the original
//   patch before applying.
// NOTE(review): line breaks are missing from this copy, so the hunks cannot be
//   applied as-is; context lines must be re-split to match the target file.
// NOTE(review): `#if defined(HAVE_PPL)` is also true when the macro is defined
//   as 0, unlike `#if HAVE_PPL` -- confirm no build defines it that way.
// NOTE(review): the OpenMP loop index is `size_t`; some older OpenMP
//   implementations require a signed loop variable -- confirm the supported
//   compilers accept this.
// NOTE(review): the HAVE_TBB TaskFunctor takes `int &`, while dispatch() hands
//   parallel_for a blocked_range; TBB's range form is expected to call the
//   body with the range itself -- confirm this compiles, or switch to a body
//   that iterates the range / the index-based parallel_for overload.
// NOTE(review): `parallel_for` and `blocked_range` are unqualified; confirm a
//   `using` for the tbb namespace is in scope in the target file.
// NOTE(review): a struct named TaskFunctor already exists near line 83; if its
//   guard can be active together with HAVE_TBB the two definitions would
//   clash -- confirm the guards are mutually exclusive or rename one.
diff --git a/src/nvtt/TaskDispatcher.h b/src/nvtt/TaskDispatcher.h index 65cc326..0995dc9 100644 --- a/src/nvtt/TaskDispatcher.h +++ b/src/nvtt/TaskDispatcher.h @@ -14,11 +14,14 @@ #define HAVE_PPL 1 #endif -#if HAVE_PPL +#if defined(HAVE_PPL) #include //#include #endif +#if defined(HAVE_TBB) +#include +#endif namespace nvtt { @@ -50,12 +53,9 @@ namespace nvtt { struct OpenMPTaskDispatcher : public TaskDispatcher { virtual void dispatch(Task * task, void * context, size_t count) { - #pragma omp parallel - { - #pragma omp for - for (size_t i = 0; i < count; i++) { - task(context, i); - } + #pragma omp parallel for + for (size_t i = 0; i < count; i++) { + task(context, i); } } }; @@ -83,7 +83,7 @@ namespace nvtt { struct TaskFunctor { TaskFunctor(Task * task, void * context) : task(task), context(context) {} void operator()(int & n) const { - n *= n; + task(context, n); } Task * task; void * context; @@ -105,5 +105,25 @@ namespace nvtt { #endif +#if defined(HAVE_TBB) + + struct TaskFunctor { + TaskFunctor(Task * task, void * context) : task(task), context(context) {} + void operator()(int & n) const { + task(context, n); + } + Task * task; + void * context; + }; + + struct IntelTaskDispatcher : public TaskDispatcher + { + virtual void dispatch(Task * task, void * context, size_t count) { + parallel_for(blocked_range(0, count, 1), TaskFunctor(task, context)); + } + }; + +#endif + } // namespace nvtt