From 55997ba442a3ccbefae38ea0f2e84267d64f0aa4 Mon Sep 17 00:00:00 2001 From: castano Date: Sun, 20 May 2007 10:37:32 +0000 Subject: [PATCH] some progress in DXT5 cuda compressor. --- src/nvimage/nvtt/cuda/CompressKernel.cu | 949 +++++++++++++----------- 1 file changed, 500 insertions(+), 449 deletions(-) diff --git a/src/nvimage/nvtt/cuda/CompressKernel.cu b/src/nvimage/nvtt/cuda/CompressKernel.cu index 4121d9c..84079ca 100644 --- a/src/nvimage/nvtt/cuda/CompressKernel.cu +++ b/src/nvimage/nvtt/cuda/CompressKernel.cu @@ -36,6 +36,7 @@ #define __debugsync() #endif +typedef unsigned char uchar; typedef unsigned short ushort; typedef unsigned int uint; @@ -50,6 +51,132 @@ __device__ inline void swap(T & a, T & b) __constant__ float3 kColorMetric = { 1.0f, 1.0f, 1.0f }; + +//////////////////////////////////////////////////////////////////////////////// +// Sort colors +//////////////////////////////////////////////////////////////////////////////// +__device__ void sortColors(float * values, int * cmp) +{ + int tid = threadIdx.x; + + cmp[tid] = (values[0] < values[tid]); + cmp[tid] += (values[1] < values[tid]); + cmp[tid] += (values[2] < values[tid]); + cmp[tid] += (values[3] < values[tid]); + cmp[tid] += (values[4] < values[tid]); + cmp[tid] += (values[5] < values[tid]); + cmp[tid] += (values[6] < values[tid]); + cmp[tid] += (values[7] < values[tid]); + cmp[tid] += (values[8] < values[tid]); + cmp[tid] += (values[9] < values[tid]); + cmp[tid] += (values[10] < values[tid]); + cmp[tid] += (values[11] < values[tid]); + cmp[tid] += (values[12] < values[tid]); + cmp[tid] += (values[13] < values[tid]); + cmp[tid] += (values[14] < values[tid]); + cmp[tid] += (values[15] < values[tid]); + + // Resolve elements with the same index. + if (tid > 0 && cmp[tid] == cmp[0]) ++cmp[tid]; + if (tid > 1 && cmp[tid] == cmp[1]) ++cmp[tid]; + if (tid > 2 && cmp[tid] == cmp[2]) ++cmp[tid]; + if (tid > 3 && cmp[tid] == cmp[3]) ++cmp[tid]; + if (tid > 4 && cmp[tid] == cmp[4]) ++cmp[tid]; + if (tid > 5 && cmp[tid] == cmp[5]) ++cmp[tid]; + if (tid > 6 && cmp[tid] == cmp[6]) ++cmp[tid]; + if (tid > 7 && cmp[tid] == cmp[7]) ++cmp[tid]; + if (tid > 8 && cmp[tid] == cmp[8]) ++cmp[tid]; + if (tid > 9 && cmp[tid] == cmp[9]) ++cmp[tid]; + if (tid > 10 && cmp[tid] == cmp[10]) ++cmp[tid]; + if (tid > 11 && cmp[tid] == cmp[11]) ++cmp[tid]; + if (tid > 12 && cmp[tid] == cmp[12]) ++cmp[tid]; + if (tid > 13 && cmp[tid] == cmp[13]) ++cmp[tid]; + if (tid > 14 && cmp[tid] == cmp[14]) ++cmp[tid]; +} + + +//////////////////////////////////////////////////////////////////////////////// +// Load color block to shared mem +//////////////////////////////////////////////////////////////////////////////// +__device__ void loadColorBlock(const uint * image, float3 colors[16], int xrefs[16]) +{ + const int bid = blockIdx.x; + const int idx = threadIdx.x; + + __shared__ float dps[16]; + + if (idx < 16) + { + // Read color and copy to shared mem. + uint c = image[(bid) * 16 + idx]; + + colors[idx].z = ((c >> 0) & 0xFF) * (1.0f / 255.0f); + colors[idx].y = ((c >> 8) & 0xFF) * (1.0f / 255.0f); + colors[idx].x = ((c >> 16) & 0xFF) * (1.0f / 255.0f); + + // No need to synchronize, 16 < warp size. +#if __DEVICE_EMULATION__ + } __debugsync(); if (idx < 16) { +#endif + + // Sort colors along the best fit line. + float3 axis = bestFitLine(colors); + + dps[idx] = dot(colors[idx], axis); + +#if __DEVICE_EMULATION__ + } __debugsync(); if (idx < 16) { +#endif + + sortColors(dps, xrefs); + + float3 tmp = colors[idx]; + colors[xrefs[idx]] = tmp; + } +} + +__device__ void loadColorBlock(const uint * image, float3 colors[16], float weights[16], int xrefs[16]) +{ + const int bid = blockIdx.x; + const int idx = threadIdx.x; + + __shared__ float dps[16]; + + if (idx < 16) + { + // Read color and copy to shared mem. + uint c = image[(bid) * 16 + idx]; + + colors[idx].z = ((c >> 0) & 0xFF) * (1.0f / 255.0f); + colors[idx].y = ((c >> 8) & 0xFF) * (1.0f / 255.0f); + colors[idx].x = ((c >> 16) & 0xFF) * (1.0f / 255.0f); + weights[idx] = ((c >> 24) & 0xFF) * (1.0f / 255.0f); + + // No need to synchronize, 16 < warp size. +#if __DEVICE_EMULATION__ + } __debugsync(); if (idx < 16) { +#endif + + // Sort colors along the best fit line. + float3 axis = bestFitLine(colors); + + dps[idx] = dot(colors[idx], axis); + +#if __DEVICE_EMULATION__ + } __debugsync(); if (idx < 16) { +#endif + + sortColors(dps, xrefs); + + float3 tmp = colors[idx]; + colors[xrefs[idx]] = tmp; + + float w = weights[idx]; + weights[xrefs[idx]] = w; + } +} + + //////////////////////////////////////////////////////////////////////////////// // Round color to RGB565 and expand //////////////////////////////////////////////////////////////////////////////// @@ -69,7 +196,7 @@ inline __device__ float3 roundAndExpand(float3 v, ushort * w) //////////////////////////////////////////////////////////////////////////////// // Evaluate permutations //////////////////////////////////////////////////////////////////////////////// -static __device__ float evalPermutation4(const float3 * colors, uint permutation, ushort * start, ushort * end) +__device__ float evalPermutation4(const float3 * colors, uint permutation, ushort * start, ushort * end) { // Compute endpoints using least squares. float alpha2_sum = 0.0f; @@ -81,19 +208,19 @@ static __device__ float evalPermutation4(const float3 * colors, uint permutation // Compute alpha & beta for this permutation. for (int i = 0; i < 16; i++) { - const uint bits = permutation >> (2*i); - + const uint bits = permutation >> (2*i); + float beta = (bits & 1); if (bits & 2) beta = (1 + beta) / 3.0f; float alpha = 1.0f - beta; - alpha2_sum += alpha * alpha; - beta2_sum += beta * beta; - alphabeta_sum += alpha * beta; + alpha2_sum += alpha * alpha; + beta2_sum += beta * beta; + alphabeta_sum += alpha * beta; alphax_sum += alpha * colors[i]; betax_sum += beta * colors[i]; } - + const float factor = 1.0f / (alpha2_sum * beta2_sum - alphabeta_sum * alphabeta_sum); float3 a = (alphax_sum * beta2_sum - betax_sum * alphabeta_sum) * factor; @@ -109,8 +236,7 @@ static __device__ float evalPermutation4(const float3 * colors, uint permutation return dot(e, kColorMetric); } - -static __device__ float evalPermutation3(const float3 * colors, uint permutation, ushort * start, ushort * end) +__device__ float evalPermutation3(const float3 * colors, uint permutation, ushort * start, ushort * end) { // Compute endpoints using least squares. float alpha2_sum = 0.0f; @@ -150,7 +276,7 @@ static __device__ float evalPermutation3(const float3 * colors, uint permutation return dot(e, kColorMetric); } -static __device__ float evalPermutation4(const float3 * colors, const float * weights, uint permutation, ushort * start, ushort * end) +__device__ float evalPermutation4(const float3 * colors, const float * weights, uint permutation, ushort * start, ushort * end) { // Compute endpoints using least squares. float alpha2_sum = 0.0f; @@ -190,8 +316,7 @@ static __device__ float evalPermutation4(const float3 * colors, const float * we return dot(e, kColorMetric); } - -static __device__ float evalPermutation3(const float3 * colors, const float * weights, uint permutation, ushort * start, ushort * end) +__device__ float evalPermutation3(const float3 * colors, const float * weights, uint permutation, ushort * start, ushort * end) { // Compute endpoints using least squares. float alpha2_sum = 0.0f; @@ -233,73 +358,195 @@ static __device__ float evalPermutation3(const float3 * colors, const float * we //////////////////////////////////////////////////////////////////////////////// -// Sort colors +// Evaluate all permutations //////////////////////////////////////////////////////////////////////////////// -__device__ void sortColors(float * values, int * cmp) +__device__ void evalAllPermutations(const float3 * colors, const uint * permutations, ushort & bestStart, ushort & bestEnd, uint & bestPermutation, float * errors) { - int tid = threadIdx.x; - - cmp[tid] = (values[0] < values[tid]); - cmp[tid] += (values[1] < values[tid]); - cmp[tid] += (values[2] < values[tid]); - cmp[tid] += (values[3] < values[tid]); - cmp[tid] += (values[4] < values[tid]); - cmp[tid] += (values[5] < values[tid]); - cmp[tid] += (values[6] < values[tid]); - cmp[tid] += (values[7] < values[tid]); - cmp[tid] += (values[8] < values[tid]); - cmp[tid] += (values[9] < values[tid]); - cmp[tid] += (values[10] < values[tid]); - cmp[tid] += (values[11] < values[tid]); - cmp[tid] += (values[12] < values[tid]); - cmp[tid] += (values[13] < values[tid]); - cmp[tid] += (values[14] < values[tid]); - cmp[tid] += (values[15] < values[tid]); + const int idx = threadIdx.x; - // Resolve elements with the same index. - if (tid > 0 && cmp[tid] == cmp[0]) ++cmp[tid]; - if (tid > 1 && cmp[tid] == cmp[1]) ++cmp[tid]; - if (tid > 2 && cmp[tid] == cmp[2]) ++cmp[tid]; - if (tid > 3 && cmp[tid] == cmp[3]) ++cmp[tid]; - if (tid > 4 && cmp[tid] == cmp[4]) ++cmp[tid]; - if (tid > 5 && cmp[tid] == cmp[5]) ++cmp[tid]; - if (tid > 6 && cmp[tid] == cmp[6]) ++cmp[tid]; - if (tid > 7 && cmp[tid] == cmp[7]) ++cmp[tid]; - if (tid > 8 && cmp[tid] == cmp[8]) ++cmp[tid]; - if (tid > 9 && cmp[tid] == cmp[9]) ++cmp[tid]; - if (tid > 10 && cmp[tid] == cmp[10]) ++cmp[tid]; - if (tid > 11 && cmp[tid] == cmp[11]) ++cmp[tid]; - if (tid > 12 && cmp[tid] == cmp[12]) ++cmp[tid]; - if (tid > 13 && cmp[tid] == cmp[13]) ++cmp[tid]; - if (tid > 14 && cmp[tid] == cmp[14]) ++cmp[tid]; -} + float bestError = FLT_MAX; + + __shared__ uint s_permutations[160]; + for(int i = 0; i < 16; i++) + { + int pidx = idx + NUM_THREADS * i; + if (pidx >= 992) break; + + ushort start, end; + uint permutation = permutations[pidx]; + if (pidx < 160) s_permutations[pidx] = permutation; + + float error = evalPermutation4(colors, permutation, &start, &end); + + if (error < bestError) + { + bestError = error; + bestPermutation = permutation; + bestStart = start; + bestEnd = end; + } + } -//////////////////////////////////////////////////////////////////////////////// -// Find index with minimum error -//////////////////////////////////////////////////////////////////////////////// -__device__ void minimizeError(float * errors, int * indices) + if (bestStart < bestEnd) + { + swap(bestEnd, bestStart); + bestPermutation ^= 0x55555555; // Flip indices. + } + + for(int i = 0; i < 3; i++) + { + int pidx = idx + NUM_THREADS * i; + if (pidx >= 160) break; + + ushort start, end; + uint permutation = s_permutations[pidx]; + float error = evalPermutation3(colors, permutation, &start, &end); + + if (error < bestError) + { + bestError = error; + bestPermutation = permutation; + bestStart = start; + bestEnd = end; + + if (bestStart > bestEnd) + { + swap(bestEnd, bestStart); + bestPermutation ^= (~bestPermutation >> 1) & 0x55555555; // Flip indices. + } + } + } + + errors[idx] = bestError; +} + +__device__ void evalAllPermutations(const float3 * colors, const float * weights, const uint * permutations, ushort & bestStart, ushort & bestEnd, uint & bestPermutation, float * errors) { const int idx = threadIdx.x; + + float bestError = FLT_MAX; + + __shared__ uint s_permutations[160]; + + for(int i = 0; i < 16; i++) + { + int pidx = idx + NUM_THREADS * i; + if (pidx >= 992) break; + + ushort start, end; + uint permutation = permutations[pidx]; + if (pidx < 160) s_permutations[pidx] = permutation; -#if __DEVICE_EMULATION__ - for(int d = NUM_THREADS/2; d > 0; d >>= 1) - { - __syncthreads(); - - if (idx < d) - { - float err0 = errors[idx]; - float err1 = errors[idx + d]; - - if (err1 < err0) { - errors[idx] = err1; - indices[idx] = indices[idx + d]; - } - } - } + float error = evalPermutation4(colors, weights, permutation, &start, &end); + + if (error < bestError) + { + bestError = error; + bestPermutation = permutation; + bestStart = start; + bestEnd = end; + } + } -#else + if (bestStart < bestEnd) + { + swap(bestEnd, bestStart); + bestPermutation ^= 0x55555555; // Flip indices. + } + + for(int i = 0; i < 3; i++) + { + int pidx = idx + NUM_THREADS * i; + if (pidx >= 160) break; + + ushort start, end; + uint permutation = s_permutations[pidx]; + float error = evalPermutation3(colors, weights, permutation, &start, &end); + + if (error < bestError) + { + bestError = error; + bestPermutation = permutation; + bestStart = start; + bestEnd = end; + + if (bestStart > bestEnd) + { + swap(bestEnd, bestStart); + bestPermutation ^= (~bestPermutation >> 1) & 0x55555555; // Flip indices. + } + } + } + + errors[idx] = bestError; +} + + +__device__ void evalLevel4Permutations(const float3 * colors, const float * weights, const uint * permutations, ushort & bestStart, ushort & bestEnd, uint & bestPermutation, float * errors) +{ + const int idx = threadIdx.x; + + float bestError = FLT_MAX; + + for(int i = 0; i < 16; i++) + { + int pidx = idx + NUM_THREADS * i; + if (pidx >= 992) break; + + ushort start, end; + uint permutation = permutations[pidx]; + + float error = evalPermutation4(colors, weights, permutation, &start, &end); + + if (error < bestError) + { + bestError = error; + bestPermutation = permutation; + bestStart = start; + bestEnd = end; + } + } + + if (bestStart < bestEnd) + { + swap(bestEnd, bestStart); + bestPermutation ^= 0x55555555; // Flip indices. + } + + errors[idx] = bestError; +} + + + +//////////////////////////////////////////////////////////////////////////////// +// Find index with minimum error +//////////////////////////////////////////////////////////////////////////////// +__device__ int findMinError(float * errors) +{ + const int idx = threadIdx.x; + + __shared__ int indices[NUM_THREADS]; + indices[idx] = idx; + +#if __DEVICE_EMULATION__ + for(int d = NUM_THREADS/2; d > 0; d >>= 1) + { + __syncthreads(); + + if (idx < d) + { + float err0 = errors[idx]; + float err1 = errors[idx + d]; + + if (err1 < err0) { + errors[idx] = err1; + indices[idx] = indices[idx + d]; + } + } + } + +#else for(int d = NUM_THREADS/2; d > 32; d >>= 1) { __syncthreads(); @@ -345,95 +592,19 @@ __device__ void minimizeError(float * errors, int * indices) } } #endif -} - - -//////////////////////////////////////////////////////////////////////////////// -// Load color block to shared mem -//////////////////////////////////////////////////////////////////////////////// -__device__ void loadColorBlock(const uint * image, float3 colors[16], int xrefs[16]) -{ - const int bid = blockIdx.x; - const int idx = threadIdx.x; - - __shared__ float dps[16]; - - if (idx < 16) - { - // Read color and copy to shared mem. - uint c = image[(bid) * 16 + idx]; - - colors[idx].z = ((c >> 0) & 0xFF) * (1.0f / 255.0f); - colors[idx].y = ((c >> 8) & 0xFF) * (1.0f / 255.0f); - colors[idx].x = ((c >> 16) & 0xFF) * (1.0f / 255.0f); - - // No need to synchronize, 16 < warp size. -#if __DEVICE_EMULATION__ - } __debugsync(); if (idx < 16) { -#endif - // Sort colors along the best fit line. - float3 axis = bestFitLine(colors); - - dps[idx] = dot(colors[idx], axis); - -#if __DEVICE_EMULATION__ - } __debugsync(); if (idx < 16) { -#endif - - sortColors(dps, xrefs); - - float3 tmp = colors[idx]; - colors[xrefs[idx]] = tmp; - } -} - -__device__ void loadColorBlock(const uint * image, float3 colors[16], float weights[16], int xrefs[16]) -{ - const int bid = blockIdx.x; - const int idx = threadIdx.x; - - __shared__ float dps[16]; - - if (idx < 16) - { - // Read color and copy to shared mem. - uint c = image[(bid) * 16 + idx]; - - colors[idx].z = ((c >> 0) & 0xFF) * (1.0f / 255.0f); - colors[idx].y = ((c >> 8) & 0xFF) * (1.0f / 255.0f); - colors[idx].x = ((c >> 16) & 0xFF) * (1.0f / 255.0f); - weights[idx] = ((c >> 24) & 0xFF) * (1.0f / 255.0f); - - // No need to synchronize, 16 < warp size. -#if __DEVICE_EMULATION__ - } __debugsync(); if (idx < 16) { -#endif + __syncthreads(); - // Sort colors along the best fit line. - float3 axis = bestFitLine(colors); - - dps[idx] = dot(colors[idx], axis); - -#if __DEVICE_EMULATION__ - } __debugsync(); if (idx < 16) { -#endif - - sortColors(dps, xrefs); - - float3 tmp = colors[idx]; - colors[xrefs[idx]] = tmp; - - float w = weights[idx]; - weights[xrefs[idx]] = tmp; - } + return indices[0]; } -__device__ void saveBlockDXT1(ushort start, ushort end, uint permutation, int xrefs[16]) +//////////////////////////////////////////////////////////////////////////////// +// Save DXT block +//////////////////////////////////////////////////////////////////////////////// +__device__ void saveBlockDXT1(ushort start, ushort end, uint permutation, int xrefs[16], uint2 * result) { const int bid = blockIdx.x; - const int idx = threadIdx.x; if (start == end) { @@ -441,7 +612,7 @@ __device__ void saveBlockDXT1(ushort start, ushort end, uint permutation, int xr } // Reorder permutation. - uint indices = 0; + uint indices = permutation; for(int i = 0; i < 16; i++) { int ref = xrefs[i]; @@ -456,362 +627,242 @@ __device__ void saveBlockDXT1(ushort start, ushort end, uint permutation, int xr } + + //////////////////////////////////////////////////////////////////////////////// // Compress color block //////////////////////////////////////////////////////////////////////////////// __global__ void compress(const uint * permutations, const uint * image, uint2 * result) { - const int bid = blockIdx.x; - const int idx = threadIdx.x; - __shared__ float3 colors[16]; __shared__ int xrefs[16]; loadColorBlock(image, colors, xrefs); - ushort bestStart, bestEnd; - uint bestPermutation; - float bestError = FLT_MAX; - __syncthreads(); -#if 0 - // This version is more clear, but slightly slower. - for(int i = 0; i < 16; i++) - { - if (i == 15 && idx >= 32) break; - - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation4(colors, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - } - } - - if (bestStart < bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= 0x55555555; // Flip indices. - } - - for(int i = 0; i < 3; i++) - { - if (i == 2 && idx >= 32) break; - - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation3(colors, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - - if (bestStart > bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= (~bestPermutation >> 1) & 0x55555555; // Flip indices. - } - } - } -#else - { - int pidx = idx + NUM_THREADS * 15; - if (idx >= 32) - { - pidx = idx + NUM_THREADS * 2; - } - - ushort start, end; - uint permutation = permutations[pidx]; - float error = evalPermutation4(colors, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - } - } - - for(int i = 3; i < 15; i++) - { - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation4(colors, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - } - } - - if (bestStart < bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= 0x55555555; // Flip indices. - } + ushort bestStart, bestEnd; + uint bestPermutation; + __shared__ float errors[NUM_THREADS]; - for(int i = 0; i < 3; i++) - { - if (i == 2 && idx >= 32) break; - - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation3(colors, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - - if (bestStart > bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= (~bestPermutation >> 1) & 0x55555555; // Flip indices. - } - } - - error = evalPermutation4(colors, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - - if (bestStart < bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= 0x55555555; // Flip indices. - } - } - } -#endif - - __syncthreads(); + evalAllPermutations(colors, permutations, bestStart, bestEnd, bestPermutation, errors); // Use a parallel reduction to find minimum error. - __shared__ float errors[NUM_THREADS]; - __shared__ int indices[NUM_THREADS]; - - errors[idx] = bestError; - indices[idx] = idx; - - minimizeError(errors, indices); - - __syncthreads(); + const int minIdx = findMinError(errors); // Only write the result of the winner thread. - if (idx == indices[0]) + if (threadIdx.x == minIdx) { - saveBlockDXT1(bestStart, bestEnd, bestPermutation, xrefs); + saveBlockDXT1(bestStart, bestEnd, bestPermutation, xrefs, result); } } __global__ void compressWeighted(const uint * permutations, const uint * image, uint2 * result) { - const int bid = blockIdx.x; - const int idx = threadIdx.x; - __shared__ float3 colors[16]; __shared__ float weights[16]; __shared__ int xrefs[16]; loadColorBlock(image, colors, weights, xrefs); + + __syncthreads(); ushort bestStart, bestEnd; uint bestPermutation; - float bestError = FLT_MAX; + + __shared__ float errors[NUM_THREADS]; - __syncthreads(); + evalLevel4Permutations(colors, weights, permutations, bestStart, bestEnd, bestPermutation, errors); + + // Use a parallel reduction to find minimum error. + int minIdx = findMinError(errors); + + // Only write the result of the winner thread. + if (threadIdx.x == minIdx) + { + saveBlockDXT1(bestStart, bestEnd, bestPermutation, xrefs, result); + } +} -#if 0 - // This version is more clear, but slightly slower. - for(int i = 0; i < 16; i++) - { - if (i == 15 && idx >= 32) break; - - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation4(colors, weights, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - } - } - if (bestStart < bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= 0x55555555; // Flip indices. - } +__device__ float computeError(const float weights[16], uchar a0, uchar a1) +{ + float palette[6]; + palette[0] = (6.0f/7.0f * a0 + 1.0f/7.0f * a1); + palette[1] = (5.0f/7.0f * a0 + 2.0f/7.0f * a1); + palette[2] = (4.0f/7.0f * a0 + 3.0f/7.0f * a1); + palette[3] = (3.0f/7.0f * a0 + 4.0f/7.0f * a1); + palette[4] = (2.0f/7.0f * a0 + 5.0f/7.0f * a1); + palette[5] = (1.0f/7.0f * a0 + 6.0f/7.0f * a1); + + float total = 0.0f; + + for (uint i = 0; i < 16; i++) + { + float alpha = weights[i]; + + float error = a0 - alpha; + error = min(error, palette[0] - alpha); + error = min(error, palette[1] - alpha); + error = min(error, palette[2] - alpha); + error = min(error, palette[3] - alpha); + error = min(error, palette[4] - alpha); + error = min(error, palette[5] - alpha); + error = min(error, a1 - alpha); + + total += error; + } + + return total; +} - for(int i = 0; i < 3; i++) - { - if (i == 2 && idx >= 32) break; - - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation3(colors, weights, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - - if (bestStart > bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= (~bestPermutation >> 1) & 0x55555555; // Flip indices. - } - } - } -#else +inline __device__ uchar roundAndExpand(float a) +{ + return rintf(__saturatef(a) * 255.0f); +} + +/* +__device__ void optimizeAlpha8(const float alphas[16], uchar & a0, uchar & a1) +{ + float alpha2_sum = 0; + float beta2_sum = 0; + float alphabeta_sum = 0; + float alphax_sum = 0; + float betax_sum = 0; + + for (int i = 0; i < 16; i++) + { + uint idx = index[i]; + float alpha; + if (idx < 2) alpha = 1.0f - idx; + else alpha = (8.0f - idx) / 7.0f; + + float beta = 1 - alpha; + + alpha2_sum += alpha * alpha; + beta2_sum += beta * beta; + alphabeta_sum += alpha * beta; + alphax_sum += alpha * alphas[i]; + betax_sum += beta * alphas[i]; + } + + const float factor = 1.0f / (alpha2_sum * beta2_sum - alphabeta_sum * alphabeta_sum); + + float a = (alphax_sum * beta2_sum - betax_sum * alphabeta_sum) * factor; + float b = (betax_sum * alpha2_sum - alphax_sum * alphabeta_sum) * factor; + + a0 = roundAndExpand(a); + a1 = roundAndExpand(b); +} +*/ + +__device__ void compressAlpha(const float alphas[16], uint4 * result) +{ + const int tid = threadIdx.x; + + // Compress alpha block! + // Brute force approach: + // Try all color pairs: 256*256/2 = 32768, 32768/64 = 512 iterations? + + // Determine min & max alphas + + float A0, A1; + + if (tid < 16) { - int pidx = idx + NUM_THREADS * 15; - if (idx >= 32) - { - pidx = idx + NUM_THREADS * 2; - } + __shared__ uint s_alphas[16]; - ushort start, end; - uint permutation = permutations[pidx]; - float error = evalPermutation4(colors, weights, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - } + s_alphas[tid] = alphas[tid]; + s_alphas[tid] = min(s_alphas[tid], s_alphas[tid^8]); + s_alphas[tid] = min(s_alphas[tid], s_alphas[tid^4]); + s_alphas[tid] = min(s_alphas[tid], s_alphas[tid^2]); + s_alphas[tid] = min(s_alphas[tid], s_alphas[tid^1]); + A0 = s_alphas[tid]; + + s_alphas[tid] = alphas[tid]; + s_alphas[tid] = max(s_alphas[tid], s_alphas[tid^8]); + s_alphas[tid] = max(s_alphas[tid], s_alphas[tid^4]); + s_alphas[tid] = max(s_alphas[tid], s_alphas[tid^2]); + s_alphas[tid] = max(s_alphas[tid], s_alphas[tid^1]); + A1 = s_alphas[tid]; } - for(int i = 3; i < 15; i++) - { - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation4(colors, weights, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - } - } + __syncthreads(); - if (bestStart < bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= 0x55555555; // Flip indices. - } + int minIdx = 0; + if (A1 - A0 > 8) + { + float bestError = FLT_MAX; - for(int i = 0; i < 3; i++) - { - if (i == 2 && idx >= 32) break; - - ushort start, end; - uint permutation = permutations[idx + NUM_THREADS * i]; - float error = evalPermutation3(colors, weights, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; - - if (bestStart > bestEnd) - { - swap(bestEnd, bestStart); - bestPermutation ^= (~bestPermutation >> 1) & 0x55555555; // Flip indices. - } - } + // 64 threads -> 8x8 + // divide [A1-A0] in partitions. + // test endpoints - error = evalPermutation4(colors, weights, permutation, &start, &end); - - if (error < bestError) - { - bestError = error; - bestPermutation = permutation; - bestStart = start; - bestEnd = end; + for (int i = 0; i < 128; i++) + { + uint idx = (i * NUM_THREADS + tid) * 4; + uchar a0 = idx & 255; + uchar a1 = idx >> 8; - if (bestStart < bestEnd) + float error = computeError(alphas, a0, a1); + + if (error < bestError) { - swap(bestEnd, bestStart); - bestPermutation ^= 0x55555555; // Flip indices. + bestError = error; + A0 = a0; + A1 = a1; } - } - } -#endif + } + + __shared__ float errors[NUM_THREADS]; + errors[tid] = bestError; + + // Minimize error. + minIdx = findMinError(errors); + + } + + if (minIdx == tid) + { + // @@ Compute indices. - __syncthreads(); + // @@ Write alpha block. + } +} + +__global__ void compressDXT5(const uint * permutations, const uint * image, uint4 * result) +{ + __shared__ float3 colors[16]; + __shared__ float weights[16]; + __shared__ int xrefs[16]; - // Use a parallel reduction to find minimum error. + loadColorBlock(image, colors, weights, xrefs); + + __syncthreads(); + + compressAlpha(weights, result); + + ushort bestStart, bestEnd; + uint bestPermutation; + __shared__ float errors[NUM_THREADS]; - __shared__ int indices[NUM_THREADS]; - - errors[idx] = bestError; - indices[idx] = idx; - minimizeError(errors, indices); + evalLevel4Permutations(colors, weights, permutations, bestStart, bestEnd, bestPermutation, errors); - __syncthreads(); + // Use a parallel reduction to find minimum error. + int minIdx = findMinError(errors); // Only write the result of the winner thread. - if (idx == indices[0]) + if (threadIdx.x == minIdx) { - if (bestStart == bestEnd) - { - bestPermutation = 0; - } - - // Reorder permutation. - uint perm = 0; - for(int i = 0; i < 16; i++) - { - int ref = xrefs[i]; - perm |= ((bestPermutation >> (2 * ref)) & 3) << (2 * i); - } - - // Write endpoints. (bestStart, bestEnd) - result[bid].x = (bestEnd << 16) | bestStart; - - // Write palette indices (permutation). - result[bid].y = perm; + saveBlockDXT1(bestStart, bestEnd, bestPermutation, xrefs, (uint2 *)result); } } - //////////////////////////////////////////////////////////////////////////////// // Setup kernel ////////////////////////////////////////////////////////////////////////////////