lib/ModelStatics.js

'use strict';
import * as TENSOR_FLOW from '@tensorflow/tfjs-node';
import * as Axis from './Axis';
//TODO: Sub these ^^ in: export { Axis, AxisDefaults, AxisNames, AxisTypes };
import { FailureMessage } from './FailureMessage';
import * as Utils from './Utils';
/**
 * Manages the hyperparameters that do <i>not</i> change over the course of
 * the grid search (i.e. those not governed by an {@link Axis}).
 */
class ModelStatics {
    /**
     * Creates an instance of ModelStatics.<br>
     * - See {@link Axis.AxisTypes} for the available fields.<br>
     * - See {@link Axis.AxisDefaults} for defaults.<br>
     * All fields are optional, as is the argument itself; omitting it (or
     * passing null) yields a model built entirely from the defaults. Any field
     * used here that also has an axis will be ignored (the dynamic axis values
     * will be used instead).
     * @param {Types.StringKeyedNumbersObject} [_userStatics={}]
     * @throws {Error} If any user-supplied param fails validation.
     * @example
     * new tngs.ModelStatics({
     *   batchSize: 10,
     *   epochs: 50,
     *   hiddenLayers: 2,
     *   learnRate: 0.001,
     *   neuronsPerHiddenLayer: 16,
     *   validationSplit: 0.2
     * });
     */
    constructor(_userStatics = {}) {
        // "no statics at all" is a legitimate request for the defaults; normalize it to an empty object so the
        // lookups in WriteStaticParams() and AttemptStripParam() never dereference undefined/null
        this._userStatics = _userStatics === null ? {} : _userStatics;
        this._staticParams = {};
        // validate the user-supplied static model params, i.e. those params that never change during grid search
        const FAILURE_MESSAGE = new FailureMessage();
        for (const k in this._userStatics) {
            if (!Axis.Axis.AttemptValidateParameter(k, this._userStatics[k], FAILURE_MESSAGE)) {
                // fatal, so that users don't kick off a (potentially very long) grid search with a bad model config
                throw new Error('There was a problem with the static model params. ' + FAILURE_MESSAGE.text);
            }
        }
        // params are valid; write the working set, backfilling w/ defaults for any the user left out
        this.WriteStaticParams();
    }
    /**
     * Check whether the received parameter key is also part of our set. If so,
     * delete the entry, after printing an informative warning to the log. The
     * warning is only printed when the user explicitly supplied the param (a
     * dropped default is not worth mentioning).
     * @param {string} paramKey
     */
    AttemptStripParam(paramKey) {
        Utils.Assert(paramKey !== '');
        if (this._staticParams[paramKey] === undefined) {
            // nothing to strip
            return;
        }
        // this static param will be dropped
        // if the user sent this (as opposed to it being a default), warn them that it won't be used
        if (this._userStatics[paramKey] !== undefined) {
            console.warn('The static model param "' + paramKey + '" will be ignored. (It\'s likely overridden by a dynamic grid axis.)');
        }
        delete this._staticParams[paramKey];
    }
    //TODO: Each of these four 'Generate' calls will be overridable via user callback.
    /**
     * Produces a TensorFlow initializer for bias nodes.<br>
     * Currently set to constant(0.1)
     * @return {TF_INITIALIZERS.Initializer}
     */
    GenerateInitializerBias() {
        //NOTE: See https://js.tensorflow.org/api/2.7.0/#class:initializers.Initializer
        return TENSOR_FLOW.initializers.constant({ value: 0.1 });
    }
    /**
     * Produces a TensorFlow initializer for kernel nodes.<br>
     * Currently set to heNormal(), seeded with a fresh Math.random() value on
     * every call.
     * @return {TF_INITIALIZERS.Initializer}
     */
    GenerateInitializerKernel() {
        //NOTE: See https://js.tensorflow.org/api/2.7.0/#class:initializers.Initializer
        return TENSOR_FLOW.initializers.heNormal({ seed: Math.random() });
    }
    /**
     * Produces a TensorFlow loss function identifier.<br>
     * Currently set to "categoricalCrossentropy"
     * @return {string}
     */
    GenerateLossFunction() {
        //TODO: This will have a more complex type. It can take a string or string[], or a LossOrMetricFn or LossOrMetricFn[].
        //NOTE: See https://js.tensorflow.org/api/2.7.0/#tf.LayersModel.compile
        return 'categoricalCrossentropy';
    }
    /**
     * Produces a TensorFlow optimizer.<br>
     * Currently set to adam(learnRate)
     * @param {number} learnRate The learning rate; asserted to lie strictly between 0 and 1. See {@link https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam}
     * @return {TENSOR_FLOW.Optimizer}
     */
    GenerateOptimizer(learnRate) {
        //NOTE: See https://js.tensorflow.org/api/2.7.0/#tf.LayersModel.compile
        Utils.Assert(learnRate > 0.0);
        Utils.Assert(learnRate < 1.0);
        return TENSOR_FLOW.train.adam(learnRate);
    }
    /**
     * Produce a shallow clone of the remaining parameters as a simple object.
     * Callers may mutate the returned object without affecting this instance.
     * @return {Types.StringKeyedSimpleObject}
     */
    ShallowCloneParams() {
        return Object.assign({}, this._staticParams);
    }
    /**
     * Build an object with all available axes (hyperparams), taking the user's
     * value if available, otherwise taking the system default.
     * @private
     */
    WriteStaticParams() {
        // each entry is a key shared by Axis.AxisNames (the param's public name) and Axis.AxisDefaults (its fallback)
        const AXIS_KEYS = ['BATCH_SIZE', 'EPOCHS', 'LAYERS', 'LEARN_RATE', 'NEURONS', 'VALIDATION_SPLIT'];
        for (const axisKey of AXIS_KEYS) {
            const paramName = Axis.AxisNames[axisKey];
            const userValue = this._userStatics[paramName];
            // set the user's value, or take the program default (these are optional from the user's point-of-view)
            //NOTE: Only 'undefined' triggers the default; any other value has already passed validation.
            this._staticParams[paramName] = userValue !== undefined ? userValue : Axis.AxisDefaults[axisKey];
        }
        // now we tack on the parameters that can't be axes (or rather not-yet-supported-as-axes)
        //TODO: The primary goal of this project is to support as many of these in as many ways as possible (progressions,
        //      custom schedules, randomizers, smart systems, ... ?).
        this._staticParams.activationHidden = 'relu';
        this._staticParams.activationInput = 'relu';
        this._staticParams.activationOutput = 'softmax';
    }
}
// Freeze the class constructor object itself, so its own (static) properties cannot be added, removed or
// reassigned after this point. The freeze is shallow: instances created with `new` are not frozen.
Object.freeze(ModelStatics);
// Expose the class as a named export.
export { ModelStatics };
//# sourceMappingURL=ModelStatics.js.map