Use KahanSum to compute RMSE.

Fix typos.
This commit is contained in:
castano 2008-12-07 23:15:06 +00:00
parent a30490ab9b
commit 127052f404

View File

@ -26,6 +26,7 @@
#include <nvimage/ImageIO.h>
#include <nvimage/BlockDXT.h>
#include <nvimage/ColorBlock.h>
#include <nvmath/KahanSum.h>
#include <nvcore/Ptr.h>
#include <nvcore/Debug.h>
#include <nvcore/StrLib.h>
@ -166,7 +167,7 @@ float rmsError(const Image * a, const Image * b)
nvCheck(a->width() == b->width());
nvCheck(a->height() == b->height());
float mse = 0;
KahanSum mse;
const uint count = a->width() * a->height();
@ -180,14 +181,12 @@ float rmsError(const Image * a, const Image * b)
int b = c0.b - c1.b;
//int a = c0.a - c1.a;
mse += r * r;
mse += g * g;
mse += b * b;
mse.add(r * r);
mse.add(g * g);
mse.add(b * b);
}
mse /= count;
return sqrtf(mse);
return sqrtf(mse.sum() / count);
}
@ -277,7 +276,7 @@ int main(int argc, char *argv[])
TextWriter csvWriter(&csvStream);
float totalTime = 0;
float totalRMS = 0;
float totalRMSE = 0;
int failedTests = 0;
float totalDiff = 0;
@ -314,13 +313,13 @@ int main(int argc, char *argv[])
printf("Error saving file '%s'.\n", outputFileName.str());
}
float rms = rmsError(img.ptr(), img_out.ptr());
totalRMS += rms;
float rmse = rmsError(img.ptr(), img_out.ptr());
totalRMSE += rmse;
printf(" RMS: \t%.4f\n", rms);
printf(" RMSE: \t%.4f\n", rmse);
// Output csv file
csvWriter << "\"" << s_fileNames[i] << "\"," << rms << "\n";
csvWriter << "\"" << s_fileNames[i] << "\"," << rmse << "\n";
if (regressPath != NULL)
{
@ -335,9 +334,9 @@ int main(int argc, char *argv[])
return EXIT_FAILURE;
}
float rms_reg = rmsError(img.ptr(), img_reg.ptr());
float rmse_reg = rmsError(img.ptr(), img_reg.ptr());
float diff = rms_reg - rms;
float diff = rmse_reg - rmse;
totalDiff += diff;
const char * text = "PASSED";
@ -353,12 +352,12 @@ int main(int argc, char *argv[])
fflush(stdout);
}
totalRMS /= s_fileCount;
totalRMSE /= s_fileCount;
totalDiff /= s_fileCount;
printf("Total Results:\n");
printf(" Total Time: \t%.3f sec\n", totalTime / CLOCKS_PER_SEC);
printf(" Average RMS:\t%.4f\n", totalRMS);
printf(" Average RMSE:\t%.4f\n", totalRMSE);
if (regressPath != NULL)
{