'use strict';
import { linearRegression } from 'simple-statistics';
import * as Utils from './Utils';
/**
* Manages the training statistics for one model. TensorFlow produces stats each
* epoch. This class records them, and maintains trailing averages to smooth
* spikes and dips. It calculates deltas and slopes for these averages. This
* information can be used to detect problematic situations such as overfitting.
* EpochStats also has text helpers for logging and output as CSV.
*/
class EpochStats {
    /**
     * Creates an instance of EpochStats.
     * @param {number} _trailDepth Total samples in a (simple) trailing average.
     *                             Asserted (via Utils.Assert) to be a positive
     *                             integer.
     */
    constructor(_trailDepth) {
        this._trailDepth = _trailDepth;
        // Trailing windows of raw per-epoch samples, each capped at _trailDepth
        // entries by Utils.QueueRotate().
        this._samplesAccuracy = [];
        this._samplesLoss = [];
        this._samplesValidationAccuracy = [];
        this._samplesValidationLoss = [];
        // Averages of the trailing windows above.
        this._averageAccuracy = 0;
        this._averageLoss = 0;
        this._averageLossDelta = 0;
        this._averageValidationAccuracy = 0;
        this._averageValidationLoss = 0;
        // Least-squares lines ({ m: slope, b: intercept }) fitted to each window.
        this._lineAccuracy = { m: 0, b: 0 };
        this._lineLoss = { m: 0, b: 0 };
        this._lineValidationAccuracy = { m: 0, b: 0 };
        this._lineValidationLoss = { m: 0, b: 0 };
        Utils.Assert(this._trailDepth > 0);
        Utils.Assert(Math.floor(this._trailDepth) === this._trailDepth);
    }
    get averageAccuracy() { return this._averageAccuracy; }
    get averageLoss() { return this._averageLoss; }
    get averageValidationAccuracy() { return this._averageValidationAccuracy; }
    get averageValidationLoss() { return this._averageValidationLoss; }
    get lineAccuracy() { return this._lineAccuracy; }
    get lineLoss() { return this._lineLoss; }
    get lineValidationAccuracy() { return this._lineValidationAccuracy; }
    get lineValidationLoss() { return this._lineValidationLoss; }
    /**
     * Takes the results of an epoch, and updates the trailing averages, deltas
     * and slopes.
     * @param {number} epoch Iteration count from model fit. It is asserted to be
     *                       a non-negative integer; its value is otherwise
     *                       unused.
     * @param {Logs} logs A TensorFlow object with the latest values for
     *                    accuracy, loss, validation-accuracy and
     *                    validation-loss, read from its `acc`, `loss`,
     *                    `val_acc` and `val_loss` keys respectively.
     */
    Update(epoch, logs) {
        Utils.Assert(epoch >= 0);
        Utils.Assert(Math.floor(epoch) === epoch);
        Utils.QueueRotate(this._samplesAccuracy, logs.acc, this._trailDepth);
        Utils.QueueRotate(this._samplesLoss, logs.loss, this._trailDepth);
        Utils.QueueRotate(this._samplesValidationAccuracy, logs.val_acc, this._trailDepth);
        Utils.QueueRotate(this._samplesValidationLoss, logs.val_loss, this._trailDepth);
        this._averageAccuracy = Utils.ArrayCalculateAverage(this._samplesAccuracy);
        this._averageLoss = Utils.ArrayCalculateAverage(this._samplesLoss);
        this._averageValidationAccuracy = Utils.ArrayCalculateAverage(this._samplesValidationAccuracy);
        this._averageValidationLoss = Utils.ArrayCalculateAverage(this._samplesValidationLoss);
        // Fit a line to each trailing window, using each sample's position in
        // the window as x; the slope (m) shows which way the metric is trending.
        this._lineAccuracy = linearRegression(EpochStats._AsIndexedPoints(this._samplesAccuracy));
        this._lineLoss = linearRegression(EpochStats._AsIndexedPoints(this._samplesLoss));
        this._lineValidationAccuracy = linearRegression(EpochStats._AsIndexedPoints(this._samplesValidationAccuracy));
        this._lineValidationLoss = linearRegression(EpochStats._AsIndexedPoints(this._samplesValidationLoss));
        this._averageLossDelta = this._averageLoss - this._averageValidationLoss;
    }
    //vv TODO: These move into a CSVSource interface
    /**
     * Gets the CSV header row; column order matches WriteCSVLineValues().
     * @return {string}
     */
    WriteCSVLineKeys() {
        return [
            'averageAccuracy',
            'averageLoss',
            'averageValidationAccuracy',
            'averageValidationLoss',
            'slopeAccuracy',
            'slopeLoss',
            'slopeValidationAccuracy',
            'slopeValidationLoss',
        ].join(',');
    }
    /**
     * Gets the current values as one CSV row; column order matches
     * WriteCSVLineKeys().
     * @return {string}
     */
    WriteCSVLineValues() {
        return [
            this._averageAccuracy,
            this._averageLoss,
            this._averageValidationAccuracy,
            this._averageValidationLoss,
            this._lineAccuracy.m,
            this._lineLoss.m,
            this._lineValidationAccuracy.m,
            this._lineValidationLoss.m,
        ].map(String).join(',');
    }
    //^^
    /**
     * Generates a one-line text report with the following:
     * <ul>
     * <li>all of the trailing averages</li>
     * <li>the slope of each average (accuracy, loss, validation-accuracy and
     * validation-loss)</li>
     * <li>relevant deltas between the training and validation values</li>
     * </ul>
     * @return {string}
     */
    WriteReport() {
        const signed = EpochStats._FormatSigned;
        // The loss/validation-loss delta is written with 2 decimals (not
        // REPORTING_DIGITS_STAT); kept as-is to preserve the report layout.
        const lossText = this._averageLoss.toFixed(REPORTING_DIGITS_STAT)
            + '(' + this._averageValidationLoss.toFixed(REPORTING_DIGITS_STAT) + ') '
            + 'Δ ' + signed(this._averageLossDelta, 2) + ', '
            + 'm ' + signed(this._lineLoss.m, REPORTING_DIGITS_SLOPE)
            + '(' + signed(this._lineValidationLoss.m, REPORTING_DIGITS_SLOPE) + ') ';
        const accuracyText = this._averageAccuracy.toFixed(REPORTING_DIGITS_STAT)
            + '(' + this._averageValidationAccuracy.toFixed(REPORTING_DIGITS_STAT) + '), '
            + 'm ' + signed(this._lineAccuracy.m, REPORTING_DIGITS_SLOPE)
            + '(' + signed(this._lineValidationAccuracy.m, REPORTING_DIGITS_SLOPE) + ')';
        return lossText + '\\/ ' + accuracyText;
    }
    /**
     * Gets the header that goes with {@link WriteReport}.
     * @static
     * @return {string}
     */
    static WriteReportHeader() {
        //NOTE: This must be kept in sync with the text written by WriteReport().
        return 'EPOCH '
            + 'LOSS(VALIDATION) '
            + 'Δ L-V DELTA, '
            + 'm LOSS-SLOPE(VALIDATION)'
            + ' \\/ '
            + 'ACCURACY(VALIDATION) '
            + 'm ACCURACY-SLOPE(VALIDATION)';
    }
    /**
     * Converts a series of samples into [index, value] points, the input
     * format taken by simple-statistics' linearRegression().
     * @param {number[]} samples
     * @return {Array<Array<number>>}
     */
    static _AsIndexedPoints(samples) {
        return samples.map((value, index) => [index, value]);
    }
    /**
     * Formats a number with a fixed number of decimals. A space is added
     * before any value that is not negative, so the decimal points stay in
     * the same column as those of negative values. This helps when reading
     * a long scroll of report lines.
     * @param {number} value
     * @param {number} digits Decimal places passed to toFixed().
     * @return {string}
     */
    static _FormatSigned(value, digits) {
        return (value < 0 ? '' : ' ') + value.toFixed(digits);
    }
}
const REPORTING_DIGITS_SLOPE = 6;
const REPORTING_DIGITS_STAT = 4;
Object.freeze(EpochStats);
export { EpochStats };
//# sourceMappingURL=EpochStats.js.map