Use KahanSum to compute RMSE.

Fix typos.
This commit is contained in:
castano 2008-12-07 23:15:06 +00:00
parent a30490ab9b
commit 127052f404

View File

@ -26,6 +26,7 @@
#include <nvimage/ImageIO.h> #include <nvimage/ImageIO.h>
#include <nvimage/BlockDXT.h> #include <nvimage/BlockDXT.h>
#include <nvimage/ColorBlock.h> #include <nvimage/ColorBlock.h>
#include <nvmath/KahanSum.h>
#include <nvcore/Ptr.h> #include <nvcore/Ptr.h>
#include <nvcore/Debug.h> #include <nvcore/Debug.h>
#include <nvcore/StrLib.h> #include <nvcore/StrLib.h>
@ -166,7 +167,7 @@ float rmsError(const Image * a, const Image * b)
nvCheck(a->width() == b->width()); nvCheck(a->width() == b->width());
nvCheck(a->height() == b->height()); nvCheck(a->height() == b->height());
float mse = 0; KahanSum mse;
const uint count = a->width() * a->height(); const uint count = a->width() * a->height();
@ -180,14 +181,12 @@ float rmsError(const Image * a, const Image * b)
int b = c0.b - c1.b; int b = c0.b - c1.b;
//int a = c0.a - c1.a; //int a = c0.a - c1.a;
mse += r * r; mse.add(r * r);
mse += g * g; mse.add(g * g);
mse += b * b; mse.add(b * b);
} }
mse /= count; return sqrtf(mse.sum() / count);
return sqrtf(mse);
} }
@ -277,7 +276,7 @@ int main(int argc, char *argv[])
TextWriter csvWriter(&csvStream); TextWriter csvWriter(&csvStream);
float totalTime = 0; float totalTime = 0;
float totalRMS = 0; float totalRMSE = 0;
int failedTests = 0; int failedTests = 0;
float totalDiff = 0; float totalDiff = 0;
@ -314,13 +313,13 @@ int main(int argc, char *argv[])
printf("Error saving file '%s'.\n", outputFileName.str()); printf("Error saving file '%s'.\n", outputFileName.str());
} }
float rms = rmsError(img.ptr(), img_out.ptr()); float rmse = rmsError(img.ptr(), img_out.ptr());
totalRMS += rms; totalRMSE += rmse;
printf(" RMS: \t%.4f\n", rms); printf(" RMSE: \t%.4f\n", rmse);
// Output csv file // Output csv file
csvWriter << "\"" << s_fileNames[i] << "\"," << rms << "\n"; csvWriter << "\"" << s_fileNames[i] << "\"," << rmse << "\n";
if (regressPath != NULL) if (regressPath != NULL)
{ {
@ -335,9 +334,9 @@ int main(int argc, char *argv[])
return EXIT_FAILURE; return EXIT_FAILURE;
} }
float rms_reg = rmsError(img.ptr(), img_reg.ptr()); float rmse_reg = rmsError(img.ptr(), img_reg.ptr());
float diff = rms_reg - rms; float diff = rmse_reg - rmse;
totalDiff += diff; totalDiff += diff;
const char * text = "PASSED"; const char * text = "PASSED";
@ -353,12 +352,12 @@ int main(int argc, char *argv[])
fflush(stdout); fflush(stdout);
} }
totalRMS /= s_fileCount; totalRMSE /= s_fileCount;
totalDiff /= s_fileCount; totalDiff /= s_fileCount;
printf("Total Results:\n"); printf("Total Results:\n");
printf(" Total Time: \t%.3f sec\n", totalTime / CLOCKS_PER_SEC); printf(" Total Time: \t%.3f sec\n", totalTime / CLOCKS_PER_SEC);
printf(" Average RMS:\t%.4f\n", totalRMS); printf(" Average RMSE:\t%.4f\n", totalRMSE);
if (regressPath != NULL) if (regressPath != NULL)
{ {