=1;b-=1){P=g[b],E=k+b*(b+3)/2,S=E-b;for(w=b+1;w<=v;w+=1)P-=g[E]*g[C+w],E+=w;P/=g[S],g[C+b]=P;if(d[b]p)g[L+O]=P;else{g[L+O]=-Math.abs(P);if(P>0){for(w=1;w<=o;w+=1)f[w][O]=-f[w][O];l[O]=-l[O]}}return 700}v+=1,d[v]=O,E=k+(v-1)*v/2+1;for(b=1;b<=v-1;b+=1)g[E]=g[b],E+=1;if(v===o)g[E]=g[o];else{for(b=o;b>=v+1;b-=1){if(g[b]===0)break;j=Math.max(Math.abs(g[b-1]),Math.abs(g[b])),F=Math.min(Math.abs(g[b-1]),Math.abs(g[b])),g[b-1]>=0?D=Math.abs(j*Math.sqrt(1+F*F/(j*j))):D=-Math.abs(j*Math.sqrt(1+F*F/(j*j))),j=g[b-1]/D,F=g[b]/D;if(j===1)break;if(j===0){g[b-1]=F*D;for(w=1;w<=o;w+=1)D=e[w][b-1],e[w][b-1]=e[w][b],e[w][b]=D}else{g[b-1]=D,I=F/(1+j);for(w=1;w<=o;w+=1)D=j*e[w][b-1]+F*e[w][b],e[w][b]=I*(e[w][b-1]+D)-e[w][b],e[w][b-1]=D}}g[E]=g[v]}return 0}function J(){E=k+T*(T+1)/2+1,S=E+T;if(g[S]===0)return 798;j=Math.max(Math.abs(g[S-1]),Math.abs(g[S])),F=Math.min(Math.abs(g[S-1]),Math.abs(g[S])),g[S-1]>=0?D=Math.abs(j*Math.sqrt(1+F*F/(j*j))):D=-Math.abs(j*Math.sqrt(1+F*F/(j*j))),j=g[S-1]/D,F=g[S]/D;if(j===1)return 798;if(j===0){for(b=T+1;b<=v;b+=1)D=g[S-1],g[S-1]=g[S],g[S]=D,S+=b;for(b=1;b<=o;b+=1)D=e[b][T],e[b][T]=e[b][T+1],e[b][T+1]=D}else{I=F/(1+j);for(b=T+1;b<=v;b+=1)D=j*g[S-1]+F*g[S],g[S]=I*(g[S-1]+D)-g[S],g[S-1]=D,S+=b;for(b=1;b<=o;b+=1)D=j*e[b][T]+F*e[b][T+1],e[b][T+1]=I*(e[b][T]+D)-e[b][T+1],e[b][T]=D}return 0}function K(){S=E-T;for(b=1;b<=T;b+=1)g[S]=g[E],E+=1,S+=1;return g[A+T]=g[A+T+1],d[T]=d[T+1],T+=1,Tt?e*Math.sqrt(1+t*t/e/e):t==0?e:t*Math.sqrt(1+e*e/t/t)}var n,r=numeric.epsilon,i=1e-64/r,s=50,o=0,u=0,a=0,f=0,l=0,c=numeric.clone(t),h=c.length,p=c[0].length;if(h=0&&(b=-b),w=y*b-T,c[u][u]=y-b;for(a=l;a
=0&&(b=-b),w=y*b-T,c[u][u+1]=y-b;for(a=l;a
E&&(E=S)}for(u=p-1;u!=-1;u+=-1){if(b!=0){w=b*c[u][u+1];for(a=l;a
=s-1)throw"Error: no convergence.";E=v[l],S=v[f-1],b=d[f-1],w=d[f],y=((S-x)*(S+x)+(b-w)*(b+w))/(2*w*S),b=g(y,1),y<0?y=((E-x)*(E+x)+w*(S/(y-b)-w))/E:y=((E-x)*(E+x)+w*(S/(y+b)-w))/E,o=1,T=1;for(u=l+1;u=0;a--)if(v[a] 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
- if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
- if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
- if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
- if (t[2]) _.ops.pop();
- _.trys.pop(); continue;
- }
- op = body.call(thisArg, _);
- } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
- if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };
- }
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
- var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
- /**
- * The environment contains evaluated flags as well as the registered platform.
- * This is always used as a global singleton and can be retrieved with
- * `tf.env()`.
- */
- /** @doc {heading: 'Environment'} */
- var Environment = /** @class */ (function () {
- // tslint:disable-next-line: no-any
- function Environment(global) {
- this.global = global;
- this.flags = {};
- this.flagRegistry = {};
- this.urlFlags = {};
- this.populateURLFlags();
- }
- Environment.prototype.setPlatform = function (platformName, platform) {
- if (this.platform != null) {
- console.warn("Platform " + this.platformName + " has already been set. " +
- ("Overwriting the platform with " + platform + "."));
- }
- this.platformName = platformName;
- this.platform = platform;
- };
- Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) {
- this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook };
- // Override the flag value from the URL. This has to happen here because the
- // environment is initialized before flags get registered.
- if (this.urlFlags[flagName] != null) {
- var flagValue = this.urlFlags[flagName];
- console.warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
- this.set(flagName, flagValue);
- }
- };
- Environment.prototype.get = function (flagName) {
- if (flagName in this.flags) {
- return this.flags[flagName];
- }
- this.flags[flagName] = this.evaluateFlag(flagName);
- return this.flags[flagName];
- };
- Environment.prototype.getNumber = function (flagName) {
- return this.get(flagName);
- };
- Environment.prototype.getBool = function (flagName) {
- return this.get(flagName);
- };
- Environment.prototype.getFlags = function () {
- return this.flags;
- };
- Object.defineProperty(Environment.prototype, "features", {
- // For backwards compatibility.
- get: function () {
- return this.flags;
- },
- enumerable: true,
- configurable: true
- });
- Environment.prototype.set = function (flagName, value) {
- if (this.flagRegistry[flagName] == null) {
- throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
- }
- this.flags[flagName] = value;
- if (this.flagRegistry[flagName].setHook != null) {
- this.flagRegistry[flagName].setHook(value);
- }
- };
- Environment.prototype.evaluateFlag = function (flagName) {
- if (this.flagRegistry[flagName] == null) {
- throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
- }
- return this.flagRegistry[flagName].evaluationFn();
- };
- Environment.prototype.setFlags = function (flags) {
- this.flags = Object.assign({}, flags);
- };
- Environment.prototype.reset = function () {
- this.flags = {};
- this.urlFlags = {};
- this.populateURLFlags();
- };
- Environment.prototype.populateURLFlags = function () {
- var _this = this;
- if (typeof this.global === 'undefined' ||
- typeof this.global.location === 'undefined' ||
- typeof this.global.location.search === 'undefined') {
- return;
- }
- var urlParams = getQueryParams(this.global.location.search);
- if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
- var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
- keyValues.forEach(function (keyValue) {
- var _a = keyValue.split(':'), key = _a[0], value = _a[1];
- _this.urlFlags[key] = parseValue(key, value);
- });
- }
- };
- return Environment;
- }());
- function getQueryParams(queryString) {
- var params = {};
- queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) {
- var t = [];
- for (var _i = 1; _i < arguments.length; _i++) {
- t[_i - 1] = arguments[_i];
- }
- decodeParam(params, t[0], t[1]);
- return t.join('=');
- });
- return params;
- }
- function decodeParam(params, name, value) {
- params[decodeURIComponent(name)] = decodeURIComponent(value || '');
- }
- function parseValue(flagName, value) {
- value = value.toLowerCase();
- if (value === 'true' || value === 'false') {
- return value === 'true';
- }
- else if ("" + +value === value) {
- return +value;
- }
- throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
- }
- /**
- * Returns the current environment (a global singleton).
- *
- * The environment object contains the evaluated feature values as well as the
- * active platform.
- */
- /** @doc {heading: 'Environment'} */
- function env() {
- return exports.ENV;
- }
- exports.ENV = null;
- function setEnvironmentGlobal(environment) {
- exports.ENV = environment;
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var kernelRegistry = new Map();
- var gradRegistry = new Map();
- /**
- * Returns the kernel function (code) associated with the provided names.
- *
- * @param kernelName The official name of the kernel.
- * @param backendName The official name of the backend.
- */
- function getKernel(kernelName, backendName) {
- var key = makeKey(kernelName, backendName);
- return kernelRegistry.get(key);
- }
- /**
- * Returns the registered gradient info associated with the provided kernel.
- * @param kernelName The official TF kernel name.
- */
- function getGradient(kernelName) {
- return gradRegistry.get(kernelName);
- }
- function getKernelsForBackend(backendName) {
- var it = kernelRegistry.entries();
- var result = [];
- while (true) {
- var _a = it.next(), done = _a.done, value = _a.value;
- if (done) {
- break;
- }
- var key = value[0], config = value[1];
- var backend = key.split('_')[0];
- if (backend === backendName) {
- result.push(config);
- }
- }
- return result;
- }
- /**
- * Registers the function (forward pass) for the kernel in a global registry.
- *
- * @param config A config object with the following properties:
- * - `kernelName` The official name of the kernel.
- * - `backendName` The official name of the backend.
- * - `kernelFunc` The function to run during the forward pass of the kernel.
- * - `setupFunc` Optional. Gets called once, after the backend initializes.
- * - `disposeFunc` Optional. Gets called once, right before the backend is
- * disposed.
- */
- function registerKernel(config) {
- var kernelName = config.kernelName, backendName = config.backendName;
- var key = makeKey(kernelName, backendName);
- if (kernelRegistry.has(key)) {
- throw new Error("The kernel '" + kernelName + "' for backend " +
- ("'" + backendName + "' is already registered"));
- }
- kernelRegistry.set(key, config);
- }
- /**
- * Registers a gradient function for a given kernel in the global registry,
- * to be used during the back-propagation of that kernel.
- *
- * @param config An object with the following properties:
- * - `kernelName` The name of the kernel that the gradient function is for.
- * - `gradFunc` The function to run during back-propagation.
- */
- function registerGradient(config) {
- var kernelName = config.kernelName;
- if (gradRegistry.has(kernelName)) {
- console.warn("Overriding the gradient for '" + kernelName + "'");
- }
- gradRegistry.set(kernelName, config);
- }
- /**
- * Removes the kernel function from the registry.
- *
- * @param kernelName The official name of the kernel.
- * @param backendName The official name of the backend.
- *
- */
- function unregisterKernel(kernelName, backendName) {
- var key = makeKey(kernelName, backendName);
- if (!kernelRegistry.has(key)) {
- throw new Error("The kernel '" + kernelName + "' for backend " +
- ("'" + backendName + "' is not registered"));
- }
- kernelRegistry.delete(key);
- }
- /** Removes the registered gradient from the global registry. */
- function unregisterGradient(kernelName) {
- if (!gradRegistry.has(kernelName)) {
- throw new Error("The gradient '" + kernelName + "' for backend is not registered");
- }
- gradRegistry.delete(kernelName);
- }
- function makeKey(kernelName, backendName) {
- return backendName + "_" + kernelName;
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Shuffles the array in-place using Fisher-Yates algorithm.
- *
- * ```js
- * const a = [1, 2, 3, 4, 5];
- * tf.util.shuffle(a);
- * console.log(a);
- * ```
- *
- * @param array The array to shuffle in-place.
- */
- /** @doc {heading: 'Util', namespace: 'util'} */
- // tslint:disable-next-line:no-any
- function shuffle(array) {
- var counter = array.length;
- var temp = 0;
- var index = 0;
- // While there are elements in the array
- while (counter > 0) {
- // Pick a random index
- index = (Math.random() * counter) | 0;
- // Decrease counter by 1
- counter--;
- // And swap the last element with it
- temp = array[counter];
- array[counter] = array[index];
- array[index] = temp;
- }
- }
- /** Clamps a value to a specified range. */
- function clamp(min, x, max) {
- return Math.max(min, Math.min(x, max));
- }
- function nearestLargerEven(val) {
- return val % 2 === 0 ? val : val + 1;
- }
- function sum(arr) {
- var sum = 0;
- for (var i = 0; i < arr.length; i++) {
- sum += arr[i];
- }
- return sum;
- }
- /**
- * Returns a sample from a uniform [a, b) distribution.
- *
- * @param a The minimum support (inclusive).
- * @param b The maximum support (exclusive).
- * @return A pseudorandom number on the half-open interval [a,b).
- */
- function randUniform(a, b) {
- var r = Math.random();
- return (b * r) + (1 - r) * a;
- }
- /** Returns the squared Euclidean distance between two vectors. */
- function distSquared(a, b) {
- var result = 0;
- for (var i = 0; i < a.length; i++) {
- var diff = Number(a[i]) - Number(b[i]);
- result += diff * diff;
- }
- return result;
- }
- /**
- * Asserts that the expression is true. Otherwise throws an error with the
- * provided message.
- *
- * ```js
- * const x = 2;
- * tf.util.assert(x === 2, 'x is not 2');
- * ```
- *
- * @param expr The expression to assert (as a boolean).
- * @param msg A function that returns the message to report when throwing an
- * error. We use a function for performance reasons.
- */
- /** @doc {heading: 'Util', namespace: 'util'} */
- function assert(expr, msg) {
- if (!expr) {
- throw new Error(typeof msg === 'string' ? msg : msg());
- }
- }
- function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
- if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; }
- assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); });
- }
- function assertNonNull(a) {
- assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; });
- }
- // NOTE: We explicitly type out what T extends instead of any so that
- // util.flatten on a nested array of number doesn't try to infer T as a
- // number[][], causing us to explicitly type util.flatten().
- /**
- * Flattens an arbitrarily nested array.
- *
- * ```js
- * const a = [[1, 2], [3, 4], [5, [6, [7]]]];
- * const flat = tf.util.flatten(a);
- * console.log(flat);
- * ```
- *
- * @param arr The nested array to flatten.
- * @param result The destination array which holds the elements.
- * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
- * to false.
- */
- /** @doc {heading: 'Util', namespace: 'util'} */
- function flatten(arr, result, skipTypedArray) {
- if (result === void 0) { result = []; }
- if (skipTypedArray === void 0) { skipTypedArray = false; }
- if (result == null) {
- result = [];
- }
- if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) {
- for (var i = 0; i < arr.length; ++i) {
- flatten(arr[i], result, skipTypedArray);
- }
- }
- else {
- result.push(arr);
- }
- return result;
- }
- /**
- * Returns the size (number of elements) of the tensor given its shape.
- *
- * ```js
- * const shape = [3, 4, 2];
- * const size = tf.util.sizeFromShape(shape);
- * console.log(size);
- * ```
- */
- /** @doc {heading: 'Util', namespace: 'util'} */
- function sizeFromShape(shape) {
- if (shape.length === 0) {
- // Scalar.
- return 1;
- }
- var size = shape[0];
- for (var i = 1; i < shape.length; i++) {
- size *= shape[i];
- }
- return size;
- }
- function isScalarShape(shape) {
- return shape.length === 0;
- }
- function arraysEqual(n1, n2) {
- if (n1 === n2) {
- return true;
- }
- if (n1 == null || n2 == null) {
- return false;
- }
- if (n1.length !== n2.length) {
- return false;
- }
- for (var i = 0; i < n1.length; i++) {
- if (n1[i] !== n2[i]) {
- return false;
- }
- }
- return true;
- }
- function isInt(a) {
- return a % 1 === 0;
- }
- function tanh(x) {
- // tslint:disable-next-line:no-any
- if (Math.tanh != null) {
- // tslint:disable-next-line:no-any
- return Math.tanh(x);
- }
- if (x === Infinity) {
- return 1;
- }
- else if (x === -Infinity) {
- return -1;
- }
- else {
- var e2x = Math.exp(2 * x);
- return (e2x - 1) / (e2x + 1);
- }
- }
- function sizeToSquarishShape(size) {
- var width = Math.ceil(Math.sqrt(size));
- return [width, Math.ceil(size / width)];
- }
- /**
- * Creates a new array with randomized indicies to a given quantity.
- *
- * ```js
- * const randomTen = tf.util.createShuffledIndices(10);
- * console.log(randomTen);
- * ```
- *
- * @param number Quantity of how many shuffled indicies to create.
- */
- /** @doc {heading: 'Util', namespace: 'util'} */
- function createShuffledIndices(n) {
- var shuffledIndices = new Uint32Array(n);
- for (var i = 0; i < n; ++i) {
- shuffledIndices[i] = i;
- }
- shuffle(shuffledIndices);
- return shuffledIndices;
- }
- function rightPad(a, size) {
- if (size <= a.length) {
- return a;
- }
- return a + ' '.repeat(size - a.length);
- }
- function repeatedTry(checkFn, delayFn, maxCounter) {
- if (delayFn === void 0) { delayFn = function (counter) { return 0; }; }
- return new Promise(function (resolve, reject) {
- var tryCount = 0;
- var tryFn = function () {
- if (checkFn()) {
- resolve();
- return;
- }
- tryCount++;
- var nextBackoff = delayFn(tryCount);
- if (maxCounter != null && tryCount >= maxCounter) {
- reject();
- return;
- }
- setTimeout(tryFn, nextBackoff);
- };
- tryFn();
- });
- }
- /**
- * Given the full size of the array and a shape that may contain -1 as the
- * implicit dimension, returns the inferred shape where -1 is replaced.
- * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
- *
- * @param shape The shape, which may contain -1 in some dimension.
- * @param size The full size (number of elements) of the array.
- * @return The inferred shape where -1 is replaced with the inferred size.
- */
- function inferFromImplicitShape(shape, size) {
- var shapeProd = 1;
- var implicitIdx = -1;
- for (var i = 0; i < shape.length; ++i) {
- if (shape[i] >= 0) {
- shapeProd *= shape[i];
- }
- else if (shape[i] === -1) {
- if (implicitIdx !== -1) {
- throw Error("Shapes can only have 1 implicit size. " +
- ("Found -1 at dim " + implicitIdx + " and dim " + i));
- }
- implicitIdx = i;
- }
- else if (shape[i] < 0) {
- throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i);
- }
- }
- if (implicitIdx === -1) {
- if (size > 0 && size !== shapeProd) {
- throw Error("Size(" + size + ") must match the product of shape " + shape);
- }
- return shape;
- }
- if (shapeProd === 0) {
- throw Error("Cannot infer the missing size in [" + shape + "] when " +
- "there are 0 elements");
- }
- if (size % shapeProd !== 0) {
- throw Error("The implicit shape can't be a fractional number. " +
- ("Got " + size + " / " + shapeProd));
- }
- var newShape = shape.slice();
- newShape[implicitIdx] = size / shapeProd;
- return newShape;
- }
- function parseAxisParam(axis, shape) {
- var rank = shape.length;
- // Normalize input
- axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis);
- // Check for valid range
- assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () {
- return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " +
- ("got axis " + axis);
- });
- // Check for only integers
- assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " +
- ("got axis " + axis); });
- // Handle negative axis.
- return axis.map(function (a) { return a < 0 ? rank + a : a; });
- }
- /** Reduces the shape by removing all dimensions of shape 1. */
- function squeezeShape(shape, axis) {
- var newShape = [];
- var keptDims = [];
- var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
- var axes = (axis == null || isEmptyArray) ?
- null :
- parseAxisParam(axis, shape).sort();
- var j = 0;
- for (var i = 0; i < shape.length; ++i) {
- if (axes != null) {
- if (axes[j] === i && shape[i] !== 1) {
- throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
- }
- if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
- newShape.push(shape[i]);
- keptDims.push(i);
- }
- if (axes[j] <= i) {
- j++;
- }
- }
- if (shape[i] !== 1) {
- newShape.push(shape[i]);
- keptDims.push(i);
- }
- }
- return { newShape: newShape, keptDims: keptDims };
- }
- function getTypedArrayFromDType(dtype, size) {
- var values = null;
- if (dtype == null || dtype === 'float32') {
- values = new Float32Array(size);
- }
- else if (dtype === 'int32') {
- values = new Int32Array(size);
- }
- else if (dtype === 'bool') {
- values = new Uint8Array(size);
- }
- else {
- throw new Error("Unknown data type " + dtype);
- }
- return values;
- }
- function getArrayFromDType(dtype, size) {
- var values = null;
- if (dtype == null || dtype === 'float32') {
- values = new Float32Array(size);
- }
- else if (dtype === 'int32') {
- values = new Int32Array(size);
- }
- else if (dtype === 'bool') {
- values = new Uint8Array(size);
- }
- else if (dtype === 'string') {
- values = new Array(size);
- }
- else {
- throw new Error("Unknown data type " + dtype);
- }
- return values;
- }
- function checkConversionForErrors(vals, dtype) {
- for (var i = 0; i < vals.length; i++) {
- var num = vals[i];
- if (isNaN(num) || !isFinite(num)) {
- throw Error("A tensor of type " + dtype + " being uploaded contains " + num + ".");
- }
- }
- }
- /** Returns true if the dtype is valid. */
- function isValidDtype(dtype) {
- return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||
- dtype === 'int32' || dtype === 'string';
- }
- /**
- * Returns true if the new type can't encode the old type without loss of
- * precision.
- */
- function hasEncodingLoss(oldType, newType) {
- if (newType === 'complex64') {
- return false;
- }
- if (newType === 'float32' && oldType !== 'complex64') {
- return false;
- }
- if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
- return false;
- }
- if (newType === 'bool' && oldType === 'bool') {
- return false;
- }
- return true;
- }
- function isTypedArray(a) {
- return a instanceof Float32Array || a instanceof Int32Array ||
- a instanceof Uint8Array;
- }
- function bytesPerElement(dtype) {
- if (dtype === 'float32' || dtype === 'int32') {
- return 4;
- }
- else if (dtype === 'complex64') {
- return 8;
- }
- else if (dtype === 'bool') {
- return 1;
- }
- else {
- throw new Error("Unknown dtype " + dtype);
- }
- }
- /**
- * Returns the approximate number of bytes allocated in the string array - 2
- * bytes per character. Computing the exact bytes for a native string in JS is
- * not possible since it depends on the encoding of the html page that serves
- * the website.
- */
- function bytesFromStringArray(arr) {
- if (arr == null) {
- return 0;
- }
- var bytes = 0;
- arr.forEach(function (x) { return bytes += x.length; });
- return bytes;
- }
- /** Returns true if the value is a string. */
- function isString(value) {
- return typeof value === 'string' || value instanceof String;
- }
- function isBoolean(value) {
- return typeof value === 'boolean';
- }
- function isNumber(value) {
- return typeof value === 'number';
- }
- function inferDtype(values) {
- if (Array.isArray(values)) {
- return inferDtype(values[0]);
- }
- if (values instanceof Float32Array) {
- return 'float32';
- }
- else if (values instanceof Int32Array || values instanceof Uint8Array) {
- return 'int32';
- }
- else if (isNumber(values)) {
- return 'float32';
- }
- else if (isString(values)) {
- return 'string';
- }
- else if (isBoolean(values)) {
- return 'bool';
- }
- return 'float32';
- }
- function isFunction(f) {
- return !!(f && f.constructor && f.call && f.apply);
- }
- function nearestDivisor(size, start) {
- for (var i = start; i < size; ++i) {
- if (size % i === 0) {
- return i;
- }
- }
- return size;
- }
- function computeStrides(shape) {
- var rank = shape.length;
- if (rank < 2) {
- return [];
- }
- // Last dimension has implicit stride of 1, thus having D-1 (instead of D)
- // strides.
- var strides = new Array(rank - 1);
- strides[rank - 2] = shape[rank - 1];
- for (var i = rank - 3; i >= 0; --i) {
- strides[i] = strides[i + 1] * shape[i + 1];
- }
- return strides;
- }
- function toTypedArray(a, dtype, debugMode) {
- if (dtype === 'string') {
- throw new Error('Cannot convert a string[] to a TypedArray');
- }
- if (Array.isArray(a)) {
- a = flatten(a);
- }
- if (debugMode) {
- checkConversionForErrors(a, dtype);
- }
- if (noConversionNeeded(a, dtype)) {
- return a;
- }
- if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
- return new Float32Array(a);
- }
- else if (dtype === 'int32') {
- return new Int32Array(a);
- }
- else if (dtype === 'bool') {
- var bool = new Uint8Array(a.length);
- for (var i = 0; i < bool.length; ++i) {
- if (Math.round(a[i]) !== 0) {
- bool[i] = 1;
- }
- }
- return bool;
- }
- else {
- throw new Error("Unknown data type " + dtype);
- }
- }
- function createNestedArray(offset, shape, a) {
- var ret = new Array();
- if (shape.length === 1) {
- var d = shape[0];
- for (var i = 0; i < d; i++) {
- ret[i] = a[offset + i];
- }
- }
- else {
- var d = shape[0];
- var rest = shape.slice(1);
- var len = rest.reduce(function (acc, c) { return acc * c; });
- for (var i = 0; i < d; i++) {
- ret[i] = createNestedArray(offset + i * len, rest, a);
- }
- }
- return ret;
- }
- // Provide a nested array of TypedArray in given shape.
- function toNestedArray(shape, a) {
- if (shape.length === 0) {
- // Scalar type should return a single number.
- return a[0];
- }
- var size = shape.reduce(function (acc, c) { return acc * c; });
- if (size === 0) {
- // A tensor with shape zero should be turned into empty list.
- return [];
- }
- if (size !== a.length) {
- throw new Error("[" + shape + "] does not match the input size.");
- }
- return createNestedArray(0, shape, a);
- }
- function noConversionNeeded(a, dtype) {
- return (a instanceof Float32Array && dtype === 'float32') ||
- (a instanceof Int32Array && dtype === 'int32') ||
- (a instanceof Uint8Array && dtype === 'bool');
- }
- function makeOnesTypedArray(size, dtype) {
- var array = makeZerosTypedArray(size, dtype);
- for (var i = 0; i < array.length; i++) {
- array[i] = 1;
- }
- return array;
- }
- function makeZerosTypedArray(size, dtype) {
- if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
- return new Float32Array(size);
- }
- else if (dtype === 'int32') {
- return new Int32Array(size);
- }
- else if (dtype === 'bool') {
- return new Uint8Array(size);
- }
- else {
- throw new Error("Unknown data type " + dtype);
- }
- }
- /**
- * Returns the current high-resolution time in milliseconds relative to an
- * arbitrary time in the past. It works across different platforms (node.js,
- * browsers).
- *
- * ```js
- * console.log(tf.util.now());
- * ```
- */
- /** @doc {heading: 'Util', namespace: 'util'} */
- function now() {
- return env().platform.now();
- }
- function assertNonNegativeIntegerDimensions(shape) {
- shape.forEach(function (dimSize) {
- assert(Number.isInteger(dimSize) && dimSize >= 0, function () {
- return "Tensor must have a shape comprised of positive integers but got " +
- ("shape [" + shape + "].");
- });
- });
- }
- /**
- * Returns a platform-specific implementation of
- * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
- *
- * If `fetch` is defined on the global object (`window`, `process`, etc.),
- * `tf.util.fetch` returns that function.
- *
- * If not, `tf.util.fetch` returns a platform-specific solution.
- *
- * ```js
- * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
- * // handle response
- * ```
- */
- /** @doc {heading: 'Util'} */
- function fetch$1(path, requestInits) {
- return env().platform.fetch(path, requestInits);
- }
- /**
- * Encodes the provided string into bytes using the provided encoding scheme.
- *
- * @param s The string to encode.
- * @param encoding The encoding scheme. Defaults to utf-8.
- *
- */
- /** @doc {heading: 'Util'} */
- function encodeString(s, encoding) {
- if (encoding === void 0) { encoding = 'utf-8'; }
- encoding = encoding || 'utf-8';
- return env().platform.encode(s, encoding);
- }
- /**
- * Decodes the provided bytes into a string using the provided encoding scheme.
- * @param bytes The bytes to decode.
- *
- * @param encoding The encoding scheme. Defaults to utf-8.
- */
- /** @doc {heading: 'Util'} */
- function decodeString(bytes, encoding) {
- if (encoding === void 0) { encoding = 'utf-8'; }
- encoding = encoding || 'utf-8';
- return env().platform.decode(bytes, encoding);
- }
- /**
- * Computes flat index for a given location (multidimentionsal index) in a
- * Tensor/multidimensional array.
- *
- * @param locs Location in the tensor.
- * @param rank Rank of the tensor.
- * @param strides Tensor strides.
- */
- function locToIndex(locs, rank, strides) {
- if (rank === 0) {
- return 0;
- }
- else if (rank === 1) {
- return locs[0];
- }
- var index = locs[locs.length - 1];
- for (var i = 0; i < locs.length - 1; ++i) {
- index += strides[i] * locs[i];
- }
- return index;
- }
- /**
- * Computes the location (multidimensional index) in a tensor/multidimentional
- * array for a given flat index.
- *
- * @param index Index in flat array.
- * @param rank Rank of tensor.
- * @param strides Strides of tensor.
- */
- function indexToLoc(index, rank, strides) {
- if (rank === 0) {
- return [];
- }
- else if (rank === 1) {
- return [index];
- }
- var locs = new Array(rank);
- for (var i = 0; i < locs.length - 1; ++i) {
- locs[i] = Math.floor(index / strides[i]);
- index -= locs[i] * strides[i];
- }
- locs[locs.length - 1] = index;
- return locs;
- }
-
- var util = /*#__PURE__*/Object.freeze({
- shuffle: shuffle,
- clamp: clamp,
- nearestLargerEven: nearestLargerEven,
- sum: sum,
- randUniform: randUniform,
- distSquared: distSquared,
- assert: assert,
- assertShapesMatch: assertShapesMatch,
- assertNonNull: assertNonNull,
- flatten: flatten,
- sizeFromShape: sizeFromShape,
- isScalarShape: isScalarShape,
- arraysEqual: arraysEqual,
- isInt: isInt,
- tanh: tanh,
- sizeToSquarishShape: sizeToSquarishShape,
- createShuffledIndices: createShuffledIndices,
- rightPad: rightPad,
- repeatedTry: repeatedTry,
- inferFromImplicitShape: inferFromImplicitShape,
- parseAxisParam: parseAxisParam,
- squeezeShape: squeezeShape,
- getTypedArrayFromDType: getTypedArrayFromDType,
- getArrayFromDType: getArrayFromDType,
- checkConversionForErrors: checkConversionForErrors,
- isValidDtype: isValidDtype,
- hasEncodingLoss: hasEncodingLoss,
- isTypedArray: isTypedArray,
- bytesPerElement: bytesPerElement,
- bytesFromStringArray: bytesFromStringArray,
- isString: isString,
- isBoolean: isBoolean,
- isNumber: isNumber,
- inferDtype: inferDtype,
- isFunction: isFunction,
- nearestDivisor: nearestDivisor,
- computeStrides: computeStrides,
- toTypedArray: toTypedArray,
- toNestedArray: toNestedArray,
- makeOnesTypedArray: makeOnesTypedArray,
- makeZerosTypedArray: makeZerosTypedArray,
- now: now,
- assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
- fetch: fetch$1,
- encodeString: encodeString,
- decodeString: decodeString,
- locToIndex: locToIndex,
- indexToLoc: indexToLoc
- });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var Profiler = /** @class */ (function () {
- function Profiler(backendTimer, logger) {
- this.backendTimer = backendTimer;
- this.logger = logger;
- if (logger == null) {
- this.logger = new Logger();
- }
- }
- Profiler.prototype.profileKernel = function (kernelName, inputs, f) {
- var _this = this;
- var outputs;
- var holdResultWrapperFn = function () {
- outputs = f();
- };
- var timer = this.backendTimer.time(holdResultWrapperFn);
- outputs.forEach(function (r) {
- // Dangling promise here because we don't want to propagate up
- // asynchronicity.
- r.data().then(function (vals) {
- checkComputationForErrors(vals, r.dtype, kernelName);
- timer.then(function (timing) {
- var extraInfo = '';
- if (timing.getExtraProfileInfo != null) {
- extraInfo = timing.getExtraProfileInfo();
- }
- _this.logger.logKernelProfile(kernelName, r, vals, timing.kernelMs, inputs, extraInfo);
- });
- });
- });
- return outputs;
- };
- return Profiler;
- }());
- function checkComputationForErrors(vals, dtype, kernelName) {
- if (dtype !== 'float32') {
- // Only floating point computations will generate NaN values
- return false;
- }
- for (var i = 0; i < vals.length; i++) {
- var num = vals[i];
- if (isNaN(num) || !isFinite(num)) {
- // Throwing custom exception so behavior is testable.
- console.warn("Found " + num + " in the result of '" + kernelName + "'");
- return true;
- }
- }
- return false;
- }
- var Logger = /** @class */ (function () {
- function Logger() {
- }
- Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) {
- var time = typeof timeMs === 'number' ? rightPad(timeMs + "ms", 9) :
- timeMs['error'];
- var paddedName = rightPad(name, 25);
- var rank = result.rank;
- var size = result.size;
- var shape = rightPad(result.shape.toString(), 14);
- var inputShapesDescription = '';
- for (var name_1 in inputs) {
- var input = inputs[name_1];
- // The input might be a non-tensor (e.g HTMLImageElement), in which case
- // we claim the output shape as input shape.
- var inputShape = input.shape || result.shape;
- var inputRank = inputShape.length;
- inputShapesDescription +=
- name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " ";
- }
- console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
- };
- return Logger;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes a list of TapeNodes that connect x to y, filtering everything else
- * out and preserving the order of the original tape elements.
- *
- * @param tape The tape elements to filter.
- * @param xs The input Tensors.
- * @param y The output Tensor.
- */
- function getFilteredNodesXToY(tape, xs, y) {
- // Forward pass to compute all the nodes and Tensors that are transitively a
- // function of x.
- var tensorsFromX = {};
- var nodesFromX = {};
- for (var i = 0; i < xs.length; i++) {
- tensorsFromX[xs[i].id] = true;
- }
- for (var i = 0; i < tape.length; i++) {
- var node = tape[i];
- var nodeInputs = node.inputs;
- for (var inputName in nodeInputs) {
- var input = nodeInputs[inputName];
- var anyInputFromX = false;
- for (var j = 0; j < xs.length; j++) {
- if (tensorsFromX[input.id]) {
- node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; });
- anyInputFromX = true;
- nodesFromX[node.id] = true;
- break;
- }
- }
- if (anyInputFromX) {
- break;
- }
- }
- }
- // Backward pass to find all of the nodes and Tensors that lead to y.
- var tensorsLeadToY = {};
- tensorsLeadToY[y.id] = true;
- var nodesToY = {};
- for (var i = tape.length - 1; i >= 0; i--) {
- var node = tape[i];
- var nodeInputs = node.inputs;
- // If any of the outputs lead to y, mark all of the inputs as leading to y.
- for (var j = 0; j < node.outputs.length; j++) {
- if (tensorsLeadToY[node.outputs[j].id]) {
- for (var inputName in nodeInputs) {
- tensorsLeadToY[nodeInputs[inputName].id] = true;
- nodesToY[node.id] = true;
- }
- break;
- }
- }
- }
- // Return the paths that come from x and lead to y.
- var filteredTape = [];
- for (var i = 0; i < tape.length; i++) {
- var node = tape[i];
- if (nodesFromX[node.id] && nodesToY[node.id]) {
- // Prune the inputs from the node that aren't a function of x.
- var prunedInputs = {};
- for (var inputName in node.inputs) {
- var nodeInput = node.inputs[inputName];
- if (tensorsFromX[nodeInput.id]) {
- prunedInputs[inputName] = nodeInput;
- }
- }
- // Copy the node and overwrite inputsAndArgs to the pruned version.
- var prunedNode = Object.assign({}, node);
- prunedNode.inputs = prunedInputs;
- prunedNode.outputs = node.outputs;
- filteredTape.push(prunedNode);
- }
- }
- return filteredTape;
- }
- /**
- * Backpropagate gradients through the filtered TapeNodes.
- *
- * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
- * is mutated by this method.
- * @param filteredTape The filtered TapeNodes to backprop through.
- */
- function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) {
- var _loop_1 = function (i) {
- var node = filteredTape[i];
- var dys = [];
- node.outputs.forEach(function (o) {
- var gradTensor = tensorAccumulatedGradientMap[o.id];
- if (gradTensor != null) {
- dys.push(gradTensor);
- }
- else {
- // This particular output is not in the back-propagation subgraph, so it
- // does not affect the final output, thus we put null for its dy.
- dys.push(null);
- }
- });
- if (node.gradient == null) {
- throw new Error("Cannot compute gradient: gradient function not found " +
- ("for " + node.kernelName + "."));
- }
- // Backprop dy through this node and accumulate gradients over the inputs.
- var inputGradients = node.gradient(dys);
- var _loop_2 = function (inputName) {
- if (!(inputName in inputGradients)) {
- throw new Error("Cannot backprop through input " + inputName + ". " +
- ("Available gradients found: " + Object.keys(inputGradients) + "."));
- }
- // Call the gradient function.
- var dx = tidy(function () { return inputGradients[inputName](); });
- if (dx.dtype !== 'float32') {
- throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " +
- (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'"));
- }
- var x = node.inputs[inputName];
- if (!arraysEqual(dx.shape, x.shape)) {
- throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " +
- ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") +
- ("the shape of the input '" + x.shape + "'"));
- }
- if (tensorAccumulatedGradientMap[x.id] == null) {
- tensorAccumulatedGradientMap[x.id] = dx;
- }
- else {
- var curGradient = tensorAccumulatedGradientMap[x.id];
- tensorAccumulatedGradientMap[x.id] = curGradient.add(dx);
- curGradient.dispose();
- }
- };
- for (var inputName in node.inputs) {
- _loop_2(inputName);
- }
- };
- // Walk the tape backward and keep a map of Tensor to its gradient.
- for (var i = filteredTape.length - 1; i >= 0; i--) {
- _loop_1(i);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // Maximum number of values before we decide to show ellipsis.
- var FORMAT_LIMIT_NUM_VALS = 20;
- // Number of first and last values to show when displaying a, b,...,y, z.
- var FORMAT_NUM_FIRST_LAST_VALS = 3;
- // Number of significant digits to show.
- var FORMAT_NUM_SIG_DIGITS = 7;
- function tensorToString(vals, shape, dtype, verbose) {
- var strides = computeStrides(shape);
- var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
- var rank = shape.length;
- var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
- var lines = ['Tensor'];
- if (verbose) {
- lines.push(" dtype: " + dtype);
- lines.push(" rank: " + rank);
- lines.push(" shape: [" + shape + "]");
- lines.push(" values:");
- }
- lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n'));
- return lines.join('\n');
- }
- function computeMaxSizePerColumn(vals, shape, dtype, strides) {
- var n = sizeFromShape(shape);
- var numCols = strides[strides.length - 1];
- var padPerCol = new Array(numCols).fill(0);
- var rank = shape.length;
- var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
- if (rank > 1) {
- for (var row = 0; row < n / numCols; row++) {
- var offset = row * numCols;
- for (var j = 0; j < numCols; j++) {
- padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
- }
- }
- }
- return padPerCol;
- }
- function valToString(val, pad, dtype) {
- var valStr;
- if (Array.isArray(val)) {
- valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " +
- (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j");
- }
- else if (isString(val)) {
- valStr = "'" + val + "'";
- }
- else if (dtype === 'bool') {
- valStr = boolNumToString(val);
- }
- else {
- valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
- }
- return rightPad(valStr, pad);
- }
- function boolNumToString(v) {
- return v === 0 ? 'false' : 'true';
- }
- function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) {
- if (isLast === void 0) { isLast = true; }
- var storagePerElement = dtype === 'complex64' ? 2 : 1;
- var size = shape[0];
- var rank = shape.length;
- if (rank === 0) {
- if (dtype === 'complex64') {
- var complexTuple = createComplexTuples(vals);
- return [valToString(complexTuple[0], 0, dtype)];
- }
- if (dtype === 'bool') {
- return [boolNumToString(vals[0])];
- }
- return [vals[0].toString()];
- }
- if (rank === 1) {
- if (size > FORMAT_LIMIT_NUM_VALS) {
- var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
- var firstVals = Array.from(vals.slice(0, firstValsSize));
- var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
- if (dtype === 'complex64') {
- firstVals = createComplexTuples(firstVals);
- lastVals = createComplexTuples(lastVals);
- }
- return [
- '[' +
- firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); })
- .join(', ') +
- ', ..., ' +
- lastVals
- .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); })
- .join(', ') +
- ']'
- ];
- }
- var displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
- Array.from(vals);
- return [
- '[' +
- displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); })
- .join(', ') +
- ']'
- ];
- }
- // The array is rank 2 or more.
- var subshape = shape.slice(1);
- var substrides = strides.slice(1);
- var stride = strides[0] * storagePerElement;
- var lines = [];
- if (size > FORMAT_LIMIT_NUM_VALS) {
- for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
- var start = i * stride;
- var end = start + stride;
- lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
- }
- lines.push('...');
- for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
- var start = i * stride;
- var end = start + stride;
- lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
- }
- }
- else {
- for (var i = 0; i < size; i++) {
- var start = i * stride;
- var end = start + stride;
- lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
- }
- }
- var sep = rank === 2 ? ',' : '';
- lines[0] = '[' + lines[0] + sep;
- for (var i = 1; i < lines.length - 1; i++) {
- lines[i] = ' ' + lines[i] + sep;
- }
- var newLineSep = ',\n';
- for (var i = 2; i < rank; i++) {
- newLineSep += '\n';
- }
- lines[lines.length - 1] =
- ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
- return lines;
- }
- function createComplexTuples(vals) {
- var complexTuples = [];
- for (var i = 0; i < vals.length; i += 2) {
- complexTuples.push([vals[i], vals[i + 1]]);
- }
- return complexTuples;
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * A mutable object, similar to `tf.Tensor`, that allows users to set values
- * at locations before converting to an immutable `tf.Tensor`.
- *
- * See `tf.buffer` for creating a tensor buffer.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- var TensorBuffer = /** @class */ (function () {
- function TensorBuffer(shape, dtype, values) {
- var _this = this;
- this.dtype = dtype;
- this.shape = shape.slice();
- this.size = sizeFromShape(shape);
- if (values != null) {
- var n_1 = values.length;
- assert(n_1 === this.size, function () { return "Length of values '" + n_1 + "' does not match the size " +
- ("inferred by the shape '" + _this.size + "'."); });
- }
- if (dtype === 'complex64') {
- throw new Error("complex64 dtype TensorBuffers are not supported. Please create " +
- "a TensorBuffer for the real and imaginary parts separately and " +
- "call tf.complex(real, imag).");
- }
- this.values = values || getArrayFromDType(dtype, this.size);
- this.strides = computeStrides(shape);
- }
- /**
- * Sets a value in the buffer at a given location.
- *
- * @param value The value to set.
- * @param locs The location indices.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- TensorBuffer.prototype.set = function (value) {
- var _this = this;
- var locs = [];
- for (var _i = 1; _i < arguments.length; _i++) {
- locs[_i - 1] = arguments[_i];
- }
- if (locs.length === 0) {
- locs = [0];
- }
- assert(locs.length === this.rank, function () { return "The number of provided coordinates (" + locs.length + ") must " +
- ("match the rank (" + _this.rank + ")"); });
- var index = this.locToIndex(locs);
- this.values[index] = value;
- };
- /**
- * Returns the value in the buffer at the provided location.
- *
- * @param locs The location indices.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- TensorBuffer.prototype.get = function () {
- var locs = [];
- for (var _i = 0; _i < arguments.length; _i++) {
- locs[_i] = arguments[_i];
- }
- if (locs.length === 0) {
- locs = [0];
- }
- var i = 0;
- for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) {
- var loc = locs_1[_a];
- if (loc < 0 || loc >= this.shape[i]) {
- var msg = "Requested out of range element at " + locs + ". " +
- (" Buffer shape=" + this.shape);
- throw new Error(msg);
- }
- i++;
- }
- var index = locs[locs.length - 1];
- for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) {
- index += this.strides[i_1] * locs[i_1];
- }
- return this.values[index];
- };
- TensorBuffer.prototype.locToIndex = function (locs) {
- if (this.rank === 0) {
- return 0;
- }
- else if (this.rank === 1) {
- return locs[0];
- }
- var index = locs[locs.length - 1];
- for (var i = 0; i < locs.length - 1; ++i) {
- index += this.strides[i] * locs[i];
- }
- return index;
- };
- TensorBuffer.prototype.indexToLoc = function (index) {
- if (this.rank === 0) {
- return [];
- }
- else if (this.rank === 1) {
- return [index];
- }
- var locs = new Array(this.shape.length);
- for (var i = 0; i < locs.length - 1; ++i) {
- locs[i] = Math.floor(index / this.strides[i]);
- index -= locs[i] * this.strides[i];
- }
- locs[locs.length - 1] = index;
- return locs;
- };
- Object.defineProperty(TensorBuffer.prototype, "rank", {
- get: function () {
- return this.shape.length;
- },
- enumerable: true,
- configurable: true
- });
- /**
- * Creates an immutable `tf.Tensor` object from the buffer.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- TensorBuffer.prototype.toTensor = function () {
- return trackerFn().makeTensor(this.values, this.shape, this.dtype);
- };
- return TensorBuffer;
- }());
- // For tracking tensor creation and disposal.
- var trackerFn = null;
- // Used by chaining methods to call into ops.
- var opHandler = null;
- // Used to warn about deprecated methods.
- var deprecationWarningFn = null;
- /**
- * An external consumer can register itself as the tensor tracker. This way
- * the Tensor class can notify the tracker for every tensor created and
- * disposed.
- */
- function setTensorTracker(fn) {
- trackerFn = fn;
- }
- /**
- * An external consumer can register itself as the op handler. This way the
- * Tensor class can have chaining methods that call into ops via the op
- * handler.
- */
- function setOpHandler(handler) {
- opHandler = handler;
- }
- /**
- * Sets the deprecation warning function to be used by this file. This way the
- * Tensor class can be a leaf but still use the environment.
- */
- function setDeprecationWarningFn(fn) {
- deprecationWarningFn = fn;
- }
- /**
- * A `tf.Tensor` object represents an immutable, multidimensional array of
- * numbers that has a shape and a data type.
- *
- * See `tf.tensor` for details on how to create a `tf.Tensor`.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- var Tensor = /** @class */ (function () {
- function Tensor(shape, dtype, dataId, id) {
- /** Whether this tensor has been globally kept. */
- this.kept = false;
- this.isDisposedInternal = false;
- this.shape = shape.slice();
- this.dtype = dtype || 'float32';
- this.size = sizeFromShape(shape);
- this.strides = computeStrides(shape);
- this.dataId = dataId;
- this.id = id;
- this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
- }
- /** Flatten a Tensor to a 1D array. */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.flatten = function () {
- this.throwIfDisposed();
- return this.as1D();
- };
- /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`. */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.asScalar = function () {
- this.throwIfDisposed();
- assert(this.size === 1, function () { return 'The array must have only 1 element.'; });
- return this.reshape([]);
- };
- /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.as1D = function () {
- this.throwIfDisposed();
- return this.reshape([this.size]);
- };
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor2D`.
- *
- * @param rows Number of rows in `tf.Tensor2D`.
- * @param columns Number of columns in `tf.Tensor2D`.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.as2D = function (rows, columns) {
- this.throwIfDisposed();
- return this.reshape([rows, columns]);
- };
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor3D`.
- *
- * @param rows Number of rows in `tf.Tensor3D`.
- * @param columns Number of columns in `tf.Tensor3D`.
- * @param depth Depth of `tf.Tensor3D`.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.as3D = function (rows, columns, depth) {
- this.throwIfDisposed();
- return this.reshape([rows, columns, depth]);
- };
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor4D`.
- *
- * @param rows Number of rows in `tf.Tensor4D`.
- * @param columns Number of columns in `tf.Tensor4D`.
- * @param depth Depth of `tf.Tensor4D`.
- * @param depth2 4th dimension of `tf.Tensor4D`.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.as4D = function (rows, columns, depth, depth2) {
- this.throwIfDisposed();
- return this.reshape([rows, columns, depth, depth2]);
- };
- /**
- * Converts a `tf.Tensor` to a `tf.Tensor5D`.
- *
- * @param rows Number of rows in `tf.Tensor5D`.
- * @param columns Number of columns in `tf.Tensor5D`.
- * @param depth Depth of `tf.Tensor5D`.
- * @param depth2 4th dimension of `tf.Tensor5D`.
- * @param depth3 5th dimension of 'tf.Tensor5D'
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.as5D = function (rows, columns, depth, depth2, depth3) {
- this.throwIfDisposed();
- return this.reshape([rows, columns, depth, depth2, depth3]);
- };
- /**
- * Casts a `tf.Tensor` to a specified dtype.
- *
- * @param dtype Data-type to cast the tensor to.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.asType = function (dtype) {
- this.throwIfDisposed();
- return opHandler.cast(this, dtype);
- };
- Object.defineProperty(Tensor.prototype, "rank", {
- get: function () {
- return this.shape.length;
- },
- enumerable: true,
- configurable: true
- });
- /**
- * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.buffer = function () {
- return __awaiter(this, void 0, void 0, function () {
- var vals;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0: return [4 /*yield*/, this.data()];
- case 1:
- vals = _a.sent();
- return [2 /*return*/, opHandler.buffer(this.shape, this.dtype, vals)];
- }
- });
- });
- };
- /** Returns a `tf.TensorBuffer` that holds the underlying data. */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.bufferSync = function () {
- return opHandler.buffer(this.shape, this.dtype, this.dataSync());
- };
- /**
- * Returns the tensor data as a nested array. The transfer of data is done
- * asynchronously.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.array = function () {
- return __awaiter(this, void 0, void 0, function () {
- var vals;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0: return [4 /*yield*/, this.data()];
- case 1:
- vals = _a.sent();
- return [2 /*return*/, toNestedArray(this.shape, vals)];
- }
- });
- });
- };
- /**
- * Returns the tensor data as a nested array. The transfer of data is done
- * synchronously.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.arraySync = function () {
- return toNestedArray(this.shape, this.dataSync());
- };
- /**
- * Asynchronously downloads the values from the `tf.Tensor`. Returns a
- * promise of `TypedArray` that resolves when the computation has finished.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.data = function () {
- return __awaiter(this, void 0, void 0, function () {
- var data, bytes;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0:
- this.throwIfDisposed();
- data = trackerFn().read(this.dataId);
- if (!(this.dtype === 'string')) return [3 /*break*/, 2];
- return [4 /*yield*/, data];
- case 1:
- bytes = _a.sent();
- try {
- return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })];
- }
- catch (_b) {
- throw new Error('Failed to decode the string bytes into utf-8. ' +
- 'To get the original bytes, call tensor.bytes().');
- }
- _a.label = 2;
- case 2: return [2 /*return*/, data];
- }
- });
- });
- };
- /**
- * Synchronously downloads the values from the `tf.Tensor`. This blocks the
- * UI thread until the values are ready, which can cause performance issues.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.dataSync = function () {
- this.throwIfDisposed();
- var data = trackerFn().readSync(this.dataId);
- if (this.dtype === 'string') {
- try {
- return data.map(function (b) { return decodeString(b); });
- }
- catch (_a) {
- throw new Error('Failed to decode the string bytes into utf-8. ' +
- 'To get the original bytes, call tensor.bytes().');
- }
- }
- return data;
- };
- /** Returns the underlying bytes of the tensor's data. */
- Tensor.prototype.bytes = function () {
- return __awaiter(this, void 0, void 0, function () {
- var data;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0:
- this.throwIfDisposed();
- return [4 /*yield*/, trackerFn().read(this.dataId)];
- case 1:
- data = _a.sent();
- if (this.dtype === 'string') {
- return [2 /*return*/, data];
- }
- else {
- return [2 /*return*/, new Uint8Array(data.buffer)];
- }
- return [2 /*return*/];
- }
- });
- });
- };
- /**
- * Disposes `tf.Tensor` from memory.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.dispose = function () {
- if (this.isDisposed) {
- return;
- }
- trackerFn().disposeTensor(this);
- this.isDisposedInternal = true;
- };
- Object.defineProperty(Tensor.prototype, "isDisposed", {
- get: function () {
- return this.isDisposedInternal;
- },
- enumerable: true,
- configurable: true
- });
- Tensor.prototype.throwIfDisposed = function () {
- if (this.isDisposed) {
- throw new Error("Tensor is disposed.");
- }
- };
- /** Casts the array to type `float32` */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.toFloat = function () {
- return this.asType('float32');
- };
- /** Casts the array to type `int32` */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.toInt = function () {
- return this.asType('int32');
- };
- /** Casts the array to type `bool` */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.toBool = function () {
- return this.asType('bool');
- };
- /**
- * Prints the `tf.Tensor`. See `tf.print` for details.
- *
- * @param verbose Whether to print verbose information about the tensor,
- * including dtype and size.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.print = function (verbose) {
- if (verbose === void 0) { verbose = false; }
- return opHandler.print(this, verbose);
- };
- /**
- * Reshapes the tensor into the provided shape.
- * See `tf.reshape` for more details.
- *
- * @param newShape An array of integers defining the output tensor shape.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.reshape = function (newShape) {
- this.throwIfDisposed();
- return opHandler.reshape(this, newShape);
- };
- /**
- * Reshapes the tensor into the shape of the provided tensor.
- *
- * @param x The tensor of required shape.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.reshapeAs = function (x) {
- this.throwIfDisposed();
- return this.reshape(x.shape);
- };
- /**
- * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
- * into the tensor's shape. See `tf.expandDims` for details.
- *
- * @param axis The dimension index at which to insert shape of 1. Defaults to
- * 0 (the first dimension).
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.expandDims = function (axis) {
- if (axis === void 0) { axis = 0; }
- return opHandler.expandDims(this, axis);
- };
- /**
- * Returns the cumulative sum of the `tf.Tensor` along `axis`.
- *
- * @param axis The axis along which to sum. Optional. Defaults to 0.
- * @param exclusive Whether to perform exclusive cumulative sum. Defaults to
- * false. If set to true then the sum of each tensor entry does not
- * include its own value, but only the values previous to it along the
- * specified axis.
- * @param reverse Whether to sum in the opposite direction. Defaults to
- * false.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.cumsum = function (axis, exclusive, reverse) {
- if (axis === void 0) { axis = 0; }
- if (exclusive === void 0) { exclusive = false; }
- if (reverse === void 0) { reverse = false; }
- return opHandler.cumsum(this, axis, exclusive, reverse);
- };
- /**
- * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape.
- * See `tf.squeeze` for more details.
- *
- * @param axis A list of numbers. If specified, only squeezes the
- * dimensions listed. The dimension index starts at 0. It is an error to
- * squeeze a dimension that is not 1.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.squeeze = function (axis) {
- this.throwIfDisposed();
- return opHandler.squeeze(this, axis);
- };
- /** Returns a copy of the tensor. See `tf.clone` for details. */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.clone = function () {
- this.throwIfDisposed();
- return opHandler.clone(this);
- };
- /**
- * Returns a human-readable description of the tensor. Useful for logging.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Tensor.prototype.toString = function (verbose) {
- if (verbose === void 0) { verbose = false; }
- var vals = this.dataSync();
- return tensorToString(vals, this.shape, this.dtype, verbose);
- };
- // Below is chain API that is not exposed to docs to avoid repetition. To
- // expose a method, move it above this comment and add @doc and jsdoc.
- Tensor.prototype.gather = function (indices, axis) {
- if (axis === void 0) { axis = 0; }
- this.throwIfDisposed();
- return opHandler.gather(this, indices, axis);
- };
- Tensor.prototype.matMul = function (b, transposeA, transposeB) {
- if (transposeA === void 0) { transposeA = false; }
- if (transposeB === void 0) { transposeB = false; }
- this.throwIfDisposed();
- return opHandler.matMul(this, b, transposeA, transposeB);
- };
- Tensor.prototype.dot = function (b) {
- this.throwIfDisposed();
- return opHandler.dot(this, b);
- };
- Tensor.prototype.norm = function (ord, axis, keepDims) {
- if (ord === void 0) { ord = 'euclidean'; }
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.norm(this, ord, axis, keepDims);
- };
- Tensor.prototype.slice = function (begin, size) {
- this.throwIfDisposed();
- return opHandler.slice(this, begin, size);
- };
- Tensor.prototype.reverse = function (axis) {
- this.throwIfDisposed();
- return opHandler.reverse(this, axis);
- };
- Tensor.prototype.concat = function (x, axis) {
- if (axis === void 0) { axis = 0; }
- this.throwIfDisposed();
- if (x instanceof Tensor) {
- x = [x];
- }
- return opHandler.concat([this].concat(x), axis);
- };
- Tensor.prototype.split = function (numOrSizeSplits, axis) {
- if (axis === void 0) { axis = 0; }
- this.throwIfDisposed();
- return opHandler.split(this, numOrSizeSplits, axis);
- };
- Tensor.prototype.stack = function (x, axis) {
- if (axis === void 0) { axis = 0; }
- return opHandler.stack([this, x], axis);
- };
- Tensor.prototype.unstack = function (axis) {
- if (axis === void 0) { axis = 0; }
- return opHandler.unstack(this, axis);
- };
- /**
- * @deprecated Use `tf.batchNorm` instead, and note the positional argument
- * change of scale, offset, and varianceEpsilon.
- */
- Tensor.prototype.batchNormalization = function (mean, variance, varianceEpsilon, scale, offset) {
- if (varianceEpsilon === void 0) { varianceEpsilon = .001; }
- deprecationWarningFn('tf.batchNormalization() is going away. ' +
- 'Use tf.batchNorm() instead, and note the positional argument change ' +
- 'of scale, offset, and varianceEpsilon');
- return this.batchNorm(mean, variance, offset, scale, varianceEpsilon);
- };
- // Reduction ops.
- Tensor.prototype.all = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.all(this, axis, keepDims);
- };
- Tensor.prototype.any = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.any(this, axis, keepDims);
- };
- Tensor.prototype.logSumExp = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.logSumExp(this, axis, keepDims);
- };
- Tensor.prototype.sum = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.sum(this, axis, keepDims);
- };
- Tensor.prototype.prod = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.prod(this, axis, keepDims);
- };
- Tensor.prototype.mean = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.mean(this, axis, keepDims);
- };
- Tensor.prototype.min = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.min(this, axis, keepDims);
- };
- Tensor.prototype.max = function (axis, keepDims) {
- if (axis === void 0) { axis = null; }
- if (keepDims === void 0) { keepDims = false; }
- this.throwIfDisposed();
- return opHandler.max(this, axis, keepDims);
- };
- Tensor.prototype.argMin = function (axis) {
- if (axis === void 0) { axis = null; }
- this.throwIfDisposed();
- return opHandler.argMin(this, axis);
- };
- Tensor.prototype.argMax = function (axis) {
- if (axis === void 0) { axis = null; }
- this.throwIfDisposed();
- return opHandler.argMax(this, axis);
- };
- // Transformations
- Tensor.prototype.cast = function (dtype) {
- this.throwIfDisposed();
- return opHandler.cast(this, dtype);
- };
- // Binary ops.
- Tensor.prototype.addStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.addStrict(this, x);
- };
- Tensor.prototype.atan2 = function (x) {
- this.throwIfDisposed();
- return opHandler.atan2(this, x);
- };
- Tensor.prototype.sub = function (x) {
- this.throwIfDisposed();
- return opHandler.sub(this, x);
- };
- Tensor.prototype.subStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.subStrict(this, x);
- };
- Tensor.prototype.pow = function (exp) {
- this.throwIfDisposed();
- return opHandler.pow(this, exp);
- };
- Tensor.prototype.powStrict = function (exp) {
- this.throwIfDisposed();
- return opHandler.powStrict(this, exp);
- };
- Tensor.prototype.mul = function (x) {
- this.throwIfDisposed();
- return opHandler.mul(this, x);
- };
- Tensor.prototype.mulStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.mulStrict(this, x);
- };
- Tensor.prototype.floorDiv = function (x) {
- this.throwIfDisposed();
- return opHandler.floorDiv(this, x);
- };
- Tensor.prototype.divStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.divStrict(this, x);
- };
- Tensor.prototype.minimum = function (x) {
- this.throwIfDisposed();
- return opHandler.minimum(this, x);
- };
- Tensor.prototype.minimumStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.minimumStrict(this, x);
- };
- Tensor.prototype.maximum = function (x) {
- this.throwIfDisposed();
- return opHandler.maximum(this, x);
- };
- Tensor.prototype.maximumStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.maximumStrict(this, x);
- };
- Tensor.prototype.mod = function (x) {
- this.throwIfDisposed();
- return opHandler.mod(this, x);
- };
- Tensor.prototype.modStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.modStrict(this, x);
- };
- Tensor.prototype.squaredDifferenceStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.squaredDifferenceStrict(this, x);
- };
- // Compare ops.
- Tensor.prototype.notEqual = function (x) {
- this.throwIfDisposed();
- return opHandler.notEqual(this, x);
- };
- Tensor.prototype.notEqualStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.notEqualStrict(this, x);
- };
- Tensor.prototype.less = function (x) {
- this.throwIfDisposed();
- return opHandler.less(this, x);
- };
- Tensor.prototype.lessStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.lessStrict(this, x);
- };
- Tensor.prototype.equal = function (x) {
- this.throwIfDisposed();
- return opHandler.equal(this, x);
- };
- Tensor.prototype.equalStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.equalStrict(this, x);
- };
- Tensor.prototype.lessEqual = function (x) {
- this.throwIfDisposed();
- return opHandler.lessEqual(this, x);
- };
- Tensor.prototype.lessEqualStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.lessEqualStrict(this, x);
- };
- Tensor.prototype.greater = function (x) {
- this.throwIfDisposed();
- return opHandler.greater(this, x);
- };
- Tensor.prototype.greaterStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.greaterStrict(this, x);
- };
- Tensor.prototype.greaterEqual = function (x) {
- this.throwIfDisposed();
- return opHandler.greaterEqual(this, x);
- };
- Tensor.prototype.greaterEqualStrict = function (x) {
- this.throwIfDisposed();
- return opHandler.greaterEqualStrict(this, x);
- };
- // Compare ops.
- Tensor.prototype.logicalAnd = function (x) {
- this.throwIfDisposed();
- return opHandler.logicalAnd(this, x);
- };
- Tensor.prototype.logicalOr = function (x) {
- this.throwIfDisposed();
- return opHandler.logicalOr(this, x);
- };
- Tensor.prototype.logicalNot = function () {
- this.throwIfDisposed();
- return opHandler.logicalNot(this);
- };
- Tensor.prototype.logicalXor = function (x) {
- this.throwIfDisposed();
- return opHandler.logicalXor(this, x);
- };
- Tensor.prototype.where = function (condition, x) {
- this.throwIfDisposed();
- return opHandler.where(condition, this, x);
- };
- // Unary ops.
- Tensor.prototype.neg = function () {
- this.throwIfDisposed();
- return opHandler.neg(this);
- };
- Tensor.prototype.ceil = function () {
- this.throwIfDisposed();
- return opHandler.ceil(this);
- };
- Tensor.prototype.floor = function () {
- this.throwIfDisposed();
- return opHandler.floor(this);
- };
- Tensor.prototype.sign = function () {
- this.throwIfDisposed();
- return opHandler.sign(this);
- };
- Tensor.prototype.isNaN = function () {
- this.throwIfDisposed();
- return opHandler.isNaN(this);
- };
- Tensor.prototype.isInf = function () {
- this.throwIfDisposed();
- return opHandler.isInf(this);
- };
- Tensor.prototype.isFinite = function () {
- this.throwIfDisposed();
- return opHandler.isFinite(this);
- };
- Tensor.prototype.exp = function () {
- this.throwIfDisposed();
- return opHandler.exp(this);
- };
- Tensor.prototype.expm1 = function () {
- this.throwIfDisposed();
- return opHandler.expm1(this);
- };
- Tensor.prototype.log = function () {
- this.throwIfDisposed();
- return opHandler.log(this);
- };
- Tensor.prototype.log1p = function () {
- this.throwIfDisposed();
- return opHandler.log1p(this);
- };
- Tensor.prototype.sqrt = function () {
- this.throwIfDisposed();
- return opHandler.sqrt(this);
- };
- Tensor.prototype.rsqrt = function () {
- this.throwIfDisposed();
- return opHandler.rsqrt(this);
- };
- Tensor.prototype.square = function () {
- this.throwIfDisposed();
- return opHandler.square(this);
- };
- Tensor.prototype.reciprocal = function () {
- this.throwIfDisposed();
- return opHandler.reciprocal(this);
- };
- Tensor.prototype.abs = function () {
- this.throwIfDisposed();
- return opHandler.abs(this);
- };
- Tensor.prototype.clipByValue = function (min, max) {
- this.throwIfDisposed();
- return opHandler.clipByValue(this, min, max);
- };
- Tensor.prototype.relu = function () {
- this.throwIfDisposed();
- return opHandler.relu(this);
- };
- Tensor.prototype.relu6 = function () {
- this.throwIfDisposed();
- return opHandler.relu6(this);
- };
- Tensor.prototype.elu = function () {
- this.throwIfDisposed();
- return opHandler.elu(this);
- };
- Tensor.prototype.selu = function () {
- this.throwIfDisposed();
- return opHandler.selu(this);
- };
- Tensor.prototype.leakyRelu = function (alpha) {
- if (alpha === void 0) { alpha = 0.2; }
- this.throwIfDisposed();
- return opHandler.leakyRelu(this, alpha);
- };
- Tensor.prototype.prelu = function (alpha) {
- this.throwIfDisposed();
- return opHandler.prelu(this, alpha);
- };
- Tensor.prototype.sigmoid = function () {
- this.throwIfDisposed();
- return opHandler.sigmoid(this);
- };
- Tensor.prototype.logSigmoid = function () {
- this.throwIfDisposed();
- return opHandler.logSigmoid(this);
- };
- Tensor.prototype.softplus = function () {
- this.throwIfDisposed();
- return opHandler.softplus(this);
- };
- Tensor.prototype.zerosLike = function () {
- this.throwIfDisposed();
- return opHandler.zerosLike(this);
- };
- Tensor.prototype.onesLike = function () {
- this.throwIfDisposed();
- return opHandler.onesLike(this);
- };
- Tensor.prototype.sin = function () {
- this.throwIfDisposed();
- return opHandler.sin(this);
- };
- Tensor.prototype.cos = function () {
- this.throwIfDisposed();
- return opHandler.cos(this);
- };
- Tensor.prototype.tan = function () {
- this.throwIfDisposed();
- return opHandler.tan(this);
- };
- Tensor.prototype.asin = function () {
- this.throwIfDisposed();
- return opHandler.asin(this);
- };
- Tensor.prototype.acos = function () {
- this.throwIfDisposed();
- return opHandler.acos(this);
- };
- Tensor.prototype.atan = function () {
- this.throwIfDisposed();
- return opHandler.atan(this);
- };
- Tensor.prototype.sinh = function () {
- this.throwIfDisposed();
- return opHandler.sinh(this);
- };
- Tensor.prototype.cosh = function () {
- this.throwIfDisposed();
- return opHandler.cosh(this);
- };
- Tensor.prototype.tanh = function () {
- this.throwIfDisposed();
- return opHandler.tanh(this);
- };
- Tensor.prototype.asinh = function () {
- this.throwIfDisposed();
- return opHandler.asinh(this);
- };
- Tensor.prototype.acosh = function () {
- this.throwIfDisposed();
- return opHandler.acosh(this);
- };
- Tensor.prototype.atanh = function () {
- this.throwIfDisposed();
- return opHandler.atanh(this);
- };
- Tensor.prototype.erf = function () {
- this.throwIfDisposed();
- return opHandler.erf(this);
- };
- Tensor.prototype.round = function () {
- this.throwIfDisposed();
- return opHandler.round(this);
- };
- Tensor.prototype.step = function (alpha) {
- if (alpha === void 0) { alpha = 0.0; }
- this.throwIfDisposed();
- return opHandler.step(this, alpha);
- };
- Tensor.prototype.softmax = function (dim) {
- if (dim === void 0) { dim = -1; }
- this.throwIfDisposed();
- return opHandler.softmax(this, dim);
- };
- Tensor.prototype.logSoftmax = function (axis) {
- if (axis === void 0) { axis = -1; }
- this.throwIfDisposed();
- return opHandler.logSoftmax(this, axis);
- };
- // Image ops.
- Tensor.prototype.resizeBilinear = function (newShape2D, alignCorners) {
- if (alignCorners === void 0) { alignCorners = false; }
- this.throwIfDisposed();
- return opHandler.image.resizeBilinear(this, newShape2D, alignCorners);
- };
- Tensor.prototype.resizeNearestNeighbor = function (newShape2D, alignCorners) {
- if (alignCorners === void 0) { alignCorners = false; }
- this.throwIfDisposed();
- return opHandler.image.resizeNearestNeighbor(this, newShape2D, alignCorners);
- };
- // Convolutions.
- Tensor.prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
- if (dataFormat === void 0) { dataFormat = 'NWC'; }
- if (dilation === void 0) { dilation = 1; }
- this.throwIfDisposed();
- return opHandler.conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
- };
- Tensor.prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
- if (dataFormat === void 0) { dataFormat = 'NHWC'; }
- if (dilations === void 0) { dilations = [1, 1]; }
- this.throwIfDisposed();
- return opHandler.conv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- };
- Tensor.prototype.conv2dTranspose = function (filter, outputShape, strides, pad, dimRoundingMode) {
- this.throwIfDisposed();
- return opHandler.conv2dTranspose(this, filter, outputShape, strides, pad, dimRoundingMode);
- };
- Tensor.prototype.depthwiseConv2D = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
- if (dataFormat === void 0) { dataFormat = 'NHWC'; }
- if (dilations === void 0) { dilations = [1, 1]; }
- this.throwIfDisposed();
- return opHandler.depthwiseConv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
- };
- Tensor.prototype.separableConv2d = function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
- if (dilation === void 0) { dilation = [1, 1]; }
- if (dataFormat === void 0) { dataFormat = 'NHWC'; }
- this.throwIfDisposed();
- return opHandler.separableConv2d(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat);
- };
- // Pooling.
- Tensor.prototype.avgPool = function (filterSize, strides, pad, dimRoundingMode) {
- this.throwIfDisposed();
- return opHandler.avgPool(this, filterSize, strides, pad, dimRoundingMode);
- };
- Tensor.prototype.maxPool = function (filterSize, strides, pad, dimRoundingMode) {
- this.throwIfDisposed();
- return opHandler.maxPool(this, filterSize, strides, pad, dimRoundingMode);
- };
- Tensor.prototype.localResponseNormalization = function (radius, bias, alpha, beta) {
- if (radius === void 0) { radius = 5; }
- if (bias === void 0) { bias = 1; }
- if (alpha === void 0) { alpha = 1; }
- if (beta === void 0) { beta = 0.5; }
- return opHandler.localResponseNormalization(this, radius, bias, alpha, beta);
- };
- Tensor.prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides) {
- this.throwIfDisposed();
- return opHandler.pool(this, windowShape, poolingType, padding, dilationRate, strides);
- };
- Tensor.prototype.variable = function (trainable, name, dtype) {
- if (trainable === void 0) { trainable = true; }
- this.throwIfDisposed();
- return trackerFn().makeVariable(this, trainable, name, dtype);
- };
- Tensor.prototype.unsortedSegmentSum = function (segmentIds, numSegments) {
- this.throwIfDisposed();
- return opHandler.unsortedSegmentSum(this, segmentIds, numSegments);
- };
- Tensor.prototype.batchToSpaceND = function (blockShape, crops) {
- this.throwIfDisposed();
- return opHandler.batchToSpaceND(this, blockShape, crops);
- };
- Tensor.prototype.spaceToBatchND = function (blockShape, paddings) {
- this.throwIfDisposed();
- return opHandler.spaceToBatchND(this, blockShape, paddings);
- };
- Tensor.prototype.topk = function (k, sorted) {
- if (k === void 0) { k = 1; }
- if (sorted === void 0) { sorted = true; }
- this.throwIfDisposed();
- return opHandler.topk(this, k, sorted);
- };
- Tensor.prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
- if (beginMask === void 0) { beginMask = 0; }
- if (endMask === void 0) { endMask = 0; }
- if (ellipsisMask === void 0) { ellipsisMask = 0; }
- if (newAxisMask === void 0) { newAxisMask = 0; }
- if (shrinkAxisMask === void 0) { shrinkAxisMask = 0; }
- this.throwIfDisposed();
- return opHandler.stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
- };
- Tensor.prototype.depthToSpace = function (blockSize, dataFormat) {
- this.throwIfDisposed();
- return opHandler.depthToSpace(this, blockSize, dataFormat);
- };
- Tensor.prototype.fft = function () {
- this.throwIfDisposed();
- return opHandler.spectral.fft(this);
- };
- Tensor.prototype.ifft = function () {
- this.throwIfDisposed();
- return opHandler.spectral.ifft(this);
- };
- Tensor.prototype.rfft = function () {
- this.throwIfDisposed();
- return opHandler.spectral.rfft(this);
- };
- Tensor.prototype.irfft = function () {
- this.throwIfDisposed();
- return opHandler.spectral.irfft(this);
- };
- return Tensor;
- }());
- Object.defineProperty(Tensor, Symbol.hasInstance, {
- value: function (instance) {
- return !!instance && instance.dataId != null && instance.shape != null &&
- instance.dtype != null;
- }
- });
- /**
- * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- var Variable = /** @class */ (function (_super) {
- __extends(Variable, _super);
- function Variable(initialValue, trainable, name, tensorId) {
- var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this;
- _this.trainable = trainable;
- _this.name = name;
- return _this;
- }
- /**
- * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
- * the same shape and dtype as the old `tf.Tensor`.
- *
- * @param newValue New tensor to be assigned to this variable.
- */
- /** @doc {heading: 'Tensors', subheading: 'Classes'} */
- Variable.prototype.assign = function (newValue) {
- if (newValue.dtype !== this.dtype) {
- throw new Error("dtype of the new value (" + newValue.dtype + ") and " +
- ("previous value (" + this.dtype + ") must match"));
- }
- if (!arraysEqual(newValue.shape, this.shape)) {
- throw new Error("shape of the new value (" + newValue.shape + ") and " +
- ("previous value (" + this.shape + ") must match"));
- }
- trackerFn().disposeTensor(this);
- this.dataId = newValue.dataId;
- trackerFn().incRef(this, null /* backend */);
- };
- Variable.prototype.dispose = function () {
- trackerFn().disposeVariable(this);
- this.isDisposedInternal = true;
- };
- return Variable;
- }(Tensor));
- Object.defineProperty(Variable, Symbol.hasInstance, {
- value: function (instance) {
- return instance instanceof Tensor && instance.assign != null &&
- instance.assign instanceof Function;
- }
- });
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- (function (Rank) {
- Rank["R0"] = "R0";
- Rank["R1"] = "R1";
- Rank["R2"] = "R2";
- Rank["R3"] = "R3";
- Rank["R4"] = "R4";
- Rank["R5"] = "R5";
- Rank["R6"] = "R6";
- })(exports.Rank || (exports.Rank = {}));
- // Looks for upcasting types. Used, for example, in operations with mixed dtype
- // inputs.
- var UpcastInt32AndMap;
- (function (UpcastInt32AndMap) {
- UpcastInt32AndMap["float32"] = "float32";
- UpcastInt32AndMap["int32"] = "int32";
- UpcastInt32AndMap["bool"] = "int32";
- UpcastInt32AndMap["complex64"] = "complex64";
- })(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
- var UpcastBoolAndMap;
- (function (UpcastBoolAndMap) {
- UpcastBoolAndMap["float32"] = "float32";
- UpcastBoolAndMap["int32"] = "int32";
- UpcastBoolAndMap["bool"] = "bool";
- UpcastBoolAndMap["complex64"] = "complex64";
- })(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
- var UpcastFloat32AndMap;
- (function (UpcastFloat32AndMap) {
- UpcastFloat32AndMap["float32"] = "float32";
- UpcastFloat32AndMap["int32"] = "float32";
- UpcastFloat32AndMap["bool"] = "float32";
- UpcastFloat32AndMap["complex64"] = "complex64";
- })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
- var UpcastComplex64AndMap;
- (function (UpcastComplex64AndMap) {
- UpcastComplex64AndMap["float32"] = "complex64";
- UpcastComplex64AndMap["int32"] = "complex64";
- UpcastComplex64AndMap["bool"] = "complex64";
- UpcastComplex64AndMap["complex64"] = "complex64";
- })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
- var upcastTypeMap = {
- 'float32': UpcastFloat32AndMap,
- 'int32': UpcastInt32AndMap,
- 'bool': UpcastBoolAndMap,
- 'complex64': UpcastComplex64AndMap
- };
- function upcastType(typeA, typeB) {
- if (typeA === 'string' || typeB === 'string') {
- if (typeA === 'string' && typeB === 'string') {
- return 'string';
- }
- throw new Error("Can not upcast " + typeA + " with " + typeB);
- }
- return upcastTypeMap[typeA][typeB];
- }
- /** Returns the output type after summation. */
- function sumOutType(type) {
- return upcastType(type, 'int32');
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function makeTypesMatch(a, b) {
- if (a.dtype === b.dtype) {
- return [a, b];
- }
- var dtype = upcastType(a.dtype, b.dtype);
- return [a.cast(dtype), b.cast(dtype)];
- }
- function assertTypesMatch(a, b) {
- assert(a.dtype === b.dtype, function () { return "The dtypes of the first(" + a.dtype + ") and" +
- (" second(" + b.dtype + ") input must match"); });
- }
- function isTensorInList(tensor, tensorList) {
- return tensorList.some(function (x) { return x.id === tensor.id; });
- }
- /**
- * Extracts any `Tensor`s found within the provided object.
- *
- * @param container an object that may be a `Tensor` or may directly contain
- * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
- * is safe to pass any object here, except that `Promise`s are not
- * supported.
- * @returns An array of `Tensors` found within the passed object. If the
- * argument is simply a `Tensor', a list containing that `Tensor` is
- * returned. If the object is not a `Tensor` or does not
- * contain `Tensors`, an empty list is returned.
- */
- function getTensorsInContainer(result) {
- var list = [];
- var seen = new Set();
- walkTensorContainer(result, list, seen);
- return list;
- }
- function walkTensorContainer(container, list, seen) {
- if (container == null) {
- return;
- }
- if (container instanceof Tensor) {
- list.push(container);
- return;
- }
- if (!isIterable(container)) {
- return;
- }
- // Iteration over keys works also for arrays.
- var iterable = container;
- for (var k in iterable) {
- var val = iterable[k];
- if (!seen.has(val)) {
- seen.add(val);
- walkTensorContainer(val, list, seen);
- }
- }
- }
- // tslint:disable-next-line:no-any
- function isIterable(obj) {
- return Array.isArray(obj) || typeof obj === 'object';
- }
-
- var tensor_util = /*#__PURE__*/Object.freeze({
- makeTypesMatch: makeTypesMatch,
- assertTypesMatch: assertTypesMatch,
- isTensorInList: isTensorInList,
- getTensorsInContainer: getTensorsInContainer
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var EngineState = /** @class */ (function () {
- function EngineState() {
- // Public since optimizers will use it.
- this.registeredVariables = {};
- this.nextTapeNodeId = 0;
- this.numBytes = 0;
- this.numTensors = 0;
- this.numStringTensors = 0;
- this.numDataBuffers = 0;
- // Number of nested tf.grad() statements when computing higher-order
- // gradients. E.g. `1` for first-order gradients and `2` for second-order
- // gradients. Used to track if the tape should be removed after a backprop.
- this.gradientDepth = 0;
- // Number of nested kernel calls. When kernel depth is greater than 1, we turn
- // off the tape.
- this.kernelDepth = 0;
- this.scopeStack = [];
- /**
- * Keeps track of the number of data moves during a kernel execution. We
- * maintain a stack since kernels can call other kernels, recursively.
- */
- this.numDataMovesStack = [];
- this.nextScopeId = 0;
- this.tensorInfo = new WeakMap();
- this.profiling = false;
- this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null };
- }
- EngineState.prototype.dispose = function () {
- for (var variableName in this.registeredVariables) {
- this.registeredVariables[variableName].dispose();
- }
- };
- return EngineState;
- }());
- var Engine = /** @class */ (function () {
- function Engine(ENV) {
- this.ENV = ENV;
- this.registry = {};
- this.registryFactory = {};
- this.pendingBackendInitId = 0;
- this.state = new EngineState();
- }
- Engine.prototype.ready = function () {
- return __awaiter(this, void 0, void 0, function () {
- var sortedBackends, i, backendName, success;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0:
- if (this.pendingBackendInit != null) {
- return [2 /*return*/, this.pendingBackendInit.then(function () { })];
- }
- if (this.backendInstance != null) {
- return [2 /*return*/];
- }
- sortedBackends = this.getSortedBackends();
- i = 0;
- _a.label = 1;
- case 1:
- if (!(i < sortedBackends.length)) return [3 /*break*/, 5];
- backendName = sortedBackends[i];
- return [4 /*yield*/, this.initializeBackend(backendName).success];
- case 2:
- success = _a.sent();
- if (!success) return [3 /*break*/, 4];
- return [4 /*yield*/, this.setBackend(backendName)];
- case 3:
- _a.sent();
- return [2 /*return*/];
- case 4:
- i++;
- return [3 /*break*/, 1];
- case 5: throw new Error("Could not initialize any backends, all backend initializations " +
- "failed.");
- }
- });
- });
- };
- Object.defineProperty(Engine.prototype, "backend", {
- get: function () {
- if (this.pendingBackendInit != null) {
- throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " +
- "sure to await tf.ready() or await tf.setBackend() before calling " +
- "other methods");
- }
- if (this.backendInstance == null) {
- var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit;
- if (asyncInit) {
- throw new Error("The highest priority backend '" + name_1 + "' has not yet been " +
- "initialized. Make sure to await tf.ready() or " +
- "await tf.setBackend() before calling other methods");
- }
- this.setBackend(name_1);
- }
- return this.backendInstance;
- },
- enumerable: true,
- configurable: true
- });
- Engine.prototype.backendNames = function () {
- return Object.keys(this.registryFactory);
- };
- Engine.prototype.findBackend = function (backendName) {
- if (!(backendName in this.registry)) {
- // If the backend hasn't been initialized but we have a registry entry for
- // it, initialize it and return it.
- if (backendName in this.registryFactory) {
- var asyncInit = this.initializeBackend(backendName).asyncInit;
- if (asyncInit) {
- // Backend is not ready yet.
- return null;
- }
- }
- else {
- return null;
- }
- }
- return this.registry[backendName];
- };
- Engine.prototype.findBackendFactory = function (backendName) {
- if (!(backendName in this.registryFactory)) {
- return null;
- }
- return this.registryFactory[backendName].factory;
- };
- Engine.prototype.registerBackend = function (backendName, factory, priority) {
- if (priority === void 0) { priority = 1; }
- if (backendName in this.registryFactory) {
- console.warn(backendName + " backend was already registered. " +
- "Reusing existing backend factory.");
- return false;
- }
- this.registryFactory[backendName] = { factory: factory, priority: priority };
- return true;
- };
- Engine.prototype.setBackend = function (backendName) {
- return __awaiter(this, void 0, void 0, function () {
- var _a, success, asyncInit, result, _b;
- return __generator(this, function (_c) {
- switch (_c.label) {
- case 0:
- if (this.registryFactory[backendName] == null) {
- throw new Error("Backend name '" + backendName + "' not found in registry");
- }
- this.backendName = backendName;
- if (!(this.registry[backendName] == null)) return [3 /*break*/, 4];
- this.backendInstance = null;
- _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
- if (!asyncInit) return [3 /*break*/, 2];
- return [4 /*yield*/, success];
- case 1:
- _b = _c.sent();
- return [3 /*break*/, 3];
- case 2:
- _b = success;
- _c.label = 3;
- case 3:
- result = _b;
- if (!result) {
- return [2 /*return*/, false];
- }
- _c.label = 4;
- case 4:
- this.backendInstance = this.registry[backendName];
- this.setupRegisteredKernels();
- // Reset the profiler.
- this.profiler = new Profiler(this.backendInstance);
- return [2 /*return*/, true];
- }
- });
- });
- };
- Engine.prototype.setupRegisteredKernels = function () {
- var _this = this;
- var kernels = getKernelsForBackend(this.backendName);
- kernels.forEach(function (kernel) {
- if (kernel.setupFunc != null) {
- kernel.setupFunc(_this.backendInstance);
- }
- });
- };
- Engine.prototype.disposeRegisteredKernels = function (backendName) {
- var _this = this;
- var kernels = getKernelsForBackend(backendName);
- kernels.forEach(function (kernel) {
- if (kernel.disposeFunc != null) {
- kernel.disposeFunc(_this.registry[backendName]);
- }
- });
- };
- /**
- * Initializes a backend by looking up the backend name in the factory
- * registry and calling the factory method. Returns a boolean representing
- * whether the initialization of the backend suceeded. Throws an error if
- * there is no backend in the factory registry.
- */
- Engine.prototype.initializeBackend = function (backendName) {
- var _this = this;
- var registryFactoryEntry = this.registryFactory[backendName];
- if (registryFactoryEntry == null) {
- throw new Error("Cannot initialize backend " + backendName + ", no registration found.");
- }
- try {
- var backend = registryFactoryEntry.factory();
- // Test if the factory returns a promise.
- if (Promise.resolve(backend) === backend) {
- var promiseId_1 = ++this.pendingBackendInitId;
- var success = backend
- .then(function (backendInstance) {
- // Outdated promise. Another backend was set in the meantime.
- if (promiseId_1 < _this.pendingBackendInitId) {
- return false;
- }
- _this.registry[backendName] = backendInstance;
- _this.pendingBackendInit = null;
- return true;
- })
- .catch(function (err) {
- // Outdated promise. Another backend was set in the meantime.
- if (promiseId_1 < _this.pendingBackendInitId) {
- return false;
- }
- _this.pendingBackendInit = null;
- console.warn("Initialization of backend " + backendName + " failed");
- console.warn(err.stack || err.message);
- return false;
- });
- this.pendingBackendInit = success;
- return { success: success, asyncInit: true };
- }
- else {
- this.registry[backendName] = backend;
- return { success: true, asyncInit: false };
- }
- }
- catch (err) {
- console.warn("Initialization of backend " + backendName + " failed");
- console.warn(err.stack || err.message);
- return { success: false, asyncInit: false };
- }
- };
- Engine.prototype.removeBackend = function (backendName) {
- if (!(backendName in this.registryFactory)) {
- throw new Error(backendName + " backend not found in registry");
- }
- if (this.backendName === backendName && this.pendingBackendInit != null) {
- // There is a pending promise of the backend we want to remove. Make it
- // obsolete.
- this.pendingBackendInitId++;
- }
- if (backendName in this.registry) {
- this.disposeRegisteredKernels(backendName);
- this.registry[backendName].dispose();
- delete this.registry[backendName];
- }
- delete this.registryFactory[backendName];
- // Unset the backend if it is active.
- if (this.backendName === backendName) {
- this.pendingBackendInit = null;
- this.backendName = null;
- this.backendInstance = null;
- }
- };
- Engine.prototype.getSortedBackends = function () {
- var _this = this;
- if (Object.keys(this.registryFactory).length === 0) {
- throw new Error('No backend found in registry.');
- }
- return Object.keys(this.registryFactory).sort(function (a, b) {
- // Highest priority comes first.
- return _this.registryFactory[b].priority -
- _this.registryFactory[a].priority;
- });
- };
- Engine.prototype.initializeBackendsAndReturnBest = function () {
- var sortedBackends = this.getSortedBackends();
- for (var i = 0; i < sortedBackends.length; i++) {
- var backendName = sortedBackends[i];
- var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
- if (asyncInit || success) {
- return { name: backendName, asyncInit: asyncInit };
- }
- }
- throw new Error("Could not initialize any backends, all backend initializations " +
- "failed.");
- };
- Engine.prototype.moveData = function (destBackend, dataId) {
- var info = this.state.tensorInfo.get(dataId);
- var srcBackend = info.backend;
- var values = this.readSync(dataId);
- // Delete the tensor from the old backend and move it to the new
- // backend.
- srcBackend.disposeData(dataId);
- info.backend = destBackend;
- destBackend.move(dataId, values, info.shape, info.dtype);
- if (this.shouldCheckForMemLeaks()) {
- // Track the number of moves during a kernel execution to correctly
- // detect memory leaks.
- this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
- }
- };
- Engine.prototype.tidy = function (nameOrFn, fn) {
- var _this = this;
- var name = null;
- if (fn == null) {
- // Called with only 1 argument.
- if (typeof nameOrFn !== 'function') {
- throw new Error('Please provide a function to tidy()');
- }
- fn = nameOrFn;
- }
- else {
- // Called with 2 arguments.
- if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
- throw new Error('When calling with two arguments, the first argument ' +
- 'to tidy() must be a string');
- }
- if (typeof fn !== 'function') {
- throw new Error('When calling with two arguments, the 2nd argument ' +
- 'to tidy() must be a function');
- }
- name = nameOrFn;
- // TODO(nsthorat,smilkov): Do operation logging and performance
- // profiling.
- }
- var result;
- return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () {
- result = fn();
- if (result instanceof Promise) {
- console.error('Cannot return a Promise inside of tidy.');
- }
- return result;
- });
- };
- Engine.prototype.scopedRun = function (start, end, f) {
- start();
- try {
- var res = f();
- end();
- return res;
- }
- catch (ex) {
- end();
- throw ex;
- }
- };
- Engine.prototype.nextTensorId = function () {
- return Engine.nextTensorId++;
- };
- Engine.prototype.nextVariableId = function () {
- return Engine.nextVariableId++;
- };
- /**
- * This method is called instead of the public-facing tensor.clone() when
- * saving a tensor for backwards pass. It makes sure to add the clone
- * operation to the tape regardless of being called inside a kernel
- * execution.
- *
- * This method will go away once all kernels are modularized since we won't
- * need to turn off the tape inside runKernel().
- */
- Engine.prototype.clone = function (x) {
- var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype);
- var inputs = { x: x };
- var grad = function (dy) { return ({ x: function () { return dy.toFloat(); } }); };
- var saved = [];
- this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
- return y;
- };
- /**
- * Execute a kernel with the given name and return the output tensor.
- *
- * @param kernelName The name of the kernel to execute.
- * @param inputs A map of input names to tensors.
- * @param attrs A map of attribute names to their values. An attribute is a
- * primitive (non-tensor) input to the kernel.
- * @param inputsToSave A list of tensors, inputs to save for the backprop
- * computation.
- * @param outputsToSave A list of booleans, specifying which output to save
- * for the backprop computation. These are booleans since the output
- * tensors are not visible to the user.
- */
- Engine.prototype.runKernel = function (kernelName, inputs, attrs, inputsToSave, outputsToSave) {
- var forwardFunc = null;
- var backwardsFunc = null;
- // Call runKernel as a stop-gap until we modularize all kernels.
- // Once we modularize all kernels, we will remove the existing
- // `runKernelFunc`.
- return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave);
- };
- Engine.prototype.shouldCheckForMemLeaks = function () {
- return this.ENV.getBool('IS_TEST');
- };
- Engine.prototype.checkKernelForMemLeak = function (kernelName, numDataIdsBefore, outInfos) {
- var numDataIdsAfter = this.backend.numDataIds();
- // Count the number of data ids associated with the result of the kernel.
- var numOutputDataIds = 0;
- outInfos.forEach(function (info) {
- // Complex numbers allocate 3 data ids, one for 'real', one for
- // 'imaginary', and one for the container that holds the former two.
- numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
- });
- // Account for the number of moves during kernel execution. A "data move"
- // can happen in the middle of a kernel execution, placing a new (key,value)
- // pair in the data storage. Since data moves have net zero effect (we
- // always remove the data from the old backend), we have to cancel them out
- // when detecting memory leaks.
- var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
- var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
- if (dataIdsLeaked > 0) {
- throw new Error("Backend '" + this.backendName + "' has an internal memory leak " +
- ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'"));
- }
- };
- /**
- * @deprecated Use `runKernel` for newly added kernels. Keep using this method
- * only for kernels that are not yet fully modularized.
- */
- Engine.prototype.runKernelFunc = function (forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) {
- var _this = this;
- var outputs;
- var saved = [];
- var isTapeOn = this.isTapeOn();
- if (kernelName == null) {
- kernelName =
- this.state.activeScope != null ? this.state.activeScope.name : '';
- }
- var startingBytecount = this.state.numBytes;
- var startingNumTensors = this.state.numTensors;
- if (this.shouldCheckForMemLeaks()) {
- this.state.numDataMovesStack.push(0);
- }
- var kernelFunc;
- var kernel = getKernel(kernelName, this.backendName);
- var out;
- if (kernel != null) {
- kernelFunc = function () {
- var numDataIdsBefore = _this.backend.numDataIds();
- out = kernel.kernelFunc({ inputs: inputs, attrs: attrs, backend: _this.backend });
- var outInfos = Array.isArray(out) ? out : [out];
- if (_this.shouldCheckForMemLeaks()) {
- _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
- }
- var outTensors = outInfos.map(function (_a) {
- var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype;
- return _this.makeTensorFromDataId(dataId, shape, dtype);
- });
- // Save the inputs and outputs.
- // Do not save unless we are recording to the tape. Otherwise it would
- // cause a mem leak since we would never run backprop, which disposes
- // the kept tensors.
- if (isTapeOn) {
- var tensorsToSave = _this.getTensorsForGradient(kernelName, inputs, outTensors);
- if (tensorsToSave == null) {
- // Fallback for ops that call runKernelFunc and pass in
- // inputsToSave and outputsToSave. Currently this is the set of ops
- // with kernel support in the WASM backend. Once those ops and
- // respective gradients are modularised we can remove this path.
- if (outputsToSave == null) {
- outputsToSave = [];
- }
- var outsToSave = outTensors.filter(function (_, i) { return outputsToSave[i]; });
- tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
- }
- saved = _this.saveTensorsForBackwardMode(tensorsToSave);
- }
- return outTensors;
- };
- }
- else {
- var saveFunc_1 = function (tensors) {
- // Do not save unless we are recording to the tape. Otherwise it would
- // cause a mem leak since we would never run backprop, which disposes
- // the kept tensors.
- if (!isTapeOn) {
- return;
- }
- saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); });
- };
- kernelFunc = function () {
- var numDataIdsBefore = _this.backend.numDataIds();
- out = _this.tidy(function () { return forwardFunc(_this.backend, saveFunc_1); });
- var outs = (Array.isArray(out) ? out : [out]);
- if (_this.shouldCheckForMemLeaks()) {
- _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs);
- }
- return outs;
- };
- }
- // Stop recording to a tape when running a kernel.
- this.scopedRun(function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () {
- if (!_this.ENV.getBool('DEBUG')) {
- outputs = kernelFunc();
- }
- else {
- outputs = _this.profiler.profileKernel(kernelName, inputs, function () { return kernelFunc(); });
- }
- });
- if (isTapeOn) {
- this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
- }
- if (this.state.profiling) {
- this.state.activeProfile.kernels.push({
- name: kernelName,
- bytesAdded: this.state.numBytes - startingBytecount,
- totalBytesSnapshot: this.state.numBytes,
- tensorsAdded: this.state.numTensors - startingNumTensors,
- totalTensorsSnapshot: this.state.numTensors,
- inputShapes: Object.keys(inputs).map(function (key) { return inputs[key].shape; }),
- outputShapes: outputs.map(function (item) { return item.shape; })
- });
- }
- return (Array.isArray(out) ? outputs : outputs[0]);
- };
- /**
- * Saves tensors used in forward mode for use in backward mode.
- *
- * @param tensors the list of tensors to save.
- */
- Engine.prototype.saveTensorsForBackwardMode = function (tensors) {
- var _this = this;
- var saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); });
- return saved;
- };
- /**
- * Returns a list of tensors to save for a given gradient calculation.
- *
- * Returns undefined if their is no registered gradient for this kernel in the
- * gradient registry.
- *
- * @param kernelName name of kernel to look up gradient for.
- * @param inputs a map of input tensors.
- * @param outputs an array of output tensors from forward mode of kernel.
- */
- Engine.prototype.getTensorsForGradient = function (kernelName, inputs, outputs) {
- var gradConfig = getGradient(kernelName);
- if (gradConfig != null) {
- var inputsToSave = gradConfig.inputsToSave || [];
- var outputsToSave_1 = gradConfig.outputsToSave || [];
- // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
- // specified in inputsToSave will be saved.
- var inputTensorsToSave = void 0;
- if (gradConfig.saveAllInputs) {
- assert(Array.isArray(inputs), function () { return 'saveAllInputs is true, expected inputs to be an array.'; });
- inputTensorsToSave = Object.keys(inputs).map(function (key) { return inputs[key]; });
- }
- else {
- inputTensorsToSave = inputsToSave.map(function (inputName) { return inputs[inputName]; });
- }
- var outputTensorsToSave = outputs.filter(function (_, i) { return outputsToSave_1[i]; });
- return inputTensorsToSave.concat(outputTensorsToSave);
- }
- // TODO(yassogba) throw exception here once all runkernelFunc calls with
- // inputsToSave/outputsToSave are removed
- return null;
- };
- /**
- * Internal method used by public APIs for tensor creation. Makes a new
- * tensor with the provided shape, dtype and values. It always
- * creates a new data id and writes the values to the underlying backend.
- */
- Engine.prototype.makeTensor = function (values, shape, dtype, backend) {
- if (values == null) {
- throw new Error('Values passed to engine.makeTensor() are null');
- }
- dtype = dtype || 'float32';
- backend = backend || this.backend;
- var backendVals = values;
- if (dtype === 'string' && isString(values[0])) {
- backendVals = values.map(function (d) { return encodeString(d); });
- }
- var dataId = backend.write(backendVals, shape, dtype);
- var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
- this.incRef(t, backend);
- // Count bytes for string tensors.
- if (dtype === 'string') {
- var info = this.state.tensorInfo.get(dataId);
- var newBytes = bytesFromStringArray(backendVals);
- this.state.numBytes += newBytes - info.bytes;
- info.bytes = newBytes;
- }
- return t;
- };
- /**
- * Internal method used by backends. Makes a new tensor
- * that is a wrapper around an existing data id. It doesn't create
- * a new data id, only increments the ref count used in memory tracking.
- */
- Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) {
- dtype = dtype || 'float32';
- var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
- this.incRef(t, backend);
- return t;
- };
- Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) {
- if (trainable === void 0) { trainable = true; }
- name = name || this.nextVariableId().toString();
- if (dtype != null && dtype !== initialValue.dtype) {
- initialValue = initialValue.asType(dtype);
- }
- var v = new Variable(initialValue, trainable, name, this.nextTensorId());
- if (this.state.registeredVariables[v.name] != null) {
- throw new Error("Variable with name " + v.name + " was already registered");
- }
- this.state.registeredVariables[v.name] = v;
- this.incRef(v, this.backend);
- return v;
- };
- Engine.prototype.incRef = function (a, backend) {
- var refCount = this.state.tensorInfo.has(a.dataId) ?
- this.state.tensorInfo.get(a.dataId).refCount :
- 0;
- this.state.numTensors++;
- if (a.dtype === 'string') {
- this.state.numStringTensors++;
- }
- if (refCount === 0) {
- this.state.numDataBuffers++;
- // Bytes for complex numbers are counted by their components. Bytes for
- // string tensors are counted when writing values.
- var bytes = 0;
- if (a.dtype !== 'complex64' && a.dtype !== 'string') {
- bytes = a.size * bytesPerElement(a.dtype);
- }
- this.state.tensorInfo.set(a.dataId, {
- backend: backend || this.backend,
- dtype: a.dtype,
- shape: a.shape,
- bytes: bytes,
- refCount: 0
- });
- this.state.numBytes += bytes;
- }
- this.state.tensorInfo.get(a.dataId).refCount++;
- if (!(a instanceof Variable)) {
- this.track(a);
- }
- };
- Engine.prototype.disposeTensor = function (a) {
- if (!this.state.tensorInfo.has(a.dataId)) {
- return;
- }
- this.state.numTensors--;
- if (a.dtype === 'string') {
- this.state.numStringTensors--;
- }
- var info = this.state.tensorInfo.get(a.dataId);
- var refCount = info.refCount;
- if (refCount <= 1) {
- // Don't count bytes for complex numbers as they are counted by their
- // components.
- if (a.dtype !== 'complex64') {
- this.state.numBytes -= info.bytes;
- }
- this.state.numDataBuffers--;
- info.backend.disposeData(a.dataId);
- this.state.tensorInfo.delete(a.dataId);
- }
- else {
- this.state.tensorInfo.get(a.dataId).refCount--;
- }
- // TODO(nsthorat): Construct an error and save the stack trace for
- // debugging when in debug mode. Creating a stack trace is too expensive
- // to do unconditionally.
- };
- Engine.prototype.disposeVariables = function () {
- for (var varName in this.state.registeredVariables) {
- var v = this.state.registeredVariables[varName];
- this.disposeVariable(v);
- }
- };
- Engine.prototype.disposeVariable = function (v) {
- this.disposeTensor(v);
- if (this.state.registeredVariables[v.name] != null) {
- delete this.state.registeredVariables[v.name];
- }
- };
- Engine.prototype.memory = function () {
- var info = this.backend.memory();
- info.numTensors = this.state.numTensors;
- info.numDataBuffers = this.state.numDataBuffers;
- info.numBytes = this.state.numBytes;
- if (this.state.numStringTensors > 0) {
- info.unreliable = true;
- if (info.reasons == null) {
- info.reasons = [];
- }
- info.reasons.push('Memory usage by string tensors is approximate ' +
- '(2 bytes per character)');
- }
- return info;
- };
- Engine.prototype.profile = function (query) {
- return __awaiter(this, void 0, void 0, function () {
- var startBytes, startNumTensors;
- return __generator(this, function (_a) {
- this.state.profiling = true;
- startBytes = this.state.numBytes;
- startNumTensors = this.state.numTensors;
- this.state.activeProfile.kernels = [];
- this.state.activeProfile.result = query();
- this.state.profiling = false;
- this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; }));
- this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
- this.state.activeProfile.newTensors =
- this.state.numTensors - startNumTensors;
- return [2 /*return*/, this.state.activeProfile];
- });
- });
- };
- Engine.prototype.isTapeOn = function () {
- return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
- };
- Engine.prototype.addTapeNode = function (kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
- var _this = this;
- var tapeNode = { id: this.state.nextTapeNodeId++, kernelName: kernelName, inputs: inputs, outputs: outputs, saved: saved };
- var gradConfig = getGradient(kernelName);
- if (gradConfig != null) {
- gradientsFunc = gradConfig.gradFunc;
- }
- if (gradientsFunc != null) {
- tapeNode.gradient = function (dys) {
- // TODO(smilkov): To optimize back-prop, pass dys that are not used in
- // the backprop graph to the user as null instead of zeros
- dys = dys.map(function (dy, i) {
- if (dy == null) {
- var output = outputs[i];
- var vals = makeZerosTypedArray(output.size, output.dtype);
- return _this.makeTensor(vals, output.shape, output.dtype);
- }
- return dy;
- });
- // Grad functions of ops with single outputs expect a dy, while ops
- // with multiple outputs expect dys (array of dy).
- return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
- };
- }
- this.state.activeTape.push(tapeNode);
- };
- Engine.prototype.keep = function (result) {
- result.kept = true;
- return result;
- };
- Engine.prototype.startTape = function () {
- if (this.state.gradientDepth === 0) {
- this.state.activeTape = [];
- }
- this.state.gradientDepth++;
- };
- Engine.prototype.endTape = function () {
- this.state.gradientDepth--;
- };
- /**
- * Start a scope. Use this with endScope() to achieve the same functionality
- * as scope() without the need for a function closure.
- */
- Engine.prototype.startScope = function (name) {
- var scopeInfo = {
- track: [],
- name: 'unnamed scope',
- id: this.state.nextScopeId++
- };
- if (name) {
- scopeInfo.name = name;
- }
- this.state.scopeStack.push(scopeInfo);
- this.state.activeScope = scopeInfo;
- };
- /**
- * End a scope. Use this with startScope() to achieve the same functionality
- * as scope() without the need for a function closure.
- */
- Engine.prototype.endScope = function (result) {
- var _this = this;
- var tensorsToTrackInParent = getTensorsInContainer(result);
- var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; }));
- // Dispose the arrays tracked in this scope.
- for (var i = 0; i < this.state.activeScope.track.length; i++) {
- var tensor = this.state.activeScope.track[i];
- if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
- tensor.dispose();
- }
- }
- var oldScope = this.state.scopeStack.pop();
- this.state.activeScope = this.state.scopeStack.length === 0 ?
- null :
- this.state.scopeStack[this.state.scopeStack.length - 1];
- // Track the current result in the parent scope.
- tensorsToTrackInParent.forEach(function (tensor) {
- // Only track the tensor if was allocated in the inner scope and is not
- // globally kept.
- if (!tensor.kept && tensor.scopeId === oldScope.id) {
- _this.track(tensor);
- }
- });
- };
- /**
- * Returns gradients of `f` with respect to each of the `xs`. The gradients
- * returned are of the same length as `xs`, but some might be null if `f`
- * was not a function of that `x`. It also takes optional dy to multiply the
- * gradient, which defaults to `1`.
- */
- Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) {
- var _this = this;
- if (allowNoGradients === void 0) { allowNoGradients = false; }
- assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; });
- if (dy != null && dy.dtype !== 'float32') {
- throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'");
- }
- var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); });
- assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; });
- // Filter out the nodes that don't connect x => y.
- var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
- if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
- throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
- 'that the f you passed encloses all operations that lead from x ' +
- 'to y.');
- }
- return this.tidy('backward', function () {
- var accumulatedGradientMap = {};
- accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
- // Backprop gradients through the filtered nodes.
- backpropagateGradients(accumulatedGradientMap, filteredTape,
- // Pass the tidy function to avoid circular dep with `tape.ts`.
- function (f) { return _this.tidy(f); });
- var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; });
- if (_this.state.gradientDepth === 0) {
- // This means that we are not computing higher-order gradients
- // and can clean up the tape.
- _this.state.activeTape.forEach(function (node) {
- for (var _i = 0, _a = node.saved; _i < _a.length; _i++) {
- var tensor = _a[_i];
- tensor.dispose();
- }
- });
- _this.state.activeTape = null;
- }
- return { value: y, grads: grads };
- });
- };
- Engine.prototype.customGrad = function (f) {
- var _this = this;
- assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; });
- return function () {
- var inputs = [];
- for (var _i = 0; _i < arguments.length; _i++) {
- inputs[_i] = arguments[_i];
- }
- assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
- 'tensors'; });
- var res;
- var inputMap = {};
- inputs.forEach(function (input, i) {
- inputMap[i] = input;
- });
- return _this.runKernelFunc(function (_, save) {
- res = f.apply(void 0, inputs.concat([save]));
- assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.value` is a tensor'; });
- assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.gradFunc` is a function.'; });
- return res.value;
- }, inputMap, function (dy, saved) {
- var gradRes = res.gradFunc(dy, saved);
- var grads = Array.isArray(gradRes) ? gradRes : [gradRes];
- assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.gradFunc` is a function that returns ' +
- 'the same number of tensors as inputs passed to f(...).'; });
- assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' +
- 'object where `obj.gradFunc` is a function that returns ' +
- 'a list of only tensors.'; });
- var gradMap = {};
- grads.forEach(function (grad, i) {
- gradMap[i] = function () { return grad; };
- });
- return gradMap;
- });
- };
- };
- Engine.prototype.readSync = function (dataId) {
- // Route the read to the correct backend.
- var info = this.state.tensorInfo.get(dataId);
- return info.backend.readSync(dataId);
- };
- Engine.prototype.read = function (dataId) {
- // Route the read to the correct backend.
- var info = this.state.tensorInfo.get(dataId);
- return info.backend.read(dataId);
- };
- Engine.prototype.time = function (query) {
- return __awaiter(this, void 0, void 0, function () {
- var start, timingInfo;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0:
- start = now();
- return [4 /*yield*/, this.backend.time(query)];
- case 1:
- timingInfo = _a.sent();
- timingInfo.wallMs = now() - start;
- return [2 /*return*/, timingInfo];
- }
- });
- });
- };
- /**
- * Tracks a Tensor in the current scope to be automatically cleaned up
- * when the current scope ends, and returns the value.
- *
- * @param result The Tensor to track in the current scope.
- */
- Engine.prototype.track = function (result) {
- if (this.state.activeScope != null) {
- result.scopeId = this.state.activeScope.id;
- this.state.activeScope.track.push(result);
- }
- return result;
- };
- Object.defineProperty(Engine.prototype, "registeredVariables", {
- get: function () {
- return this.state.registeredVariables;
- },
- enumerable: true,
- configurable: true
- });
- /**
- * Resets the engine state. Removes all backends but does not remove
- * registered backend factories.
- */
- Engine.prototype.reset = function () {
- // Make any pending promise obsolete.
- this.pendingBackendInitId++;
- this.state.dispose();
- this.ENV.reset();
- this.state = new EngineState();
- for (var backendName in this.registry) {
- this.disposeRegisteredKernels(backendName);
- this.registry[backendName].dispose();
- delete this.registry[backendName];
- }
- this.backendName = null;
- this.backendInstance = null;
- this.pendingBackendInit = null;
- };
- Engine.nextTensorId = 0;
- Engine.nextVariableId = 0;
- return Engine;
- }());
- function ones(shape) {
- var values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
- return ENGINE.makeTensor(values, shape, 'float32');
- }
- var GLOBAL;
- function getGlobalNamespace() {
- if (GLOBAL == null) {
- // tslint:disable-next-line:no-any
- var ns = void 0;
- if (typeof (window) !== 'undefined') {
- ns = window;
- }
- else if (typeof (global) !== 'undefined') {
- ns = global;
- }
- else if (typeof (process) !== 'undefined') {
- ns = process;
- }
- else if (typeof (self) !== 'undefined') {
- ns = self;
- }
- else {
- throw new Error('Could not find a global object');
- }
- GLOBAL = ns;
- }
- return GLOBAL;
- }
- function getOrMakeEngine() {
- var ns = getGlobalNamespace();
- if (ns._tfengine == null) {
- var environment = new Environment(ns);
- ns._tfengine = new Engine(environment);
- }
- setEnvironmentGlobal(ns._tfengine.ENV);
- // Tell the current tensor interface that the global engine is responsible
- // for tracking.
- setTensorTracker(function () { return ns._tfengine; });
- return ns._tfengine;
- }
- var ENGINE = getOrMakeEngine();
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function isMobile() {
- // tslint:disable-next-line:no-any
- var a = navigator.userAgent || navigator.vendor || window.opera;
- // tslint:disable-next-line:max-line-length
- return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
- .test(a) ||
- // tslint:disable-next-line:max-line-length
- /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
- .test(a.substr(0, 4));
- }
- function isBrowser() {
- return (typeof window !== 'undefined' && window.document != null) ||
- //@ts-ignore
- (typeof WorkerGlobalScope !== 'undefined');
- }
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ENV = env();
- /**
- * This file contains environment-related flag registrations.
- */
- /** Whether to enable debug mode. */
- ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) {
- if (debugValue) {
- console.warn('Debugging mode is ON. The output of every math call will ' +
- 'be downloaded to CPU and checked for NaNs. ' +
- 'This significantly impacts performance.');
- }
- });
- /** Whether we are in a browser (as versus, say, node.js) environment. */
- ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); });
- /** Whether we are in a browser (as versus, say, node.js) environment. */
- ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') &&
- (typeof process.versions !== 'undefined') &&
- (typeof process.versions.node !== 'undefined'); });
- /** Whether this browser is Chrome. */
- ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null &&
- navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
- /Google Inc/.test(navigator.vendor); });
- /**
- * True when the environment is "production" where we disable safety checks
- * to gain performance.
- */
- ENV.registerFlag('PROD', function () { return false; });
- /**
- * Whether to do sanity checks when inferring a shape from user-provided
- * values, used when creating a new tensor.
- */
- ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); });
- /** Whether deprecation warnings are enabled. */
- ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; });
- /** True if running unit tests. */
- ENV.registerFlag('IS_TEST', function () { return false; });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var contexts = {};
- var WEBGL_ATTRIBUTES = {
- alpha: false,
- antialias: false,
- premultipliedAlpha: false,
- preserveDrawingBuffer: false,
- depth: false,
- stencil: false,
- failIfMajorPerformanceCaveat: true
- };
- function setWebGLContext(webGLVersion, gl) {
- contexts[webGLVersion] = gl;
- }
- function getWebGLContext(webGLVersion) {
- if (!(webGLVersion in contexts)) {
- contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion);
- }
- var gl = contexts[webGLVersion];
- if (gl.isContextLost()) {
- delete contexts[webGLVersion];
- return getWebGLContext(webGLVersion);
- }
- gl.disable(gl.DEPTH_TEST);
- gl.disable(gl.STENCIL_TEST);
- gl.disable(gl.BLEND);
- gl.disable(gl.DITHER);
- gl.disable(gl.POLYGON_OFFSET_FILL);
- gl.disable(gl.SAMPLE_COVERAGE);
- gl.enable(gl.SCISSOR_TEST);
- gl.enable(gl.CULL_FACE);
- gl.cullFace(gl.BACK);
- return contexts[webGLVersion];
- }
- function createCanvas(webGLVersion) {
- if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) {
- return new OffscreenCanvas(300, 150);
- }
- else if (typeof document !== 'undefined') {
- return document.createElement('canvas');
- }
- else {
- throw new Error('Cannot create a canvas in this context');
- }
- }
- function getWebGLRenderingContext(webGLVersion) {
- if (webGLVersion !== 1 && webGLVersion !== 2) {
- throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
- }
- var canvas = createCanvas(webGLVersion);
- canvas.addEventListener('webglcontextlost', function (ev) {
- ev.preventDefault();
- delete contexts[webGLVersion];
- }, false);
- if (webGLVersion === 1) {
- return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
- canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES));
- }
- return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var PackingScheme;
- (function (PackingScheme) {
- /**
- * All values in a single texel are densely packed without any constraints.
- *
- * This is how the shader encodes a tensor with shape = [2, 3, 4]
- * (indices are [batch, row, col]).
- *
- * 000|001 010|011 020|021
- * ------- ------- -------
- * 002|003 012|013 022|023
- *
- * 100|101 110|111 120|121
- * ------- ------- -------
- * 102|103 112|113 122|123
- *
- */
- PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
- /**
- * Single texels contain only values from the same batch, and from adjacent
- * rows and columns.
- *
- * This is how the shader encodes a tensor with shape = [2, 3, 5]
- * (indices are [batch, row, col]).
- *
- * 000|001 002|003 004|xxx 020|021 022|023 024|xxx
- * ------- ------- ------- ------- ------- -------
- * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
- *
- * 100|101 102|103 104|xxx 120|121 122|123 124|xxx
- * ------- ------- ------- ------- ------- -------
- * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
- *
- */
- PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
- })(PackingScheme || (PackingScheme = {}));
- var TextureUsage;
- (function (TextureUsage) {
- TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
- TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
- TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
- TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
- })(TextureUsage || (TextureUsage = {}));
- var PhysicalTextureType;
- (function (PhysicalTextureType) {
- PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
- PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
- PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
- PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
- PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
- })(PhysicalTextureType || (PhysicalTextureType = {}));
- function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
- return [columns, rows];
- }
- function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
- return matrixSize * channelsPerTexture;
- }
- /**
- * Get shape for densely packed RGBA texture.
- */
- function getDenseTexShape(shape) {
- var size = sizeFromShape(shape);
- var texelsNeeded = Math.ceil(size / 4);
- return sizeToSquarishShape(texelsNeeded);
- }
- function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
- return [
- Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))
- ];
- }
- function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
- var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), w = _a[0], h = _a[1];
- return w * h * 4;
- }
- function getTextureConfig(
- // tslint:disable-next-line:no-any
- gl, textureHalfFloatExtension) {
- // tslint:disable-next-line:no-any
- var glany = gl;
- var internalFormatFloat;
- var internalFormatHalfFloat;
- var internalFormatPackedHalfFloat;
- var internalFormatPackedFloat;
- var textureFormatFloat;
- var downloadTextureFormat;
- var downloadUnpackNumChannels;
- var defaultNumChannels;
- var textureTypeHalfFloat;
- var textureTypeFloat;
- if (env().getNumber('WEBGL_VERSION') === 2) {
- internalFormatFloat = glany.R32F;
- internalFormatHalfFloat = glany.R16F;
- internalFormatPackedHalfFloat = glany.RGBA16F;
- internalFormatPackedFloat = glany.RGBA32F;
- textureFormatFloat = glany.RED;
- downloadUnpackNumChannels = 4;
- defaultNumChannels = 1;
- textureTypeHalfFloat = glany.HALF_FLOAT;
- textureTypeFloat = glany.FLOAT;
- }
- else {
- internalFormatFloat = gl.RGBA;
- internalFormatHalfFloat = gl.RGBA;
- internalFormatPackedHalfFloat = gl.RGBA;
- internalFormatPackedFloat = glany.RGBA;
- textureFormatFloat = gl.RGBA;
- downloadUnpackNumChannels = 4;
- defaultNumChannels = 4;
- textureTypeHalfFloat = textureHalfFloatExtension != null ?
- textureHalfFloatExtension.HALF_FLOAT_OES :
- null;
- textureTypeFloat = gl.FLOAT;
- }
- downloadTextureFormat = gl.RGBA;
- return {
- internalFormatFloat: internalFormatFloat,
- internalFormatHalfFloat: internalFormatHalfFloat,
- internalFormatPackedHalfFloat: internalFormatPackedHalfFloat,
- internalFormatPackedFloat: internalFormatPackedFloat,
- textureFormatFloat: textureFormatFloat,
- downloadTextureFormat: downloadTextureFormat,
- downloadUnpackNumChannels: downloadUnpackNumChannels,
- defaultNumChannels: defaultNumChannels,
- textureTypeHalfFloat: textureTypeHalfFloat,
- textureTypeFloat: textureTypeFloat
- };
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function callAndCheck(gl, debugMode, func) {
- var returnValue = func();
- if (debugMode) {
- checkWebGLError(gl);
- }
- return returnValue;
- }
- function checkWebGLError(gl) {
- var error = gl.getError();
- if (error !== gl.NO_ERROR) {
- throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
- }
- }
- // https://en.wikipedia.org/wiki/Half-precision_floating-point_format
- var MIN_FLOAT16 = 5.96e-8;
- var MAX_FLOAT16 = 65504;
- function canBeRepresented(num) {
- if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 ||
- (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) {
- return true;
- }
- return false;
- }
- function getWebGLErrorMessage(gl, status) {
- switch (status) {
- case gl.NO_ERROR:
- return 'NO_ERROR';
- case gl.INVALID_ENUM:
- return 'INVALID_ENUM';
- case gl.INVALID_VALUE:
- return 'INVALID_VALUE';
- case gl.INVALID_OPERATION:
- return 'INVALID_OPERATION';
- case gl.INVALID_FRAMEBUFFER_OPERATION:
- return 'INVALID_FRAMEBUFFER_OPERATION';
- case gl.OUT_OF_MEMORY:
- return 'OUT_OF_MEMORY';
- case gl.CONTEXT_LOST_WEBGL:
- return 'CONTEXT_LOST_WEBGL';
- default:
- return "Unknown error code " + status;
- }
- }
- function getExtensionOrThrow(gl, debug, extensionName) {
- return throwIfNull(gl, debug, function () { return gl.getExtension(extensionName); }, 'Extension "' + extensionName + '" not supported on this browser.');
- }
- function createVertexShader(gl, debug, vertexShaderSource) {
- var vertexShader = throwIfNull(gl, debug, function () { return gl.createShader(gl.VERTEX_SHADER); }, 'Unable to create vertex WebGLShader.');
- callAndCheck(gl, debug, function () { return gl.shaderSource(vertexShader, vertexShaderSource); });
- callAndCheck(gl, debug, function () { return gl.compileShader(vertexShader); });
- if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {
- console.log(gl.getShaderInfoLog(vertexShader));
- throw new Error('Failed to compile vertex shader.');
- }
- return vertexShader;
- }
- function createFragmentShader(gl, debug, fragmentShaderSource) {
- var fragmentShader = throwIfNull(gl, debug, function () { return gl.createShader(gl.FRAGMENT_SHADER); }, 'Unable to create fragment WebGLShader.');
- callAndCheck(gl, debug, function () { return gl.shaderSource(fragmentShader, fragmentShaderSource); });
- callAndCheck(gl, debug, function () { return gl.compileShader(fragmentShader); });
- if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {
- logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));
- throw new Error('Failed to compile fragment shader.');
- }
- return fragmentShader;
- }
- var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
- function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
- var lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
- if (lineNumberRegexResult == null) {
- console.log("Couldn't parse line number in error: " + shaderInfoLog);
- console.log(shaderSource);
- return;
- }
- var lineNumber = +lineNumberRegexResult[1];
- var shaderLines = shaderSource.split('\n');
- var pad = shaderLines.length.toString().length + 2;
- var linesWithLineNumbers = shaderLines.map(function (line, lineNumber) {
- return rightPad((lineNumber + 1).toString(), pad) + line;
- });
- var maxLineLength = 0;
- for (var i = 0; i < linesWithLineNumbers.length; i++) {
- maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
- }
- var beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
- var errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
- var afterErrorLines = linesWithLineNumbers.slice(lineNumber);
- console.log(beforeErrorLines.join('\n'));
- console.log(shaderInfoLog.split('\n')[0]);
- console.log("%c " + rightPad(errorLine[0], maxLineLength), 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
- console.log(afterErrorLines.join('\n'));
- }
- function createProgram(gl, debug) {
- return throwIfNull(gl, debug, function () { return gl.createProgram(); }, 'Unable to create WebGLProgram.');
- }
- function linkProgram(gl, debug, program) {
- callAndCheck(gl, debug, function () { return gl.linkProgram(program); });
- if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {
- console.log(gl.getProgramInfoLog(program));
- throw new Error('Failed to link vertex and fragment shaders.');
- }
- }
- function validateProgram(gl, debug, program) {
- callAndCheck(gl, debug, function () { return gl.validateProgram(program); });
- if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
- console.log(gl.getProgramInfoLog(program));
- throw new Error('Shader program validation failed.');
- }
- }
- function createStaticVertexBuffer(gl, debug, data) {
- var buffer = throwIfNull(gl, debug, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer');
- callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); });
- callAndCheck(gl, debug, function () { return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW); });
- return buffer;
- }
- function createStaticIndexBuffer(gl, debug, data) {
- var buffer = throwIfNull(gl, debug, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer');
- callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer); });
- callAndCheck(gl, debug, function () { return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW); });
- return buffer;
- }
- function getNumChannels() {
- if (env().getNumber('WEBGL_VERSION') === 2) {
- return 1;
- }
- return 4;
- }
- function createTexture(gl, debug) {
- return throwIfNull(gl, debug, function () { return gl.createTexture(); }, 'Unable to create WebGLTexture.');
- }
- function validateTextureSize(width, height) {
- var maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
- if ((width <= 0) || (height <= 0)) {
- var requested = "[" + width + "x" + height + "]";
- throw new Error('Requested texture size ' + requested + ' is invalid.');
- }
- if ((width > maxTextureSize) || (height > maxTextureSize)) {
- var requested = "[" + width + "x" + height + "]";
- var max = "[" + maxTextureSize + "x" + maxTextureSize + "]";
- throw new Error('Requested texture size ' + requested +
- ' greater than WebGL maximum on this browser / GPU ' + max + '.');
- }
- }
- function createFramebuffer(gl, debug) {
- return throwIfNull(gl, debug, function () { return gl.createFramebuffer(); }, 'Unable to create WebGLFramebuffer.');
- }
- function bindVertexBufferToProgramAttribute(gl, debug, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
- var loc = gl.getAttribLocation(program, attribute);
- if (loc === -1) {
- // The GPU compiler decided to strip out this attribute because it's unused,
- // thus no need to bind.
- return false;
- }
- callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); });
- callAndCheck(gl, debug, function () { return gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes); });
- callAndCheck(gl, debug, function () { return gl.enableVertexAttribArray(loc); });
- return true;
- }
- function bindTextureUnit(gl, debug, texture, textureUnit) {
- validateTextureUnit(gl, textureUnit);
- callAndCheck(gl, debug, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); });
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); });
- }
- function unbindTextureUnit(gl, debug, textureUnit) {
- validateTextureUnit(gl, textureUnit);
- callAndCheck(gl, debug, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); });
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
- }
- function getProgramUniformLocationOrThrow(gl, debug, program, uniformName) {
- return throwIfNull(gl, debug, function () { return gl.getUniformLocation(program, uniformName); }, 'uniform "' + uniformName + '" not present in program.');
- }
- function getProgramUniformLocation(gl, program, uniformName) {
- return gl.getUniformLocation(program, uniformName);
- }
- function bindTextureToProgramUniformSampler(gl, debug, program, texture, uniformSamplerLocation, textureUnit) {
- callAndCheck(gl, debug, function () { return bindTextureUnit(gl, debug, texture, textureUnit); });
- callAndCheck(gl, debug, function () { return gl.uniform1i(uniformSamplerLocation, textureUnit); });
- }
- function bindCanvasToFramebuffer(gl, debug) {
- callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); });
- callAndCheck(gl, debug, function () { return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height); });
- callAndCheck(gl, debug, function () { return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height); });
- }
- function bindColorTextureToFramebuffer(gl, debug, texture, framebuffer) {
- callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); });
- callAndCheck(gl, debug, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); });
- }
- function unbindColorTextureFromFramebuffer(gl, debug, framebuffer) {
- callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); });
- callAndCheck(gl, debug, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0); });
- }
- function validateFramebuffer(gl) {
- var status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
- if (status !== gl.FRAMEBUFFER_COMPLETE) {
- throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
- }
- }
- function getFramebufferErrorMessage(gl, status) {
- switch (status) {
- case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
- return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
- case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
- return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
- case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
- return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
- case gl.FRAMEBUFFER_UNSUPPORTED:
- return 'FRAMEBUFFER_UNSUPPORTED';
- default:
- return "unknown error " + status;
- }
- }
- function throwIfNull(gl, debug, returnTOrNull, failureMessage) {
- var tOrNull = callAndCheck(gl, debug, function () { return returnTOrNull(); });
- if (tOrNull == null) {
- throw new Error(failureMessage);
- }
- return tOrNull;
- }
- function validateTextureUnit(gl, textureUnit) {
- var maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
- var glTextureUnit = textureUnit + gl.TEXTURE0;
- if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
- var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE" + maxTextureUnit + "]";
- throw new Error("textureUnit must be in " + textureUnitRange + ".");
- }
- }
- function getBatchDim(shape, dimsToSkip) {
- if (dimsToSkip === void 0) { dimsToSkip = 2; }
- return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
- }
- function getRowsCols(shape) {
- if (shape.length === 0) {
- throw Error('Cannot get rows and columns of an empty shape array.');
- }
- return [
- shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
- ];
- }
- function getShapeAs3D(shape) {
- var shapeAs3D = [1, 1, 1];
- var isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
- if (!isScalar) {
- shapeAs3D =
- [getBatchDim(shape)].concat(getRowsCols(shape));
- }
- return shapeAs3D;
- }
- function getTextureShapeFromLogicalShape(logShape, isPacked) {
- var _a;
- if (isPacked === void 0) { isPacked = false; }
- var maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
- if (isPacked) {
- maxTexSize = maxTexSize * 2;
- // This logic ensures we accurately count the number of packed texels needed
- // to accommodate the tensor. We can only pack values in the same texel if
- // they are from adjacent pairs of rows/cols within the same batch. So if a
- // tensor has 3 rows, we pretend it has 4 rows in order to account for the
- // fact that the texels containing the third row are half empty.
- logShape = logShape.map(function (d, i) { return i >= logShape.length - 2 ?
- nearestLargerEven(logShape[i]) :
- logShape[i]; });
- // Packed texture height is at least 2 (the channel height of a single
- // texel).
- if (logShape.length === 1) {
- logShape = [2, logShape[0]];
- }
- }
- // If logical shape is 2, we don't squeeze, since we want to match physical.
- if (logShape.length !== 2) {
- var squeezeResult = squeezeShape(logShape);
- logShape = squeezeResult.newShape;
- }
- var size = sizeFromShape(logShape);
- if (logShape.length <= 1 && size <= maxTexSize) {
- return [1, size];
- }
- else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
- logShape[1] <= maxTexSize) {
- return logShape;
- }
- else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
- logShape[2] <= maxTexSize) {
- return [logShape[0] * logShape[1], logShape[2]];
- }
- else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
- logShape[1] * logShape[2] <= maxTexSize) {
- return [logShape[0], logShape[1] * logShape[2]];
- }
- else if (logShape.length === 4 &&
- logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
- logShape[3] <= maxTexSize) {
- return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
- }
- else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
- logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
- return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
- }
- else {
- if (isPacked) {
- // For packed textures size equals the number of channels required to
- // accommodate the texture data. However in order to squarify such that
- // inner dimensions stay even, we rewrite size to equal the number of
- // texels. Then in the return statement we rehydrate the squarified
- // dimensions to channel units.
- var batchDim = getBatchDim(logShape);
- var rows = 2, cols = 2;
- if (logShape.length) {
- _a = getRowsCols(logShape), rows = _a[0], cols = _a[1];
- }
- size = batchDim * (rows / 2) * (cols / 2);
- return sizeToSquarishShape(size).map(function (d) { return d * 2; });
- }
- return sizeToSquarishShape(size);
- }
- }
- function isEven(n) {
- return n % 2 === 0;
- }
- /**
- * This determines whether reshaping a packed texture requires rearranging
- * the data within the texture, assuming 2x2 packing.
- */
- function isReshapeFree(shape1, shape2) {
- shape1 = shape1.slice(-2);
- shape2 = shape2.slice(-2);
- if (arraysEqual(shape1, shape2)) {
- return true;
- }
- if (!shape1.length || !shape2.length) { // One of the shapes is a scalar.
- return true;
- }
- if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
- shape2[1] === 0) {
- return true;
- }
- if (shape1.length !== shape2.length) { // One of the shapes is a vector.
- var shape1Cols = shape1.slice(-1)[0];
- var shape2Cols = shape2.slice(-1)[0];
- if (shape1Cols === shape2Cols) {
- return true;
- }
- if (isEven(shape1Cols) && isEven(shape2Cols) &&
- (shape1[0] === 1 || shape2[0] === 1)) {
- return true;
- }
- }
- return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
- }
- // We cache webgl params because the environment gets reset between
- // unit tests and we don't want to constantly query the WebGLContext for
- // MAX_TEXTURE_SIZE.
- var MAX_TEXTURE_SIZE;
- var MAX_TEXTURES_IN_SHADER;
- function getWebGLMaxTextureSize(webGLVersion) {
- if (MAX_TEXTURE_SIZE == null) {
- var gl = getWebGLContext(webGLVersion);
- MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
- }
- return MAX_TEXTURE_SIZE;
- }
- function resetMaxTextureSize() {
- MAX_TEXTURE_SIZE = null;
- }
- function resetMaxTexturesInShader() {
- MAX_TEXTURES_IN_SHADER = null;
- }
- function getMaxTexturesInShader(webGLVersion) {
- if (MAX_TEXTURES_IN_SHADER == null) {
- var gl = getWebGLContext(webGLVersion);
- MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
- }
- // We cap at 16 to avoid spurious runtime "memory exhausted" error.
- return Math.min(16, MAX_TEXTURES_IN_SHADER);
- }
- function getWebGLDisjointQueryTimerVersion(webGLVersion) {
- if (webGLVersion === 0) {
- return 0;
- }
- var queryTimerVersion;
- var gl = getWebGLContext(webGLVersion);
- if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&
- webGLVersion === 2) {
- queryTimerVersion = 2;
- }
- else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
- queryTimerVersion = 1;
- }
- else {
- queryTimerVersion = 0;
- }
- return queryTimerVersion;
- }
- function hasExtension(gl, extensionName) {
- var ext = gl.getExtension(extensionName);
- return ext != null;
- }
- function isWebGLVersionEnabled(webGLVersion) {
- try {
- var gl = getWebGLContext(webGLVersion);
- if (gl != null) {
- return true;
- }
- }
- catch (e) {
- return false;
- }
- return false;
- }
- function isCapableOfRenderingToFloatTexture(webGLVersion) {
- if (webGLVersion === 0) {
- return false;
- }
- var gl = getWebGLContext(webGLVersion);
- if (webGLVersion === 1) {
- if (!hasExtension(gl, 'OES_texture_float')) {
- return false;
- }
- }
- else {
- if (!hasExtension(gl, 'EXT_color_buffer_float')) {
- return false;
- }
- }
- var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
- return isFrameBufferComplete;
- }
- /**
- * Check if we can download values from a float/half-float texture.
- *
- * Note that for performance reasons we use binding a texture to a framebuffer
- * as a proxy for ability to download float values later using readPixels. The
- * texture params of this texture will not match those in readPixels exactly
- * but if we are unable to bind some kind of float texture to the frameBuffer
- * then we definitely will not be able to read float values from it.
- */
- function isDownloadFloatTextureEnabled(webGLVersion) {
- if (webGLVersion === 0) {
- return false;
- }
- var gl = getWebGLContext(webGLVersion);
- if (webGLVersion === 1) {
- if (!hasExtension(gl, 'OES_texture_float')) {
- return false;
- }
- if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
- return false;
- }
- }
- else {
- if (hasExtension(gl, 'EXT_color_buffer_float')) {
- return createFloatTextureAndBindToFramebuffer(gl);
- }
- var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
- if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
- var textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
- return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
- }
- return false;
- }
- var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
- return isFrameBufferComplete;
- }
- function createFloatTextureAndBindToFramebuffer(gl) {
- var texConfig = getTextureConfig(gl);
- var texture = gl.createTexture();
- gl.bindTexture(gl.TEXTURE_2D, texture);
- var width = 1;
- var height = 1;
- gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
- var frameBuffer = gl.createFramebuffer();
- gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
- gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
- var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
- gl.bindTexture(gl.TEXTURE_2D, null);
- gl.bindFramebuffer(gl.FRAMEBUFFER, null);
- gl.deleteTexture(texture);
- gl.deleteFramebuffer(frameBuffer);
- return isFrameBufferComplete;
- }
- function createHalfFloatTextureAndBindToFramebuffer(
- // tslint:disable-next-line:no-any
- gl, textureHalfFloatExtension) {
- var texConfig = getTextureConfig(gl, textureHalfFloatExtension);
- var texture = gl.createTexture();
- gl.bindTexture(gl.TEXTURE_2D, texture);
- var width = 1;
- var height = 1;
- gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
- var frameBuffer = gl.createFramebuffer();
- gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
- gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
- var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
- gl.bindTexture(gl.TEXTURE_2D, null);
- gl.bindFramebuffer(gl.FRAMEBUFFER, null);
- gl.deleteTexture(texture);
- gl.deleteFramebuffer(frameBuffer);
- return isFrameBufferComplete;
- }
- function isWebGLFenceEnabled(webGLVersion) {
- if (webGLVersion !== 2) {
- return false;
- }
- var gl = getWebGLContext(webGLVersion);
- // tslint:disable-next-line:no-any
- var isEnabled = gl.fenceSync != null;
- return isEnabled;
- }
-
- var webgl_util = /*#__PURE__*/Object.freeze({
- callAndCheck: callAndCheck,
- canBeRepresented: canBeRepresented,
- getWebGLErrorMessage: getWebGLErrorMessage,
- getExtensionOrThrow: getExtensionOrThrow,
- createVertexShader: createVertexShader,
- createFragmentShader: createFragmentShader,
- createProgram: createProgram,
- linkProgram: linkProgram,
- validateProgram: validateProgram,
- createStaticVertexBuffer: createStaticVertexBuffer,
- createStaticIndexBuffer: createStaticIndexBuffer,
- getNumChannels: getNumChannels,
- createTexture: createTexture,
- validateTextureSize: validateTextureSize,
- createFramebuffer: createFramebuffer,
- bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute,
- bindTextureUnit: bindTextureUnit,
- unbindTextureUnit: unbindTextureUnit,
- getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow,
- getProgramUniformLocation: getProgramUniformLocation,
- bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler,
- bindCanvasToFramebuffer: bindCanvasToFramebuffer,
- bindColorTextureToFramebuffer: bindColorTextureToFramebuffer,
- unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer,
- validateFramebuffer: validateFramebuffer,
- getFramebufferErrorMessage: getFramebufferErrorMessage,
- getBatchDim: getBatchDim,
- getRowsCols: getRowsCols,
- getShapeAs3D: getShapeAs3D,
- getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape,
- isReshapeFree: isReshapeFree,
- getWebGLMaxTextureSize: getWebGLMaxTextureSize,
- resetMaxTextureSize: resetMaxTextureSize,
- resetMaxTexturesInShader: resetMaxTexturesInShader,
- getMaxTexturesInShader: getMaxTexturesInShader,
- getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion,
- hasExtension: hasExtension,
- isWebGLVersionEnabled: isWebGLVersionEnabled,
- isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture,
- isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled,
- isWebGLFenceEnabled: isWebGLFenceEnabled
- });
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ENV$1 = env();
- /**
- * This file contains WebGL-specific flag registrations.
- */
- /**
- * True if WebGL is supported.
- */
- ENV$1.registerFlag('HAS_WEBGL', function () { return ENV$1.getNumber('WEBGL_VERSION') > 0; });
- /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
- ENV$1.registerFlag('WEBGL_VERSION', function () {
- if (isWebGLVersionEnabled(2)) {
- return 2;
- }
- else if (isWebGLVersionEnabled(1)) {
- return 1;
- }
- return 0;
- });
- ENV$1.registerFlag('WEBGL_BUFFER_SUPPORTED', function () { return ENV$1.get('WEBGL_VERSION') === 2; });
- /** Whether the WebGL backend will sometimes forward ops to the CPU. */
- ENV$1.registerFlag('WEBGL_CPU_FORWARD', function () { return true; });
- /** Whether the WebGL backend will always use f16 textures for rendering. */
- ENV$1.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () { return false; });
- /** Whether to turn all packing related flags on. */
- ENV$1.registerFlag('WEBGL_PACK', function () { return ENV$1.getBool('HAS_WEBGL'); });
- /** Whether we will pack the batchnormalization op. */
- ENV$1.registerFlag('WEBGL_PACK_NORMALIZATION', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will pack the clip op. */
- ENV$1.registerFlag('WEBGL_PACK_CLIP', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will pack the depthwise conv op. */
- // TODO: https://github.com/tensorflow/tfjs/issues/1679
- ENV$1.registerFlag('WEBGL_PACK_DEPTHWISECONV', function () { return false; });
- /** Whether we will pack binary ops. */
- ENV$1.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will pack unary ops. */
- ENV$1.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will pack array ops. */
- ENV$1.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will pack image ops. */
- ENV$1.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will pack reduce ops. */
- ENV$1.registerFlag('WEBGL_PACK_REDUCE', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether packed WebGL kernels lazily unpack their outputs. */
- ENV$1.registerFlag('WEBGL_LAZILY_UNPACK', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** Whether we will use the im2col algorithm to speed up convolutions. */
- ENV$1.registerFlag('WEBGL_CONV_IM2COL', function () { return ENV$1.getBool('WEBGL_PACK'); });
- /** The maximum texture dimension. */
- ENV$1.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () { return getWebGLMaxTextureSize(ENV$1.getNumber('WEBGL_VERSION')); });
- /** The maximum texture dimension. */
- ENV$1.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () { return getMaxTexturesInShader(ENV$1.getNumber('WEBGL_VERSION')); });
- /**
- * The disjoint_query_timer extension version.
- * 0: disabled, 1: EXT_disjoint_timer_query, 2:
- * EXT_disjoint_timer_query_webgl2.
- * In Firefox with WebGL 2.0,
- * EXT_disjoint_timer_query_webgl2 is not available, so we must use the
- * WebGL 1.0 extension.
- */
- ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () {
- var webGLVersion = ENV$1.getNumber('WEBGL_VERSION');
- if (webGLVersion === 0) {
- return 0;
- }
- return getWebGLDisjointQueryTimerVersion(webGLVersion);
- });
- /**
- * Whether the timer object from the disjoint_query_timer extension gives
- * timing information that is reliable.
- */
- ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () { return ENV$1.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
- !isMobile(); });
- /**
- * Whether the device is physically capable of rendering to float32 textures.
- */
- ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () { return isCapableOfRenderingToFloatTexture(ENV$1.getNumber('WEBGL_VERSION')); });
- /**
- * Whether rendering to float32 textures is enabled. If disabled, renders to
- * float16 textures.
- */
- ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () {
- return ENV$1.getBool('WEBGL_FORCE_F16_TEXTURES') ?
- false :
- ENV$1.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
- });
- /**
- * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
- * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
- */
- ENV$1.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () { return isDownloadFloatTextureEnabled(ENV$1.getNumber('WEBGL_VERSION')); });
- /** Whether the fence API is available. */
- ENV$1.registerFlag('WEBGL_FENCE_API_ENABLED', function () { return isWebGLFenceEnabled(ENV$1.getNumber('WEBGL_VERSION')); });
- /**
- * Tensors with size <= than this will be uploaded as uniforms, not textures.
- */
- ENV$1.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () {
- // Use uniform uploads only when 32bit floats are supported. In
- // 16bit
- // environments there are problems with comparing a 16bit texture value
- // with a 32bit uniform value.
- var useUniforms = ENV$1.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
- return useUniforms ? 4 : 0;
- });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Enables production mode which disables correctness checks in favor of
- * performance.
- */
- /** @doc {heading: 'Environment'} */
- function enableProdMode() {
- env().set('PROD', true);
- }
- /**
- * Enables debug mode which will log information about all executed kernels:
- * the elapsed time of the kernel execution, as well as the rank, shape, and
- * size of the output tensor.
- *
- * Debug mode will significantly slow down your application as it will
- * download the result of every operation to the CPU. This should not be used in
- * production. Debug mode does not affect the timing information of the kernel
- * execution as we do not measure download time in the kernel execution time.
- *
- * See also: `tf.profile`, `tf.memory`.
- */
- /** @doc {heading: 'Environment'} */
- function enableDebugMode() {
- env().set('DEBUG', true);
- }
- /** Globally disables deprecation warnings */
- function disableDeprecationWarnings() {
- env().set('DEPRECATION_WARNINGS_ENABLED', false);
- console.warn("TensorFlow.js deprecation warnings have been disabled.");
- }
- /** Warn users about deprecated functionality. */
- function deprecationWarn(msg) {
- if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
- console.warn(msg + ' You can disable deprecation warnings with ' +
- 'tf.disableDeprecationWarnings().');
- }
- }
- setDeprecationWarningFn(deprecationWarn);
- /**
- * Dispose all variables kept in backend engine.
- */
- /** @doc {heading: 'Environment'} */
- function disposeVariables() {
- ENGINE.disposeVariables();
- }
- /**
- * It returns the global engine that keeps track of all tensors and backends.
- */
- /** @doc {heading: 'Environment'} */
- function engine() {
- return ENGINE;
- }
- /**
- * Returns memory info at the current time in the program. The result is an
- * object with the following properties:
- *
- * - `numBytes`: Number of bytes allocated (undisposed) at this time.
- * - `numTensors`: Number of unique tensors allocated.
- * - `numDataBuffers`: Number of unique data buffers allocated
- * (undisposed) at this time, which is ≤ the number of tensors
- * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
- * data buffer with `a`).
- * - `unreliable`: True if the memory usage is unreliable. See `reasons` when
- * `unreliable` is true.
- * - `reasons`: `string[]`, reasons why the memory is unreliable, present if
- * `unreliable` is true.
- *
- * WebGL Properties:
- * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
- * this time.
- */
- /** @doc {heading: 'Performance', subheading: 'Memory'} */
- function memory() {
- return ENGINE.memory();
- }
- /**
- * Executes the provided function `f()` and returns a promise that resolves
- * with information about the function's memory use:
- * - `newBytes`: the number of new bytes allocated
- * - `newTensors`: the number of new tensors created
- * - `peakBytes`: the peak number of bytes allocated
- * - `kernels`: an array of objects for each kernel involved that reports
- * their input and output shapes, number of bytes used, and number of new
- * tensors created.
- *
- * ```js
- * const profile = await tf.profile(() => {
- * const x = tf.tensor1d([1, 2, 3]);
- * let x2 = x.square();
- * x2.dispose();
- * x2 = x.square();
- * x2.dispose();
- * return x;
- * });
- *
- * console.log(`newBytes: ${profile.newBytes}`);
- * console.log(`newTensors: ${profile.newTensors}`);
- * console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
- * k.totalBytesSnapshot)}`);
- * ```
- *
- */
- /** @doc {heading: 'Performance', subheading: 'Profile'} */
- function profile(f) {
- return ENGINE.profile(f);
- }
- /**
- * Executes the provided function `fn` and after it is executed, cleans up all
- * intermediate tensors allocated by `fn` except those returned by `fn`.
- * `fn` must not return a Promise (async functions not allowed). The returned
- * result can be a complex object.
- *
- * Using this method helps avoid memory leaks. In general, wrap calls to
- * operations in `tf.tidy` for automatic memory cleanup.
- *
- * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
- * dispose variables, please use `tf.disposeVariables` or call dispose()
- * directly on variables.
- *
- * ```js
- * // y = 2 ^ 2 + 1
- * const y = tf.tidy(() => {
- * // a, b, and one will be cleaned up when the tidy ends.
- * const one = tf.scalar(1);
- * const a = tf.scalar(2);
- * const b = a.square();
- *
- * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
- *
- * // The value returned inside the tidy function will return
- * // through the tidy, in this case to the variable y.
- * return b.add(one);
- * });
- *
- * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
- * y.print();
- * ```
- *
- * @param nameOrFn The name of the closure, or the function to execute.
- * If a name is provided, the 2nd argument should be the function.
- * If debug mode is on, the timing and the memory usage of the function
- * will be tracked and displayed on the console using the provided name.
- * @param fn The function to execute.
- */
- /** @doc {heading: 'Performance', subheading: 'Memory'} */
- function tidy(nameOrFn, fn) {
- return ENGINE.tidy(nameOrFn, fn);
- }
- /**
- * Disposes any `tf.Tensor`s found within the provided object.
- *
- * @param container an object that may be a `tf.Tensor` or may directly
- * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If
- * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing
- * happens. In general it is safe to pass any object here, except that
- * `Promise`s are not supported.
- */
- /** @doc {heading: 'Performance', subheading: 'Memory'} */
- function dispose(container) {
- var tensors = getTensorsInContainer(container);
- tensors.forEach(function (tensor) { return tensor.dispose(); });
- }
- /**
- * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
- * automatically.
- *
- * ```js
- * let b;
- * const y = tf.tidy(() => {
- * const one = tf.scalar(1);
- * const a = tf.scalar(2);
- *
- * // b will not be cleaned up by the tidy. a and one will be cleaned up
- * // when the tidy ends.
- * b = tf.keep(a.square());
- *
- * console.log('numTensors (in tidy): ' + tf.memory().numTensors);
- *
- * // The value returned inside the tidy function will return
- * // through the tidy, in this case to the variable y.
- * return b.add(one);
- * });
- *
- * console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
- * console.log('y:');
- * y.print();
- * console.log('b:');
- * b.print();
- * ```
- *
- * @param result The tensor to keep from being disposed.
- */
- /** @doc {heading: 'Performance', subheading: 'Memory'} */
- function keep(result) {
- return ENGINE.keep(result);
- }
- /**
- * Executes `f()` and returns a promise that resolves with timing
- * information.
- *
- * The result is an object with the following properties:
- *
- * - `wallMs`: Wall execution time.
- * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
- * WebGL backend and the query timer extension is not available, this will
- * return an error object.
- * - On `WebGL` The following additional properties exist:
- * - `uploadWaitMs`: CPU blocking time on texture uploads.
- * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
- *
- * ```js
- * const x = tf.randomNormal([20, 20]);
- * const time = await tf.time(() => x.matMul(x));
- *
- * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
- * ```
- *
- * @param f The function to execute and time.
- */
- /** @doc {heading: 'Performance', subheading: 'Timing'} */
- function time(f) {
- return ENGINE.time(f);
- }
- /**
- * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and
- * executing operations on those tensors. Returns a promise that resolves
- * to a boolean if the backend initialization was successful.
- *
- * Note this disposes the current backend, if any, as well as any tensors
- * associated with it. A new backend is initialized, even if it is of the
- * same type as the previous one.
- *
- * @param backendName The name of the backend. Currently supports
- * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js
- * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).
- */
- /** @doc {heading: 'Backends'} */
- function setBackend(backendName) {
- return ENGINE.setBackend(backendName);
- }
- /**
- * Returns a promise that resolves when the currently selected backend (or the
- * highest priority one) has initialized. Await this promise when you are using
- * a backend that has async initialization.
- */
- /** @doc {heading: 'Backends'} */
- function ready() {
- return ENGINE.ready();
- }
- /**
- * Returns the current backend name (cpu, webgl, etc). The backend is
- * responsible for creating tensors and executing operations on those tensors.
- */
- /** @doc {heading: 'Backends'} */
- function getBackend() {
- return ENGINE.backendName;
- }
- /**
- * Removes a backend and the registered factory.
- */
- /** @doc {heading: 'Backends'} */
- function removeBackend(name) {
- ENGINE.removeBackend(name);
- }
- /**
- * Finds the backend registered under the provided name. Returns null if the
- * name is not in the registry, or the registration hasn't finished yet.
- */
- function findBackend(name) {
- return ENGINE.findBackend(name);
- }
- /**
- * Finds the backend factory registered under the provided name. Returns a
- * function that produces a new backend when called. Returns null if the name
- * is not in the registry.
- */
- function findBackendFactory(name) {
- return ENGINE.findBackendFactory(name);
- }
- /**
- * Registers a global backend. The registration should happen when importing
- * a module file (e.g. when importing `backend_webgl.ts`), and is used for
- * modular builds (e.g. custom tfjs bundle with only webgl support).
- *
- * @param factory The backend factory function. When called, it should
- * return a backend instance, or a promise of an instance.
- * @param priority The priority of the backend (higher = more important).
- * In case multiple backends are registered, the priority is used to find
- * the best backend. Defaults to 1.
- * @return False if there is already a registered backend under this name, true
- * if not.
- */
- /** @doc {heading: 'Backends'} */
- function registerBackend(name, factory, priority) {
- if (priority === void 0) { priority = 1; }
- return ENGINE.registerBackend(name, factory, priority);
- }
- /**
- * Gets the current backend. If no backends have been initialized, this will
- * attempt to initialize the best backend. Will throw an error if the highest
- * priority backend has async initialization, in which case, you should call
- * 'await tf.ready()' before running other code.
- */
- /** @doc {heading: 'Backends'} */
- function backend() {
- return ENGINE.backend;
- }
- /**
- * Sets the global platform.
- *
- * @param platformName The name of this platform.
- * @param platform A platform implementation.
- */
- function setPlatform(platformName, platform) {
- env().setPlatform(platformName, platform);
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function warn() {
- var msg = [];
- for (var _i = 0; _i < arguments.length; _i++) {
- msg[_i] = arguments[_i];
- }
- if (!env().getBool('IS_TEST')) {
- console.warn.apply(console, msg);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function inferShape(val, dtype) {
- var firstElem = val;
- if (isTypedArray(val)) {
- return dtype === 'string' ? [] : [val.length];
- }
- if (!Array.isArray(val)) {
- return []; // Scalar.
- }
- var shape = [];
- while (Array.isArray(firstElem) ||
- isTypedArray(firstElem) && dtype !== 'string') {
- shape.push(firstElem.length);
- firstElem = firstElem[0];
- }
- if (Array.isArray(val) &&
- env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
- deepAssertShapeConsistency(val, shape, []);
- }
- return shape;
- }
- function deepAssertShapeConsistency(val, shape, indices) {
- indices = indices || [];
- if (!(Array.isArray(val)) && !isTypedArray(val)) {
- assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " +
- ("but should be an array/TypedArray of " + shape[0] + " elements"); });
- return;
- }
- assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " +
- ("but is an array of " + val.length + " elements"); });
- assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " +
- ("elements, but has " + val.length + " elements"); });
- var subShape = shape.slice(1);
- for (var i = 0; i < val.length; ++i) {
- deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
- }
- }
- function assertDtype(expectedDtype, actualDType, argName, functionName) {
- if (expectedDtype == null) {
- return;
- }
- if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
- expectedDtype === 'numeric' && actualDType === 'string') {
- throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " +
- ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor"));
- }
- }
- function convertToTensor(x, argName, functionName, parseAsDtype) {
- if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
- if (x instanceof Tensor) {
- assertDtype(parseAsDtype, x.dtype, argName, functionName);
- return x;
- }
- var inferredDtype = inferDtype(x);
- // If the user expects a bool/int/float, use that info to update the
- // inferredDtype when it is not a string.
- if (inferredDtype !== 'string' &&
- ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
- inferredDtype = parseAsDtype;
- }
- assertDtype(parseAsDtype, inferredDtype, argName, functionName);
- if ((x == null) ||
- (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
- typeof x !== 'boolean' && typeof x !== 'string')) {
- var type = x == null ? 'null' : x.constructor.name;
- throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " +
- ("Tensor or TensorLike, but got '" + type + "'"));
- }
- var inferredShape = inferShape(x, inferredDtype);
- if (!isTypedArray(x) && !Array.isArray(x)) {
- x = [x];
- }
- var skipTypedArray = true;
- var values = inferredDtype !== 'string' ?
- toTypedArray(x, inferredDtype, env().getBool('DEBUG')) :
- flatten(x, [], skipTypedArray);
- return ENGINE.makeTensor(values, inferredShape, inferredDtype);
- }
- function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
- if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; }
- if (!Array.isArray(arg)) {
- throw new Error("Argument " + argName + " passed to " + functionName + " must be a " +
- '`Tensor[]` or `TensorLike[]`');
- }
- var tensors = arg;
- return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName); }, parseAsDtype);
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns true if the axis specifies the inner most dimensions of the
- * array.
- */
- function axesAreInnerMostDims(axes, rank) {
- for (var i = 0; i < axes.length; ++i) {
- if (axes[axes.length - i - 1] !== rank - 1 - i) {
- return false;
- }
- }
- return true;
- }
- function combineLocations(outputLoc, reduceLoc, axes) {
- var rank = outputLoc.length + reduceLoc.length;
- var loc = [];
- var outIdx = 0;
- var reduceIdx = 0;
- for (var dim = 0; dim < rank; dim++) {
- if (axes.indexOf(dim) === -1) {
- loc.push(outputLoc[outIdx++]);
- }
- else {
- loc.push(reduceLoc[reduceIdx++]);
- }
- }
- return loc;
- }
- function computeOutAndReduceShapes(aShape, axes) {
- var outShape = [];
- var rank = aShape.length;
- for (var dim = 0; dim < rank; dim++) {
- if (axes.indexOf(dim) === -1) {
- outShape.push(aShape[dim]);
- }
- }
- var reduceShape = axes.map(function (dim) { return aShape[dim]; });
- return [outShape, reduceShape];
- }
- function expandShapeToKeepDim(shape, axes) {
- var reduceSubShape = axes.map(function (x) { return 1; });
- return combineLocations(shape, reduceSubShape, axes);
- }
- function assertAxesAreInnerMostDims(msg, axes, rank) {
- assert(axesAreInnerMostDims(axes, rank), function () { return msg + " supports only inner-most axes for now. " +
- ("Got axes " + axes + " and rank-" + rank + " input."); });
- }
- /**
- * Returns the axes permutation to be used with `tf.transpose`, if such
- * permutation is necessary. Otherwise it returns null. This method is used by
- * operations that operate only on inner-most axes.
- */
- function getAxesPermutation(axes, rank) {
- if (axesAreInnerMostDims(axes, rank)) {
- return null;
- }
- var result = [];
- for (var i = 0; i < rank; ++i) {
- if (axes.indexOf(i) === -1) {
- result.push(i);
- }
- }
- axes.forEach(function (axis) { return result.push(axis); });
- return result;
- }
- /** Returns the axes permutation that undoes the original permutation. */
- function getUndoAxesPermutation(axes) {
- return axes.map(function (axis, i) { return [i, axis]; })
- .sort(function (a, b) { return a[1] - b[1]; })
- .map(function (x) { return x[0]; });
- }
- function getInnerMostAxes(numAxes, rank) {
- var res = [];
- for (var i = rank - numAxes; i < rank; ++i) {
- res.push(i);
- }
- return res;
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function assertParamsConsistent(shapes, axis) {
- var rank = shapes[0].length;
- shapes.forEach(function (shape, i) {
- assert(shape.length === rank, function () {
- return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " +
- ("as the rank of the rest (" + rank + ")");
- });
- });
- assert(axis >= 0 && axis < rank, function () { return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + "."; });
- var firstShape = shapes[0];
- shapes.forEach(function (shape, i) {
- for (var r = 0; r < rank; r++) {
- assert((r === axis) || (shape[r] === firstShape[r]), function () { return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " +
- ("does not match the shape of the rest (" + firstShape + ") ") +
- ("along the non-concatenated axis " + i + "."); });
- }
- });
- }
- function computeOutShape(shapes, axis) {
- var outputShape = shapes[0].slice();
- for (var i = 1; i < shapes.length; i++) {
- outputShape[axis] += shapes[i][axis];
- }
- return outputShape;
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Used for wrapping functions that perform math operations on
- * Tensors. The function will be wrapped in a named scope that cleans all
- * memory usage after the function is done.
- */
- function op(f) {
- var keys = Object.keys(f);
- if (keys.length !== 1) {
- throw new Error("Please provide an object with a single key " +
- "(operation name) mapping to a function. Got an object with " +
- (keys.length + " keys."));
- }
- var opName = keys[0];
- var fn = f[opName];
- // Strip the underscore from the end of the function name.
- if (opName.endsWith('_')) {
- opName = opName.substring(0, opName.length - 1);
- }
- // tslint:disable-next-line:no-any
- var f2 = function () {
- var args = [];
- for (var _i = 0; _i < arguments.length; _i++) {
- args[_i] = arguments[_i];
- }
- ENGINE.startScope(opName);
- try {
- var result = fn.apply(void 0, args);
- if (result instanceof Promise) {
- console.error('Cannot return a Promise inside of tidy.');
- }
- ENGINE.endScope(result);
- return result;
- }
- catch (ex) {
- ENGINE.endScope(null);
- throw ex;
- }
- };
- Object.defineProperty(f2, 'name', { value: opName, configurable: true });
- // tslint:disable-next-line:no-any
- return f2;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Converts two real numbers to a complex number.
- *
- * Given a tensor `real` representing the real part of a complex number, and a
- * tensor `imag` representing the imaginary part of a complex number, this
- * operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
- * where r represents the real part and i represents the imag part.
- *
- * The input tensors real and imag must have the same shape.
- *
- * ```js
- * const real = tf.tensor1d([2.25, 3.25]);
- * const imag = tf.tensor1d([4.75, 5.75]);
- * const complex = tf.complex(real, imag);
- *
- * complex.print();
- * ```
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function complex_(real, imag) {
- var $real = convertToTensor(real, 'real', 'complex');
- var $imag = convertToTensor(imag, 'imag', 'complex');
- assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " +
- "must match in call to tf.complex().");
- return ENGINE.runKernelFunc(function (backend) { return backend.complex($real, $imag); }, { $real: $real, $imag: $imag });
- }
- /**
- * Returns the real part of a complex (or real) tensor.
- *
- * Given a tensor input, this operation returns a tensor of type float that is
- * the real part of each element in input considered as a complex number.
- *
- * If the input is real, it simply makes a clone.
- *
- * ```js
- * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
- * tf.real(x).print();
- * ```
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function real_(input) {
- var $input = convertToTensor(input, 'input', 'real');
- return ENGINE.runKernelFunc(function (backend) { return backend.real($input); }, { $input: $input });
- }
- /**
- * Returns the imaginary part of a complex (or real) tensor.
- *
- * Given a tensor input, this operation returns a tensor of type float that is
- * the imaginary part of each element in input considered as a complex number.
- * If input is real, a tensor of all zeros is returned.
- *
- * ```js
- * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
- * tf.imag(x).print();
- * ```
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function imag_(input) {
- var $input = convertToTensor(input, 'input', 'imag');
- return ENGINE.runKernelFunc(function (backend) { return backend.imag($input); }, { $input: $input });
- }
- var complex = op({ complex_: complex_ });
- var real = op({ real_: real_ });
- var imag = op({ imag_: imag_ });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Creates a `tf.Tensor` with the provided values, shape and dtype.
- *
- * ```js
- * // Pass an array of values to create a vector.
- * tf.tensor([1, 2, 3, 4]).print();
- * ```
- *
- * ```js
- * // Pass a nested array of values to make a matrix or a higher
- * // dimensional tensor.
- * tf.tensor([[1, 2], [3, 4]]).print();
- * ```
- *
- * ```js
- * // Pass a flat array and specify a shape yourself.
- * tf.tensor([1, 2, 3, 4], [2, 2]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`. If the values are strings,
- * they will be encoded as utf-8 and kept as `Uint8Array[]`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor(values, shape, dtype) {
- var inferredShape = inferShape(values, dtype);
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /** This is shared code across all tensor creation methods. */
- function makeTensor(values, shape, inferredShape, dtype) {
- if (dtype == null) {
- dtype = inferDtype(values);
- }
- if (dtype === 'complex64') {
- throw new Error("Cannot construct a complex64 tensor directly. " +
- "Please use tf.complex(real, imag).");
- }
- if (!isTypedArray(values) && !Array.isArray(values) &&
- typeof values !== 'number' && typeof values !== 'boolean' &&
- typeof values !== 'string') {
- throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
- 'an array of numbers/booleans/strings, or a TypedArray');
- }
- if (shape != null) {
- assertNonNegativeIntegerDimensions(shape);
- var providedSize_1 = sizeFromShape(shape);
- var inferredSize_1 = sizeFromShape(inferredShape);
- assert(providedSize_1 === inferredSize_1, function () {
- return "Based on the provided shape, [" + shape + "], the tensor should have " +
- (providedSize_1 + " values but has " + inferredSize_1);
- });
- for (var i = 0; i < inferredShape.length; ++i) {
- var inferred = inferredShape[i];
- var flatDimsDontMatch = i === inferredShape.length - 1 ?
- inferred !== sizeFromShape(shape.slice(i)) :
- true;
- assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " +
- ("(" + inferredShape + ") does not match the provided ") +
- ("shape (" + shape + "). "); });
- }
- }
- if (!isTypedArray(values) && !Array.isArray(values)) {
- values = [values];
- }
- shape = shape || inferredShape;
- values = dtype !== 'string' ?
- toTypedArray(values, dtype, env().getBool('DEBUG')) :
- flatten(values, [], true);
- return ENGINE.makeTensor(values, shape, dtype);
- }
- /**
- * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.scalar` as it makes the code more readable.
- *
- * ```js
- * tf.scalar(3.14).print();
- * ```
- *
- * @param value The value of the scalar.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function scalar(value, dtype) {
- if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) &&
- dtype !== 'complex64') {
- throw new Error('Error creating a new Scalar: value must be a primitive ' +
- '(number|boolean|string)');
- }
- if (dtype === 'string' && isTypedArray(value) &&
- !(value instanceof Uint8Array)) {
- throw new Error('When making a scalar from encoded string, ' +
- 'the value must be `Uint8Array`.');
- }
- var shape = [];
- var inferredShape = [];
- return makeTensor(value, shape, inferredShape, dtype);
- }
- /**
- * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor1d` as it makes the code more readable.
- *
- * ```js
- * tf.tensor1d([1, 2, 3]).print();
- * ```
- *
- * @param values The values of the tensor. Can be array of numbers,
- * or a `TypedArray`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor1d(values, dtype) {
- assertNonNull(values);
- var inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 1) {
- throw new Error('tensor1d() requires values to be a flat/TypedArray');
- }
- var shape = null;
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /**
- * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor2d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor2d([[1, 2], [3, 4]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. If not provided, it is inferred from
- * `values`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor2d(values, shape, dtype) {
- assertNonNull(values);
- if (shape != null && shape.length !== 2) {
- throw new Error('tensor2d() requires shape to have two numbers');
- }
- var inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 2 && inferredShape.length !== 1) {
- throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
- }
- if (inferredShape.length === 1 && shape == null) {
- throw new Error('tensor2d() requires shape to be provided when `values` ' +
- 'are a flat/TypedArray');
- }
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /**
- * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor3d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. If not provided, it is inferred from
- * `values`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor3d(values, shape, dtype) {
- assertNonNull(values);
- if (shape != null && shape.length !== 3) {
- throw new Error('tensor3d() requires shape to have three numbers');
- }
- var inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 3 && inferredShape.length !== 1) {
- throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
- }
- if (inferredShape.length === 1 && shape == null) {
- throw new Error('tensor3d() requires shape to be provided when `values` ' +
- 'are a flat array');
- }
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /**
- * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor4d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor4d(values, shape, dtype) {
- assertNonNull(values);
- if (shape != null && shape.length !== 4) {
- throw new Error('tensor4d() requires shape to have four numbers');
- }
- var inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 4 && inferredShape.length !== 1) {
- throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
- }
- if (inferredShape.length === 1 && shape == null) {
- throw new Error('tensor4d() requires shape to be provided when `values` ' +
- 'are a flat array');
- }
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /**
- * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor5d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor5d([[[[[1], [2]], [[3], [4]]]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor5d(values, shape, dtype) {
- assertNonNull(values);
- if (shape != null && shape.length !== 5) {
- throw new Error('tensor5d() requires shape to have five numbers');
- }
- var inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 5 && inferredShape.length !== 1) {
- throw new Error('tensor5d() requires values to be ' +
- 'number[][][][][] or flat/TypedArray');
- }
- if (inferredShape.length === 1 && shape == null) {
- throw new Error('tensor5d() requires shape to be provided when `values` ' +
- 'are a flat array');
- }
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /**
- * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.
- *
- * The same functionality can be achieved with `tf.tensor`, but in general
- * we recommend using `tf.tensor6d` as it makes the code more readable.
- *
- * ```js
- * // Pass a nested array.
- * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
- * ```
- * ```js
- * // Pass a flat array and specify a shape.
- * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
- * ```
- *
- * @param values The values of the tensor. Can be nested array of numbers,
- * or a flat array, or a `TypedArray`.
- * @param shape The shape of the tensor. Optional. If not provided,
- * it is inferred from `values`.
- * @param dtype The data type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function tensor6d(values, shape, dtype) {
- assertNonNull(values);
- if (shape != null && shape.length !== 6) {
- throw new Error('tensor6d() requires shape to have six numbers');
- }
- var inferredShape = inferShape(values, dtype);
- if (inferredShape.length !== 6 && inferredShape.length !== 1) {
- throw new Error('tensor6d() requires values to be number[][][][][][] or ' +
- 'flat/TypedArray');
- }
- if (inferredShape.length === 1 && shape == null) {
- throw new Error('tensor6d() requires shape to be provided when `values` ' +
- 'are a flat array');
- }
- shape = shape ||
- inferredShape;
- return makeTensor(values, shape, inferredShape, dtype);
- }
- /**
- * Creates a new variable with the provided initial value.
- * ```js
- * const x = tf.variable(tf.tensor([1, 2, 3]));
- * x.assign(tf.tensor([4, 5, 6]));
- *
- * x.print();
- * ```
- *
- * @param initialValue Initial value for the tensor.
- * @param trainable If true, optimizers are allowed to update it.
- * @param name Name of the variable. Defaults to a unique id.
- * @param dtype If set, initialValue will be converted to the given type.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function variable(initialValue, trainable, name, dtype) {
- if (trainable === void 0) { trainable = true; }
- return ENGINE.makeVariable(initialValue, trainable, name, dtype);
- }
- /**
- * Creates a `tf.Tensor` with all elements set to 1.
- *
- * ```js
- * tf.ones([2, 2]).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param dtype The type of an element in the resulting tensor. Defaults to
- * 'float'.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function ones$1(shape, dtype) {
- if (dtype === void 0) { dtype = 'float32'; }
- if (dtype === 'complex64') {
- var real_1 = ones$1(shape, 'float32');
- var imag_1 = zeros(shape, 'float32');
- return complex(real_1, imag_1);
- }
- var values = makeOnesTypedArray(sizeFromShape(shape), dtype);
- return ENGINE.makeTensor(values, shape, dtype);
- }
- /**
- * Creates a `tf.Tensor` with all elements set to 0.
- *
- * ```js
- * tf.zeros([2, 2]).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param dtype The type of an element in the resulting tensor. Can
- * be 'float32', 'int32' or 'bool'. Defaults to 'float'.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function zeros(shape, dtype) {
- if (dtype === void 0) { dtype = 'float32'; }
- if (dtype === 'complex64') {
- var real_2 = zeros(shape, 'float32');
- var imag_2 = zeros(shape, 'float32');
- return complex(real_2, imag_2);
- }
- var values = makeZerosTypedArray(sizeFromShape(shape), dtype);
- return ENGINE.makeTensor(values, shape, dtype);
- }
- /**
- * Creates a `tf.Tensor` filled with a scalar value.
- *
- * ```js
- * tf.fill([2, 2], 4).print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param value The scalar value to fill the tensor with.
- * @param dtype The type of an element in the resulting tensor. Defaults to
- * 'float'.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function fill(shape, value, dtype) {
- return ENGINE.runKernelFunc(function (backend) { return backend.fill(shape, value, dtype); }, {});
- }
- /**
- * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
- * given tensor.
- *
- * ```js
- * const x = tf.tensor([1, 2]);
- * tf.onesLike(x).print();
- * ```
- * @param x A tensor.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function onesLike_(x) {
- var $x = convertToTensor(x, 'x', 'onesLike');
- if ($x.dtype === 'complex64') {
- var r = onesLike(real($x));
- var i = zerosLike(imag($x));
- return complex(r, i);
- }
- var der = function (dy, saved) { return ({ x: function () { return zerosLike(dy); } }); };
- return ENGINE.runKernelFunc(function (backend) { return backend.onesLike($x); }, { x: $x }, der, 'OnesLike');
- }
- /**
- * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
- * given tensor.
- *
- * ```js
- * const x = tf.tensor([1, 2]);
- * tf.zerosLike(x).print();
- * ```
- *
- * @param x The tensor of required shape.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function zerosLike_(x) {
- var $x = convertToTensor(x, 'x', 'zerosLike');
- var der = function (dy, saved) { return ({ x: function () { return zerosLike(dy); } }); };
- return ENGINE.runKernelFunc(function (backend) { return backend.zerosLike($x); }, { x: $x }, der, 'ZerosLike');
- }
- /**
- * Return an evenly spaced sequence of numbers over the given interval.
- *
- * ```js
- * tf.linspace(0, 9, 10).print();
- * ```
- * @param start The start value of the sequence.
- * @param stop The end value of the sequence.
- * @param num The number of values to generate.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function linspace(start, stop, num) {
- if (num <= 0) {
- throw new Error('The number of values should be positive.');
- }
- return ENGINE.runKernelFunc(function (backend) { return backend.linspace(start, stop, num); }, {});
- }
- /**
- * Creates a new `tf.Tensor1D` filled with the numbers in the range provided.
- *
- * The tensor is a is half-open interval meaning it includes start, but
- * excludes stop. Decrementing ranges and negative step values are also
- * supported.
- *
- * ```js
- * tf.range(0, 9, 2).print();
- * ```
- *
- * @param start An integer start value
- * @param stop An integer stop value
- * @param step An integer increment (will default to 1 or -1)
- * @param dtype The data type of the output tensor. Defaults to 'float32'.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function range(start, stop, step, dtype) {
- if (step === void 0) { step = 1; }
- if (dtype === void 0) { dtype = 'float32'; }
- if (step === 0) {
- throw new Error('Cannot have a step of zero');
- }
- var sameStartStop = start === stop;
- var increasingRangeNegativeStep = start < stop && step < 0;
- var decreasingRangePositiveStep = stop < start && step > 1;
- if (sameStartStop || increasingRangeNegativeStep ||
- decreasingRangePositiveStep) {
- return zeros([0], dtype);
- }
- var numElements = Math.abs(Math.ceil((stop - start) / step));
- var values = makeZerosTypedArray(numElements, dtype);
- if (stop < start && step === 1) {
- // Auto adjust the step's sign if it hasn't been set
- // (or was set to 1)
- step = -1;
- }
- values[0] = start;
- for (var i = 1; i < values.length; i++) {
- values[i] = values[i - 1] + step;
- }
- return tensor1d(values, dtype);
- }
- var onesLike = op({ onesLike_: onesLike_ });
- var zerosLike = op({ zerosLike_: zerosLike_ });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details.
- *
- * For example, if:
- * A: shape(3) = |r1, g1, b1|
- * B: shape(2) = |r2, g2|
- * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
- *
- * @param tensors A list of`tf.Tensor`s to concatenate.
- * @return The concatenated array.
- */
- function concat1d_(tensors) {
- return concat(tensors, 0 /* axis */);
- }
- /**
- * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details.
- *
- * For example, if:
- * A: shape(2, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- *
- * B: shape(2, 3) = | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * C = tf.concat2d([A, B], axis)
- *
- * if axis = 0:
- * C: shape(4, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- * | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * if axis = 1:
- * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
- * | r2, g2, b2, r4, g4, b4 |
- *
- *
- * @param tensors A list of `tf.Tensor`s to concatenate.
- * @param axis The axis to concatenate along.
- * @return The concatenated array.
- */
- function concat2d_(tensors, axis) {
- return concat(tensors, axis);
- }
- /**
- * Concatenates a list of `tf.Tensor3D`s along an axis.
- * See `concat` for details.
- *
- * For example, if:
- * A: shape(2, 1, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- *
- * B: shape(2, 1, 3) = | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * C = tf.concat3d([A, B], axis)
- *
- * if axis = 0:
- * C: shape(4, 1, 3) = | r1, g1, b1 |
- * | r2, g2, b2 |
- * | r3, g3, b3 |
- * | r4, g4, b4 |
- *
- * if axis = 1:
- * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |
- * | r2, g2, b2, r4, g4, b4 |
- *
- * if axis = 2:
- * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |
- * | r2, g2, b2, r4, g4, b4 |
- *
- * @param tensors A list of`tf.Tensor`s to concatenate.
- * @param axis The axis to concate along.
- * @return The concatenated array.
- */
- function concat3d_(tensors, axis) {
- return concat(tensors, axis);
- }
- /**
- * Concatenates a list of `tf.Tensor4D`s along an axis.
- * See `concat` for details.
- *
- * @param tensors A list of `tf.Tensor`s to concatenate.
- * @param axis The axis to concate along.
- * @return The concatenated array.
- */
- function concat4d_(tensors, axis) {
- return concat(tensors, axis);
- }
- /**
- * Concatenates a list of `tf.Tensor`s along a given axis.
- *
- * The tensors ranks and types must match, and their sizes must match in all
- * dimensions except `axis`.
- *
- * Also available are stricter rank-specific methods that assert that
- * `tensors` are of the given rank:
- * - `tf.concat1d`
- * - `tf.concat2d`
- * - `tf.concat3d`
- * - `tf.concat4d`
- *
- * Except `tf.concat1d` (which does not have axis param), all methods have
- * same signature as this method.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * a.concat(b).print(); // or a.concat(b)
- * ```
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * const c = tf.tensor1d([5, 6]);
- * tf.concat([a, b, c]).print();
- * ```
- *
- * ```js
- * const a = tf.tensor2d([[1, 2], [10, 20]]);
- * const b = tf.tensor2d([[3, 4], [30, 40]]);
- * const axis = 1;
- * tf.concat([a, b], axis).print();
- * ```
- * @param tensors A list of tensors to concatenate.
- * @param axis The axis to concate along. Defaults to 0 (the first dim).
- */
- /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
- function concat_(tensors, axis) {
- if (axis === void 0) { axis = 0; }
- assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; });
- var $tensors = convertToTensorArray(tensors, 'tensors', 'concat');
- if ($tensors[0].dtype === 'complex64') {
- $tensors.forEach(function (tensor) {
- if (tensor.dtype !== 'complex64') {
- throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". ");
- }
- });
- }
- axis = parseAxisParam(axis, $tensors[0].shape)[0];
- var outShape = computeOutShape($tensors.map(function (t) { return t.shape; }), axis);
- if (sizeFromShape(outShape) === 0) {
- return tensor([], outShape);
- }
- // Keep only non-empty tensors (ignore tensors with 0 in their shape).
- $tensors = $tensors.filter(function (t) { return t.size > 0; });
- if ($tensors.length === 1) {
- return $tensors[0];
- }
- var shapes = $tensors.map(function (t) { return t.shape; });
- assertParamsConsistent(shapes, axis);
- var der = function (dy) {
- var sizeSplits = shapes.map(function (s) { return s[axis]; });
- var derTensors = split(dy, sizeSplits, axis);
- return derTensors.map(function (t) { return function () { return t; }; });
- };
- var inputs = $tensors;
- var attr = { axis: axis };
- return ENGINE.runKernelFunc(function (backend) { return backend.concat($tensors, axis); }, inputs, der, 'Concat', attr);
- }
- /**
- * Splits a `tf.Tensor` into sub tensors.
- *
- * If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
- * into `numOrSizeSplits` smaller tensors.
- * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
- *
- * If `numOrSizeSplits` is a number array, splits `x` into
- * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
- * same size as `x` except along dimension `axis` where the size is
- * `numOrSizeSplits[i]`.
- *
- * ```js
- * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
- * const [a, b] = tf.split(x, 2, 1);
- * a.print();
- * b.print();
- *
- * const [c, d, e] = tf.split(x, [1, 2, 1], 1);
- * c.print();
- * d.print();
- * e.print();
- * ```
- *
- * @param x The input tensor to split.
- * @param numOrSizeSplits Either an integer indicating the number of
- * splits along the axis or an array of integers containing the sizes of
- * each output tensor along the axis. If a number then it must evenly divide
- * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
- * @param axis The dimension along which to split. Defaults to 0 (the first
- * dim).
- */
- /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
- function split_(x, numOrSizeSplits, axis) {
- if (axis === void 0) { axis = 0; }
- var $x = convertToTensor(x, 'x', 'split');
- axis = parseAxisParam(axis, $x.shape)[0];
- var splitSizes;
- if (typeof (numOrSizeSplits) === 'number') {
- assert($x.shape[axis] % numOrSizeSplits === 0, function () { return 'Number of splits must evenly divide the axis.'; });
- splitSizes =
- new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits);
- }
- else {
- assert($x.shape[axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }), function () { return 'The sum of sizes must match the size of the axis dimension.'; });
- splitSizes = numOrSizeSplits;
- }
- var der = function (dy) { return ({ $x: function () { return concat(dy, axis); } }); };
- return ENGINE.runKernelFunc(function (backend) { return backend.split($x, splitSizes, axis); }, { $x: $x }, der);
- }
- var concat = op({ concat_: concat_ });
- var concat1d = op({ concat1d_: concat1d_ });
- var concat2d = op({ concat2d_: concat2d_ });
- var concat3d = op({ concat3d_: concat3d_ });
- var concat4d = op({ concat4d_: concat4d_ });
- var split = op({ split_: split_ });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Reshapes a `tf.Tensor` to a given shape.
- *
- * Given an input tensor, returns a new tensor with the same values as the
- * input tensor with shape `shape`.
- *
- * If one component of shape is the special value -1, the size of that
- * dimension is computed so that the total size remains constant. In
- * particular, a shape of [-1] flattens into 1-D. At most one component of
- * shape can be -1.
- *
- * If shape is 1-D or higher, then the operation returns a tensor with shape
- * shape filled with the values of tensor. In this case, the number of
- * elements implied by shape must be the same as the number of elements in
- * tensor.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * x.reshape([2, 2]).print();
- * ```
- *
- * @param x The input tensor to be reshaped.
- * @param shape An array of integers defining the output tensor shape.
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function reshape_(x, shape) {
- var $x = convertToTensor(x, 'x', 'reshape', null);
- shape = inferFromImplicitShape(shape, $x.size);
- assert($x.size === sizeFromShape(shape), function () { return 'new shape and old shape must have the same number of elements.'; });
- var grad = function (dy) {
- return { x: function () { return dy.reshape($x.shape); } };
- };
- var attrs = { shape: shape };
- return ENGINE.runKernelFunc(function (backend) { return backend.reshape($x, shape); }, { x: $x }, grad, 'Reshape', attrs);
- }
- /**
- * Removes dimensions of size 1 from the shape of a `tf.Tensor`.
- *
- * ```js
- * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
- * x.squeeze().print();
- * ```
- *
- * @param x The input tensor to be squeezed.
- * @param axis An optional list of numbers. If specified, only
- * squeezes the dimensions listed. The dimension index starts at 0. It
- * is an error to squeeze a dimension that is not 1.
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function squeeze_(x, axis) {
- var $x = convertToTensor(x, 'x', 'squeeze');
- return reshape($x, squeezeShape($x.shape, axis).newShape);
- }
- /**
- * Casts a `tf.Tensor` to a new dtype.
- *
- * ```js
- * const x = tf.tensor1d([1.5, 2.5, 3]);
- * tf.cast(x, 'int32').print();
- * ```
- * @param x The input tensor to be casted.
- * @param dtype The dtype to cast the input tensor to.
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function cast_(x, dtype) {
- var $x = convertToTensor(x, 'x', 'cast');
- // Sanity checks.
- if (!isValidDtype(dtype)) {
- throw new Error("Failed to cast to unknown dtype " + dtype);
- }
- if (dtype === 'string' && $x.dtype !== 'string' ||
- dtype !== 'string' && $x.dtype === 'string') {
- throw new Error('Only strings can be casted to strings');
- }
- var grad = function (dy) {
- return { x: function () { return dy.clone(); } };
- };
- var attrs = { dtype: dtype };
- return ENGINE.runKernelFunc(function (backend) { return backend.cast($x, dtype); }, { x: $x }, grad, 'Cast', attrs);
- }
- /**
- * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
- *
- * ```js
- * const a = tf.tensor1d([1, 2]);
- * const b = tf.tensor1d([3, 4]);
- * const c = tf.tensor1d([5, 6]);
- * tf.stack([a, b, c]).print();
- * ```
- *
- * @param tensors A list of tensor objects with the same shape and dtype.
- * @param axis The axis to stack along. Defaults to 0 (the first dim).
- */
- /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
- function stack_(tensors, axis) {
- if (axis === void 0) { axis = 0; }
- var $tensors = convertToTensorArray(tensors, 'tensors', 'stack');
- assert($tensors.length >= 1, function () { return 'Pass at least one tensor to tf.stack'; });
- if ($tensors.length === 1) {
- return $tensors[0].expandDims(axis);
- }
- var rank = $tensors[0].rank;
- var shape = $tensors[0].shape;
- var dtype = $tensors[0].dtype;
- assert(axis <= rank, function () { return 'Axis must be <= rank of the tensor'; });
- $tensors.forEach(function (t) {
- assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
- });
- $tensors.forEach(function (t) {
- assert(dtype === t.dtype, function () { return 'All tensors passed to stack must have matching dtypes'; });
- });
- var expandedTensors = $tensors.map(function (t) { return t.expandDims(axis); });
- return concat(expandedTensors, axis);
- }
- /**
- * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
- * shape `blockShape + [batch]`, interleaves these blocks back into the grid
- * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
- * the same rank as the input. The spatial dimensions of this intermediate
- * result are then optionally cropped according to `crops` to produce the
- * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
- * description.
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
- * const blockShape = [2, 2];
- * const crops = [[0, 0], [0, 0]];
- *
- * x.batchToSpaceND(blockShape, crops).print();
- * ```
- *
- * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
- * remainingShape`, where spatialShape has `M` dimensions.
- * @param blockShape A 1-D array. Must have shape `[M]`, all values must
- * be >= 1.
- * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
- * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
- * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
- * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
- *
- * This operation is equivalent to the following steps:
- *
- * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
- * blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
- * x.shape[N-1]]`
- *
- * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch /
- * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M],
- * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
- *
- * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
- * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] *
- * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]`
- *
- * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
- * according to `crops` to produce the output of shape: `[batch /
- * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
- * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
- * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]`
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function batchToSpaceND_(x, blockShape, crops) {
- var $x = convertToTensor(x, 'x', 'batchToSpaceND');
- var prod = blockShape.reduce(function (a, b) { return a * b; });
- assert($x.rank >= 1 + blockShape.length, function () { return "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length; });
- assert(crops.length === blockShape.length, function () { return "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length; });
- assert($x.shape[0] % prod === 0, function () { return "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " +
- ("the elements of blockShape " + blockShape.join(' * ') + " === " + prod); });
- var grad = function (dy) {
- return { $x: function () { return dy.spaceToBatchND(blockShape, crops); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.batchToSpaceND($x, blockShape, crops); }, { $x: $x }, grad);
- }
- /**
- * This operation divides "spatial" dimensions `[1, ..., M]` of the input into
- * a grid of blocks of shape `blockShape`, and interleaves these blocks with
- * the "batch" dimension (0) such that in the output, the spatial
- * dimensions `[1, ..., M]` correspond to the position within the grid,
- * and the batch dimension combines both the position within a spatial block
- * and the original batch position. Prior to division into blocks,
- * the spatial dimensions of the input are optionally zero padded
- * according to `paddings`. See below for a precise description.
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
- * const blockShape = [2, 2];
- * const paddings = [[0, 0], [0, 0]];
- *
- * x.spaceToBatchND(blockShape, paddings).print();
- * ```
- *
- * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
- * remainingShape`, where spatialShape has `M` dimensions.
- * @param blockShape A 1-D array. Must have shape `[M]`, all values must
- * be >= 1.
- * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
- * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
- * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
- * is required that
- * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
- *
- * This operation is equivalent to the following steps:
- *
- * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
- * according to `paddings` to produce `padded` of shape paddedShape.
- *
- * 2. Reshape `padded` to `reshapedPadded` of shape:
- * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
- *
- * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
- * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1]] + remainingShape`
- *
- * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
- * batch dimension, producing an output tensor of shape:
- * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
- * paddedShape[M] / blockShape[M-1]] + remainingShape`
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function spaceToBatchND_(x, blockShape, paddings) {
- var $x = convertToTensor(x, 'x', 'spaceToBatchND');
- assert($x.rank >= 1 + blockShape.length, function () { return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length; });
- assert(paddings.length === blockShape.length, function () { return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length; });
- assert($x.shape.reduce(function (a, b, i) {
- if (i > 0 && i <= blockShape.length) {
- return a &&
- ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
- blockShape[i - 1] ===
- 0);
- }
- return a;
- }, true), function () { return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString(); });
- var grad = function (dy) {
- return { $x: function () { return dy.batchToSpaceND(blockShape, paddings); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.spaceToBatchND($x, blockShape, paddings); }, { $x: $x }, grad);
- }
- /**
- * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
- *
- * ```js
- * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
- *
- * tf.unstack(a).forEach(tensor => tensor.print());
- * ```
- *
- * @param x A tensor object.
- * @param axis The axis to unstack along. Defaults to 0 (the first dim).
- */
- /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
- function unstack_(x, axis) {
- if (axis === void 0) { axis = 0; }
- axis = axis || 0;
- var $x = convertToTensor(x, 'x', 'unstack');
- assert(axis >= -$x.shape.length && axis < $x.shape.length, function () {
- return "Axis = " + axis + " is not in [-" + $x.shape.length + ", " + $x.shape.length + ")";
- });
- if (axis < 0) {
- axis += $x.shape.length;
- }
- var grad = function (dy) {
- return { x: function () { return stack(dy, axis); } };
- };
- var attrs = { axis: axis };
- return ENGINE.runKernelFunc(function (backend) { return backend.unstack($x, axis); }, { x: $x }, grad, 'Unpack', attrs);
- }
- /**
- * Computes the cumulative sum of a `tf.Tensor` along `axis`.
- *
- * ```js
- * const x = tf.tensor([1, 2, 3, 4]);
- * x.cumsum().print();
- * ```
- * ```js
- * const x = tf.tensor([[1, 2], [3, 4]]);
- * x.cumsum().print();
- * ```
- *
- * @param x The input tensor to be summed.
- * @param axis The axis along which to sum. Optional. Defaults to 0.
- * @param exclusive Whether to perform exclusive cumulative sum. Optional.
- * Defaults to false. If set to true then the sum of each tensor entry
- * does not include its own value, but only the values previous to it
- * along the specified axis.
- * @param reverse Whether to sum in the opposite direction. Optional.
- * Defaults to false.
- */
- /** @doc {heading: 'Operations', subheading: 'Scan'} */
- function cumsum_(x, axis, exclusive, reverse) {
- if (axis === void 0) { axis = 0; }
- if (exclusive === void 0) { exclusive = false; }
- if (reverse === void 0) { reverse = false; }
- var $x = convertToTensor(x, 'x', 'cumsum');
- axis = axis | 0;
- var permutation = getAxesPermutation([axis], $x.rank);
- var permutedX = $x;
- if (permutation != null) {
- permutedX = $x.transpose(permutation);
- }
- var permutedAxis = getInnerMostAxes(1, $x.rank)[0];
- var grad = function (dy) {
- return { permutedX: function () { return dy.cumsum(axis, exclusive, !reverse); } };
- };
- var value = ENGINE.runKernelFunc(function (backend) { return backend.cumsum(permutedX, permutedAxis, exclusive, reverse); }, { permutedX: permutedX }, grad);
- if (permutation != null) {
- value = value.transpose(permutation);
- }
- return value;
- }
- /**
- * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
- * into the tensor's shape.
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 3, 4]);
- * const axis = 1;
- * x.expandDims(axis).print();
- * ```
- *
- * @param x The input tensor whose dimensions to be expanded.
- * @param axis The dimension index at which to insert shape of `1`. Defaults
- * to 0 (the first dimension).
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function expandDims_(x, axis) {
- if (axis === void 0) { axis = 0; }
- var parseAs = null;
- var $x = convertToTensor(x, 'x', 'expandDims', parseAs);
- assert(axis <= $x.rank, function () { return 'Axis must be <= rank of the tensor'; });
- var newShape = $x.shape.slice();
- if (axis < 0) {
- // Negative value is counted from the tail of rank.
- assert(-($x.rank + 1) <= axis, function () { return "Axis must be in the interval [" + -($x.rank + 1) + ", " + $x.rank + "]"; });
- axis = $x.rank + axis + 1;
- }
- newShape.splice(axis, 0, 1);
- return reshape($x, newShape);
- }
- /**
- * Rearranges data from depth into blocks of spatial data. More specifically,
- * this op outputs a copy of the input tensor where values from the `depth`
- * dimension are moved in spatial blocks to the `height` and `width` dimensions.
- * The attr `blockSize` indicates the input block size and how the data is
- * moved.
- *
- * - Chunks of data of size `blockSize * blockSize` from depth are rearranged
- * into non-overlapping blocks of size `blockSize x blockSize`
- *
- * - The width the output tensor is `inputWidth * blockSize`, whereas the
- * height is `inputHeight * blockSize`
- *
- * - The Y, X coordinates within each block of the output image are determined
- * by the high order component of the input channel index
- *
- * - The depth of the input tensor must be divisible by `blockSize *
- * blockSize`
- *
- * The `dataFormat` attr specifies the layout of the input and output tensors
- * with the following options: "NHWC": [ `batch, height, width, channels` ]
- * "NCHW": [ `batch, channels, height, width` ]
- *
- * ```js
- * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
- * const blockSize = 2;
- * const dataFormat = "NHWC";
- *
- * tf.depthToSpace(x, blockSize, dataFormat).print();
- * ```
- *
- * @param x The input tensor of rank 4
- * @param blockSIze An `int` that is `>= 2`. The size of the spatial block
- * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function depthToSpace_(x, blockSize, dataFormat) {
- if (dataFormat === void 0) { dataFormat = 'NHWC'; }
- var $x = convertToTensor(x, 'x', 'depthToSpace');
- var inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2];
- var inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3];
- var inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1];
- assert(inputHeight * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputHeight + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; });
- assert(inputWidth * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputWidth + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; });
- assert((inputDepth % (blockSize * blockSize) === 0), function () { return "Dimension size must be evenly divisible by " + blockSize * blockSize + " but is " + inputDepth + " for depthToSpace with input shape " + $x.shape; });
- return ENGINE.runKernelFunc(function (backend) { return backend.depthToSpace($x, blockSize, dataFormat); }, { $x: $x });
- }
- /**
- * Computes the difference between two lists of numbers.
- *
- * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
- * that represents all values that are in `x` but not in `y`. The returned
- * Tensor `out` is sorted in the same order that the numbers appear in `x`
- * (duplicates are preserved). This operation also returns a Tensor indices that
- * represents the position of each out element in `x`. In other words:
- *
- * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
- *
- * ```js
- * const x = [1, 2, 3, 4, 5, 6];
- * const y = [1, 3, 5];
- *
- * const [out, indices] = await tf.setdiff1dAsync(x, y);
- * out.print(); // [2, 4, 6]
- * indices.print(); // [1, 3, 5]
- * ```
- *
- * @param x 1-D Tensor. Values to keep.
- * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
- * output.
- * @returns Promise of Tensor tuple [out, indices].
- * out: Tensor with the same type as x.
- * indices: A Tensor of type int32.
- */
- /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
- function setdiff1dAsync_(x, y) {
- return __awaiter(this, void 0, void 0, function () {
- var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, i, p;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0:
- $x = convertToTensor(x, 'x', 'setdiff1d');
- $y = convertToTensor(y, 'y', 'setdiff1d');
- assert($x.dtype === $y.dtype, function () { return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ")."; });
- assert($x.rank === 1, function () { return "x should be 1D tensor, but got x (" + $x.shape + ")."; });
- assert($y.rank === 1, function () { return "y should be 1D tensor, but got y (" + $y.shape + ")."; });
- return [4 /*yield*/, $x.data()];
- case 1:
- xVals = _a.sent();
- return [4 /*yield*/, $y.data()];
- case 2:
- yVals = _a.sent();
- ySet = new Set(yVals);
- outputSize = 0;
- for (i = 0; i < xVals.length; i++) {
- if (!ySet.has(xVals[i])) {
- outputSize++;
- }
- }
- buffer = new TensorBuffer([outputSize], $x.dtype);
- indices = new TensorBuffer([outputSize], 'int32');
- for (i = 0, p = 0; i < xVals.length; i++) {
- if (!ySet.has(xVals[i])) {
- buffer.values[p] = xVals[i];
- indices.values[p] = i;
- p++;
- }
- }
- return [2 /*return*/, [buffer.toTensor(), indices.toTensor()]];
- }
- });
- });
- }
- /**
- * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
- *
- * The values are stored in CPU as `TypedArray`. Fill the buffer using
- * `buffer.set()`, or by modifying directly `buffer.values`.
- *
- * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
- * those values.
- *
- * ```js
- * // Create a buffer and set values at particular indices.
- * const buffer = tf.buffer([2, 2]);
- * buffer.set(3, 0, 0);
- * buffer.set(5, 1, 0);
- *
- * // Convert the buffer back to a tensor.
- * buffer.toTensor().print();
- * ```
- *
- * @param shape An array of integers defining the output tensor shape.
- * @param dtype The dtype of the buffer. Defaults to 'float32'.
- * @param values The values of the buffer as `TypedArray`. Defaults to
- * zeros.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function buffer(shape, dtype, values) {
- if (dtype === void 0) { dtype = 'float32'; }
- dtype = dtype || 'float32';
- assertNonNegativeIntegerDimensions(shape);
- return new TensorBuffer(shape, dtype, values);
- }
- /**
- * Prints information about the `tf.Tensor` including its data.
- *
- * ```js
- * const verbose = true;
- * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
- * ```
- * @param x The tensor to be printed.
- * @param verbose Whether to print verbose information about the ` Tensor`,
- * including dtype and size.
- */
- /** @doc {heading: 'Tensors', subheading: 'Creation'} */
- function print(x, verbose) {
- if (verbose === void 0) { verbose = false; }
- console.log(x.toString(verbose));
- }
- var batchToSpaceND = op({ batchToSpaceND_: batchToSpaceND_ });
- var cast = op({ cast_: cast_ });
- var cumsum = op({ cumsum_: cumsum_ });
- var depthToSpace = op({ depthToSpace_: depthToSpace_ });
- var expandDims = op({ expandDims_: expandDims_ });
- var reshape = op({ reshape_: reshape_ });
- var spaceToBatchND = op({ spaceToBatchND_: spaceToBatchND_ });
- var squeeze = op({ squeeze_: squeeze_ });
- var stack = op({ stack_: stack_ });
- var unstack = op({ unstack_: unstack_ });
- var setdiff1dAsync = setdiff1dAsync_;
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Gets the new shape of the input Tensor after it's been reshaped
- * to:
- * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
- * inputShape[1], ..., inputShape[N-1]]
- *
- * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getReshaped(inputShape, blockShape, prod, batchToSpace) {
- if (batchToSpace === void 0) { batchToSpace = true; }
- var reshaped = [];
- if (batchToSpace) {
- reshaped = reshaped.concat(blockShape.slice(0));
- reshaped.push(inputShape[0] / prod);
- reshaped = reshaped.concat(inputShape.slice(1));
- }
- else {
- reshaped = reshaped.concat(inputShape[0]);
- var spatialLength = blockShape.length;
- for (var i = 0; i < spatialLength; ++i) {
- reshaped =
- reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
- }
- reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
- }
- return reshaped;
- }
- /**
- * Gets the permutation that will transpose the dimensions of the
- * reshaped tensor to shape:
- *
- * [batch / prod(block_shape),inputShape[1], blockShape[0], ...,
- * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
- *
- * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getPermuted(reshapedRank, blockShapeRank, batchToSpace) {
- if (batchToSpace === void 0) { batchToSpace = true; }
- var permuted = [];
- if (batchToSpace) {
- permuted.push(blockShapeRank);
- for (var i = blockShapeRank + 1; i < reshapedRank; ++i) {
- if (i <= 2 * blockShapeRank) {
- permuted.push(i);
- permuted.push(i - (blockShapeRank + 1));
- }
- else {
- permuted.push(i);
- }
- }
- }
- else {
- var permutedBeforeBatch = [];
- var permutedAfterBatch = [];
- for (var i = 1; i < reshapedRank; ++i) {
- if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
- permutedAfterBatch.push(i);
- }
- else {
- permutedBeforeBatch.push(i);
- }
- }
- permuted.push.apply(permuted, permutedBeforeBatch);
- permuted.push(0);
- permuted.push.apply(permuted, permutedAfterBatch);
- }
- return permuted;
- }
- /**
- * Gets the shape of the reshaped and permuted input Tensor before any cropping
- * is applied. The new shape will be:
- *
- * [batch / prod(blockShape),inputShape[1] * blockShape[0], ...,
- * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]]
- *
- * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace) {
- if (batchToSpace === void 0) { batchToSpace = true; }
- var reshapedPermuted = [];
- if (batchToSpace) {
- reshapedPermuted.push(inputShape[0] / prod);
- }
- else {
- reshapedPermuted.push(inputShape[0] * prod);
- }
- for (var i = 1; i < inputShape.length; ++i) {
- if (i <= blockShape.length) {
- if (batchToSpace) {
- reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
- }
- else {
- reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
- }
- }
- else {
- reshapedPermuted.push(inputShape[i]);
- }
- }
- return reshapedPermuted;
- }
- /**
- * Converts the crops argument into the beginning coordinates of a slice
- * operation.
- */
- function getSliceBeginCoords(crops, blockShape) {
- var sliceBeginCoords = [0];
- for (var i = 0; i < blockShape; ++i) {
- sliceBeginCoords.push(crops[i][0]);
- }
- return sliceBeginCoords;
- }
- /**
- * Converts the crops argument into the size of a slice operation. When
- * combined with getSliceBeginCoords this function allows the reshaped and
- * permuted Tensor to be cropped to its final output shape of:
- *
- * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ...,
- * inputShape[M] * blockShape[M-1] -crops[M-1,0] -
- * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]]
- *
- * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
- */
- function getSliceSize(uncroppedShape, crops, blockShape) {
- var sliceSize = uncroppedShape.slice(0, 1);
- for (var i = 0; i < blockShape; ++i) {
- sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
- }
- return sliceSize;
- }
-
- var Add = 'Add';
- var AddN = 'AddN';
- var Div = 'Div';
- var FusedBatchNorm = 'FusedBatchNorm';
- var SquaredDifference = 'SquaredDifference';
- var Square = 'Square';
- var Transpose = 'Transpose';
- var NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
- var BroadcastTo = 'BroadcastTo';
- var OneHot = 'OneHot';
- var Identity = 'Identity';
- var Tile = 'Tile';
- var PadV2 = 'PadV2';
- /**
- * TensorFlow.js-only kernels
- */
- var FromPixels = 'FromPixels';
- var MaxPoolWithArgmax = 'MaxPoolWithArgmax';
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
- *
- * We also expose `tf.addStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3, 4]);
- * const b = tf.tensor1d([10, 20, 30, 40]);
- *
- * a.add(b).print(); // or tf.add(a, b)
- * ```
- *
- * ```js
- * // Broadcast add a with b.
- * const a = tf.scalar(5);
- * const b = tf.tensor1d([10, 20, 30, 40]);
- *
- * a.add(b).print(); // or tf.add(a, b)
- * ```
- * @param a The first `tf.Tensor` to add.
- * @param b The second `tf.Tensor` to add. Must have the same type as `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function add_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'add');
- var $b = convertToTensor(b, 'b', 'add');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- var forward = function (backend, save) {
- var res = backend.add($a, $b);
- save([$a, $b]);
- return res;
- };
- var inputs = { a: $a, b: $b };
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Add);
- }
- var add = op({ add_: add_ });
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Returns the dimensions in the input shape that are broadcasted to
- * produce the provided output shape.
- *
- * The returned dimensions are 0-indexed and sorted. An example:
- * inShape = [4, 1, 3]
- * outShape = [5, 4, 3, 3]
- * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
- */
- function getBroadcastDims(inShape, outShape) {
- var inRank = inShape.length;
- var dims = [];
- for (var i = 0; i < inRank; i++) {
- var dim = inRank - 1 - i;
- var a = inShape[dim] || 1;
- var b = outShape[outShape.length - 1 - i] || 1;
- if (b > 1 && a === 1) {
- dims.unshift(dim);
- }
- }
- return dims;
- }
- /**
- * Returns the axes in the output space that should be reduced to produce
- * the input space.
- */
- function getReductionAxes(inShape, outShape) {
- var result = [];
- for (var i = 0; i < outShape.length; i++) {
- var inDim = inShape[inShape.length - i - 1];
- var outAxis = outShape.length - i - 1;
- var outDim = outShape[outAxis];
- if (inDim == null || (inDim === 1 && outDim > 1)) {
- result.unshift(outAxis);
- }
- }
- return result;
- }
- function assertAndGetBroadcastShape(shapeA, shapeB) {
- var result = [];
- var l = Math.max(shapeA.length, shapeB.length);
- for (var i = 0; i < l; i++) {
- var a = shapeA[shapeA.length - i - 1];
- if (a == null) {
- a = 1;
- }
- var b = shapeB[shapeB.length - i - 1];
- if (b == null) {
- b = 1;
- }
- if (a === 1) {
- result.unshift(b);
- }
- else if (b === 1) {
- result.unshift(a);
- }
- else if (a !== b) {
- var errMsg = "Operands could not be broadcast together with shapes " +
- (shapeA + " and " + shapeB + ".");
- throw Error(errMsg);
- }
- else {
- result.unshift(a);
- }
- }
- return result;
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes `-1 * x` element-wise.
- *
- * ```js
- * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
- *
- * x.neg().print(); // or tf.neg(x)
- * ```
- *
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function neg_(x) {
- var $x = convertToTensor(x, 'x', 'neg');
- var grad = function (dy) {
- return { x: function () { return dy.neg(); } };
- };
- var attrs = {};
- var inputsToSave = [$x];
- return ENGINE.runKernelFunc(function (backend) { return backend.neg($x); }, { x: $x }, grad, 'Neg', attrs, inputsToSave);
- }
- /**
- * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3]);
- *
- * x.ceil().print(); // or tf.ceil(x)
- * ```
- * @param x The input Tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function ceil_(x) {
- var $x = convertToTensor(x, 'x', 'ceil');
- // TODO(manrajgrover): Return null for gradients when backprop supports it.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.ceil($x); }, { $x: $x }, grad);
- }
- /**
- * Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3]);
- *
- * x.floor().print(); // or tf.floor(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function floor_(x) {
- var $x = convertToTensor(x, 'x', 'floor');
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.floor($x); }, { $x: $x }, grad);
- }
- /**
- * Returns an element-wise indication of the sign of a number.
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
- *
- * x.sign().print(); // or tf.sign(x)
- * ```
- * @param x The input Tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function sign_(x) {
- var $x = convertToTensor(x, 'x', 'sign');
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.sign($x); }, { $x: $x }, grad);
- }
- /**
- * RReturns which elements of x are NaN.
- *
- * ```js
- * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
- *
- * x.isNaN().print(); // or tf.isNaN(x)
- * ```
- * @param x The input Tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function isNaN_(x) {
- var $x = convertToTensor(x, 'x', 'isNaN');
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.isNaN($x); }, { $x: $x }, grad);
- }
- /**
- * Returns which elements of x are Infinity or -Infinity.
- *
- * ```js
- * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
- *
- * x.isInf().print(); // or tf.isNaN(x)
- * ```
- * @param x The input Tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function isInf_(x) {
- var $x = convertToTensor(x, 'x', 'isInf');
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.isInf($x); }, { $x: $x }, grad);
- }
- /**
- * Returns which elements of x are finite.
- *
- * ```js
- * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
- *
- * x.isFinite().print(); // or tf.isNaN(x)
- * ```
- * @param x The input Tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function isFinite_(x) {
- var $x = convertToTensor(x, 'x', 'isFinite');
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.isFinite($x); }, { $x: $x }, grad);
- }
- /**
- * Computes round of input `tf.Tensor` element-wise: `round(x)`.
- * It implements banker's rounding.
- *
- * ```js
- * const x = tf.tensor1d([.6, 1.1, -3.3]);
- *
- * x.round().print(); // or tf.round(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function round_(x) {
- var $x = convertToTensor(x, 'x', 'round');
- // TODO(nsthorat): Let gradients be null for cases where we want to stop
- // backpropgation.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.round($x); }, { $x: $x }, grad);
- }
- /**
- * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, -3]);
- *
- * x.exp().print(); // or tf.exp(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function exp_(x) {
- var $x = convertToTensor(x, 'x', 'exp');
- var bck = function (dy, saved) {
- return { x: function () { return dy.mulStrict(saved[0]); } };
- };
- var attrs = {};
- var inputsToSave = [];
- var outputsToSave = [true];
- return ENGINE.runKernelFunc(function (backend, save) {
- var y = backend.exp($x);
- save([y]);
- return y;
- }, { x: $x }, bck, 'Exp', attrs, inputsToSave, outputsToSave);
- }
- /**
- * Computes exponential of the input `tf.Tensor` minus one element-wise.
- * `e ^ x - 1`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, -3]);
- *
- * x.expm1().print(); // or tf.expm1(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function expm1_(x) {
- var $x = convertToTensor(x, 'x', 'expm1');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.mul($x.exp()); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.expm1($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, Math.E]);
- *
- * x.log().print(); // or tf.log(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function log_(x) {
- var $x = convertToTensor(x, 'x', 'log');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { x: function () { return dy.div($x.toFloat()); } };
- };
- var attrs = {};
- var inputsToSave = [$x];
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.log($x);
- save([$x]);
- return res;
- }, { x: $x }, grad, 'Log', attrs, inputsToSave);
- }
- /**
- * Computes natural logarithm of the input `tf.Tensor` plus one
- * element-wise: `ln(1 + x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, Math.E - 1]);
- *
- * x.log1p().print(); // or tf.log1p(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function log1p_(x) {
- var $x = convertToTensor(x, 'x', 'log1p');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.div($x.add(1)); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.log1p($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 4, -1]);
- *
- * x.sqrt().print(); // or tf.sqrt(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function sqrt_(x) {
- var $x = convertToTensor(x, 'x', 'sqrt');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.div($x.toFloat().sqrt().mul(2)); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.sqrt($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes reciprocal of square root of the input `tf.Tensor` element-wise:
- * `y = 1 / sqrt(x)`
- *
- * ```js
- * const x = tf.tensor1d([1, 2, 4, -1]);
- *
- * x.rsqrt().print(); // or tf.rsqrt(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function rsqrt_(x) {
- var $x = convertToTensor(x, 'x', 'rsqrt');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { x: function () { return dy.div($x.pow(1.5).mul(2)).neg(); } };
- };
- var inputsToSave = [$x];
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.rsqrt($x);
- save([$x]);
- return res;
- }, { x: $x }, grad, 'Rsqrt', {} /* attrs */, inputsToSave);
- }
- /**
- * Computes reciprocal of x element-wise: `1 / x`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, 2]);
- *
- * x.reciprocal().print(); // or tf.reciprocal(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function reciprocal_(x) {
- var $x = convertToTensor(x, 'x', 'reciprocal');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.div($x.square().neg()); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.reciprocal($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes absolute value element-wise: `abs(x)`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.abs().print(); // or tf.abs(x)
- * ```
- * @param x The input `tf.Tensor`.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function abs_(x) {
- var $x = convertToTensor(x, 'x', 'abs');
- if ($x.dtype === 'complex64') {
- return ENGINE.runKernelFunc(function (backend) { return backend.complexAbs($x); }, { $x: $x });
- }
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { x: function () { return dy.mul($x.toFloat().step(-1)); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.abs($x);
- save([$x]);
- return res;
- }, { x: $x }, grad, 'Abs');
- }
- /**
- * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`
- *
- * ```js
- * const x = tf.tensor1d([-1, 2, -3, 4]);
- *
- * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)
- * ```
- * @param x The input tensor.
- * @param clipValueMin Lower-bound of range to be clipped to.
- * @param clipValueMax Upper-bound of range to be clipped to.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function clipByValue_(x, clipValueMin, clipValueMax) {
- var $x = convertToTensor(x, 'x', 'clipByValue');
- assert((clipValueMin <= clipValueMax), function () { return "Error in clip: min (" + clipValueMin + ") must be " +
- ("less than or equal to max (" + clipValueMax + ")."); });
- var grad = function (dy, saved) {
- var $x = saved[0];
- return {
- x: function () { return dy.where($x.greaterEqual(clipValueMin)
- .logicalAnd($x.lessEqual(clipValueMax)), zerosLike(dy)); },
- };
- };
- var inputsToSave = [$x];
- var attr = { min: clipValueMin, max: clipValueMax };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.clip($x, clipValueMin, clipValueMax);
- save([$x]);
- return res;
- }, { x: $x }, grad, 'ClipByValue', attr, inputsToSave);
- }
- /**
- * Computes sigmoid element-wise, `1 / (1 + exp(-x))`
- *
- * ```js
- * const x = tf.tensor1d([0, -1, 2, -3]);
- *
- * x.sigmoid().print(); // or tf.sigmoid(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function sigmoid_(x) {
- var $x = convertToTensor(x, 'x', 'sigmoid');
- var grad = function (dy, saved) {
- var y = saved[0];
- return { x: function () { return dy.mul(y.mul(scalar(1).sub(y))); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var y = backend.sigmoid($x);
- save([y]);
- return y;
- }, { x: $x }, grad, 'Sigmoid');
- }
- /**
- * Computes log sigmoid of the input `tf.Tensor` element-wise:
- * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.logSigmoid().print(); // or tf.logSigmoid(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function logSigmoid_(x) {
- var $x = convertToTensor(x, 'x', 'logSigmoid');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.mul($x.neg().sigmoid()); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.softplus($x.neg()).neg();
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.softplus().print(); // or tf.softplus(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function softplus_(x) {
- var $x = convertToTensor(x, 'x', 'softplus');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.mul($x.sigmoid()); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.softplus($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes sin of the input Tensor element-wise: `sin(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
- *
- * x.sin().print(); // or tf.sin(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function sin_(x) {
- var $x = convertToTensor(x, 'x', 'sin');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { x: function () { return $x.toFloat().cos().mul(dy); } };
- };
- var inputsToSave = [$x];
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.sin($x);
- save([$x]);
- return res;
- }, { x: $x }, grad, 'Sin', {} /* attrs */, inputsToSave);
- }
- /**
- * Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
- *
- * x.cos().print(); // or tf.cos(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function cos_(x) {
- var $x = convertToTensor(x, 'x', 'cos');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { x: function () { return $x.toFloat().sin().neg().mul(dy); } };
- };
- var inputsToSave = [$x];
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.cos($x);
- save([$x]);
- return res;
- }, { x: $x }, grad, 'Cos', {} /* attrs */, inputsToSave);
- }
- /**
- * Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
- *
- * x.tan().print(); // or tf.tan(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function tan_(x) {
- var $x = convertToTensor(x, 'x', 'tan');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.div($x.cos().square()); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.tan($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.asin().print(); // or tf.asin(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function asin_(x) {
- var $x = convertToTensor(x, 'x', 'asin');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return {
- $x: function () { return dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt()); }
- };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.asin($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.acos().print(); // or tf.acos(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function acos_(x) {
- var $x = convertToTensor(x, 'x', 'acos');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return {
- $x: function () {
- return dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt()).neg();
- }
- };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.acos($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.atan().print(); // or tf.atan(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function atan_(x) {
- var $x = convertToTensor(x, 'x', 'atan');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.div($x.toFloat().square().add(1)); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.atan($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.sinh().print(); // or tf.sinh(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function sinh_(x) {
- var $x = convertToTensor(x, 'x', 'sinh');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return $x.toFloat().cosh().mulStrict(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.sinh($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.cosh().print(); // or tf.cosh(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function cosh_(x) {
- var $x = convertToTensor(x, 'x', 'cosh');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return $x.toFloat().sinh().mulStrict(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.cosh($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, 70]);
- *
- * x.tanh().print(); // or tf.tanh(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function tanh_(x) {
- var $x = convertToTensor(x, 'x', 'tanh');
- var grad = function (dy, saved) {
- var y = saved[0];
- return { x: function () { return scalar(1).sub(y.square()).mulStrict(dy); } };
- };
- var outputsToSave = [true];
- return ENGINE.runKernelFunc(function (backend, save) {
- var y = backend.tanh($x);
- save([y]);
- return y;
- }, { x: $x }, grad, 'Tanh', {} /* attrs */, null /* inputsToSave */, outputsToSave);
- }
- /**
- * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
- * `asinh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, 1, -1, .7]);
- *
- * x.asinh().print(); // or tf.asinh(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function asinh_(x) {
- var $x = convertToTensor(x, 'x', 'asinh');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return {
- $x: function () { return dy.divStrict(scalar(1).add($x.toFloat().square()).sqrt()); }
- };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.asinh($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
- * `acosh(x)`
- *
- * ```js
- * const x = tf.tensor1d([10, 1, 3, 5.7]);
- *
- * x.acosh().print(); // or tf.acosh(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function acosh_(x) {
- var $x = convertToTensor(x, 'x', 'acosh');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.divStrict($x.toFloat().square().sub(1).sqrt()); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.acosh($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
- * `atanh(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, .1, -.1, .7]);
- *
- * x.atanh().print(); // or tf.atanh(x)
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function atanh_(x) {
- var $x = convertToTensor(x, 'x', 'atanh');
- var grad = function (dy, saved) {
- var $x = saved[0];
- return { $x: function () { return dy.div(scalar(1).sub($x.toFloat().square())); } };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.atanh($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes gause error function of the input `tf.Tensor` element-wise:
- * `erf(x)`
- *
- * ```js
- * const x = tf.tensor1d([0, .1, -.1, .7]);
- *
- * x.erf().print(); // or tf.erf(x);
- * ```
- * @param x The input tensor.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function erf_(x) {
- var $x = convertToTensor(x, 'x', 'erf');
- assert($x.dtype === 'int32' || $x.dtype === 'float32', function () { return 'Input dtype must be `int32` or `float32`.'; });
- if ($x.dtype === 'int32') {
- $x = $x.toFloat();
- }
- var grad = function (dy, saved) {
- var $x = saved[0];
- return {
- $x: function () { return dy.mul($x.square().neg().exp().mul(2 / Math.sqrt(Math.PI))); }
- };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.erf($x);
- save([$x]);
- return res;
- }, { $x: $x }, grad);
- }
- /**
- * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x`
- *
- * ```js
- * const x = tf.tensor1d([0, 2, -1, -3]);
- *
- * x.step(.5).print(); // or tf.step(x, .5)
- * ```
- * @param x The input tensor.
- * @param alpha The gradient when input is negative.
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function step_(x, alpha) {
- if (alpha === void 0) { alpha = 0.0; }
- var $x = convertToTensor(x, 'x', 'step');
- // TODO(manrajgrover): Return null for gradients when backprop supports
- // it.
- var grad = function (dy) {
- return { $x: function () { return zerosLike(dy); } };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.step($x, alpha); }, { $x: $x }, grad);
- }
- var abs = op({ abs_: abs_ });
- var acos = op({ acos_: acos_ });
- var acosh = op({ acosh_: acosh_ });
- var asin = op({ asin_: asin_ });
- var asinh = op({ asinh_: asinh_ });
- var atan = op({ atan_: atan_ });
- var atanh = op({ atanh_: atanh_ });
- var ceil = op({ ceil_: ceil_ });
- var clipByValue = op({ clipByValue_: clipByValue_ });
- var cos = op({ cos_: cos_ });
- var cosh = op({ cosh_: cosh_ });
- var erf = op({ erf_: erf_ });
- var exp = op({ exp_: exp_ });
- var expm1 = op({ expm1_: expm1_ });
- var floor = op({ floor_: floor_ });
- var log = op({ log_: log_ });
- var log1p = op({ log1p_: log1p_ });
- var logSigmoid = op({ logSigmoid_: logSigmoid_ });
- var neg = op({ neg_: neg_ });
- var reciprocal = op({ reciprocal_: reciprocal_ });
- var round = op({ round_: round_ });
- var rsqrt = op({ rsqrt_: rsqrt_ });
- var sigmoid = op({ sigmoid_: sigmoid_ });
- var sign = op({ sign_: sign_ });
- var isNaN$1 = op({ isNaN_: isNaN_ });
- var isInf = op({ isInf_: isInf_ });
- var isFinite$1 = op({ isFinite_: isFinite_ });
- var sin = op({ sin_: sin_ });
- var sinh = op({ sinh_: sinh_ });
- var softplus = op({ softplus_: softplus_ });
- var sqrt = op({ sqrt_: sqrt_ });
- var step = op({ step_: step_ });
- var tan = op({ tan_: tan_ });
- var tanh$1 = op({ tanh_: tanh_ });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Adds two `tf.Tensor`s element-wise, A + B.
- *
- * Inputs must be the same shape. For broadcasting support, use add() instead.
- *
- * @param a The first Tensor to add element-wise.
- * @param b The second Tensor to add element-wise.
- */
- function addStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'addStrict');
- var $b = convertToTensor(b, 'b', 'addStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: ');
- return $a.add($b);
- }
- /**
- * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
- *
- * We also expose `tf.subStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([10, 20, 30, 40]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.sub(b).print(); // or tf.sub(a, b)
- * ```
- *
- * ```js
- * // Broadcast subtract a with b.
- * const a = tf.tensor1d([10, 20, 30, 40]);
- * const b = tf.scalar(5);
- *
- * a.sub(b).print(); // or tf.sub(a, b)
- * ```
- * @param a The first `tf.Tensor` to subtract from.
- * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
- * `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function sub_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'sub');
- var $b = convertToTensor(b, 'b', 'sub');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- var outShape = assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy) {
- var derA = function () {
- var res = dy;
- var reduceAxes = getReductionAxes($a.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes);
- }
- return res.reshape($a.shape);
- };
- var derB = function () {
- var res = dy;
- var reduceAxes = getReductionAxes($b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes);
- }
- return res.neg().reshape($b.shape);
- };
- return { a: derA, b: derB };
- };
- return ENGINE.runKernelFunc(function (backend) { return backend.subtract($a, $b); }, { a: $a, b: $b }, der, 'Sub');
- }
- /**
- * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must
- * be the same shape.
- *
- * For broadcasting support, use `tf.sub` instead.
- *
- * @param a The first Tensor to subtract element-wise.
- * @param b The second Tensor to subtract element-wise.
- */
- function subStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'subStrict');
- var $b = convertToTensor(b, 'b', 'subStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: ');
- return $a.sub($b);
- }
- /**
- * Computes the power of one `tf.Tensor` to another. Supports broadcasting.
- *
- * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
- * corresponding elements in x and y. The result's dtype will be the upcasted
- * type of the `base` and `exp` dtypes.
- *
- * ```js
- * const a = tf.tensor([[2, 3], [4, 5]])
- * const b = tf.tensor([[1, 2], [3, 0]]).toInt();
- *
- * a.pow(b).print(); // or tf.pow(a, b)
- * ```
- *
- * ```js
- * const a = tf.tensor([[1, 2], [3, 4]])
- * const b = tf.tensor(2).toInt();
- *
- * a.pow(b).print(); // or tf.pow(a, b)
- * ```
- * We also expose `powStrict` which has the same signature as this op and
- * asserts that `base` and `exp` are the same shape (does not broadcast).
- *
- * @param base The base `tf.Tensor` to pow element-wise.
- * @param exp The exponent `tf.Tensor` to pow element-wise.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function pow_(base, exp) {
- var _a;
- var $base = convertToTensor(base, 'base', 'pow');
- var $exp = convertToTensor(exp, 'exp', 'pow');
- _a = makeTypesMatch($base, $exp), $base = _a[0], $exp = _a[1];
- var outShape = assertAndGetBroadcastShape($base.shape, $exp.shape);
- var grad = function (dy, saved) {
- var $base = saved[0], $exp = saved[1], y = saved[2];
- var derBase = function () {
- var expFloat = $exp.toFloat();
- var res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1)))));
- var reduceAxes = getReductionAxes($base.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes);
- }
- return res.reshape($base.shape);
- };
- var derExp = function () {
- var condition = $base.greater(0);
- var logBase = $base.log().where(condition, zerosLike($base));
- var res = dy.mul(y.mul(logBase));
- var reduceAxes = getReductionAxes($exp.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes);
- }
- return res.reshape($exp.shape);
- };
- return { a: derBase, b: derExp };
- };
- var attrs = {};
- var inputsToSave = [$base, $exp];
- var outputsToSave = [true];
- return ENGINE.runKernelFunc(function (backend, save) {
- var y = backend.pow($base, $exp);
- save([$base, $exp, y]);
- return y;
- }, { a: $base, b: $exp }, grad, 'Pow', attrs, inputsToSave, outputsToSave);
- }
- /**
- * Computes the power of one `tf.Tensor` to another. Inputs must
- * be the same shape.
- *
- * For broadcasting support, use `tf.pow` instead.
- *
- * @param base The base tensor to pow element-wise.
- * @param exp The exponent tensor to pow element-wise.
- */
- function powStrict_(base, exp) {
- assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: ');
- return base.pow(exp);
- }
- /**
- * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
- *
- * We also expose `tf.mulStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3, 4]);
- * const b = tf.tensor1d([2, 3, 4, 5]);
- *
- * a.mul(b).print(); // or tf.mul(a, b)
- * ```
- *
- * ```js
- * // Broadcast mul a with b.
- * const a = tf.tensor1d([1, 2, 3, 4]);
- * const b = tf.scalar(5);
- *
- * a.mul(b).print(); // or tf.mul(a, b)
- * ```
- * @param a The first tensor to multiply.
- * @param b The second tensor to multiply. Must have the same dtype as `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function mul_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'mul');
- var $b = convertToTensor(b, 'b', 'mul');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- var outShape = assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy, saved) {
- var $a = saved[0], $b = saved[1];
- var derA = function () {
- var res = dy.mul($b.toFloat());
- var reduceAxes = getReductionAxes($a.shape, outShape);
- if (reduceAxes.length > 0) {
- return res.sum(reduceAxes).reshape($a.shape);
- }
- return res;
- };
- var derB = function () {
- var res = dy.mul($a.toFloat());
- var reduceAxes = getReductionAxes($b.shape, outShape);
- if (reduceAxes.length > 0) {
- return res.sum(reduceAxes).reshape($b.shape);
- }
- return res;
- };
- return { a: derA, b: derB };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.multiply($a, $b);
- save([$a, $b]);
- return res;
- }, { a: $a, b: $b }, der, 'Mul');
- }
- /**
- * Multiplies two `tf.Tensor`s element-wise, A * B.
- *
- * Inputs must be the same shape. For broadcasting support, use `tf.mul`.
- *
- * @param a The first tensor to multiply.
- * @param b The first tensor to multiply. Must have the same
- * dtype as `a`.
- */
- function mulStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'mul');
- var $b = convertToTensor(b, 'b', 'mul');
- assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: ');
- return $a.mul($b);
- }
- /**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
- * The result is rounded with floor function.
- *
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.floorDiv(b).print(); // or tf.div(a, b)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- *
- * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function floorDiv_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'floorDiv');
- var $b = convertToTensor(b, 'b', 'floorDiv');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- var outShape = assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy, saved) {
- var $a = saved[0], $b = saved[1];
- var derA = function () {
- var res = dy.div($b.toFloat());
- var reduceAxes = getReductionAxes($a.shape, outShape);
- if (reduceAxes.length > 0) {
- return res.sum(reduceAxes).reshape($a.shape);
- }
- return res;
- };
- var derB = function () {
- var res = dy.mul($a.toFloat());
- var reduceAxes = getReductionAxes($b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes).reshape($b.shape);
- }
- var tmp = $b.square();
- return res.div(tmp.toFloat()).neg();
- };
- return { a: derA, b: derB };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.floorDiv($a, $b);
- save([$a, $b]);
- return res;
- }, { a: $a, b: $b }, der, 'FloorDiv');
- }
- /**
- * Divides two `tf.Tensor`s element-wise, A / B. Inputs must
- * be the same shape.
- *
- * @param a The first tensor as the numerator for element-wise division.
- * @param b The second tensor as the denominator for element-wise division.
- */
- function divStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'div');
- var $b = convertToTensor(b, 'b', 'div');
- assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: ');
- return $a.div($b);
- }
- /**
- * Returns the mod of a and b element-wise.
- * `floor(x / y) * y + mod(x, y) = x`
- * Supports broadcasting.
- *
- * We also expose `tf.modStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.mod(b).print(); // or tf.mod(a, b)
- * ```
- *
- * ```js
- * // Broadcast a mod b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.mod(b).print(); // or tf.mod(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function mod_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'mod');
- var $b = convertToTensor(b, 'b', 'mod');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- var outShape = assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy, saved) {
- var $a = saved[0], $b = saved[1];
- var derA = function () {
- var reduceAxes = getReductionAxes($a.shape, outShape);
- if (reduceAxes.length > 0) {
- return dy.sum(reduceAxes).reshape($a.shape);
- }
- return dy;
- };
- var derB = function () {
- var res = dy.mul($a.div($b).floor().neg());
- var reduceAxes = getReductionAxes($b.shape, outShape);
- if (reduceAxes.length > 0) {
- return res.sum(reduceAxes).reshape($b.shape);
- }
- return res;
- };
- return { $a: derA, $b: derB };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.mod($a, $b);
- save([$a, $b]);
- return res;
- }, { $a: $a, $b: $b }, der);
- }
- /**
- * Returns the mod of a and b (`a < b ? a : b`) element-wise. Inputs must
- * be the same shape. For broadcasting support, use mod().
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- */
- function modStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'modStrict');
- var $b = convertToTensor(b, 'b', 'modStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: ');
- return $a.mod($b);
- }
- /**
- * Returns the min of a and b (`a < b ? a : b`) element-wise.
- * Supports broadcasting.
- *
- * We also expose `minimumStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.minimum(b).print(); // or tf.minimum(a, b)
- * ```
- *
- * ```js
- * // Broadcast minimum a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.minimum(b).print(); // or tf.minimum(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function minimum_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'minimum');
- var $b = convertToTensor(b, 'b', 'minimum');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- if ($a.dtype === 'bool') {
- $a = $a.toInt();
- $b = $b.toInt();
- }
- assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy, saved) {
- var $a = saved[0], $b = saved[1];
- var derA = function () { return dy.mul($a.lessEqual($b).toFloat()); };
- var derB = function () { return dy.mul($a.greater($b).toFloat()); };
- return { a: derA, b: derB };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.minimum($a, $b);
- save([$a, $b]);
- return res;
- }, { a: $a, b: $b }, der, 'Minimum');
- }
- /**
- * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must
- * be the same shape. For broadcasting support, use minimum().
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- */
- function minimumStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'minimumStrict');
- var $b = convertToTensor(b, 'b', 'minimumStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: ');
- return $a.minimum($b);
- }
- /**
- * Returns the max of a and b (`a > b ? a : b`) element-wise.
- * Supports broadcasting.
- *
- * We also expose `tf.maximumStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 3, 16]);
- * const b = tf.tensor1d([1, 2, 9, 4]);
- *
- * a.maximum(b).print(); // or tf.maximum(a, b)
- * ```
- *
- * ```js
- * // Broadcast maximum a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(5);
- *
- * a.maximum(b).print(); // or tf.maximum(a, b)
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function maximum_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'maximum');
- var $b = convertToTensor(b, 'b', 'maximum');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- if ($a.dtype === 'bool') {
- $a = $a.toInt();
- $b = $b.toInt();
- }
- assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy, saved) {
- var $a = saved[0], $b = saved[1];
- var derA = function () { return dy.mul($a.greaterEqual($b).toFloat()); };
- var derB = function () { return dy.mul($a.less($b).toFloat()); };
- return { a: derA, b: derB };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.maximum($a, $b);
- save([$a, $b]);
- return res;
- }, { a: $a, b: $b }, der, 'Maximum');
- }
- /**
- * Returns the max of a and b (`a > b ? a : b`) element-wise. Inputs must
- * be the same shape. For broadcasting support, use maximum().
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- */
- function maximumStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'maximumStrict');
- var $b = convertToTensor(b, 'b', 'maximumStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: ');
- return $a.maximum($b);
- }
- /**
- * Returns (a - b) * (a - b) element-wise.
- *
- * Inputs must be the same shape. For broadcasting support, use
- * `tf.squaredDifference` instead.
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same type as `a`.
- */
- function squaredDifferenceStrict_(a, b) {
- var $a = convertToTensor(a, 'a', 'squaredDifferenceStrict');
- var $b = convertToTensor(b, 'b', 'squaredDifferenceStrict');
- assertShapesMatch($a.shape, $b.shape, 'Error in squaredDifferenceStrict: ');
- return $a.squaredDifference($b);
- }
- /**
- * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
- * Supports broadcasting.
- *
- * ```js
- * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
- * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
- *
- * tf.atan2(a, b).print()
- * ```
- *
- * @param a The first tensor.
- * @param b The second tensor. Must have the same dtype as `a`.
- *
- */
- /** @doc {heading: 'Operations', subheading: 'Basic math'} */
- function atan2_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'atan2');
- var $b = convertToTensor(b, 'b', 'atan2');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- var outShape = assertAndGetBroadcastShape($a.shape, $b.shape);
- var der = function (dy, saved) {
- var $a = saved[0], $b = saved[1];
- var derA = function () {
- var d = add($a.square(), $b.square());
- var res = dy.mul($b.div(d));
- var reduceAxes = getReductionAxes($a.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes);
- }
- return res.reshape($a.shape);
- };
- var derB = function () {
- var d = add($a.square(), $b.square());
- var res = neg(dy.mul($a.div(d)));
- var reduceAxes = getReductionAxes($b.shape, outShape);
- if (reduceAxes.length > 0) {
- res = res.sum(reduceAxes);
- }
- return res.reshape($b.shape);
- };
- return { $a: derA, $b: derB };
- };
- return ENGINE.runKernelFunc(function (backend, save) {
- var res = backend.atan2($a, $b);
- save([$a, $b]);
- return res;
- }, { $a: $a, $b: $b }, der);
- }
- var addStrict = op({ addStrict_: addStrict_ });
- var atan2 = op({ atan2_: atan2_ });
- var divStrict = op({ divStrict_: divStrict_ });
- var floorDiv = op({ floorDiv_: floorDiv_ });
- var maximum = op({ maximum_: maximum_ });
- var maximumStrict = op({ maximumStrict_: maximumStrict_ });
- var minimum = op({ minimum_: minimum_ });
- var minimumStrict = op({ minimumStrict_: minimumStrict_ });
- var mod = op({ mod_: mod_ });
- var modStrict = op({ modStrict_: modStrict_ });
- var mul = op({ mul_: mul_ });
- var mulStrict = op({ mulStrict_: mulStrict_ });
- var pow = op({ pow_: pow_ });
- var powStrict = op({ powStrict_: powStrict_ });
- var squaredDifferenceStrict = op({ squaredDifferenceStrict_: squaredDifferenceStrict_ });
- var sub = op({ sub_: sub_ });
- var subStrict = op({ subStrict_: subStrict_ });
-
- /**
- * @license
- * Copyright 2020 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
- *
- * We also expose `tf.divStrict` which has the same signature as this op and
- * asserts that `a` and `b` are the same shape (does not broadcast).
- *
- * ```js
- * const a = tf.tensor1d([1, 4, 9, 16]);
- * const b = tf.tensor1d([1, 2, 3, 4]);
- *
- * a.div(b).print(); // or tf.div(a, b)
- * ```
- *
- * ```js
- * // Broadcast div a with b.
- * const a = tf.tensor1d([2, 4, 6, 8]);
- * const b = tf.scalar(2);
- *
- * a.div(b).print(); // or tf.div(a, b)
- * ```
- *
- * @param a The first tensor as the numerator.
- * @param b The second tensor as the denominator. Must have the same dtype as
- * `a`.
- */
- /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
- function div_(a, b) {
- var _a;
- var $a = convertToTensor(a, 'a', 'div');
- var $b = convertToTensor(b, 'b', 'div');
- _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1];
- if ($a.dtype === 'int32' && $b.dtype === 'int32') {
- return floorDiv($a, $b);
- }
- var forward = function (backend, save) {
- var res = backend.realDivide($a, $b);
- save([$a, $b]);
- return res;
- };
- var inputs = { a: $a, b: $b };
- var attrs = {};
- return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Div, attrs);
- }
- var div = op({ div_: div_ });
-
- /**
- * Validate gather nd inputs.
- *
- * @param tensor The tensor contains the source values.
- * @param indices The tensor contains the indices to slice the source.
- *
- * @returns [resultShape, numUpdates, sliceSize, strides]
- */
- function prepareAndValidate(tensor, indices) {
- if (tensor.rank < 1) {
- throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
- (" but the rank was " + tensor.rank + "."));
- }
- if (indices.rank < 1) {
- throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
- (" but the rank was " + indices.rank + "."));
- }
- if (indices.dtype !== 'int32') {
- throw new Error('tf.gatherND() expects the indices to be int32 type,' +
- (" but the dtype was " + indices.dtype + "."));
- }
- if (indices.shape[indices.rank - 1] > tensor.rank) {
- throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
- (indices.shape[indices.rank - 1] + " vs. " + tensor.rank));
- }
- if (tensor.size === 0) {
- throw new Error('Requested more than 0 entries, but input is empty.' +
- (" Input shape: " + tensor.shape + "."));
- }
- var indicesShape = indices.shape;
- var sliceRank = indicesShape[indicesShape.length - 1];
- // The result shape is
- // indices.shape[:-1] + params.shape[indices.shape[-1]:]
- var nResult = 1;
- for (var i = 0; i < indicesShape.length - 1; ++i) {
- nResult *= indicesShape[i];
- }
- var inputShape = tensor.shape;
- var resultShape = indicesShape.slice();
- resultShape.pop();
- var sliceSize = 1;
- for (var i = sliceRank; i < tensor.rank; ++i) {
- sliceSize *= inputShape[i];
- resultShape.push(inputShape[i]);
- }
- var strides = computeStrides(tensor.shape).map(function (stride) { return stride / sliceSize; }).concat([1]).slice(0, sliceRank);
- return [resultShape, nResult, sliceSize, strides];
- }
-
- var gather_nd_util = /*#__PURE__*/Object.freeze({
- prepareAndValidate: prepareAndValidate
- });
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var PARALLELIZE_THRESHOLD = 30;
- function computeOptimalWindowSize(inSize) {
- if (inSize <= PARALLELIZE_THRESHOLD) {
- return inSize;
- }
- return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
- }
-
- /**
- * Check whether updates.shape = indices.shape[:batchDim] +
- * shape[sliceDim:]
- *
- * @param x The input tensor.
- */
- function validateUpdateShape(shape, indices, updates) {
- var sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
- var batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
- var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
- ("shape[sliceDim:], got updates.shape: " + updates.shape) +
- (", indices.shape: " + indices.shape + ", shape: " + shape) +
- (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + ".");
- if (updates.rank < batchDim) {
- throw new Error(shapeError + (" update.rank < " + batchDim + ". "));
- }
- if (shape.length < sliceDim + (updates.rank - batchDim)) {
- throw new Error(shapeError +
- (" Output shape length < " + (sliceDim + (updates.rank - batchDim))));
- }
- if (updates.rank !== batchDim + shape.length - sliceDim) {
- throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim)));
- }
- for (var d = 0; d < batchDim; ++d) {
- if (updates.shape[d] !== indices.shape[d]) {
- throw new Error(shapeError +
- (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ")."));
- }
- }
- for (var d = 0; d < updates.rank - batchDim; ++d) {
- if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
- throw new Error(shapeError +
- (" updates.shape[" + (d + batchDim) + "] (" + updates.shape[d + batchDim] + ") != shape[" + (d + batchDim) + "] (" + shape[d + batchDim] + ")"));
- }
- }
- }
- /**
- * Validate scatter nd inputs.
- *
- * @param update The tensor contains the update values.
- * @param indices The tensor contains the indices for the update values.
- * @param shape The shape of the output tensor.
- */
- function validateInput(updates, indices, shape) {
- if (indices.rank < 1) {
- throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
- (" but the rank was " + indices.rank + "."));
- }
- if (updates.rank < 1) {
- throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
- (" but the rank was " + updates.rank + "."));
- }
- if (indices.dtype !== 'int32') {
- throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype);
- }
- if (shape.length < 1) {
- throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape);
- }
- if (shape.length === 0) {
- if (indices.size === 0) {
- throw new Error("Indices specified for empty output. indices shape: " + indices.shape);
- }
- if (updates.size === 0) {
- throw new Error("Updates specified for empty output. updates shape: " + updates.shape);
- }
- }
- validateUpdateShape(shape, indices, updates);
- }
- /**
- * Calculate the shape information for the output.
- *
- * @param update The tensor contains the update values.
- * @param indices The tensor contains the indices for the update values.
- * @param shape The shape of the output tensor.
- *
- * @returns ScatterShapeInfo
- */
- function calculateShapes(updates, indices, shape) {
- // Calculate the number of dimensions in indices
- var indicesRank = indices.shape.length;
- var sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
- // Calculate the number of elements that make up each slice of our updated
- // tensor. This allows us to work with flattened tensors and copy over whole
- // slices at a time.
- var totalNd = shape.length;
- var sliceSize = 1;
- for (var i = sliceRank; i < totalNd; ++i) {
- sliceSize *= shape[i];
- }
- var safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
- var numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
- var strides = computeStrides(shape.slice(0, sliceRank)).concat([1]);
- var outputSize = sizeFromShape(shape);
- return { sliceRank: sliceRank, numUpdates: numUpdates, sliceSize: sliceSize, strides: strides, outputSize: outputSize };
- }
-
- var scatter_nd_util = /*#__PURE__*/Object.freeze({
- validateUpdateShape: validateUpdateShape,
- validateInput: validateInput,
- calculateShapes: calculateShapes
- });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function segOpComputeOptimalWindowSize(inSize, numSegments) {
- var done = false;
- var res;
- if (inSize <= PARALLELIZE_THRESHOLD) {
- res = inSize;
- done = true;
- }
- else {
- res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
- }
- while (!done) {
- if (res > numSegments || res === inSize) {
- done = true;
- }
- else {
- res = nearestDivisor(inSize, res + 1);
- }
- }
- return res;
- }
- function computeOutShape$1(aShape, axis, numSegments) {
- var outShape = [];
- var rank = aShape.length;
- for (var dim = 0; dim < rank; dim++) {
- if (dim !== axis) {
- outShape.push(aShape[dim]);
- }
- else {
- outShape.push(numSegments);
- }
- }
- return outShape;
- }
- function collectGatherOpShapeInfo(x, indices, axis) {
- var dimSize = x.shape[axis];
- var outputShape = [];
- var batchSize = 1;
- var sliceSize = 1;
- for (var i = 0; i < axis; i++) {
- outputShape.push(x.shape[i]);
- batchSize *= x.shape[i];
- }
- for (var i = 0; i < indices.rank; i++) {
- outputShape.push(indices.shape[i]);
- }
- for (var i = axis + 1; i < x.rank; i++) {
- outputShape.push(x.shape[i]);
- sliceSize *= x.shape[i];
- }
- return { batchSize: batchSize, sliceSize: sliceSize, dimSize: dimSize, outputShape: outputShape };
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function assertParamsValid(input, begin, size) {
- assert(input.rank === begin.length, function () { return "Error in slice" + input.rank + "D: Length of begin " + begin + " must " +
- ("match the rank of the array (" + input.rank + ")."); });
- assert(input.rank === size.length, function () { return "Error in slice" + input.rank + "D: Length of size " + size + " must " +
- ("match the rank of the array (" + input.rank + ")."); });
- var _loop_1 = function (i) {
- assert(begin[i] + size[i] <= input.shape[i], function () { return "Error in slice" + input.rank + "D: begin[" + i + "] + size[" + i + "] " +
- ("(" + (begin[i] + size[i]) + ") would overflow input.shape[" + i + "] (" + input.shape[i] + ")"); });
- };
- for (var i = 0; i < input.rank; ++i) {
- _loop_1(i);
- }
- }
- /** Converts a binary mask to an array of axes. Used in stridedSlice(). */
- function maskToAxes(mask) {
- var axes = [];
- var axis = 0;
- while (mask > 0) {
- if (mask & 1) {
- axes.push(axis);
- }
- mask /= 2;
- axis++;
- }
- return axes;
- }
- /** Computes the output shape given the strided slice params. */
- function computeOutShape$2(begin, end, strides) {
- var size = [];
- for (var axis = 0; axis < begin.length; axis++) {
- size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
- }
- return size;
- }
- function startForAxis(beginMask, startIndices, strides, inputShape, axis) {
- // Begin with the specified index
- var start = startIndices[axis];
- var stride = strides[axis] || 1;
- // Check the axis bit from right of beginMask or the begin index is not set
- // for the axis.
- if (beginMask & 1 << axis || start == null) {
- if (stride > 0) {
- // Forward iteration - use the first element. These values will get
- // clamped below (Note: We could have set them to 0 and axis_size-1, but
- // use lowest() and max() to maintain symmetry with StopForAxis())
- start = Number.MIN_SAFE_INTEGER;
- }
- else {
- // Backward iteration - use the last element.
- start = Number.MAX_SAFE_INTEGER;
- }
- }
- // Handle negative indices
- var axisSize = inputShape[axis];
- if (start < 0) {
- start += axisSize;
- }
- // Clamping
- start = clamp(0, start, axisSize - 1);
- return start;
- }
- function stopForAxis(endMask, stopIndices, strides, inputShape, axis) {
- // Begin with the specified index
- var stop = stopIndices[axis];
- var stride = strides[axis] || 1;
- // Check the axis bit from right of endMask or if the stop index is not set
- // for this axis.
- if (endMask & (1 << axis) || stop == null) {
- if (stride > 0) {
- // Forward iteration - use the last element. These values will get
- // clamped below
- stop = Number.MAX_SAFE_INTEGER;
- }
- else {
- // Backward iteration - use the first element.
- stop = Number.MIN_SAFE_INTEGER;
- }
- }
- // Handle negative indices
- var axisSize = inputShape[axis];
- if (stop < 0) {
- stop += axisSize;
- }
- // Clamping
- // Because the end index points one past the last element, we need slightly
- // different clamping ranges depending on the direction.
- if (stride > 0) {
- // Forward iteration
- stop = clamp(0, stop, axisSize);
- }
- else {
- // Backward iteration
- stop = clamp(-1, stop, axisSize - 1);
- }
- return stop;
- }
- /**
- * Returns true if the slice occupies a continous set of elements in the
- * 'flat' space.
- */
- function isSliceContinous(shape, begin, size) {
- // Index of the first axis that has size > 1.
- var firstNonOneAxis = size.length;
- for (var i = 0; i < size.length; i++) {
- if (size[i] > 1) {
- firstNonOneAxis = i;
- break;
- }
- }
- for (var i = firstNonOneAxis + 1; i < size.length; i++) {
- if (begin[i] > 0 || size[i] !== shape[i]) {
- return false;
- }
- }
- return true;
- }
- function computeFlatOffset(begin, strides) {
- var flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;
- for (var i = 0; i < begin.length - 1; i++) {
- flatOffset += begin[i] * strides[i];
- }
- return flatOffset;
- }
-
- var slice_util = /*#__PURE__*/Object.freeze({
- assertParamsValid: assertParamsValid,
- maskToAxes: maskToAxes,
- computeOutShape: computeOutShape$2,
- startForAxis: startForAxis,
- stopForAxis: stopForAxis,
- isSliceContinous: isSliceContinous,
- computeFlatOffset: computeFlatOffset
- });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
- * gradient of `f(x)` with respect to `x`.
- *
- * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
- * `x` is computed instead. `f(x)` must take a single tensor `x` and return a
- * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
- *
- * ```js
- * // f(x) = x ^ 2
- * const f = x => x.square();
- * // f'(x) = 2x
- * const g = tf.grad(f);
- *
- * const x = tf.tensor1d([2, 3]);
- * g(x).print();
- * ```
- *
- * ```js
- * // f(x) = x ^ 3
- * const f = x => x.pow(tf.scalar(3, 'int32'));
- * // f'(x) = 3x ^ 2
- * const g = tf.grad(f);
- * // f''(x) = 6x
- * const gg = tf.grad(g);
- *
- * const x = tf.tensor1d([2, 3]);
- * gg(x).print();
- * ```
- *
- * @param f The function f(x), to compute gradient for.
- */
- /** @doc {heading: 'Training', subheading: 'Gradients'} */
- function grad(f) {
- assert(isFunction(f), function () { return 'The f passed in grad(f) must be a function'; });
- return function (x, dy) {
- // x can be of any dtype, thus null as the last argument.
- var $x = convertToTensor(x, 'x', 'tf.grad', null);
- var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null;
- return ENGINE.tidy(function () {
- var _a = ENGINE.gradients(function () { return f($x); }, [$x], $dy), value = _a.value, grads = _a.grads;
- if ($dy != null) {
- assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' +
- 'returned by f(x)');
- }
- checkGrads(grads);
- return grads[0];
- });
- };
- }
- /**
- * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
- * which gives an array of gradients of `f()` with respect to each input
- * [`x1`,`x2`,...].
- *
- * If `dy` is passed when calling `g()`, the gradient of
- * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
- * The provided `f` must take one or more tensors and return a single tensor
- * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
- *
- * ```js
- * // f(a, b) = a * b
- * const f = (a, b) => a.mul(b);
- * // df / da = b, df / db = a
- * const g = tf.grads(f);
- *
- * const a = tf.tensor1d([2, 3]);
- * const b = tf.tensor1d([-2, -3]);
- * const [da, db] = g([a, b]);
- * console.log('da');
- * da.print();
- * console.log('db');
- * db.print();
- * ```
- *
- * @param f The function `f(x1, x2,...)` to compute gradients for.
- */
- /** @doc {heading: 'Training', subheading: 'Gradients'} */
- function grads(f) {
- assert(isFunction(f), function () { return 'The f passed in grads(f) must be a function'; });
- return function (args, dy) {
- assert(Array.isArray(args), function () { return 'The args passed in grads(f)(args) must be an array ' +
- 'of `Tensor`s or `TensorLike`s'; });
- // args can be of any dtype, thus null as the last argument.
- var $args = convertToTensorArray(args, 'args', 'tf.grads', null);
- var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null;
- return ENGINE.tidy(function () {
- var _a = ENGINE.gradients(function () { return f.apply(void 0, $args); }, $args, $dy), value = _a.value, grads = _a.grads;
- if ($dy != null) {
- assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' +
- 'match the shape returned by f([x1,...])');
- }
- checkGrads(grads);
- return grads;
- });
- };
- }
- /**
- * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
- * returns a metric you want to show.
- *
- * The result is a rich object with the following properties:
- * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).
- * - value: The value returned by `f(x)`.
- *
- * ```js
- * // f(x) = x ^ 2
- * const f = x => x.square();
- * // f'(x) = 2x
- * const g = tf.valueAndGrad(f);
- *
- * const x = tf.tensor1d([2, 3]);
- * const {value, grad} = g(x);
- *
- * console.log('value');
- * value.print();
- * console.log('grad');
- * grad.print();
- * ```
- */
- /** @doc {heading: 'Training', subheading: 'Gradients'} */
- function valueAndGrad(f) {
- assert(isFunction(f), function () { return 'The f passed in valueAndGrad(f) must be a function'; });
- return function (x, dy) {
- assert(x instanceof Tensor, function () { return 'The x passed in valueAndGrad(f)(x) must be a tensor'; });
- assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor'; });
- var _a = ENGINE.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value;
- checkGrads(grads);
- return { grad: grads[0], value: value };
- };
- }
- /**
- * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
- * returns a metric you want to show.
- *
- * The result is a rich object with the following properties:
- * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).
- * - value: The value returned by `f(x)`.
- *
- * ```js
- * // f(a, b) = a * b
- * const f = (a, b) => a.mul(b);
- * // df/da = b, df/db = a
- * const g = tf.valueAndGrads(f);
- *
- * const a = tf.tensor1d([2, 3]);
- * const b = tf.tensor1d([-2, -3]);
- * const {value, grads} = g([a, b]);
- *
- * const [da, db] = grads;
- *
- * console.log('value');
- * value.print();
- *
- * console.log('da');
- * da.print();
- * console.log('db');
- * db.print();
- * ```
- */
- /** @doc {heading: 'Training', subheading: 'Gradients'} */
- function valueAndGrads(f) {
- assert(isFunction(f), function () { return 'The f passed in valueAndGrads(f) must be a function'; });
- return function (args, dy) {
- assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof Tensor; }), function () { return 'The args passed in valueAndGrads(f)(args) must be array of ' +
- 'tensors'; });
- assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor'; });
- var res = ENGINE.gradients(function () { return f.apply(void 0, args); }, args, dy);
- if (dy != null) {
- assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' +
- 'match the shape returned by f([x1,...])');
- }
- checkGrads(res.grads);
- return res;
- };
- }
- /**
- * Computes and returns the gradient of f(x) with respect to the list of
- * trainable variables provided by `varList`. If no list is provided, it
- * defaults to all trainable variables.
- *
- * ```js
- * const a = tf.variable(tf.tensor1d([3, 4]));
- * const b = tf.variable(tf.tensor1d([5, 6]));
- * const x = tf.tensor1d([1, 2]);
- *
- * // f(a, b) = a * x ^ 2 + b * x
- * const f = () => a.mul(x.square()).add(b.mul(x)).sum();
- * // df/da = x ^ 2, df/db = x
- * const {value, grads} = tf.variableGrads(f);
- *
- * Object.keys(grads).forEach(varName => grads[varName].print());
- * ```
- *
- * @param f The function to execute. f() should return a scalar.
- * @param varList The list of variables to compute the gradients with respect
- * to. Defaults to all trainable variables.
- * @returns An object with the following keys and values:
- * - `value`: The value of the function `f`.
- * - `grads`: A map from the names of the variables to the gradients.
- * If the `varList` argument is provided explicitly and contains a subset of
- * non-trainable variables, this map in the return value will contain keys
- * that map the names of the non-trainable variables to `null`.
- */
- /** @doc {heading: 'Training', subheading: 'Gradients'} */
- function variableGrads(f, varList) {
- assert(isFunction(f), function () { return 'The f passed in variableGrads(f) must be a function'; });
- assert(varList == null ||
- Array.isArray(varList) && varList.every(function (v) { return v instanceof Variable; }), function () {
- return 'The varList passed in variableGrads(f, varList) must be an array ' +
- 'of variables';
- });
- var specifiedVarList = varList != null;
- if (!specifiedVarList) {
- // Get all of the trainable variables.
- varList = [];
- for (var varName in ENGINE.registeredVariables) {
- varList.push(ENGINE.registeredVariables[varName]);
- }
- }
- var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) { return !variable.trainable; }) : null;
- // Prune non-trainable variables.
- var originalVarCount = varList.length;
- varList = varList.filter(function (variable) { return variable.trainable; });
- assert(varList.length > 0, function () { return "variableGrads() expects at least one of the input variables to " +
- ("be trainable, but none of the " + originalVarCount + " variables is ") +
- "trainable."; });
- var allowNoGradients = true;
- var _a = ENGINE.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads;
- assert(grads.some(function (g) { return g != null; }), function () { return 'Cannot find a connection between any variable and the result of ' +
- 'the loss function y=f(x). Please make sure the operations that ' +
- 'use variables are inside the function f passed to minimize().'; });
- assert(value.rank === 0, function () { return "The f passed in variableGrads(f) must return a scalar, but it " +
- ("returned a rank-" + value.rank + " tensor"); });
- var namedGrads = {};
- varList.forEach(function (v, i) {
- if (grads[i] != null) {
- namedGrads[v.name] = grads[i];
- }
- });
- if (specifiedNonTrainable != null) {
- // If varList is explicitly provided and contains non-trainable values,
- // add them to the returned gradients with `null` values.
- specifiedNonTrainable.forEach(function (v) { return namedGrads[v.name] = null; });
- }
- return { value: value, grads: namedGrads };
- }
- /**
- * Overrides the gradient computation of a function `f`.
- *
- * Takes a function
- * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
- * and returns another function `g(...inputs)` which takes the same inputs as
- * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
- * with respect to each input of `f` are computed using `f().gradFunc`.
- *
- * The `save` function passsed to `f` should be used for saving tensors needed
- * in the gradient. And the `saved` passed to the `gradFunc` is a
- * `NamedTensorMap`, which contains those saved tensor.
- *
- * ```js
- * const customOp = tf.customGrad((x, save) => {
- * // Save x to make sure it's available later for the gradient.
- * save([x]);
- * // Override gradient of our custom x ^ 2 op to be dy * abs(x);
- * return {
- * value: x.square(),
- * // Note `saved.x` which points to the `x` we saved earlier.
- * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
- * };
- * });
- *
- * const x = tf.tensor1d([-1, -2, 3]);
- * const dx = tf.grad(x => customOp(x));
- *
- * console.log(`f(x):`);
- * customOp(x).print();
- * console.log(`f'(x):`);
- * dx(x).print();
- * ```
- *
- * @param f The function to evaluate in forward mode, which should return
- * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
- * returns the custom gradients of `f` with respect to its inputs.
- */
- /** @doc {heading: 'Training', subheading: 'Gradients'} */
- function customGrad(f) {
- return ENGINE.customGrad(f);
- }
- function checkGrads(grads) {
- var numNullGradients = grads.filter(function (g) { return g == null; }).length;
- if (numNullGradients > 0) {
- throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.");
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Computes the softmax normalized vector given the logits.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- *
- * a.softmax().print(); // or tf.softmax(a)
- * ```
- *
- * ```js
- * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
- *
- * a.softmax().print(); // or tf.softmax(a)
- * ```
- *
- * @param logits The logits array.
- * @param dim The dimension softmax would be performed on. Defaults to `-1`
- * which indicates the last dimension.
- */
- /** @doc {heading: 'Operations', subheading: 'Normalization'} */
- function softmax_(logits, dim) {
- if (dim === void 0) { dim = -1; }
- var $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
- if (dim === -1) {
- dim = $logits.rank - 1;
- }
- if (dim !== $logits.rank - 1) {
- throw Error('Softmax along a non-last dimension is not yet supported. ' +
- ("Logits was rank " + $logits.rank + " and dim was " + dim));
- }
- var inputsToSave = [];
- var outputsToSave = [true];
- return ENGINE.runKernelFunc(function (backend, save) {
- var y = backend.softmax($logits, dim);
- save([y]);
- return y;
- }, { logits: $logits }, function (dy, saved) {
- var y = saved[0];
- var dyTimesY = dy.mul(y);
- var keepDims = true;
- return {
- logits: function () { return dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y)); }
- };
- }, 'Softmax', { dim: dim }, inputsToSave, outputsToSave);
- }
- /**
- * Computes the log softmax.
- *
- * ```js
- * const a = tf.tensor1d([1, 2, 3]);
- *
- * a.logSoftmax().print(); // or tf.logSoftmax(a)
- * ```
- *
- * ```js
- * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
- *
- * a.logSoftmax().print(); // or tf.logSoftmax(a)
- * ```
- *
- * @param logits The logits array.
- * @param axis The dimension softmax would be performed on. Defaults to `-1`
- * which indicates the last dimension.
- */
- /** @doc {heading: 'Operations', subheading: 'Normalization'} */
- function logSoftmax_(logits, axis) {
- if (axis === void 0) { axis = -1; }
- var $logits = convertToTensor(logits, 'logits', 'logSoftmax');
- if (axis === -1) {
- axis = $logits.rank - 1;
- }
- if (axis !== $logits.rank - 1) {
- throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
- ("Logits was rank " + $logits.rank + " and axis was " + axis));
- }
- var customOp = customGrad(function (logits, save) {
- var keepDims = true;
- var xMax = logits.max(axis, true);
- var shifted = logits.sub(xMax);
- var value = shifted.toFloat().sub(shifted.exp().sum(axis, keepDims).log());
- save([value]);
- var gradFunc = function (dy, saved) {
- var value = saved[0];
- var softmax = value.exp();
- return dy.sub(dy.sum(axis, keepDims).mul(softmax));
- };
- return { value: value, gradFunc: gradFunc };
- });
- return customOp($logits);
- }
- var softmax = op({ softmax_: softmax_ });
- var logSoftmax = op({ logSoftmax_: logSoftmax_ });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
- *
- * The returned `tf.Tensor`'s dimension `i` will correspond to the input
- * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
- * where `n` is the rank of the input `tf.Tensor`. Hence by default, this
- * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
- *
- * ```js
- * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
- *
- * a.transpose().print(); // or tf.transpose(a)
- * ```
- *
- * @param x The tensor to transpose.
- * @param perm The permutation of the dimensions of a.
- */
- /** @doc {heading: 'Operations', subheading: 'Matrices'} */
- function transpose_(x, perm) {
- var $x = convertToTensor(x, 'x', 'transpose');
- if (perm == null) {
- perm = $x.shape.map(function (s, i) { return i; }).reverse();
- }
- assert($x.rank === perm.length, function () { return "Error in transpose: rank of input " + $x.rank + " " +
- ("must match length of perm " + perm + "."); });
- perm.forEach(function (axis) {
- assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) +
- (" but got " + perm); });
- });
- if ($x.rank <= 1) {
- return $x.clone();
- }
- var attrs = { perm: perm };
- return ENGINE.runKernelFunc(function (backend) { return backend.transpose($x, perm); }, { x: $x }, null /* gradient */, 'Transpose', attrs);
- }
- var transpose = op({ transpose_: transpose_ });
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var EPSILON_FLOAT32 = 1e-7;
- var EPSILON_FLOAT16 = 1e-4;
- /** Convenient class for storing tensor-related data. */
- var DataStorage = /** @class */ (function () {
- function DataStorage(backend, dataMover) {
- this.backend = backend;
- this.dataMover = dataMover;
- this.data = new WeakMap();
- this.dataIdsCount = 0;
- }
- DataStorage.prototype.get = function (dataId) {
- if (!this.data.has(dataId)) {
- this.dataMover.moveData(this.backend, dataId);
- }
- return this.data.get(dataId);
- };
- DataStorage.prototype.set = function (dataId, value) {
- this.dataIdsCount++;
- this.data.set(dataId, value);
- };
- DataStorage.prototype.has = function (dataId) {
- return this.data.has(dataId);
- };
- DataStorage.prototype.delete = function (dataId) {
- this.dataIdsCount--;
- return this.data.delete(dataId);
- };
- DataStorage.prototype.numDataIds = function () {
- return this.dataIdsCount;
- };
- return DataStorage;
- }());
- /**
- * The interface that defines the kernels that should be implemented when
- * adding a new backend. New backends don't need to implement every one of the
- * methods, this can be done gradually (throw an error for unimplemented
- * methods).
- */
- var KernelBackend = /** @class */ (function () {
- function KernelBackend() {
- }
- KernelBackend.prototype.time = function (f) {
- return notYetImplemented('time');
- };
- KernelBackend.prototype.read = function (dataId) {
- return notYetImplemented('read');
- };
- KernelBackend.prototype.readSync = function (dataId) {
- return notYetImplemented('readSync');
- };
- KernelBackend.prototype.numDataIds = function () {
- return notYetImplemented('numDataIds');
- };
- KernelBackend.prototype.disposeData = function (dataId) {
- return notYetImplemented('disposeData');
- };
- KernelBackend.prototype.write = function (values, shape, dtype) {
- return notYetImplemented('write');
- };
- KernelBackend.prototype.move = function (dataId, values, shape, dtype) {
- return notYetImplemented('move');
- };
- KernelBackend.prototype.memory = function () {
- return notYetImplemented('memory');
- };
- /** Returns the highest precision for floats in bits (e.g. 16 or 32) */
- KernelBackend.prototype.floatPrecision = function () {
- return notYetImplemented('floatPrecision');
- };
- /** Returns the smallest representable number. */
- KernelBackend.prototype.epsilon = function () {
- return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
- };
- KernelBackend.prototype.batchMatMul = function (a, b, transposeA, transposeB) {
- return notYetImplemented('batchMatMul');
- };
- KernelBackend.prototype.fusedBatchMatMul = function (_a) {
- var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights;
- return notYetImplemented('fusedBatchMatMul');
- };
- KernelBackend.prototype.slice = function (x, begin, size) {
- return notYetImplemented('slice');
- };
- KernelBackend.prototype.stridedSlice = function (x, begin, end, strides) {
- return notYetImplemented('stridedSlice');
- };
- KernelBackend.prototype.unstack = function (x, axis) {
- return notYetImplemented('unstack');
- };
- KernelBackend.prototype.reverse = function (a, axis) {
- return notYetImplemented('reverse');
- };
- KernelBackend.prototype.concat = function (tensors, axis) {
- return notYetImplemented('concat');
- };
- KernelBackend.prototype.neg = function (a) {
- return notYetImplemented('neg');
- };
- KernelBackend.prototype.add = function (a, b) {
- return notYetImplemented('add');
- };
- KernelBackend.prototype.addN = function (tensors) {
- return notYetImplemented('addN');
- };
- KernelBackend.prototype.subtract = function (a, b) {
- return notYetImplemented('subtract');
- };
- KernelBackend.prototype.multiply = function (a, b) {
- return notYetImplemented('multiply');
- };
- KernelBackend.prototype.realDivide = function (a, b) {
- return notYetImplemented('realDivide');
- };
- KernelBackend.prototype.floorDiv = function (a, b) {
- return notYetImplemented('floorDiv');
- };
- KernelBackend.prototype.sum = function (x, axes) {
- return notYetImplemented('sum');
- };
- KernelBackend.prototype.prod = function (x, axes) {
- return notYetImplemented('prod');
- };
- KernelBackend.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) {
- return notYetImplemented('unsortedSegmentSum');
- };
- KernelBackend.prototype.argMin = function (x, axis) {
- return notYetImplemented('argMin');
- };
- KernelBackend.prototype.argMax = function (x, axis) {
- return notYetImplemented('argMax');
- };
- KernelBackend.prototype.equal = function (a, b) {
- return notYetImplemented('equal');
- };
- KernelBackend.prototype.notEqual = function (a, b) {
- return notYetImplemented('notEqual');
- };
- KernelBackend.prototype.less = function (a, b) {
- return notYetImplemented('less');
- };
- KernelBackend.prototype.lessEqual = function (a, b) {
- return notYetImplemented('lessEqual');
- };
- KernelBackend.prototype.greater = function (a, b) {
- return notYetImplemented('greater');
- };
- KernelBackend.prototype.greaterEqual = function (a, b) {
- return notYetImplemented('greaterEqual');
- };
- KernelBackend.prototype.logicalNot = function (a) {
- return notYetImplemented('logicalNot');
- };
- KernelBackend.prototype.logicalAnd = function (a, b) {
- return notYetImplemented('logicalAnd');
- };
- KernelBackend.prototype.logicalOr = function (a, b) {
- return notYetImplemented('logicalOr');
- };
- KernelBackend.prototype.where = function (condition) {
- return notYetImplemented('where');
- };
- KernelBackend.prototype.select = function (condition, a, b) {
- return notYetImplemented('select');
- };
- KernelBackend.prototype.topk = function (x, k, sorted) {
- return notYetImplemented('topk');
- };
- KernelBackend.prototype.min = function (x, axes) {
- return notYetImplemented('min');
- };
- KernelBackend.prototype.minimum = function (a, b) {
- return notYetImplemented('minimum');
- };
- KernelBackend.prototype.mod = function (a, b) {
- return notYetImplemented('mod');
- };
- KernelBackend.prototype.max = function (x, axes) {
- return notYetImplemented('max');
- };
- KernelBackend.prototype.maximum = function (a, b) {
- return notYetImplemented('maximum');
- };
- KernelBackend.prototype.all = function (x, axes) {
- return notYetImplemented('all');
- };
- KernelBackend.prototype.any = function (x, axes) {
- return notYetImplemented('any');
- };
- KernelBackend.prototype.squaredDifference = function (a, b) {
- return notYetImplemented('squaredDifference');
- };
- KernelBackend.prototype.ceil = function (x) {
- return notYetImplemented('ceil');
- };
- KernelBackend.prototype.floor = function (x) {
- return notYetImplemented('floor');
- };
- KernelBackend.prototype.round = function (x) {
- return notYetImplemented('round');
- };
- KernelBackend.prototype.sign = function (x) {
- return notYetImplemented('sign');
- };
- KernelBackend.prototype.isNaN = function (x) {
- return notYetImplemented('isNaN');
- };
- KernelBackend.prototype.isInf = function (x) {
- return notYetImplemented('isInf');
- };
- KernelBackend.prototype.isFinite = function (x) {
- return notYetImplemented('isFinite');
- };
- KernelBackend.prototype.pow = function (a, b) {
- return notYetImplemented('pow');
- };
- KernelBackend.prototype.exp = function (x) {
- return notYetImplemented('exp');
- };
- KernelBackend.prototype.expm1 = function (x) {
- return notYetImplemented('expm1');
- };
- KernelBackend.prototype.softmax = function (x, dim) {
- return notYetImplemented('softmax');
- };
- KernelBackend.prototype.log = function (x) {
- return notYetImplemented('log');
- };
- KernelBackend.prototype.log1p = function (x) {
- return notYetImplemented('log1p');
- };
- KernelBackend.prototype.sqrt = function (x) {
- return notYetImplemented('sqrt');
- };
- KernelBackend.prototype.rsqrt = function (x) {
- return notYetImplemented('rsqrt');
- };
- KernelBackend.prototype.square = function (x) {
- return notYetImplemented('square');
- };
- KernelBackend.prototype.reciprocal = function (x) {
- return notYetImplemented('reciprocal');
- };
- KernelBackend.prototype.relu = function (x) {
- return notYetImplemented('relu');
- };
- KernelBackend.prototype.relu6 = function (x) {
- return notYetImplemented('relu6');
- };
- KernelBackend.prototype.prelu = function (x, a) {
- return notYetImplemented('prelu');
- };
- KernelBackend.prototype.elu = function (x) {
- return notYetImplemented('elu');
- };
- KernelBackend.prototype.eluDer = function (dy, y) {
- return notYetImplemented('eluDer');
- };
- KernelBackend.prototype.selu = function (x) {
- return notYetImplemented('selu');
- };
- KernelBackend.prototype.int = function (x) {
- return notYetImplemented('int');
- };
- KernelBackend.prototype.clip = function (x, min, max) {
- return notYetImplemented('clip');
- };
- KernelBackend.prototype.abs = function (x) {
- return notYetImplemented('abs');
- };
- KernelBackend.prototype.complexAbs = function (x) {
- return notYetImplemented('complexAbs');
- };
- KernelBackend.prototype.sigmoid = function (x) {
- return notYetImplemented('sigmoid');
- };
- KernelBackend.prototype.softplus = function (x) {
- return notYetImplemented('softplus');
- };
- KernelBackend.prototype.sin = function (x) {
- return notYetImplemented('sin');
- };
- KernelBackend.prototype.cos = function (x) {
- return notYetImplemented('cos');
- };
- KernelBackend.prototype.tan = function (x) {
- return notYetImplemented('tan');
- };
- KernelBackend.prototype.asin = function (x) {
- return notYetImplemented('asin');
- };
- KernelBackend.prototype.acos = function (x) {
- return notYetImplemented('acos');
- };
- KernelBackend.prototype.atan = function (x) {
- return notYetImplemented('atan');
- };
- KernelBackend.prototype.atan2 = function (a, b) {
- return notYetImplemented('atan2');
- };
- KernelBackend.prototype.sinh = function (x) {
- return notYetImplemented('sinh');
- };
- KernelBackend.prototype.cosh = function (x) {
- return notYetImplemented('cosh');
- };
- KernelBackend.prototype.tanh = function (x) {
- return notYetImplemented('tanh');
- };
- KernelBackend.prototype.asinh = function (x) {
- return notYetImplemented('asinh');
- };
- KernelBackend.prototype.acosh = function (x) {
- return notYetImplemented('acosh');
- };
- KernelBackend.prototype.atanh = function (x) {
- return notYetImplemented('atanh');
- };
- KernelBackend.prototype.erf = function (x) {
- return notYetImplemented('erf');
- };
- KernelBackend.prototype.step = function (x, alpha) {
- return notYetImplemented('step');
- };
- KernelBackend.prototype.fusedConv2d = function (_a) {
- var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights;
- return notYetImplemented('fusedConv2d');
- };
- KernelBackend.prototype.conv2d = function (x, filter, convInfo) {
- return notYetImplemented('conv2d');
- };
- KernelBackend.prototype.conv2dDerInput = function (dy, filter, convInfo) {
- return notYetImplemented('conv2dDerInput');
- };
- KernelBackend.prototype.conv2dDerFilter = function (x, dY, convInfo) {
- return notYetImplemented('conv2dDerFilter');
- };
- KernelBackend.prototype.fusedDepthwiseConv2D = function (_a) {
- var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights;
- return notYetImplemented('fusedDepthwiseConv2D');
- };
- KernelBackend.prototype.depthwiseConv2D = function (input, filter, convInfo) {
- return notYetImplemented('depthwiseConv2D');
- };
- KernelBackend.prototype.depthwiseConv2DDerInput = function (dy, filter, convInfo) {
- return notYetImplemented('depthwiseConv2DDerInput');
- };
- KernelBackend.prototype.depthwiseConv2DDerFilter = function (x, dY, convInfo) {
- return notYetImplemented('depthwiseConv2DDerFilter');
- };
- KernelBackend.prototype.conv3d = function (x, filter, convInfo) {
- return notYetImplemented('conv3d');
- };
- KernelBackend.prototype.conv3dDerInput = function (dy, filter, convInfo) {
- return notYetImplemented('conv3dDerInput');
- };
- KernelBackend.prototype.conv3dDerFilter = function (x, dY, convInfo) {
- return notYetImplemented('conv3dDerFilter');
- };
- KernelBackend.prototype.maxPool = function (x, convInfo) {
- return notYetImplemented('maxPool');
- };
- KernelBackend.prototype.maxPoolBackprop = function (dy, x, y, convInfo) {
- return notYetImplemented('maxPoolBackprop');
- };
- KernelBackend.prototype.avgPool = function (x, convInfo) {
- return notYetImplemented('avgPool');
- };
- KernelBackend.prototype.avgPoolBackprop = function (dy, x, convInfo) {
- return notYetImplemented('avgPoolBackprop');
- };
- KernelBackend.prototype.avgPool3d = function (x, convInfo) {
- return notYetImplemented('avgPool3d');
- };
- KernelBackend.prototype.avgPool3dBackprop = function (dy, x, convInfo) {
- return notYetImplemented('avgPool3dBackprop');
- };
- KernelBackend.prototype.maxPool3d = function (x, convInfo) {
- return notYetImplemented('maxPool3d');
- };
- KernelBackend.prototype.maxPool3dBackprop = function (dy, x, y, convInfo) {
- return notYetImplemented('maxPool3dBackprop');
- };
- KernelBackend.prototype.reshape = function (x, shape) {
- return notYetImplemented('reshape');
- };
- KernelBackend.prototype.cast = function (x, dtype) {
- return notYetImplemented('cast');
- };
- KernelBackend.prototype.tile = function (x, reps) {
- return notYetImplemented('tile');
- };
- KernelBackend.prototype.pad = function (x, paddings, constantValue) {
- return notYetImplemented('pad');
- };
- KernelBackend.prototype.transpose = function (x, perm) {
- return notYetImplemented('transpose');
- };
- KernelBackend.prototype.gather = function (x, indices, axis) {
- return notYetImplemented('gather');
- };
- KernelBackend.prototype.gatherND = function (x, indices) {
- return notYetImplemented('gatherND');
- };
- KernelBackend.prototype.scatterND = function (indices, updates, shape) {
- return notYetImplemented('scatterND');
- };
- KernelBackend.prototype.batchToSpaceND = function (x, blockShape, crops) {
- return notYetImplemented('batchToSpaceND');
- };
- KernelBackend.prototype.spaceToBatchND = function (x, blockShape, paddings) {
- return notYetImplemented('spaceToBatchND');
- };
- KernelBackend.prototype.resizeBilinear = function (x, newHeight, newWidth, alignCorners) {
- return notYetImplemented('resizeBilinear');
- };
- KernelBackend.prototype.resizeBilinearBackprop = function (dy, x, alignCorners) {
- return notYetImplemented('resizeBilinearBackprop');
- };
- KernelBackend.prototype.resizeNearestNeighbor = function (x, newHEight, newWidth, alignCorners) {
- return notYetImplemented('resizeNearestNeighbor');
- };
- KernelBackend.prototype.resizeNearestNeighborBackprop = function (dy, x, alignCorners) {
- return notYetImplemented('resizeNearestNeighborBackprop');
- };
- KernelBackend.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) {
- return notYetImplemented('batchNormalization');
- };
- KernelBackend.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) {
- return notYetImplemented('localResponseNormalization4D');
- };
- KernelBackend.prototype.LRNGrad = function (dy, inputImage, outputImage, radius, bias, alpha, beta) {
- return notYetImplemented('LRNGrad');
- };
- KernelBackend.prototype.multinomial = function (logits, normalized, numSamples, seed) {
- return notYetImplemented('multinomial');
- };
- KernelBackend.prototype.oneHot = function (indices, depth, onValue, offValue) {
- return notYetImplemented('oneHot');
- };
- KernelBackend.prototype.cumsum = function (x, axis, exclusive, reverse) {
- return notYetImplemented('cumsum');
- };
- KernelBackend.prototype.nonMaxSuppression = function (boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
- return notYetImplemented('nonMaxSuppression');
- };
- KernelBackend.prototype.fft = function (x) {
- return notYetImplemented('fft');
- };
- KernelBackend.prototype.ifft = function (x) {
- return notYetImplemented('ifft');
- };
- KernelBackend.prototype.complex = function (real, imag) {
- return notYetImplemented('complex');
- };
- KernelBackend.prototype.real = function (input) {
- return notYetImplemented('real');
- };
- KernelBackend.prototype.imag = function (input) {
- return notYetImplemented('imag');
- };
- KernelBackend.prototype.cropAndResize = function (image, boxes, boxIndex, cropSize, method, extrapolationValue) {
- return notYetImplemented('cropAndResize');
- };
- KernelBackend.prototype.depthToSpace = function (x, blockSize, dataFormat) {
- return notYetImplemented('depthToSpace');
- };
- // Aligns with the "SplitV" kernel in TensorFlow.
- KernelBackend.prototype.split = function (value, sizeSplits, axis) {
- return notYetImplemented('split');
- };
- KernelBackend.prototype.sparseToDense = function (sparseIndices, sparseValues, outputShape, defaultValue) {
- return notYetImplemented('sparseToDense');
- };
- KernelBackend.prototype.diag = function (x) {
- return notYetImplemented('diag');
- };
- KernelBackend.prototype.fill = function (shape, value, dtype) {
- return notYetImplemented('fill');
- };
- KernelBackend.prototype.onesLike = function (x) {
- return notYetImplemented('onesLike');
- };
- KernelBackend.prototype.zerosLike = function (x) {
- return notYetImplemented('zerosLike');
- };
- KernelBackend.prototype.linspace = function (start, stop, num) {
- return notYetImplemented('linspace');
- };
- KernelBackend.prototype.dispose = function () {
- return notYetImplemented('dispose');
- };
- return KernelBackend;
- }());
- function notYetImplemented(kernelName) {
- throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. " +
- "Did you forget to import the kernel?");
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
- if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
- var _a = parseTupleParam(filterSize), filterHeight = _a[0], filterWidth = _a[1];
- var filterShape;
- if (dataFormat === 'channelsLast') {
- filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
- }
- else if (dataFormat === 'channelsFirst') {
- filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
- }
- else {
- throw new Error("Unknown dataFormat " + dataFormat);
- }
- return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
- }
- /**
- * Computes the information for a forward pass of a pooling3D operation.
- */
- function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
- if (dataFormat === void 0) { dataFormat = 'NDHWC'; }
- var _a = parse3TupleParam(filterSize), filterDepth = _a[0], filterHeight = _a[1], filterWidth = _a[2];
- var filterShape;
- var $dataFormat;
- if (dataFormat === 'NDHWC') {
- $dataFormat = 'channelsLast';
- filterShape =
- [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
- }
- else if (dataFormat === 'NCDHW') {
- $dataFormat = 'channelsFirst';
- filterShape =
- [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
- }
- else {
- throw new Error("Unknown dataFormat " + dataFormat);
- }
- return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
- }
- /**
- * Computes the information for a forward pass of a convolution/pooling
- * operation.
- */
- function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) {
- if (depthwise === void 0) { depthwise = false; }
- if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
- var _a = [-1, -1, -1, -1], batchSize = _a[0], inHeight = _a[1], inWidth = _a[2], inChannels = _a[3];
- if (dataFormat === 'channelsLast') {
- batchSize = inShape[0], inHeight = inShape[1], inWidth = inShape[2], inChannels = inShape[3];
- }
- else if (dataFormat === 'channelsFirst') {
- batchSize = inShape[0], inChannels = inShape[1], inHeight = inShape[2], inWidth = inShape[3];
- }
- else {
- throw new Error("Unknown dataFormat " + dataFormat);
- }
- var filterHeight = filterShape[0], filterWidth = filterShape[1], filterChannels = filterShape[3];
- var _b = parseTupleParam(strides), strideHeight = _b[0], strideWidth = _b[1];
- var _c = parseTupleParam(dilations), dilationHeight = _c[0], dilationWidth = _c[1];
- var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
- var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
- var _d = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outHeight = _d.outHeight, outWidth = _d.outWidth;
- var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
- var outShape;
- if (dataFormat === 'channelsFirst') {
- outShape = [batchSize, outChannels, outHeight, outWidth];
- }
- else if (dataFormat === 'channelsLast') {
- outShape = [batchSize, outHeight, outWidth, outChannels];
- }
- return {
- batchSize: batchSize,
- dataFormat: dataFormat,
- inHeight: inHeight,
- inWidth: inWidth,
- inChannels: inChannels,
- outHeight: outHeight,
- outWidth: outWidth,
- outChannels: outChannels,
- padInfo: padInfo,
- strideHeight: strideHeight,
- strideWidth: strideWidth,
- filterHeight: filterHeight,
- filterWidth: filterWidth,
- effectiveFilterHeight: effectiveFilterHeight,
- effectiveFilterWidth: effectiveFilterWidth,
- dilationHeight: dilationHeight,
- dilationWidth: dilationWidth,
- inShape: inShape,
- outShape: outShape,
- filterShape: filterShape
- };
- }
- /**
- * Computes the information for a forward pass of a 3D convolution/pooling
- * operation.
- */
- function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) {
- if (depthwise === void 0) { depthwise = false; }
- if (dataFormat === void 0) { dataFormat = 'channelsLast'; }
- var _a = [-1, -1, -1, -1, -1], batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4];
- if (dataFormat === 'channelsLast') {
- batchSize = inShape[0], inDepth = inShape[1], inHeight = inShape[2], inWidth = inShape[3], inChannels = inShape[4];
- }
- else if (dataFormat === 'channelsFirst') {
- batchSize = inShape[0], inChannels = inShape[1], inDepth = inShape[2], inHeight = inShape[3], inWidth = inShape[4];
- }
- else {
- throw new Error("Unknown dataFormat " + dataFormat);
- }
- var filterDepth = filterShape[0], filterHeight = filterShape[1], filterWidth = filterShape[2], filterChannels = filterShape[4];
- var _b = parse3TupleParam(strides), strideDepth = _b[0], strideHeight = _b[1], strideWidth = _b[2];
- var _c = parse3TupleParam(dilations), dilationDepth = _c[0], dilationHeight = _c[1], dilationWidth = _c[2];
- var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
- var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
- var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
- var _d = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outDepth = _d.outDepth, outHeight = _d.outHeight, outWidth = _d.outWidth;
- var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
- var outShape;
- if (dataFormat === 'channelsFirst') {
- outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
- }
- else if (dataFormat === 'channelsLast') {
- outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
- }
- return {
- batchSize: batchSize,
- dataFormat: dataFormat,
- inDepth: inDepth,
- inHeight: inHeight,
- inWidth: inWidth,
- inChannels: inChannels,
- outDepth: outDepth,
- outHeight: outHeight,
- outWidth: outWidth,
- outChannels: outChannels,
- padInfo: padInfo,
- strideDepth: strideDepth,
- strideHeight: strideHeight,
- strideWidth: strideWidth,
- filterDepth: filterDepth,
- filterHeight: filterHeight,
- filterWidth: filterWidth,
- effectiveFilterDepth: effectiveFilterDepth,
- effectiveFilterHeight: effectiveFilterHeight,
- effectiveFilterWidth: effectiveFilterWidth,
- dilationDepth: dilationDepth,
- dilationHeight: dilationHeight,
- dilationWidth: dilationWidth,
- inShape: inShape,
- outShape: outShape,
- filterShape: filterShape
- };
- }
- function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
- if (zeroPad == null) {
- zeroPad = computeDefaultPad(inShape, fieldSize, stride);
- }
- var inputRows = inShape[0];
- var inputCols = inShape[1];
- var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " +
- "Change the stride and/or zero pad parameters"; });
- var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " +
- "Change the stride and/or zero pad parameters"; });
- return [outputRows, outputCols];
- }
- function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
- if (zeroPad == null) {
- zeroPad = computeDefaultPad(inShape, fieldSize, stride);
- }
- var inputDepth = inShape[0];
- var inputRows = inShape[1];
- var inputCols = inShape[2];
- var outputDepths = conditionalRound((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputDepths), function () { return "The output # of depths (" + outputDepths + ") must be an integer. " +
- "Change the stride and/or zero pad parameters"; });
- var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " +
- "Change the stride and/or zero pad parameters"; });
- var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
- assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " +
- "Change the stride and/or zero pad parameters"; });
- return [outputDepths, outputRows, outputCols, outChannels];
- }
- function computeDefaultPad(inputShape, fieldSize, stride, dilation) {
- if (dilation === void 0) { dilation = 1; }
- var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
- return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
- }
- function parseTupleParam(param) {
- if (typeof param === 'number') {
- return [param, param, param];
- }
- if (param.length === 2) {
- return [param[0], param[1], 1];
- }
- return param;
- }
- function parse3TupleParam(param) {
- return typeof param === 'number' ? [param, param, param] : param;
- }
- /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
- * Atrous convolution is equivalent to standard convolution with upsampled
- * filters with effective_filter_height =
- * filter_height + (filter_height - 1) * (dilation - 1)
- * and effective_filter_width =
- * filter_width + (filter_width - 1) * (dilation - 1),
- * produced by inserting dilation - 1 zeros along consecutive elements across
- * the filters' spatial dimensions.
- * When there is a dilation, this converts a filter dimension to the
- * effective filter dimension, so it can be used in a standard convolution.
- */
- function getEffectiveFilterSize(filterSize, dilation) {
- if (dilation <= 1) {
- return filterSize;
- }
- return filterSize + (filterSize - 1) * (dilation - 1);
- }
- function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) {
- var padInfo;
- var outHeight;
- var outWidth;
- if (typeof pad === 'number') {
- var padType = (pad === 0) ? 'VALID' : 'NUMBER';
- padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
- var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
- outHeight = outShape[0];
- outWidth = outShape[1];
- }
- else if (pad === 'same') {
- outHeight = Math.ceil(inHeight / strideHeight);
- outWidth = Math.ceil(inWidth / strideWidth);
- var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
- var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
- var top_1 = Math.floor(padAlongHeight / 2);
- var bottom = padAlongHeight - top_1;
- var left = Math.floor(padAlongWidth / 2);
- var right = padAlongWidth - left;
- padInfo = { top: top_1, bottom: bottom, left: left, right: right, type: 'SAME' };
- }
- else if (pad === 'valid') {
- padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
- outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
- outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
- }
- else {
- throw Error("Unknown padding parameter: " + pad);
- }
- return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth };
- }
- function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
- var padInfo;
- var outDepth;
- var outHeight;
- var outWidth;
- if (typeof pad === 'number') {
- var padType = (pad === 0) ? 'VALID' : 'NUMBER';
- padInfo = {
- top: pad,
- bottom: pad,
- left: pad,
- right: pad,
- front: pad,
- back: pad,
- type: padType
- };
- var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode);
- outDepth = outShape[0];
- outHeight = outShape[1];
- outWidth = outShape[2];
- }
- else if (pad === 'same') {
- outDepth = Math.ceil(inDepth / strideDepth);
- outHeight = Math.ceil(inHeight / strideHeight);
- outWidth = Math.ceil(inWidth / strideWidth);
- var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
- var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
- var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
- var front = Math.floor(padAlongDepth / 2);
- var back = padAlongDepth - front;
- var top_2 = Math.floor(padAlongHeight / 2);
- var bottom = padAlongHeight - top_2;
- var left = Math.floor(padAlongWidth / 2);
- var right = padAlongWidth - left;
- padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' };
- }
- else if (pad === 'valid') {
- padInfo = {
- top: 0,
- bottom: 0,
- left: 0,
- right: 0,
- front: 0,
- back: 0,
- type: 'VALID'
- };
- outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
- outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
- outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
- }
- else {
- throw Error("Unknown padding parameter: " + pad);
- }
- return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth };
- }
- /**
- * Rounds a value depending on the rounding mode
- * @param value
- * @param roundingMode
- */
- function conditionalRound(value, roundingMode) {
- if (!roundingMode) {
- return value;
- }
- switch (roundingMode) {
- case 'round':
- // used for Caffe Conv
- return Math.round(value);
- case 'ceil':
- // used for Caffe Pool
- return Math.ceil(value);
- case 'floor':
- return Math.floor(value);
- default:
- throw new Error("Unknown roundingMode " + roundingMode);
- }
- }
- function tupleValuesAreOne(param) {
- var _a = parseTupleParam(param), dimA = _a[0], dimB = _a[1], dimC = _a[2];
- return dimA === 1 && dimB === 1 && dimC === 1;
- }
- function eitherStridesOrDilationsAreOne(strides, dilations) {
- return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
- }
- /**
- * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
- * 'channelsLast'|'channelsFirst'
- * @param dataFormat in 'NHWC'|'NCHW' mode
- * @return dataFormat in 'channelsLast'|'channelsFirst' mode
- * @throws unknown dataFormat
- */
- function convertConv2DDataFormat(dataFormat) {
- if (dataFormat === 'NHWC') {
- return 'channelsLast';
- }
- else if (dataFormat === 'NCHW') {
- return 'channelsFirst';
- }
- else {
- throw new Error("Unknown dataFormat " + dataFormat);
- }
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function castTensor(x, dtype, backend) {
- if (dtype === 'complex64') {
- if (x.dtype === 'complex64') {
- return x.clone();
- }
- var zerosTensor = zeros(x.shape);
- var floatX = x.toFloat();
- var result = backend.complex(floatX, zerosTensor);
- zerosTensor.dispose();
- floatX.dispose();
- return result;
- }
- if (!hasEncodingLoss(x.dtype, dtype)) {
- // We don't change the underlying data, since we cast to higher
- // precision.
- return ENGINE.makeTensorFromDataId(x.dataId, x.shape, dtype);
- }
- if (x.dtype === 'complex64') {
- var real = backend.real(x);
- var result = real.cast(dtype);
- real.dispose();
- return result;
- }
- if (dtype === 'int32') {
- return backend.int(x);
- }
- else if (dtype === 'bool') {
- var zero = scalar(0, x.dtype);
- var result = backend.notEqual(x, zero);
- zero.dispose();
- return result;
- }
- else {
- throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype);
- }
- }
- function reshapeTensor(x, shape) {
- return ENGINE.makeTensorFromDataId(x.dataId, shape, x.dtype);
- }
- function linspaceImpl(start, stop, num) {
- var step = (stop - start) / (num - 1);
- var values = makeZerosTypedArray(num, 'float32');
- values[0] = start;
- for (var i = 1; i < values.length; i++) {
- values[i] = values[i - 1] + step;
- }
- return tensor1d(values, 'float32');
- }
-
- var backend_util = /*#__PURE__*/Object.freeze({
- castTensor: castTensor,
- reshapeTensor: reshapeTensor,
- linspaceImpl: linspaceImpl,
- upcastType: upcastType,
- axesAreInnerMostDims: axesAreInnerMostDims,
- combineLocations: combineLocations,
- computeOutAndReduceShapes: computeOutAndReduceShapes,
- expandShapeToKeepDim: expandShapeToKeepDim,
- assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
- getAxesPermutation: getAxesPermutation,
- getUndoAxesPermutation: getUndoAxesPermutation,
- getInnerMostAxes: getInnerMostAxes,
- getBroadcastDims: getBroadcastDims,
- getReductionAxes: getReductionAxes,
- assertAndGetBroadcastShape: assertAndGetBroadcastShape,
- assertParamsConsistent: assertParamsConsistent,
- computeOutShape: computeOutShape,
- computePool2DInfo: computePool2DInfo,
- computePool3DInfo: computePool3DInfo,
- computeConv2DInfo: computeConv2DInfo,
- computeConv3DInfo: computeConv3DInfo,
- computeDefaultPad: computeDefaultPad,
- tupleValuesAreOne: tupleValuesAreOne,
- eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
- convertConv2DDataFormat: convertConv2DDataFormat,
- PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
- computeOptimalWindowSize: computeOptimalWindowSize
- });
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Merges real and imaginary Float32Arrays into a single complex Float32Array.
- *
- * The memory layout is interleaved as follows:
- * real: [r0, r1, r2]
- * imag: [i0, i1, i2]
- * complex: [r0, i0, r1, i1, r2, i2]
- *
- * This is the inverse of splitRealAndImagArrays.
- *
- * @param real The real values of the complex tensor values.
- * @param imag The imag values of the complex tensor values.
- * @returns A complex tensor as a Float32Array with merged values.
- */
- function mergeRealAndImagArrays(real, imag) {
- if (real.length !== imag.length) {
- throw new Error("Cannot merge real and imag arrays of different lengths. real:" +
- (real.length + ", imag: " + imag.length + "."));
- }
- var result = new Float32Array(real.length * 2);
- for (var i = 0; i < result.length; i += 2) {
- result[i] = real[i / 2];
- result[i + 1] = imag[i / 2];
- }
- return result;
- }
- /**
- * Splits a complex Float32Array into real and imag parts.
- *
- * The memory layout is interleaved as follows:
- * complex: [r0, i0, r1, i1, r2, i2]
- * real: [r0, r1, r2]
- * imag: [i0, i1, i2]
- *
- * This is the inverse of mergeRealAndImagArrays.
- *
- * @param complex The complex tensor values.
- * @returns An object with real and imag Float32Array components of the complex
- * tensor.
- */
- function splitRealAndImagArrays(complex) {
- var real = new Float32Array(complex.length / 2);
- var imag = new Float32Array(complex.length / 2);
- for (var i = 0; i < complex.length; i += 2) {
- real[i / 2] = complex[i];
- imag[i / 2] = complex[i + 1];
- }
- return { real: real, imag: imag };
- }
- /**
- * Extracts even indexed complex values in the given array.
- * @param complex The complex tensor values
- */
- function complexWithEvenIndex(complex) {
- var len = Math.ceil(complex.length / 4);
- var real = new Float32Array(len);
- var imag = new Float32Array(len);
- for (var i = 0; i < complex.length; i += 4) {
- real[Math.floor(i / 4)] = complex[i];
- imag[Math.floor(i / 4)] = complex[i + 1];
- }
- return { real: real, imag: imag };
- }
- /**
- * Extracts odd indexed comple values in the given array.
- * @param complex The complex tensor values
- */
- function complexWithOddIndex(complex) {
- var len = Math.floor(complex.length / 4);
- var real = new Float32Array(len);
- var imag = new Float32Array(len);
- for (var i = 2; i < complex.length; i += 4) {
- real[Math.floor(i / 4)] = complex[i];
- imag[Math.floor(i / 4)] = complex[i + 1];
- }
- return { real: real, imag: imag };
- }
- /**
- * Get the map representing a complex value in the given array.
- * @param complex The complex tensor values.
- * @param index An index of the target complex value.
- */
- function getComplexWithIndex(complex, index) {
- var real = complex[index * 2];
- var imag = complex[index * 2 + 1];
- return { real: real, imag: imag };
- }
- /**
- * Insert a given complex value into the TypedArray.
- * @param data The array in which the complex value is inserted.
- * @param c The complex value to be inserted.
- * @param index An index of the target complex value.
- */
- function assignToTypedArray(data, real, imag, index) {
- data[index * 2] = real;
- data[index * 2 + 1] = imag;
- }
- /**
- * Make the list of exponent terms used by FFT.
- */
- function exponents(n, inverse) {
- var real = new Float32Array(n / 2);
- var imag = new Float32Array(n / 2);
- for (var i = 0; i < Math.ceil(n / 2); i++) {
- var x = (inverse ? 2 : -2) * Math.PI * (i / n);
- real[i] = Math.cos(x);
- imag[i] = Math.sin(x);
- }
- return { real: real, imag: imag };
- }
- /**
- * Make the exponent term used by FFT.
- */
- function exponent(k, n, inverse) {
- var x = (inverse ? 2 : -2) * Math.PI * (k / n);
- var real = Math.cos(x);
- var imag = Math.sin(x);
- return { real: real, imag: imag };
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Inserts a value into a sorted array. This method allows duplicate, meaning it
- * allows inserting duplicate value, in which case, the element will be inserted
- * at the lowest index of the value.
- * @param arr The array to modify.
- * @param element The element to insert.
- * @param comparator Optional. If no comparator is specified, elements are
- * compared using array_util.defaultComparator, which is suitable for Strings
- * and Numbers in ascending arrays. If the array contains multiple instances of
- * the target value, the left-most instance will be returned. To provide a
- * comparator, it should take 2 arguments to compare and return a negative,
- * zero, or a positive number.
- */
- function binaryInsert(arr, element, comparator) {
- var index = binarySearch(arr, element, comparator);
- var insertionPoint = index < 0 ? -(index + 1) : index;
- arr.splice(insertionPoint, 0, element);
- }
- /**
- * Searches the array for the target using binary search, returns the index
- * of the found element, or position to insert if element not found. If no
- * comparator is specified, elements are compared using array_
- * util.defaultComparator, which is suitable for Strings and Numbers in
- * ascending arrays. If the array contains multiple instances of the target
- * value, the left-most instance will be returned.
- * @param arr The array to be searched in.
- * @param target The target to be searched for.
- * @param comparator Should take 2 arguments to compare and return a negative,
- * zero, or a positive number.
- * @return Lowest index of the target value if found, otherwise the insertion
- * point where the target should be inserted, in the form of
- * (-insertionPoint - 1).
- */
- function binarySearch(arr, target, comparator) {
- return binarySearch_(arr, target, comparator || defaultComparator);
- }
- /**
- * Compares its two arguments for order.
- * @param a The first element to be compared.
- * @param b The second element to be compared.
- * @return A negative number, zero, or a positive number as the first
- * argument is less than, equal to, or greater than the second.
- */
- function defaultComparator(a, b) {
- return a > b ? 1 : a < b ? -1 : 0;
- }
- function binarySearch_(arr, target, comparator) {
- var left = 0;
- var right = arr.length;
- var middle = 0;
- var found = false;
- while (left < right) {
- middle = left + ((right - left) >>> 1);
- var compareResult = comparator(target, arr[middle]);
- if (compareResult > 0) {
- left = middle + 1;
- }
- else {
- right = middle;
- // If compareResult is 0, the value is found. We record it is found,
- // and then keep looking because there may be duplicate.
- found = !compareResult;
- }
- }
- return found ? left : -left - 1;
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function nonMaxSuppressionV3(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
- var dummySoftNmsSigma = 0.0;
- return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, dummySoftNmsSigma)
- .selectedIndices;
- }
- function nonMaxSuppressionV5(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
- // For NonMaxSuppressionV5Op, we always return a second output holding
- // corresponding scores.
- var returnScoresTensor = true;
- var result = nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor);
- result.numValidOutputs.dispose();
- return {
- selectedIndices: result.selectedIndices,
- selectedScores: result.selectedScores
- };
- }
- function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor, padToMaxOutputSize) {
- if (returnScoresTensor === void 0) { returnScoresTensor = false; }
- if (padToMaxOutputSize === void 0) { padToMaxOutputSize = false; }
- // The list is sorted in ascending order, so that we can always pop the
- // candidate with the largest score in O(1) time.
- var candidates = Array.from(scores)
- .map(function (score, boxIndex) { return ({ score: score, boxIndex: boxIndex, suppressBeginIndex: 0 }); })
- .filter(function (c) { return c.score > scoreThreshold; })
- .sort(ascendingComparator);
- // If softNmsSigma is 0, the outcome of this algorithm is exactly same as
- // before.
- var scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
- var selectedIndices = [];
- var selectedScores = [];
- while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
- var candidate = candidates.pop();
- var originalScore = candidate.score, boxIndex = candidate.boxIndex, suppressBeginIndex = candidate.suppressBeginIndex;
- if (originalScore < scoreThreshold) {
- break;
- }
- // Overlapping boxes are likely to have similar scores, therefore we
- // iterate through the previously selected boxes backwards in order to
- // see if candidate's score should be suppressed. We use
- // suppressBeginIndex to track and ensure a candidate can be suppressed
- // by a selected box no more than once. Also, if the overlap exceeds
- // iouThreshold, we simply ignore the candidate.
- var ignoreCandidate = false;
- for (var j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
- var iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
- if (iou >= iouThreshold) {
- ignoreCandidate = true;
- break;
- }
- candidate.score =
- candidate.score * suppressWeight(iouThreshold, scale, iou);
- if (candidate.score <= scoreThreshold) {
- break;
- }
- }
- // At this point, if `candidate.score` has not dropped below
- // `scoreThreshold`, then we know that we went through all of the
- // previous selections and can safely update `suppressBeginIndex` to the
- // end of the selected array. Then we can re-insert the candidate with
- // the updated score and suppressBeginIndex back in the candidate list.
- // If on the other hand, `candidate.score` has dropped below the score
- // threshold, we will not add it back to the candidates list.
- candidate.suppressBeginIndex = selectedIndices.length;
- if (!ignoreCandidate) {
- // Candidate has passed all the tests, and is not suppressed, so
- // select the candidate.
- if (candidate.score === originalScore) {
- selectedIndices.push(boxIndex);
- selectedScores.push(candidate.score);
- }
- else if (candidate.score > scoreThreshold) {
- // Candidate's score is suppressed but is still high enough to be
- // considered, so add back to the candidates list.
- binaryInsert(candidates, candidate, ascendingComparator);
- }
- }
- }
- // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
- var numValidOutputs = selectedIndices.length;
- if (padToMaxOutputSize) {
- selectedIndices.fill(0, numValidOutputs);
- selectedScores.fill(0.0, numValidOutputs);
- }
- return {
- selectedIndices: tensor1d(selectedIndices, 'int32'),
- selectedScores: tensor1d(selectedScores, 'float32'),
- numValidOutputs: scalar(numValidOutputs, 'int32')
- };
- }
- function intersectionOverUnion(boxes, i, j) {
- var iCoord = boxes.subarray(i * 4, i * 4 + 4);
- var jCoord = boxes.subarray(j * 4, j * 4 + 4);
- var yminI = Math.min(iCoord[0], iCoord[2]);
- var xminI = Math.min(iCoord[1], iCoord[3]);
- var ymaxI = Math.max(iCoord[0], iCoord[2]);
- var xmaxI = Math.max(iCoord[1], iCoord[3]);
- var yminJ = Math.min(jCoord[0], jCoord[2]);
- var xminJ = Math.min(jCoord[1], jCoord[3]);
- var ymaxJ = Math.max(jCoord[0], jCoord[2]);
- var xmaxJ = Math.max(jCoord[1], jCoord[3]);
- var areaI = (ymaxI - yminI) * (xmaxI - xminI);
- var areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
- if (areaI <= 0 || areaJ <= 0) {
- return 0.0;
- }
- var intersectionYmin = Math.max(yminI, yminJ);
- var intersectionXmin = Math.max(xminI, xminJ);
- var intersectionYmax = Math.min(ymaxI, ymaxJ);
- var intersectionXmax = Math.min(xmaxI, xmaxJ);
- var intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *
- Math.max(intersectionXmax - intersectionXmin, 0.0);
- return intersectionArea / (areaI + areaJ - intersectionArea);
- }
- // A Gaussian penalty function, this method always returns values in [0, 1].
- // The weight is a function of similarity, the more overlap two boxes are, the
- // smaller the weight is, meaning highly overlapping boxe will be significantly
- // penalized. On the other hand, a non-overlapping box will not be penalized.
- function suppressWeight(iouThreshold, scale, iou) {
- var weight = Math.exp(scale * iou * iou);
- return iou <= iouThreshold ? weight : 0.0;
- }
- function ascendingComparator(c1, c2) {
- // For objects with same scores, we make the object with the larger index go
- // first. In an array that pops from the end, this means that the object with
- // the smaller index will be popped first. This ensures the same output as
- // the TensorFlow python version.
- return (c1.score - c2.score) ||
- ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /** Shared implementation of the split kernel across WebGL and CPU. */
- function split$1(x, sizeSplits, axis) {
- var begin = new Array(x.rank).fill(0);
- var size = x.shape.slice();
- return sizeSplits.map(function (s) {
- size[axis] = s;
- var slice = x.slice(begin, size);
- begin[axis] += s;
- return slice;
- });
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function tile(xBuf, reps) {
- var newShape = new Array(xBuf.rank);
- for (var i = 0; i < newShape.length; i++) {
- newShape[i] = xBuf.shape[i] * reps[i];
- }
- var result = buffer(newShape, xBuf.dtype);
- for (var i = 0; i < result.values.length; ++i) {
- var newLoc = result.indexToLoc(i);
- var originalLoc = new Array(xBuf.rank);
- for (var j = 0; j < originalLoc.length; j++) {
- originalLoc[j] = newLoc[j] % xBuf.shape[j];
- }
- var originalIndex = xBuf.locToIndex(originalLoc);
- result.values[i] = xBuf.values[originalIndex];
- }
- return result.toTensor();
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function topkImpl(x, xShape, xDtype, k, sorted) {
- // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
- var lastDim = xShape[xShape.length - 1];
- var _a = [x.length / lastDim, lastDim], batch = _a[0], size = _a[1];
- var allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
- var allTopKIndices = getTypedArrayFromDType('int32', batch * k);
- for (var b = 0; b < batch; b++) {
- var offset = b * size;
- var vals = x.subarray(offset, offset + size);
- var valAndInd = [];
- for (var i = 0; i < vals.length; i++) {
- valAndInd.push({ value: vals[i], index: i });
- }
- valAndInd.sort(function (a, b) { return b.value - a.value; });
- var outOffset = b * k;
- var topKVals = allTopKVals.subarray(outOffset, outOffset + k);
- var topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
- for (var i = 0; i < k; i++) {
- topKVals[i] = valAndInd[i].value;
- topKIndices[i] = valAndInd[i].index;
- }
- }
- // Reshape back to the original input shape, except that the last
- // dimension is k.
- var outputShape = xShape.slice();
- outputShape[outputShape.length - 1] = k;
- return [
- tensor(allTopKVals, outputShape, xDtype),
- tensor(allTopKIndices, outputShape, 'int32')
- ];
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function whereImpl(condShape, condVals) {
- var indices = [];
- for (var i = 0; i < condVals.length; i++) {
- if (condVals[i]) {
- indices.push(i);
- }
- }
- var inBuffer = buffer(condShape, 'int32');
- var out = buffer([indices.length, condShape.length], 'int32');
- for (var i = 0; i < indices.length; i++) {
- var loc = inBuffer.indexToLoc(indices[i]);
- var offset = i * condShape.length;
- out.values.set(loc, offset);
- }
- return out.toTensor();
- }
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var AddNProgram = /** @class */ (function () {
- function AddNProgram(outputShape, shapes) {
- this.outputShape = [];
- this.outputShape = outputShape;
- this.variableNames = shapes.map(function (_, i) { return "T" + i; });
- var snippets = [];
- // Get target elements from every input tensor.
- this.variableNames.forEach(function (variable) {
- snippets.push("float v" + variable + " = get" + variable + "AtOutCoords();");
- });
- // Calculate the sum of all elements.
- var operation = this.variableNames
- .map(function (variable) {
- return "v" + variable;
- })
- .join(' + ');
- this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n float result = " + operation + ";\n setOutput(result);\n }\n ";
- }
- return AddNProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var AddNPackedProgram = /** @class */ (function () {
- function AddNPackedProgram(outputShape, shapes) {
- this.outputShape = [];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- this.variableNames = shapes.map(function (_, i) { return "T" + i; });
- var snippets = [];
- // Get target elements from every input tensor.
- this.variableNames.forEach(function (variable) {
- snippets.push("vec4 v" + variable + " = get" + variable + "AtOutCoords();");
- });
- // Calculate the sum of all elements.
- var operation = this.variableNames
- .map(function (variable) {
- return "v" + variable;
- })
- .join(' + ');
- this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n vec4 result = " + operation + ";\n setOutput(result);\n }\n ";
- }
- return AddNPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ArgMinMaxProgram = /** @class */ (function () {
- function ArgMinMaxProgram(reduceInfo, op, firstPass) {
- this.variableNames = ['A'];
- var windowSize = reduceInfo.windowSize;
- var batchSize = reduceInfo.batchSize;
- var inSize = reduceInfo.inSize;
- var outSize = Math.ceil(inSize / windowSize);
- if (!firstPass) {
- this.variableNames.push('bestIndicesA');
- }
- this.outputShape = [batchSize, outSize];
- var compOp = (op === 'max') ? '>' : '<';
- var indexSnippet = firstPass ?
- 'inOffset + i;' :
- 'round(getBestIndicesA(batch, inOffset + i));';
- this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < " + windowSize + "; i++) {\n int inIdx = " + indexSnippet + ";\n float candidate = getA(batch, inIdx);\n if (candidate " + compOp + " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n ";
- }
- return ArgMinMaxProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function getVecChannels(name, rank) {
- return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(function (d) { return name + "." + d; });
- }
- function getChannels(name, rank) {
- if (rank === 1) {
- return [name];
- }
- return getVecChannels(name, rank);
- }
- function getSourceCoords(rank, dims) {
- if (rank === 1) {
- return 'rc';
- }
- var coords = '';
- for (var i = 0; i < rank; i++) {
- coords += dims[i];
- if (i < rank - 1) {
- coords += ',';
- }
- }
- return coords;
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function getGlslDifferences() {
- var version;
- var attribute;
- var varyingVs;
- var varyingFs;
- var texture2D;
- var output;
- var defineOutput;
- var defineSpecialNaN;
- var defineSpecialInf;
- var defineRound;
- if (env().getNumber('WEBGL_VERSION') === 2) {
- version = '#version 300 es';
- attribute = 'in';
- varyingVs = 'out';
- varyingFs = 'in';
- texture2D = 'texture';
- output = 'outputColor';
- defineOutput = 'out vec4 outputColor;';
- // Use custom isnan definition to work across differences between
- // implementations on various platforms. While this should happen in ANGLE
- // we still see differences between android and windows (on chrome) when
- // using isnan directly.
- defineSpecialNaN = "\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n ";
- // In webgl 2 we do not need to specify a custom isinf so there is no
- // need for a special INFINITY constant.
- defineSpecialInf = "";
- defineRound = "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n ";
- }
- else {
- version = '';
- attribute = 'attribute';
- varyingVs = 'varying';
- varyingFs = 'varying';
- texture2D = 'texture2D';
- output = 'gl_FragColor';
- defineOutput = '';
- // WebGL1 has no built in isnan so we define one here.
- defineSpecialNaN = "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n ";
- defineSpecialInf = "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n ";
- defineRound = "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n ";
- }
- return {
- version: version,
- attribute: attribute,
- varyingVs: varyingVs,
- varyingFs: varyingFs,
- texture2D: texture2D,
- output: output,
- defineOutput: defineOutput,
- defineSpecialNaN: defineSpecialNaN,
- defineSpecialInf: defineSpecialInf,
- defineRound: defineRound
- };
- }
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /**
- * Produces GLSL code that derives logical coordinates from a flat
- * index. The code performs integer division with each stride and decrements
- * the index until the index equals the final dimension coordinate.
- */
- function getLogicalCoordinatesFromFlatIndex(coords, shape, index) {
- if (index === void 0) { index = 'index'; }
- var strides = computeStrides(shape);
- return strides
- .map(function (stride, i) {
- var line1 = "int " + coords[i] + " = " + index + " / " + stride;
- var line2 = i === strides.length - 1 ?
- "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * " + stride :
- "index -= " + coords[i] + " * " + stride;
- return line1 + "; " + line2 + ";";
- })
- .join('');
- }
- /**
- * Produces GLSL that computes the flat index from 3D coordinates.
- */
- function getFlatIndexFrom3D(shape) {
- var strides = computeStrides(shape).map(function (d) { return d.toString(); });
- return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * " + strides[0] + " + coords.y * " + strides[1] + " + coords.z;\n }\n";
- }
- var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n";
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function makeShader(inputsInfo, outputShape, userCode, usesPackedTextures) {
- var prefixSnippets = [];
- inputsInfo.forEach(function (x) {
- var size = sizeFromShape(x.shapeInfo.logicalShape);
- // Snippet when we decided to upload the values as uniform.
- if (x.shapeInfo.isUniform) {
- prefixSnippets.push("uniform float " + x.name + (size > 1 ? "[" + size + "]" : '') + ";");
- }
- else {
- prefixSnippets.push("uniform sampler2D " + x.name + ";");
- prefixSnippets.push("uniform int offset" + x.name + ";");
- }
- });
- var inputPrefixSnippet = prefixSnippets.join('\n');
- var inputSamplingSnippet = inputsInfo
- .map(function (x) { return getInputSamplingSnippet(x, outputShape, usesPackedTextures); })
- .join('\n');
- var outTexShape = outputShape.texShape;
- var glsl = getGlslDifferences();
- var floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
- var outputSamplingSnippet;
- var floatTextureSetOutputSnippet;
- var shaderPrefix = getShaderPrefix(glsl);
- if (outputShape.isPacked) {
- outputSamplingSnippet =
- getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
- floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
- }
- else {
- outputSamplingSnippet =
- getOutputSamplingSnippet(outputShape.logicalShape, outTexShape);
- floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
- }
- if (usesPackedTextures) {
- shaderPrefix += SHADER_PACKED_PREFIX;
- }
- var source = [
- shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,
- inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, userCode
- ].join('\n');
- return source;
- }
- function getSamplerFromInInfo(inInfo) {
- var shape = inInfo.shapeInfo.logicalShape;
- switch (shape.length) {
- case 0:
- return getSamplerScalar(inInfo);
- case 1:
- return getSampler1D(inInfo);
- case 2:
- return getSampler2D(inInfo);
- case 3:
- return getSampler3D(inInfo);
- case 4:
- return getSampler4D(inInfo);
- case 5:
- return getSampler5D(inInfo);
- case 6:
- return getSampler6D(inInfo);
- default:
- throw new Error(shape.length + "-D input sampling" +
- " is not yet supported");
- }
- }
- function getPackedSamplerFromInInfo(inInfo) {
- var shape = inInfo.shapeInfo.logicalShape;
- switch (shape.length) {
- case 0:
- return getPackedSamplerScalar(inInfo);
- case 1:
- return getPackedSampler1D(inInfo);
- case 2:
- return getPackedSampler2D(inInfo);
- case 3:
- return getPackedSampler3D(inInfo);
- default:
- return getPackedSamplerND(inInfo);
- }
- }
- function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures) {
- if (usesPackedTextures === void 0) { usesPackedTextures = false; }
- var res = '';
- if (usesPackedTextures) {
- res += getPackedSamplerFromInInfo(inInfo);
- }
- else {
- res += getSamplerFromInInfo(inInfo);
- }
- var inShape = inInfo.shapeInfo.logicalShape;
- var outShape = outShapeInfo.logicalShape;
- if (inShape.length <= outShape.length) {
- if (usesPackedTextures) {
- res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
- }
- else {
- res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
- }
- }
- return res;
- }
- function getPackedOutputSamplingSnippet(outShape, outTexShape) {
- switch (outShape.length) {
- case 0:
- return getOutputScalarCoords();
- case 1:
- return getOutputPacked1DCoords(outShape, outTexShape);
- case 2:
- return getOutputPacked2DCoords(outShape, outTexShape);
- case 3:
- return getOutputPacked3DCoords(outShape, outTexShape);
- default:
- return getOutputPackedNDCoords(outShape, outTexShape);
- }
- }
- function getOutputSamplingSnippet(outShape, outTexShape) {
- switch (outShape.length) {
- case 0:
- return getOutputScalarCoords();
- case 1:
- return getOutput1DCoords(outShape, outTexShape);
- case 2:
- return getOutput2DCoords(outShape, outTexShape);
- case 3:
- return getOutput3DCoords(outShape, outTexShape);
- case 4:
- return getOutput4DCoords(outShape, outTexShape);
- case 5:
- return getOutput5DCoords(outShape, outTexShape);
- case 6:
- return getOutput6DCoords(outShape, outTexShape);
- default:
- throw new Error(outShape.length + "-D output sampling is not yet supported");
- }
- }
- function getFloatTextureSampleSnippet(glsl) {
- return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return " + glsl.texture2D + "(textureSampler, uv).r;\n }\n ";
- }
- function getFloatTextureSetRSnippet(glsl) {
- return "\n void setOutput(float val) {\n " + glsl.output + " = vec4(val, 0, 0, 0);\n }\n ";
- }
- function getFloatTextureSetRGBASnippet(glsl) {
- return "\n void setOutput(vec4 val) {\n " + glsl.output + " = val;\n }\n ";
- }
- function getShaderPrefix(glsl) {
- var SHADER_PREFIX = glsl.version + "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n " + glsl.varyingFs + " vec2 resultUV;\n " + glsl.defineOutput + "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n " + glsl.defineSpecialNaN + "\n " + glsl.defineSpecialInf + "\n " + glsl.defineRound + "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n ";
- return SHADER_PREFIX;
- }
- var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
- var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
- var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
- var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n";
- function getOutputScalarCoords() {
- return "\n int getOutputCoords() {\n return 0;\n }\n ";
- }
- function getOutputPacked1DCoords(shape, texShape) {
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- if (packedTexShape[0] === 1) {
- return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * " + packedTexShape[1] + ".0);\n }\n ";
- }
- if (packedTexShape[1] === 1) {
- return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * " + packedTexShape[0] + ".0);\n }\n ";
- }
- return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n return 2 * (resTexRC.x * " + packedTexShape[1] + " + resTexRC.y);\n }\n ";
- }
- function getOutput1DCoords(shape, texShape) {
- if (texShape[0] === 1) {
- return "\n int getOutputCoords() {\n return int(resultUV.x * " + texShape[1] + ".0);\n }\n ";
- }
- if (texShape[1] === 1) {
- return "\n int getOutputCoords() {\n return int(resultUV.y * " + texShape[0] + ".0);\n }\n ";
- }
- return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n return resTexRC.x * " + texShape[1] + " + resTexRC.y;\n }\n ";
- }
- function getOutputPacked3DCoords(shape, texShape) {
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- var texelsInLogicalRow = Math.ceil(shape[2] / 2);
- var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
- return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec3(b, r, c);\n }\n ";
- }
- function getOutput3DCoords(shape, texShape) {
- var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
- return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n ";
- }
- function getOutputPackedNDCoords(shape, texShape) {
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- var texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
- var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
- var texelsInBatchN = texelsInBatch;
- var batches = "";
- var coords = 'b, r, c';
- for (var b = 2; b < shape.length - 1; b++) {
- texelsInBatchN *= shape[shape.length - b - 1];
- batches = "\n int b" + b + " = index / " + texelsInBatchN + ";\n index -= b" + b + " * " + texelsInBatchN + ";\n " + batches;
- coords = "b" + b + ", " + coords;
- }
- return "\n ivec" + shape.length + " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n " + batches + "\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec" + shape.length + "(" + coords + ");\n }\n ";
- }
- function getOutput4DCoords(shape, texShape) {
- var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
- return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec4(r, c, d, d2);\n }\n ";
- }
- function getOutput5DCoords(shape, texShape) {
- var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
- return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + texShape[0] + ",\n " + texShape[1] + "));\n\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n ";
- }
- function getOutput6DCoords(shape, texShape) {
- var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
- return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n ";
- }
- function getOutputPacked2DCoords(shape, texShape) {
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- if (arraysEqual(shape, texShape)) {
- return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n }\n ";
- }
- // texels needed to accommodate a logical row
- var texelsInLogicalRow = Math.ceil(shape[1] / 2);
- /**
- * getOutputCoords
- *
- * resTexRC: The rows and columns of the texels. If you move over one
- * texel to the right in the packed texture, you are moving over one column
- * (not two).
- *
- * index: The texel index
- */
- return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec2(r, c);\n }\n ";
- }
- function getOutput2DCoords(shape, texShape) {
- if (arraysEqual(shape, texShape)) {
- return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(" + texShape[0] + ", " + texShape[1] + "));\n }\n ";
- }
- if (shape[1] === 1) {
- return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(index, 0);\n }\n ";
- }
- if (shape[0] === 1) {
- return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(0, index);\n }\n ";
- }
- return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);\n }\n ";
- }
- function getFlatOffsetUniformName(texName) {
- return "offset" + texName;
- }
- function getPackedSamplerScalar(inputInfo) {
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var glsl = getGlslDifferences();
- return "\n vec4 " + funcName + "() {\n return " + glsl.texture2D + "(" + texName + ", halfCR);\n }\n ";
- }
- function getSamplerScalar(inputInfo) {
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- if (inputInfo.shapeInfo.isUniform) {
- return "float " + funcName + "() {return " + texName + ";}";
- }
- var _a = inputInfo.shapeInfo.texShape, texNumR = _a[0], texNumC = _a[1];
- if (texNumR === 1 && texNumC === 1) {
- return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
- }
- var _b = inputInfo.shapeInfo.texShape, tNumR = _b[0], tNumC = _b[1];
- var offset = getFlatOffsetUniformName(texName);
- return "\n float " + funcName + "() {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- function getPackedSampler1D(inputInfo) {
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var texShape = inputInfo.shapeInfo.texShape;
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- var glsl = getGlslDifferences();
- return "\n vec4 " + funcName + "(int index) {\n vec2 uv = packedUVfrom1D(\n " + packedTexShape[0] + ", " + packedTexShape[1] + ", index);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
- }
- function getSampler1D(inputInfo) {
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return "\n float " + funcName + "(int index) {\n " + getUniformSampler(inputInfo) + "\n }\n ";
- }
- var texShape = inputInfo.shapeInfo.texShape;
- var tNumR = texShape[0];
- var tNumC = texShape[1];
- if (tNumC === 1 && tNumR === 1) {
- return "\n float " + funcName + "(int index) {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
- }
- var offset = getFlatOffsetUniformName(texName);
- if (tNumC === 1) {
- return "\n float " + funcName + "(int index) {\n vec2 uv = vec2(0.5, (float(index + " + offset + ") + 0.5) / " + tNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- if (tNumR === 1) {
- return "\n float " + funcName + "(int index) {\n vec2 uv = vec2((float(index + " + offset + ") + 0.5) / " + tNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- return "\n float " + funcName + "(int index) {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- function getPackedSampler2D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var texShape = inputInfo.shapeInfo.texShape;
- var texNumR = texShape[0];
- var texNumC = texShape[1];
- var glsl = getGlslDifferences();
- if (texShape != null && arraysEqual(shape, texShape)) {
- return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
- }
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- var valuesPerRow = Math.ceil(shape[1] / 2);
- return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = packedUVfrom2D(" + valuesPerRow + ", " + packedTexShape[0] + ", " + packedTexShape[1] + ", row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
- }
- function getSampler2D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var texShape = inputInfo.shapeInfo.texShape;
- if (texShape != null && arraysEqual(shape, texShape)) {
- var texNumR_1 = texShape[0];
- var texNumC_1 = texShape[1];
- return "\n float " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC_1 + ".0, " + texNumR_1 + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
- var squeezedShape = newShape;
- if (squeezedShape.length < shape.length) {
- var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
- var params = ['row', 'col'];
- return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(" + shape[1] + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
- }
- var texNumR = texShape[0];
- var texNumC = texShape[1];
- var offset = getFlatOffsetUniformName(texName);
- if (texNumC === 1) {
- // index is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- if (texNumR === 1) {
- // index is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2((index + 0.5) / " + texNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- return "\n float " + funcName + "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + shape[1] + " + col + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n";
- }
- function getPackedSampler3D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var texShape = inputInfo.shapeInfo.texShape;
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- if (shape[0] === 1) {
- var squeezedShape = shape.slice(1);
- var keptDims = [1, 2];
- var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
- var params = ['b', 'row', 'col'];
- return "\n " + getPackedSamplerFromInInfo(newInputInfo) + "\n vec4 " + funcName + "(int b, int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
- }
- var texNumR = packedTexShape[0];
- var texNumC = packedTexShape[1];
- var valuesPerRow = Math.ceil(shape[2] / 2);
- var texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
- var glsl = getGlslDifferences();
- return "\n vec4 " + funcName + "(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n " + texNumR + ", " + texNumC + ", " + texelsInBatch + ", " + valuesPerRow + ", b, row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
- }
- function getSampler3D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var stride0 = shape[1] * shape[2];
- var stride1 = shape[2];
- var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
- var squeezedShape = newShape;
- if (squeezedShape.length < shape.length) {
- var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
- var params = ['row', 'col', 'depth'];
- return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(" + stride0 + ", " + stride1 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
- }
- var texShape = inputInfo.shapeInfo.texShape;
- var texNumR = texShape[0];
- var texNumC = texShape[1];
- var flatOffset = inputInfo.shapeInfo.flatOffset;
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(" + stride1 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- if (texNumC === stride1 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(" + shape[1] + ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- var offset = getFlatOffsetUniformName(texName);
- return "\n float " + funcName + "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- function getPackedSamplerND(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var rank = shape.length;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var texShape = inputInfo.shapeInfo.texShape;
- var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
- var texNumR = packedTexShape[0];
- var texNumC = packedTexShape[1];
- var valuesPerRow = Math.ceil(shape[rank - 1] / 2);
- var texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
- var params = "int b, int row, int col";
- var index = "b * " + texelsInBatch + " + (row / 2) * " + valuesPerRow + " + (col / 2)";
- for (var b = 2; b < rank - 1; b++) {
- params = "int b" + b + ", " + params;
- texelsInBatch *= shape[rank - b - 1];
- index = "b" + b + " * " + texelsInBatch + " + " + index;
- }
- var glsl = getGlslDifferences();
- return "\n vec4 " + funcName + "(" + params + ") {\n int index = " + index + ";\n int texR = index / " + texNumC + ";\n int texC = index - texR * " + texNumC + ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ", " + texNumR + ");\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
- }
- function getSampler4D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var stride2 = shape[3];
- var stride1 = shape[2] * stride2;
- var stride0 = shape[1] * stride1;
- var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
- if (newShape.length < shape.length) {
- var newInputInfo = squeezeInputInfo(inputInfo, newShape);
- var params = ['row', 'col', 'depth', 'depth2'];
- return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
- }
- var flatOffset = inputInfo.shapeInfo.flatOffset;
- var texShape = inputInfo.shapeInfo.texShape;
- var texNumR = texShape[0];
- var texNumC = texShape[1];
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(" + stride1 + ", " + stride2 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- if (texNumC === stride2 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(" + shape[1] * shape[2] + ", " + shape[2] + ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- var offset = getFlatOffsetUniformName(texName);
- return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2;\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- function getSampler5D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var stride3 = shape[4];
- var stride2 = shape[3] * stride3;
- var stride1 = shape[2] * stride2;
- var stride0 = shape[1] * stride1;
- var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
- if (newShape.length < shape.length) {
- var newInputInfo = squeezeInputInfo(inputInfo, newShape);
- var params = ['row', 'col', 'depth', 'depth2', 'depth3'];
- return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
- }
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float index = dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n depth3;\n " + getUniformSampler(inputInfo) + "\n }\n ";
- }
- var flatOffset = inputInfo.shapeInfo.flatOffset;
- var texShape = inputInfo.shapeInfo.texShape;
- var texNumR = texShape[0];
- var texNumC = texShape[1];
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- if (texNumC === stride3 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] + ",\n " + shape[2] * shape[3] + ", " + shape[3] + ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- var offset = getFlatOffsetUniformName(texName);
- return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- function getSampler6D(inputInfo) {
- var shape = inputInfo.shapeInfo.logicalShape;
- var texName = inputInfo.name;
- var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
- var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
- if (newShape.length < shape.length) {
- var newInputInfo = squeezeInputInfo(inputInfo, newShape);
- var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
- return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
- }
- var stride4 = shape[5];
- var stride3 = shape[4] * stride4;
- var stride2 = shape[3] * stride3;
- var stride1 = shape[2] * stride2;
- var stride0 = shape[1] * stride1;
- if (inputInfo.shapeInfo.isUniform) {
- // Uniform arrays will be less than 65505 (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n dot(\n vec2(depth3, depth4),\n vec2(" + stride4 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
- }
- var flatOffset = inputInfo.shapeInfo.flatOffset;
- var texShape = inputInfo.shapeInfo.texShape;
- var texNumR = texShape[0];
- var texNumC = texShape[1];
- if (texNumC === stride0 && flatOffset == null) {
- // texC is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", " + stride4 + ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- if (texNumC === stride4 && flatOffset == null) {
- // texR is used directly as physical (no risk of float16 overflow).
- return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] * shape[4] + ",\n " + shape[2] * shape[3] * shape[4] + ",\n " + shape[3] * shape[4] + ",\n " + shape[4] + ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- var offset = getFlatOffsetUniformName(texName);
- return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 * " + stride4 + " + depth4 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
- }
- function getUniformSampler(inputInfo) {
- var texName = inputInfo.name;
- var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
- if (inSize < 2) {
- return "return " + texName + ";";
- }
- return "\n for (int i = 0; i < " + inSize + "; i++) {\n if (i == index) {\n return " + texName + "[i];\n }\n }\n ";
- }
- function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
- var texName = inputInfo.name;
- var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
- var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
- var inRank = inputInfo.shapeInfo.logicalShape.length;
- var outRank = outShapeInfo.logicalShape.length;
- var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
- var type = getCoordsDataType(outRank);
- var rankDiff = outRank - inRank;
- var coordsSnippet;
- var fields = ['x', 'y', 'z', 'w', 'u', 'v'];
- if (inRank === 0) {
- coordsSnippet = '';
- }
- else if (outRank < 2 && broadcastDims.length >= 1) {
- coordsSnippet = 'coords = 0;';
- }
- else {
- coordsSnippet =
- broadcastDims.map(function (d) { return "coords." + fields[d + rankDiff] + " = 0;"; })
- .join('\n');
- }
- var unpackedCoordsSnippet = '';
- if (outRank < 2 && inRank > 0) {
- unpackedCoordsSnippet = 'coords';
- }
- else {
- unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
- .map(function (s, i) { return "coords." + fields[i + rankDiff]; })
- .join(', ');
- }
- var output = "return outputValue;";
- var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
- var isInputScalar = inSize === 1;
- var outSize = sizeFromShape(outShapeInfo.logicalShape);
- var isOutputScalar = outSize === 1;
- if (inRank === 1 && !isInputScalar && !isOutputScalar) {
- output = "\n return vec4(outputValue.xy, outputValue.xy);\n ";
- }
- else if (isInputScalar && !isOutputScalar) {
- if (outRank === 1) {
- output = "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n ";
- }
- else {
- output = "\n return vec4(outputValue.x);\n ";
- }
- }
- else if (broadcastDims.length) {
- var rows = inRank - 2;
- var cols = inRank - 1;
- if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
- output = "return vec4(outputValue.x);";
- }
- else if (broadcastDims.indexOf(rows) > -1) {
- output = "return vec4(outputValue.x, outputValue.y, " +
- "outputValue.x, outputValue.y);";
- }
- else if (broadcastDims.indexOf(cols) > -1) {
- output = "return vec4(outputValue.xx, outputValue.zz);";
- }
- }
- return "\n vec4 " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n vec4 outputValue = get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n " + output + "\n }\n ";
- }
- function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
- var texName = inputInfo.name;
- var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
- var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
- var outTexShape = outShapeInfo.texShape;
- var inTexShape = inputInfo.shapeInfo.texShape;
- var inRank = inputInfo.shapeInfo.logicalShape.length;
- var outRank = outShapeInfo.logicalShape.length;
- if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&
- inputInfo.shapeInfo.flatOffset == null &&
- arraysEqual(inTexShape, outTexShape)) {
- return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", resultUV);\n }\n ";
- }
- var type = getCoordsDataType(outRank);
- var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
- var rankDiff = outRank - inRank;
- var coordsSnippet;
- var fields = ['x', 'y', 'z', 'w', 'u', 'v'];
- if (inRank === 0) {
- coordsSnippet = '';
- }
- else if (outRank < 2 && broadcastDims.length >= 1) {
- coordsSnippet = 'coords = 0;';
- }
- else {
- coordsSnippet =
- broadcastDims.map(function (d) { return "coords." + fields[d + rankDiff] + " = 0;"; })
- .join('\n');
- }
- var unpackedCoordsSnippet = '';
- if (outRank < 2 && inRank > 0) {
- unpackedCoordsSnippet = 'coords';
- }
- else {
- unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
- .map(function (s, i) { return "coords." + fields[i + rankDiff]; })
- .join(', ');
- }
- return "\n float " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n return get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n }\n ";
- }
- function getCoordsDataType(rank) {
- if (rank <= 1) {
- return 'int';
- }
- else if (rank === 2) {
- return 'ivec2';
- }
- else if (rank === 3) {
- return 'ivec3';
- }
- else if (rank === 4) {
- return 'ivec4';
- }
- else if (rank === 5) {
- return 'ivec5';
- }
- else if (rank === 6) {
- return 'ivec6';
- }
- else {
- throw Error("GPU for rank " + rank + " is not yet supported");
- }
- }
- /** Returns a new input info (a copy) that has a squeezed logical shape. */
- function squeezeInputInfo(inInfo, squeezedShape) {
- // Deep copy.
- var newInputInfo = JSON.parse(JSON.stringify(inInfo));
- newInputInfo.shapeInfo.logicalShape = squeezedShape;
- return newInputInfo;
- }
- function getSqueezedParams(params, keptDims) {
- return keptDims.map(function (d) { return params[d]; }).join(', ');
- }
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ArgMinMaxPackedProgram = /** @class */ (function () {
- function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- assert(shape.length > 2, function () { return "Packed arg" + (op.charAt(0).toUpperCase() +
- op.slice(1)) + " supports only inputs with rank above 2."; });
- var inSize = shape[shape.length - 1];
- var outSize = Math.ceil(inSize / windowSize);
- this.outputShape = shape.slice(0, -1);
- if (outSize > 1) {
- this.outputShape.push(outSize);
- }
- if (!firstPass) {
- this.variableNames.push('bestIndicesA');
- }
- var outShape = this.outputShape;
- var rank = outShape.length;
- var dtype = getCoordsDataType(rank);
- var coords = getChannels('coords', rank);
- var sourceLocSetup;
- var sourceRank;
- if (outSize === 1) {
- sourceRank = rank + 1;
- var sourceLocDType = getCoordsDataType(sourceRank);
- sourceLocSetup = "\n " + sourceLocDType + " sourceLocR = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocG = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 2] + ";\n " + sourceLocDType + " sourceLocA = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocB = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 2] + ";";
- }
- else {
- sourceRank = rank;
- sourceLocSetup = "\n " + dtype + " sourceLocR = coords;\n ++" + coords[rank - 1] + ";\n " + dtype + " sourceLocG = coords;\n ++" + coords[rank - 2] + ";\n " + dtype + " sourceLocA = coords;\n --" + coords[rank - 1] + ";\n " + dtype + " sourceLocB = coords;\n --" + coords[rank - 2] + ";";
- }
- var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
- var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3.
- var intChannels = channels.map(function (x) { return 'int ' + x; });
- var srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
- var srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
- var srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
- var srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
- var compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
- var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(" + srcRCoords.join() + "),\n getBestIndicesAChannel(" + srcGCoords.join() + "),\n getBestIndicesAChannel(" + srcBCoords.join() + "),\n getBestIndicesAChannel(" + srcACoords.join() + ")));";
- var fetchValue = "vec4(\n getAChannel(" + srcRCoords.join() + "),\n hasNextCol ? getAChannel(" + srcGCoords.join() + ") : 0.,\n hasNextRow ? getAChannel(" + srcBCoords.join() + ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(" + srcACoords.join() + ") : 0.)";
- var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(" + intChannels.join() + ") {\n return getChannel(getBestIndicesA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }";
- this.userCode = "\n float getAChannel(" + intChannels.join() + ") {\n return getChannel(getA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }\n " + getBestIndicesAChannelSnippet + "\n void main() {\n " + dtype + " coords = getOutputCoords();\n bool hasNextCol = " + coords[rank - 1] + " < " + (outShape[rank - 1] - 1) + ";\n bool hasNextRow = " + coords[rank - 2] + " < " + (outShape[rank - 2] - 1) + ";\n " + sourceLocSetup + "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + ",\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = " + fetchValue + ";\n\n for (int i = 0; i < " + windowSize + "; i++) {\n inIdx = srcIdx;\n " + fetchCandidateIdx + "\n vec4 candidate = " + fetchValue + ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n ";
- }
- return ArgMinMaxPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var AvgPool2DBackpropProgram = /** @class */ (function () {
- function AvgPool2DBackpropProgram(convInfo) {
- this.variableNames = ['dy'];
- this.outputShape = convInfo.inShape;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var effectiveFilterHeight = convInfo.effectiveFilterHeight;
- var effectiveFilterWidth = convInfo.effectiveFilterWidth;
- var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- var avgMultiplier = 1 / (filterHeight * filterWidth);
- this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC+= " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return AvgPool2DBackpropProgram;
- }());
- var AvgPool3DBackpropProgram = /** @class */ (function () {
- function AvgPool3DBackpropProgram(convInfo) {
- this.variableNames = ['dy'];
- this.outputShape = convInfo.inShape;
- var filterDepth = convInfo.filterDepth;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var strideDepth = convInfo.strideDepth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationDepth = convInfo.dilationDepth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var effectiveFilterDepth = convInfo.effectiveFilterDepth;
- var effectiveFilterHeight = convInfo.effectiveFilterHeight;
- var effectiveFilterWidth = convInfo.effectiveFilterWidth;
- var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
- var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
- this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return AvgPool3DBackpropProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var BatchNormProgram = /** @class */ (function () {
- function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
- this.outputShape = [];
- this.variableNames = ['x', 'mean', 'variance'];
- assertAndGetBroadcastShape(xShape, meanShape);
- assertAndGetBroadcastShape(xShape, varianceShape);
- var offsetSnippet = '0.0';
- if (offsetShape != null) {
- assertAndGetBroadcastShape(xShape, offsetShape);
- this.variableNames.push('offset');
- offsetSnippet = 'getOffsetAtOutCoords()';
- }
- var scaleSnippet = '1.0';
- if (scaleShape != null) {
- assertAndGetBroadcastShape(xShape, scaleShape);
- this.variableNames.push('scale');
- scaleSnippet = 'getScaleAtOutCoords()';
- }
- this.outputShape = xShape;
- this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = " + offsetSnippet + ";\n float scale = " + scaleSnippet + ";\n float inv = scale * inversesqrt(variance + float(" + varianceEpsilon + "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n ";
- }
- return BatchNormProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var BatchNormPackedProgram = /** @class */ (function () {
- function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
- this.packedInputs = true;
- this.packedOutput = true;
- this.variableNames = ['x', 'mean', 'variance'];
- assertAndGetBroadcastShape(xShape, meanShape);
- assertAndGetBroadcastShape(xShape, varianceShape);
- var offsetSnippet = 'vec4(0.0)';
- if (offsetShape != null) {
- assertAndGetBroadcastShape(xShape, offsetShape);
- this.variableNames.push('offset');
- offsetSnippet = 'getOffsetAtOutCoords()';
- }
- var scaleSnippet = 'vec4(1.0)';
- if (scaleShape != null) {
- assertAndGetBroadcastShape(xShape, scaleShape);
- this.variableNames.push('scale');
- scaleSnippet = 'getScaleAtOutCoords()';
- }
- this.outputShape = xShape;
- this.userCode = "\n void main() {\n vec4 offset = " + offsetSnippet + ";\n vec4 scale = " + scaleSnippet + ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(" + varianceEpsilon + "));\n\n setOutput((x - mean) * inv + offset);\n }\n ";
- }
- return BatchNormPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- // (Ar + Ai)(Br + Bi) =
- // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr
- // Yr = ArBr - AB
- // Yi = ArBi + AiBr
- var COMPLEX_MULTIPLY = {
- REAL: 'return areal * breal - aimag * bimag;',
- IMAG: 'return areal * bimag + aimag * breal;'
- };
- var BinaryOpComplexProgram = /** @class */ (function () {
- function BinaryOpComplexProgram(op, aShape, bShape) {
- this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
- this.outputShape =
- assertAndGetBroadcastShape(aShape, bShape);
- this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n " + op + "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n ";
- }
- return BinaryOpComplexProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var CHECK_NAN_SNIPPET = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
- var ADD = 'return a + b;';
- var SUB = 'return a - b;';
- var MUL = 'return a * b;';
- // Without the equality check div produces 0.9999 for a = b, which when
- // floored can cause errors.
- var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;";
- // We use native integer division to deal with floating point imprecision. Since
- // we implement floor division and glsl implements truncated division, we
- // correct for this by subtracting 1 from result when the result is negative and
- // there is a remainder.
- var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n";
- var POW = "\nif(a < 0.0 && floor(b) < b){\n return NAN;\n}\nif (b == 0.0) {\n return 1.0;\n}\nreturn (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n";
- var EQUAL = "return float(a == b);";
- var NOT_EQUAL = "return float(a != b);";
- var LESS = "return float(a < b);";
- var LESS_EQUAL = "return float(a <= b);";
- var GREATER = "return float(a > b);";
- var GREATER_EQUAL = "return float(a >= b);";
- var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);";
- var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);";
- var MAX = CHECK_NAN_SNIPPET + "\n return max(a, b);\n";
- var MIN = CHECK_NAN_SNIPPET + "\n return min(a, b);\n";
- var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);";
- var ATAN2 = CHECK_NAN_SNIPPET + "\n return atan(a, b);\n";
- var ELU_DER = "return (b >= 1.0) ? a : a * (b + 1.0);";
- var PRELU = "return (a < 0.) ? b * a : a;";
- var BinaryOpProgram = /** @class */ (function () {
- function BinaryOpProgram(op, aShape, bShape) {
- this.variableNames = ['A', 'B'];
- this.outputShape =
- assertAndGetBroadcastShape(aShape, bShape);
- this.userCode = "\n float binaryOperation(float a, float b) {\n " + op + "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n ";
- }
- return BinaryOpProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var CHECK_NAN_SNIPPET$1 = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
- // We do the same as in ./binaryop_gpu, with vec4 and ivec4.
- // On Linux, the vectorized implementation produces NaNs when a and b are 0.
- var DIV$1 = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n";
- var INT_DIV$1 = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n";
- var POW$1 = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n " +
- CHECK_NAN_SNIPPET$1 + "\n return result;\n";
- var PRELU$1 = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
- var ELU_DER$1 = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n";
- var ATAN2$1 = "\n vec4 result = atan(a, b);\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " +
- CHECK_NAN_SNIPPET$1 + "\n return result;\n";
- var EQUAL$1 = "\n return vec4(equal(a, b));\n";
- var NOT_EQUAL$1 = "\n return vec4(notEqual(a, b));\n";
- var LESS$1 = "\n return vec4(lessThan(a, b));\n";
- var LESS_EQUAL$1 = "\n return vec4(lessThanEqual(a, b));\n";
- var GREATER$1 = "\n return vec4(greaterThan(a, b));\n";
- var GREATER_EQUAL$1 = "\n return vec4(greaterThanEqual(a, b));\n";
- var LOGICAL_AND$1 = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n";
- var LOGICAL_OR$1 = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n";
- var MAX$1 = "\n vec4 result = vec4(max(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " +
- CHECK_NAN_SNIPPET$1 + "\n return result;\n";
- var MIN$1 = "\n vec4 result = vec4(min(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " +
- CHECK_NAN_SNIPPET$1 + "\n return result;\n";
- var MOD$1 = "\n vec4 result = mod(a, b);\n vec4 isNaN = vec4(equal(b, vec4(0.0)));\n " +
- CHECK_NAN_SNIPPET$1 + "\n return result;\n";
- var BinaryOpPackedProgram = /** @class */ (function () {
- function BinaryOpPackedProgram(op, aShape, bShape, checkOutOfBounds) {
- if (checkOutOfBounds === void 0) { checkOutOfBounds = false; }
- this.variableNames = ['A', 'B'];
- this.supportsBroadcasting = true;
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape =
- assertAndGetBroadcastShape(aShape, bShape);
- var rank = this.outputShape.length;
- var checkOutOfBoundsString = '';
- if (checkOutOfBounds) {
- if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
- checkOutOfBoundsString = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
- }
- else {
- var dtype = getCoordsDataType(rank);
- checkOutOfBoundsString = "\n " + dtype + " coords = getOutputCoords();\n ";
- if (rank === 1) {
- checkOutOfBoundsString += "\n result.y = (coords + 1) >= " + this.outputShape[0] + " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ";
- }
- else {
- var channels = getChannels('coords', rank);
- checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (" + channels[rank - 2] + " + 1) >= " + this.outputShape[rank - 2] + ";\n bool nextColOutOfBounds =\n (" + channels[rank - 1] + " + 1) >= " + this.outputShape[rank - 1] + ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ";
- }
- }
- }
- this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n " + op + "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n " + checkOutOfBoundsString + "\n\n setOutput(result);\n }\n ";
- }
- return BinaryOpPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ClipProgram = /** @class */ (function () {
- function ClipProgram(aShape) {
- this.variableNames = ['A'];
- this.outputShape = aShape;
- this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n ";
- }
- ClipProgram.prototype.getCustomSetupFunc = function (min, max) {
- var _this = this;
- return function (gpgpu, webGLProgram) {
- if (_this.minLoc == null) {
- _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
- _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
- }
- gpgpu.gl.uniform1f(_this.minLoc, min);
- gpgpu.gl.uniform1f(_this.maxLoc, max);
- };
- };
- return ClipProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ClipPackedProgram = /** @class */ (function () {
- function ClipPackedProgram(aShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = aShape;
- this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n ";
- }
- ClipPackedProgram.prototype.getCustomSetupFunc = function (min, max) {
- var _this = this;
- return function (gpgpu, webGLProgram) {
- if (_this.minLoc == null) {
- _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
- _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
- }
- gpgpu.gl.uniform1f(_this.minLoc, min);
- gpgpu.gl.uniform1f(_this.maxLoc, max);
- };
- };
- return ClipPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ComplexAbsProgram = /** @class */ (function () {
- function ComplexAbsProgram(shape) {
- this.variableNames = ['real', 'imag'];
- this.outputShape = shape;
- this.userCode = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n ";
- }
- return ComplexAbsProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ConcatProgram = /** @class */ (function () {
- // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
- function ConcatProgram(shapes) {
- this.outputShape = [];
- this.outputShape = computeOutShape(shapes, 1 /* axis */);
- this.variableNames = shapes.map(function (_, i) { return "T" + i; });
- var offsets = new Array(shapes.length - 1);
- offsets[0] = shapes[0][1];
- for (var i = 1; i < offsets.length; i++) {
- offsets[i] = offsets[i - 1] + shapes[i][1];
- }
- var snippets = ["if (yC < " + offsets[0] + ") setOutput(getT0(yR, yC));"];
- for (var i = 1; i < offsets.length; i++) {
- var shift = offsets[i - 1];
- snippets.push("else if (yC < " + offsets[i] + ") " +
- ("setOutput(getT" + i + "(yR, yC-" + shift + "));"));
- }
- var lastIndex = offsets.length;
- var lastShift = offsets[offsets.length - 1];
- snippets.push("else setOutput(getT" + lastIndex + "(yR, yC-" + lastShift + "));");
- this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n " + snippets.join('\n ') + "\n }\n ";
- }
- return ConcatProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ConcatPackedProgram = /** @class */ (function () {
- function ConcatPackedProgram(shapes, axis) {
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = [];
- this.outputShape = computeOutShape(shapes, axis);
- var shape = this.outputShape;
- var rank = shape.length;
- var dtype = getCoordsDataType(rank);
- var coords = getChannels('coords', rank);
- var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
- this.variableNames = shapes.map(function (_, i) { return "T" + i; });
- var offsets = new Array(shapes.length - 1);
- offsets[0] = shapes[0][axis];
- for (var i = 1; i < offsets.length; i++) {
- offsets[i] = offsets[i - 1] + shapes[i][axis];
- }
- var channel = channels[axis];
- var lastChannels = channels.slice(-2);
- var allChannels = channels.join();
- var getValueSnippet = "if (" + channel + " < " + offsets[0] + ") {\n return getChannel(\n getT0(" + allChannels + "), vec2(" + lastChannels.join() + "));\n }";
- for (var i = 1; i < offsets.length; i++) {
- var shift_1 = offsets[i - 1];
- // Note: the >= comparison below may seem unnecessary given the check
- // above but is needed to workaround branch execution issues on some
- // devices. It makes all the conditions exclusive without relying on
- // execution order.
- getValueSnippet += "\n if (" + channel + " < " + offsets[i] + " && " + channel + " >= " + offsets[i - 1] + ") {\n return getChannel(\n getT" + i + "(" + shiftedChannels(channels, channel, shift_1) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift_1) + "));\n }";
- }
- var lastIndex = offsets.length;
- var shift = offsets[offsets.length - 1];
- getValueSnippet += "\n return getChannel(\n getT" + lastIndex + "(" + shiftedChannels(channels, channel, shift) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift) + "));";
- this.userCode = "\n float getValue(" + channels.map(function (x) { return 'int ' + x; }) + ") {\n " + getValueSnippet + "\n }\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n vec4 result = vec4(getValue(" + coords + "), 0., 0., 0.);\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " + 1;\n if (" + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.g = getValue(" + coords + ");\n }\n\n " + coords[rank - 2] + " = " + coords[rank - 2] + " + 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + ") {\n result.a = getValue(" + coords + ");\n }\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " - 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + " &&\n " + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.b = getValue(" + coords + ");\n }\n setOutput(result);\n }\n ";
- }
- return ConcatPackedProgram;
- }());
- /**
- * Return an expression for coordinates into a vector where a given channel
- * will be offset by [shift].
- *
- * @param channels the channels to consider
- * @param channel the channel we want shifted
- * @param shift the amount to subtract from the channel.
- *
- * @returns a string of the form 'x, y-[shift], z' where any one channel can
- * have the shift applied.
- */
- function shiftedChannels(channels, channel, shift) {
- var channelIdx = channels.indexOf(channel);
- var res = channels.map(function (c, idx) {
- if (idx === channelIdx) {
- return c + " - " + shift;
- }
- else {
- return c;
- }
- });
- return res.join();
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var Conv2DDerFilterProgram = /** @class */ (function () {
- function Conv2DDerFilterProgram(convInfo) {
- this.variableNames = ['x', 'dy'];
- this.outputShape = convInfo.filterShape;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- var isChannelsLast = convInfo.dataFormat === 'channelsLast';
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n if (" + isChannelsLast + ") {\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n } else {\n float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return Conv2DDerFilterProgram;
- }());
- var Conv2DDerInputProgram = /** @class */ (function () {
- function Conv2DDerInputProgram(convInfo) {
- this.variableNames = ['dy', 'W'];
- this.outputShape = convInfo.inShape;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var isChannelsLast = convInfo.dataFormat === 'channelsLast';
- var padTop = filterHeight - 1 - convInfo.padInfo.top;
- var padLeft = filterWidth - 1 - convInfo.padInfo.left;
- var rowDim = isChannelsLast ? 1 : 2;
- var colDim = isChannelsLast ? 2 : 3;
- var channelDim = isChannelsLast ? 3 : 1;
- this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[" + channelDim + "];\n\n ivec2 dyCorner = ivec2(coords[" + rowDim + "], coords[" + colDim + "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n\n if (" + isChannelsLast + ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return Conv2DDerInputProgram;
- }());
- var Conv3DDerFilterProgram = /** @class */ (function () {
- function Conv3DDerFilterProgram(convInfo) {
- this.variableNames = ['x', 'dy'];
- this.outputShape = convInfo.filterShape;
- var strideDepth = convInfo.strideDepth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var padFront = convInfo.padInfo.front;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yF = 0; yF < " + convInfo.outDepth + "; yF++) {\n int xF = wF + yF * " + strideDepth + " - " + padFront + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return Conv3DDerFilterProgram;
- }());
- var Conv3DDerInputProgram = /** @class */ (function () {
- function Conv3DDerInputProgram(convInfo) {
- this.variableNames = ['dy', 'W'];
- this.outputShape = convInfo.inShape;
- var filterDepth = convInfo.filterDepth;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var strideDepth = convInfo.strideDepth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var padFront = filterDepth - 1 - convInfo.padInfo.front;
- var padTop = filterHeight - 1 - convInfo.padInfo.top;
- var padLeft = filterWidth - 1 - convInfo.padInfo.left;
- this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n float dyF = float(dyFCorner + wF) / " + strideDepth + ".0;\n\n if (dyF < 0.0 || dyF >= " + convInfo.outDepth + ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = " + filterDepth + " - 1 - wF;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return Conv3DDerInputProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DepthwiseConv2DDerFilterProgram = /** @class */ (function () {
- function DepthwiseConv2DDerFilterProgram(convInfo) {
- this.variableNames = ['x', 'dy'];
- this.outputShape = convInfo.filterShape;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- var channelMul = convInfo.outChannels / convInfo.inChannels;
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * " + channelMul + " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return DepthwiseConv2DDerFilterProgram;
- }());
- var DepthwiseConv2DDerInputProgram = /** @class */ (function () {
- function DepthwiseConv2DDerInputProgram(convInfo) {
- this.variableNames = ['dy', 'W'];
- this.outputShape = convInfo.inShape;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var padTop = filterHeight - 1 - convInfo.padInfo.top;
- var padLeft = filterWidth - 1 - convInfo.padInfo.left;
- var channelMul = convInfo.outChannels / convInfo.inChannels;
- this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < " + channelMul + "; dm++) {\n int d2 = d1 * " + channelMul + " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return DepthwiseConv2DDerInputProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var Conv2DProgram = /** @class */ (function () {
- function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights) {
- if (addBias === void 0) { addBias = false; }
- if (activation === void 0) { activation = null; }
- if (hasPreluActivationWeights === void 0) { hasPreluActivationWeights = false; }
- this.variableNames = ['x', 'W'];
- this.outputShape = convInfo.outShape;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
- var inputDepthVec4Remainder = convInfo.inChannels % 4;
- var isChannelsLast = convInfo.dataFormat === 'channelsLast';
- var rowDim = isChannelsLast ? 1 : 2;
- var colDim = isChannelsLast ? 2 : 3;
- var channelDim = isChannelsLast ? 3 : 1;
- var activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivationWeights) {
- activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
- }
- else {
- activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n ";
- }
- applyActivationSnippet = "result = activation(result);";
- }
- var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivationWeights) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[" + channelDim + "];\n\n ivec2 xRCCorner =\n ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n\n if (" + isChannelsLast + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else {\n dotProd +=\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n }\n\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 wValues = vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
- }
- return Conv2DProgram;
- }());
- var Conv3DProgram = /** @class */ (function () {
- function Conv3DProgram(convInfo) {
- this.variableNames = ['x', 'W'];
- this.outputShape = convInfo.outShape;
- var padFront = convInfo.padInfo.front;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- var strideDepth = convInfo.strideDepth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationDepth = convInfo.dilationDepth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var filterDepth = convInfo.filterDepth;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
- var inputDepthVec4Remainder = convInfo.inChannels % 4;
- this.userCode = "\n const ivec3 strides = ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n int xF = xFCorner + wF * " + dilationDepth + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n dotProd +=\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return Conv3DProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DepthwiseConv2DProgram = /** @class */ (function () {
- function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation) {
- if (addBias === void 0) { addBias = false; }
- if (activation === void 0) { activation = null; }
- if (hasPreluActivation === void 0) { hasPreluActivation = false; }
- this.variableNames = ['x', 'W'];
- this.outputShape = convInfo.outShape;
- var xNumRows = convInfo.inHeight;
- var xNumCols = convInfo.inWidth;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var channelMul = convInfo.outChannels / convInfo.inChannels;
- var activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivation) {
- activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
- }
- else {
- activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n ";
- }
- applyActivationSnippet = "result = activation(result);";
- }
- var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivation) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + xNumRows + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + xNumCols + ") {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
- }
- return DepthwiseConv2DProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DepthwiseConvPacked2DProgram = /** @class */ (function () {
- function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation) {
- if (addBias === void 0) { addBias = false; }
- if (activation === void 0) { activation = null; }
- if (hasPreluActivation === void 0) { hasPreluActivation = false; }
- this.variableNames = ['x', 'W'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = convInfo.outShape;
- var xNumRows = convInfo.inHeight;
- var xNumCols = convInfo.inWidth;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var filterHeight = convInfo.filterHeight;
- var filterWidth = convInfo.filterWidth;
- var texelsAcross = filterWidth;
- var mainLoop = "int xR; int xC; int xCOffset;";
- for (var r = 0; r < filterHeight; r++) {
- for (var c = 0; c < filterWidth; c++) {
- mainLoop += "\n vec4 xTexelR" + r + "C" + c * 2 + " = vec4(0.);\n vec4 wR" + r + "C" + c + " = vec4(0.);\n vec4 xR" + r + "C" + c + " = vec4(0.);";
- }
- }
- /**
- * This vectorized implementation works by gathering the values needed for
- * each output channel's dot product into vec4's and then multiplying them
- * all together (this happens in the final double for-loop below). Most of
- * the main loop consists of constructing these vec4's with the minimum
- * number of texture2D calls, which means making use of all four returned
- * values from a texture2D call at once.
- */
- for (var r = 0; r < filterHeight; r++) {
- for (var texelC = 0; texelC < texelsAcross; texelC++) {
- var c = texelC * 2;
- mainLoop += "\n xR = xRCorner + " + r * dilationHeight + ";\n xC = xCCorner + " + c * dilationWidth + ";\n ";
- if (strideWidth === 1) {
- if (c < filterWidth) {
- // If padding is odd, the outer texels have to be composed.
- if (padLeft % 2 === 1) {
- // TODO: Ensure vec4 previous does not result in redundant sample,
- // and avoid setting xTexelRC's that exceed the boundary in the
- // first place rather than resetting them to vec4(0)).
- // To compute xCOffset:
- // - If padding is odd, we must add 1 to ensure we ask for an
- // even-numbered row.
- // - We subtract 2 to access the previous texel.
- mainLoop += "\n xCOffset = xC + 1;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= " + xNumCols + ") {\n xTexelR" + r + "C" + c + ".zw = vec2(0.);\n }\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + 1 - 2;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n vec4 previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= " + xNumCols + ") {\n previous.zw = vec2(0.);\n }\n\n xR" + r + "C" + c + " = vec4(previous.zw, xTexelR" + r + "C" + c + ".xy);\n } else {\n xR" + r + "C" + c + " = vec4(0, 0, xTexelR" + r + "C" + c + ".xy);\n }\n ";
- }
- else {
- // Padding is even, so xRC corresponds to a single texel.
- mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + " && xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = xTexelR" + r + "C" + c + ";\n ";
- }
- if (c + 1 < filterWidth) {
- // If dilation is even, the second entry should match the first
- // (either both are composed or both are single samples). But if
- // dilation is odd, then the second entry should be the opposite
- // of the first (if the first is composed, the second is a single
- // sample, and vice versa.)
- var nextTexelOffset = padLeft % 2 === 0 ?
- nearestLargerEven(dilationWidth) :
- dilationWidth;
- if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
- (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
- mainLoop += "\n xCOffset = xC + " + padLeft % 2 + " + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n ";
- // If dilation > 1 then the xRC's will not be able to share any
- // values, so each xRC will require two unique calls to getX.
- if (dilationWidth > 1) {
- mainLoop += "\n xCOffset -= 2;\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n ";
- }
- mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".xy);\n ";
- }
- else {
- mainLoop += "\n xCOffset = xC + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n\n xR" + r + "C" + (c + 1) + " = xTexelR" + r + "C" + (c + 2) + ";\n ";
- }
- }
- }
- }
- else { // stride > 1
- if (c < filterWidth) {
- mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + ") {\n ";
- // Depending on whether padLeft is even or odd, we want either the
- // xy or zw channels from X texels for xR${r}C${c}. If padLeft is
- // even, xR${r}C${c + 1} is simply the zw channels of texels we've
- // already sampled. But if padLeft is odd, xR${r}C{$c + 1}.zw will
- // need to come from the xy channels of a new texel, hence the `vec4
- // final` initialized below.
- if (padLeft % 2 === 1) {
- mainLoop += "\n xCOffset = xC + 1 - " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n if(xC + 1 >= 0 && xC + 1 < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xC + 1, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n ";
- if (c + 1 < filterWidth) {
- mainLoop += "\n vec4 final = vec4(0.);\n xCOffset = xC + 1 + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n final = getX(batch, xR, xCOffset, d1);\n }\n xR" + r + "C" + (c + 1) + " = vec4(xTexelR" + r + "C" + (c + 2) + ".xy, final.xy);\n ";
- }
- }
- else {
- mainLoop += "\n if(xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".xy, xTexelR" + r + "C" + (c + 2) + ".xy);\n ";
- if (c + 1 < filterWidth) {
- mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n ";
- }
- }
- mainLoop += "}";
- }
- }
- if (c < filterWidth) {
- mainLoop += "\n vec4 wTexelR" + r + "C" + c + " = getW(" + r + ", " + c + ", d1, q);\n wR" + r + "C" + c + " = vec4(wTexelR" + r + "C" + c + ".xz, wTexelR" + r + "C" + c + ".xz);\n ";
- if (c + 1 < filterWidth) {
- mainLoop += "\n vec4 wTexelR" + r + "C" + (c + 1) + " = getW(" + r + ", " + (c + 1) + ", d1, q);\n wR" + r + "C" + (c + 1) + " =\n vec4(wTexelR" + r + "C" + (c + 1) + ".xz, wTexelR" + r + "C" + (c + 1) + ".xz);";
- }
- }
- }
- }
- for (var r = 0; r < filterHeight; r++) {
- for (var c = 0; c < filterWidth; c++) {
- mainLoop += "dotProd += xR" + r + "C" + c + " * wR" + r + "C" + c + ";";
- }
- }
- var activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivation) {
- activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
- }
- else {
- activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
- }
- applyActivationSnippet = "result = activation(result);";
- }
- var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivation) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2;\n int q = 0;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n vec4 dotProd = vec4(0.);\n\n " + mainLoop + "\n\n vec4 result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
- }
- return DepthwiseConvPacked2DProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var CropAndResizeProgram = /** @class */ (function () {
- function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) {
- this.variableNames = ['Image', 'Boxes', 'BoxInd'];
- this.outputShape = [];
- var batch = imageShape[0], imageHeight = imageShape[1], imageWidth = imageShape[2], depth = imageShape[3];
- var numBoxes = boxShape[0];
- var cropHeight = cropSize[0], cropWidth = cropSize[1];
- this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
- var methodId = method === 'bilinear' ? 1 : 0;
- var _a = [imageHeight - 1 + ".0", imageWidth - 1 + ".0"], inputHeightFloat = _a[0], inputWidthFloat = _a[1];
- var _b = cropHeight > 1 ?
- [
- "" + (imageHeight - 1) / (cropHeight - 1),
- '(y2-y1) * height_ratio',
- "y1*" + inputHeightFloat + " + float(y)*(height_scale)",
- ] :
- [
- '0.0',
- '0.0',
- "0.5 * (y1+y2) * " + inputHeightFloat,
- ], heightRatio = _b[0], heightScale = _b[1], inY = _b[2];
- var _c = cropWidth > 1 ?
- [
- "" + (imageWidth - 1) / (cropWidth - 1),
- '(x2-x1) * width_ratio',
- "x1*" + inputWidthFloat + " + float(x)*(width_scale)",
- ] :
- [
- '0.0',
- '0.0',
- "0.5 * (x1+x2) * " + inputWidthFloat,
- ], widthRatio = _c[0], widthScale = _c[1], inX = _c[2];
- // Reference implementation
- // tslint:disable-next-line:max-line-length
- // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
- this.userCode = "\n const float height_ratio = float(" + heightRatio + ");\n const float width_ratio = float(" + widthRatio + ");\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= " + batch + ") {\n return;\n }\n\n float height_scale = " + heightScale + ";\n float width_scale = " + widthScale + ";\n\n float in_y = " + inY + ";\n if( in_y < 0.0 || in_y > " + inputHeightFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n float in_x = " + inX + ";\n if( in_x < 0.0 || in_x > " + inputWidthFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(" + methodId + " == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n ";
- }
- return CropAndResizeProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var CumSumProgram = /** @class */ (function () {
- function CumSumProgram(shape, exclusive, reverse) {
- this.variableNames = ['x'];
- this.outputShape = shape;
- var rank = shape.length;
- var finalDim = shape[shape.length - 1];
- var comparator = reverse ? '<' : '>';
- this.userCode = "\n int getIndex(int i) {\n " + (reverse ? "return " + finalDim + " -i - 1;" : 'return i;') + "\n }\n\n void main() {\n " + getCoordsDataType(rank) + " coords = getOutputCoords();\n int end = " + getFinalCoord(rank, 'coords') + ";\n float val = 0.0;\n for (int i = " + finalDim + " - 1; i >= 0; i -= 1) {\n int idx = getIndex(i);\n if (idx " + comparator + " end) {\n continue;\n }\n if (idx == end && " + exclusive + ") {\n continue;\n }\n " + getFinalCoord(rank, 'coords') + " = idx;\n val += getX(" + getCoords(rank, 'coords') + ");\n }\n setOutput(val);\n }\n ";
- }
- return CumSumProgram;
- }());
- function getCoords(rank, name) {
- if (rank === 1) {
- return "" + name;
- }
- else if (rank === 2) {
- return name + ".x, " + name + ".y";
- }
- else if (rank === 3) {
- return name + ".x, " + name + ".y, " + name + ".z";
- }
- else if (rank === 4) {
- return name + ".x, " + name + ".y, " + name + ".z, " + name + ".w";
- }
- else {
- throw Error("Cumulative sum for rank " + rank + " is not yet supported");
- }
- }
- function getFinalCoord(rank, name) {
- if (rank === 1) {
- return "" + name;
- }
- else if (rank === 2) {
- return name + ".y";
- }
- else if (rank === 3) {
- return name + ".z";
- }
- else if (rank === 4) {
- return name + ".w";
- }
- else {
- throw Error("Cumulative sum for rank " + rank + " is not yet supported");
- }
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DecodeMatrixProgram = /** @class */ (function () {
- function DecodeMatrixProgram(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- this.outPackingScheme = PackingScheme.DENSE;
- var texShape = getDenseTexShape(outputShape);
- var glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = 4 * (resTexRC.x * " + texShape[1] + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " + glsl.output + " = result;\n }\n ";
- }
- return DecodeMatrixProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DecodeMatrixPackedProgram = /** @class */ (function () {
- function DecodeMatrixPackedProgram(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outPackingScheme = PackingScheme.DENSE;
- var texShape = getDenseTexShape(outputShape);
- var glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = 4 * (resTexRC.x * " + texShape[1] + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " + glsl.output + " = result;\n }\n ";
- }
- return DecodeMatrixPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DepthToSpaceProgram = /** @class */ (function () {
- function DepthToSpaceProgram(outputShape, blockSize, dataFormat) {
- this.variableNames = ['x'];
- this.outputShape = [];
- this.outputShape = outputShape;
- this.blockSize = blockSize;
- this.dataFormat = dataFormat;
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = " + this.getHeightCoordString() + ";\n int w = " + this.getWidthCoordString() + ";\n int d = " + this.getDepthCoordString() + ";\n\n int in_h = h / " + blockSize + ";\n int offset_h = imod(h, " + blockSize + ");\n int in_w = w / " + blockSize + ";\n int offset_w = imod(w, " + blockSize + ");\n int offset_d = (offset_h * " + blockSize + " + offset_w) *\n " + this.getOutputDepthSize() + ";\n int in_d = d + offset_d;\n\n float result = " + this.getInputSamplingString() + ";\n setOutput(result);\n }\n ";
- }
- DepthToSpaceProgram.prototype.getHeightCoordString = function () {
- if (this.dataFormat === 'NHWC') {
- return "coords[1]";
- }
- else {
- return "coords[2]";
- }
- };
- DepthToSpaceProgram.prototype.getWidthCoordString = function () {
- if (this.dataFormat === 'NHWC') {
- return "coords[2]";
- }
- else {
- return "coords[3]";
- }
- };
- DepthToSpaceProgram.prototype.getDepthCoordString = function () {
- if (this.dataFormat === 'NHWC') {
- return "coords[3]";
- }
- else {
- return "coords[1]";
- }
- };
- DepthToSpaceProgram.prototype.getOutputDepthSize = function () {
- if (this.dataFormat === 'NHWC') {
- return this.outputShape[3];
- }
- else {
- return this.outputShape[1];
- }
- };
- DepthToSpaceProgram.prototype.getInputSamplingString = function () {
- if (this.dataFormat === 'NHWC') {
- return "getX(b, in_h, in_w, in_d)";
- }
- else {
- return "getX(b, in_d, in_h, in_w)";
- }
- };
- return DepthToSpaceProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var DiagProgram = /** @class */ (function () {
- function DiagProgram(size) {
- this.variableNames = ['X'];
- this.outputShape = [size, size];
- this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;\n setOutput(val);\n }\n ";
- }
- return DiagProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var EncodeFloatProgram = /** @class */ (function () {
- function EncodeFloatProgram(outputShape) {
- this.variableNames = ['A'];
- this.outTexUsage = TextureUsage.DOWNLOAD;
- var glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n float x = getAAtOutCoords();\n " + glsl.output + " = encode_float(x);\n }\n ";
- }
- return EncodeFloatProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var EncodeFloatPackedProgram = /** @class */ (function () {
- function EncodeFloatPackedProgram(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = false;
- this.outTexUsage = TextureUsage.DOWNLOAD;
- var glsl = getGlslDifferences();
- this.outputShape = outputShape;
- this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + glsl.output + " = encode_float(x);\n }\n ";
- }
- return EncodeFloatPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var EncodeMatrixProgram = /** @class */ (function () {
- function EncodeMatrixProgram(outputShape, texShape, inputIsUnsignedByte) {
- if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; }
- this.variableNames = ['A'];
- var glsl = getGlslDifferences();
- var height = texShape[0], width = texShape[1];
- this.outputShape = outputShape;
- var output = "result";
- if (inputIsUnsignedByte) {
- output = "floor(result * 255. + 0.5)";
- }
- this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n \n int r = flatIndex / " + width + ";\n int c = imod(flatIndex, " + width + ");\n vec2 uv = (vec2(c, r) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n\n float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n " + glsl.output + " = vec4(" + output + ", 0., 0., 0.);\n }\n ";
- }
- return EncodeMatrixProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- /*
- This is how the shader encodes a tensor with shape = [2, 3, 5]
- (indices are [batch, row, col]).
-
- 000|001 002|003 004|xxx 020|021 022|023 024|xxx
- ------- ------- ------- ------- ------- -------
- 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
-
- 100|101 102|103 104|xxx 120|121 122|123 124|xxx
- ------- ------- ------- ------- ------- -------
- 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
-
- Single texels contain only values from the same batch, and from adjacent rows
- and columns.
- */
- var EncodeMatrixPackedProgram = /** @class */ (function () {
- function EncodeMatrixPackedProgram(outputShape, texShape, inputIsUnsignedByte) {
- if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; }
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- var glsl = getGlslDifferences();
- var height = texShape[0], width = texShape[1];
- this.outputShape = outputShape;
- var mainLoop = '';
- var output = 'result';
- if (inputIsUnsignedByte) {
- output = 'floor(result * 255. + 0.5)';
- }
- for (var row = 0; row <= 1; row++) {
- for (var col = 0; col <= 1; col++) {
- var channel = row * 2 + col;
- mainLoop += "\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + outputShape[2] + ") {\n localCoords[2] += " + col + ";\n if(localCoords[1] + " + row + " < " + outputShape[1] + ") {\n localCoords[1] += " + row + ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n r = flatIndex / " + width + ";\n c = imod(flatIndex, " + width + ");\n uv = (vec2(c, r) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n values = " + glsl.texture2D + "(A, uv);\n\n if(offset == 0) {\n result[" + channel + "] = values[0];\n } else if(offset == 1) {\n result[" + channel + "] = values[1];\n } else if(offset == 2) {\n result[" + channel + "] = values[2];\n } else {\n result[" + channel + "] = values[3];\n }\n }\n }\n ";
- }
- }
- this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n " + mainLoop + "\n\n " + glsl.output + " = " + output + ";\n }\n ";
- }
- return EncodeMatrixPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var COMPLEX_FFT = {
- REAL: 'return real * expR - imag * expI;',
- IMAG: 'return real * expI + imag * expR;'
- };
- var FFTProgram = /** @class */ (function () {
- function FFTProgram(op, inputShape, inverse) {
- this.variableNames = ['real', 'imag'];
- var innerDim = inputShape[1];
- this.outputShape = inputShape;
- var exponentMultiplierSnippet = inverse ? "2.0 * " + Math.PI : "-2.0 * " + Math.PI;
- var resultDenominator = inverse ? innerDim + ".0" : '1.0';
- this.userCode = "\n const float exponentMultiplier = " + exponentMultiplierSnippet + ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n " + op + "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(" + innerDim + ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < " + innerDim + "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / " + resultDenominator + ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n ";
- }
- return FFTProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var FillProgram = /** @class */ (function () {
- function FillProgram(shape, value) {
- this.outputShape = [];
- this.variableNames = ['x'];
- this.outputShape = shape;
- this.userCode = "\n uniform float value;\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n ";
- }
- FillProgram.prototype.getCustomSetupFunc = function (value) {
- var _this = this;
- return function (gpgpu, webGLProgram) {
- if (_this.valueLoc == null) {
- _this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
- }
- gpgpu.gl.uniform1f(_this.valueLoc, value);
- };
- };
- return FillProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var GatherProgram = /** @class */ (function () {
- function GatherProgram(aShape, indicesLength, axis) {
- this.variableNames = ['A', 'indices'];
- var outputShape = aShape.slice();
- outputShape[axis] = indicesLength;
- this.outputShape = outputShape;
- this.rank = outputShape.length;
- var dtype = getCoordsDataType(this.rank);
- var sourceCoords = getSourceCoords$1(aShape, axis);
- this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n ";
- }
- return GatherProgram;
- }());
- function getSourceCoords$1(aShape, axis) {
- var rank = aShape.length;
- if (rank > 4) {
- throw Error("Gather for rank " + rank + " is not yet supported");
- }
- if (rank === 1) {
- return "int(getIndices(resRC))";
- }
- var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
- var sourceCoords = [];
- for (var i = 0; i < aShape.length; i++) {
- if (i === axis) {
- sourceCoords.push("int(getIndices(" + currentCoords[i] + "))");
- }
- else {
- sourceCoords.push("" + currentCoords[i]);
- }
- }
- return sourceCoords.join();
- }
-
- var GatherNDProgram = /** @class */ (function () {
- function GatherNDProgram(sliceDim, strides, shape) {
- this.sliceDim = sliceDim;
- this.strides = strides;
- this.variableNames = ['x', 'indices'];
- this.outputShape = shape;
- var stridesType = getCoordsDataType(strides.length);
- var dtype = getCoordsDataType(shape.length);
- var strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
- this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + this.strides + ");\n void main() {\n " + dtype + " coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < " + this.sliceDim + "; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * " + strideString + ";\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n ";
- }
- return GatherNDProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function createVertexShader$1(gl, debug) {
- var glsl = getGlslDifferences();
- var vertexShaderSource = glsl.version + "\n precision highp float;\n " + glsl.attribute + " vec3 clipSpacePos;\n " + glsl.attribute + " vec2 uv;\n " + glsl.varyingVs + " vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }";
- return createVertexShader(gl, debug, vertexShaderSource);
- }
- function createVertexBuffer(gl, debug) {
- // [x y z u v] * [upper-left, lower-left, upper-right, lower-right]
- var vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
- return createStaticVertexBuffer(gl, debug, vertexArray);
- }
- function createIndexBuffer(gl, debug) {
- // OpenGL (and WebGL) have "CCW == front" winding
- var triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
- return createStaticIndexBuffer(gl, debug, triangleVertexIndices);
- }
- function createAndConfigureTexture(gl, debug, width, height, internalFormat, textureFormat, textureType) {
- validateTextureSize(width, height);
- var texture = createTexture(gl, debug);
- var tex2d = gl.TEXTURE_2D;
- callAndCheck(gl, debug, function () { return gl.bindTexture(tex2d, texture); });
- callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); });
- callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); });
- callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST); });
- callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST); });
- callAndCheck(gl, debug, function () { return gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null); });
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
- return texture;
- }
- function createFloat32MatrixTexture(gl, debug, rows, columns, textureConfig) {
- var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
- return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatFloat, textureConfig.textureFormatFloat, gl.FLOAT);
- }
- function createFloat16MatrixTexture(gl, debug, rows, columns, textureConfig) {
- var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
- return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatHalfFloat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
- }
- function createUnsignedBytesMatrixTexture(gl, debug, rows, columns, textureConfig) {
- var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
- return createAndConfigureTexture(gl, debug, width, height, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE);
- }
- function createPackedMatrixTexture(gl, debug, rows, columns, textureConfig) {
- var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
- return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatPackedFloat, gl.RGBA, gl.FLOAT);
- }
- function createFloat16PackedMatrixTexture(gl, debug, rows, columns, textureConfig) {
- var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
- return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatPackedHalfFloat, gl.RGBA, textureConfig.textureTypeHalfFloat);
- }
- function bindVertexProgramAttributeStreams(gl, debug, program, vertexBuffer) {
- var posOffset = 0; // x is the first buffer element
- var uvOffset = 3 * 4; // uv comes after [x y z]
- var stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float.
- callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); });
- var success = bindVertexBufferToProgramAttribute(gl, debug, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
- return success &&
- bindVertexBufferToProgramAttribute(gl, debug, program, 'uv', vertexBuffer, 2, stride, uvOffset);
- }
- function uploadDenseMatrixToTexture(gl, debug, texture, width, height, data, textureConfig) {
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); });
- var dataForUpload, texelDataType, internalFormat;
- if (data instanceof Uint8Array) {
- dataForUpload = new Uint8Array(width * height * 4);
- texelDataType = gl.UNSIGNED_BYTE;
- internalFormat = gl.RGBA;
- }
- else {
- dataForUpload = new Float32Array(width * height * 4);
- texelDataType = gl.FLOAT;
- internalFormat = textureConfig.internalFormatPackedFloat;
- }
- dataForUpload.set(data);
- callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload); });
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
- }
- function uploadPixelDataToTexture(gl, debug, texture, pixels) {
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); });
- if (pixels.data instanceof Uint8Array) {
- callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data); });
- }
- else {
- callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels); });
- }
- callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
- }
- function createBufferFromOutputTexture(gl2, debug, rows, columns, textureConfig) {
- // Create and bind the buffer.
- var buffer = gl2.createBuffer();
- callAndCheck(gl2, debug, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); });
- // Initialize the buffer to the size of the texture in bytes.
- var bytesPerFloat = 4;
- var valuesPerTexel = 4;
- var bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
- callAndCheck(gl2, debug, function () { return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ); });
- // Enqueue a command on the GPU command queue to copy of texture into the
- // buffer.
- callAndCheck(gl2, debug, function () { return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0); });
- callAndCheck(gl2, debug, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); });
- return buffer;
- }
- function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
- var gl2 = gl;
- var downloadTarget = new Float32Array(size);
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
- gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
- return downloadTarget;
- }
- function downloadByteEncodedFloatMatrixFromOutputTexture(gl, debug, rows, columns, textureConfig) {
- var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), w = _a[0], h = _a[1];
- var numChannels = 4;
- var downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
- callAndCheck(gl, debug, function () { return gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget); });
- // By wrapping the buffer in a Float32Array, we use native browser IEEE 754
- // decoding of the 4 bytes that back each 32 bit float.
- return new Float32Array(downloadTarget.buffer);
- }
- function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
- var gl2 = gl;
- var downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
- gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
- gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
- return downloadTarget;
- }
- function downloadMatrixFromPackedOutputTexture(gl, debug, physicalRows, physicalCols) {
- var packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
- callAndCheck(gl, debug, function () { return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA); });
- return packedRGBA;
- }
-
- var gpgpu_util = /*#__PURE__*/Object.freeze({
- createVertexShader: createVertexShader$1,
- createVertexBuffer: createVertexBuffer,
- createIndexBuffer: createIndexBuffer,
- createFloat32MatrixTexture: createFloat32MatrixTexture,
- createFloat16MatrixTexture: createFloat16MatrixTexture,
- createUnsignedBytesMatrixTexture: createUnsignedBytesMatrixTexture,
- createPackedMatrixTexture: createPackedMatrixTexture,
- createFloat16PackedMatrixTexture: createFloat16PackedMatrixTexture,
- bindVertexProgramAttributeStreams: bindVertexProgramAttributeStreams,
- uploadDenseMatrixToTexture: uploadDenseMatrixToTexture,
- uploadPixelDataToTexture: uploadPixelDataToTexture,
- createBufferFromOutputTexture: createBufferFromOutputTexture,
- downloadFloat32MatrixFromBuffer: downloadFloat32MatrixFromBuffer,
- downloadByteEncodedFloatMatrixFromOutputTexture: downloadByteEncodedFloatMatrixFromOutputTexture,
- downloadPackedMatrixFromBuffer: downloadPackedMatrixFromBuffer,
- downloadMatrixFromPackedOutputTexture: downloadMatrixFromPackedOutputTexture
- });
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var GPGPUContext = /** @class */ (function () {
- function GPGPUContext(gl) {
- this.outputTexture = null;
- this.program = null;
- this.disposed = false;
- this.vertexAttrsAreBound = false;
- this.itemsToPoll = [];
- var glVersion = env().getNumber('WEBGL_VERSION');
- if (gl != null) {
- this.gl = gl;
- setWebGLContext(glVersion, gl);
- }
- else {
- this.gl = getWebGLContext(glVersion);
- }
- // WebGL 2.0 enables texture floats without an extension.
- var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
- var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
- if (env().getNumber('WEBGL_VERSION') === 1) {
- var TEXTURE_FLOAT = 'OES_texture_float';
- var TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
- this.textureFloatExtension =
- getExtensionOrThrow(this.gl, this.debug, TEXTURE_FLOAT);
- if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
- this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, this.debug, TEXTURE_HALF_FLOAT);
- }
- else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
- throw new Error('GL context does not support half float textures, yet the ' +
- 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
- }
- this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
- if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
- this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, this.debug, COLOR_BUFFER_HALF_FLOAT);
- }
- else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
- throw new Error('GL context does not support color renderable half floats, yet ' +
- 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
- }
- }
- else {
- COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
- if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
- this.colorBufferFloatExtension =
- this.gl.getExtension(COLOR_BUFFER_FLOAT);
- }
- else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
- this.colorBufferHalfFloatExtension =
- this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
- }
- else {
- throw new Error('GL context does not support color renderable floats');
- }
- }
- this.vertexBuffer = createVertexBuffer(this.gl, this.debug);
- this.indexBuffer = createIndexBuffer(this.gl, this.debug);
- this.framebuffer = createFramebuffer(this.gl, this.debug);
- this.textureConfig =
- getTextureConfig(this.gl, this.textureHalfFloatExtension);
- }
- Object.defineProperty(GPGPUContext.prototype, "debug", {
- get: function () {
- return env().getBool('DEBUG');
- },
- enumerable: true,
- configurable: true
- });
- GPGPUContext.prototype.dispose = function () {
- var _this = this;
- if (this.disposed) {
- return;
- }
- if (this.program != null) {
- console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
- ' This is probably a resource leak, delete the program with ' +
- 'GPGPUContext.deleteProgram before disposing.');
- }
- if (this.outputTexture != null) {
- console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
- 'texture. This is probably a resource leak, delete the output ' +
- 'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
- 'disposing.');
- }
- var gl = this.gl;
- callAndCheck(gl, this.debug, function () { return gl.finish(); });
- callAndCheck(gl, this.debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); });
- callAndCheck(gl, this.debug, function () { return gl.deleteFramebuffer(_this.framebuffer); });
- callAndCheck(gl, this.debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, null); });
- callAndCheck(gl, this.debug, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null); });
- callAndCheck(gl, this.debug, function () { return gl.deleteBuffer(_this.indexBuffer); });
- this.disposed = true;
- };
- GPGPUContext.prototype.createFloat32MatrixTexture = function (rows, columns) {
- this.throwIfDisposed();
- return createFloat32MatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig);
- };
- GPGPUContext.prototype.createFloat16MatrixTexture = function (rows, columns) {
- this.throwIfDisposed();
- return createFloat16MatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig);
- };
- GPGPUContext.prototype.createUnsignedBytesMatrixTexture = function (rows, columns) {
- this.throwIfDisposed();
- return createUnsignedBytesMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig);
- };
- GPGPUContext.prototype.uploadPixelDataToTexture = function (texture, pixels) {
- this.throwIfDisposed();
- uploadPixelDataToTexture(this.gl, this.debug, texture, pixels);
- };
- GPGPUContext.prototype.uploadDenseMatrixToTexture = function (texture, width, height, data) {
- this.throwIfDisposed();
- uploadDenseMatrixToTexture(this.gl, this.debug, texture, width, height, data, this.textureConfig);
- };
- GPGPUContext.prototype.createFloat16PackedMatrixTexture = function (rows, columns) {
- this.throwIfDisposed();
- return createFloat16PackedMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig);
- };
- GPGPUContext.prototype.createPackedMatrixTexture = function (rows, columns) {
- this.throwIfDisposed();
- return createPackedMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig);
- };
- GPGPUContext.prototype.deleteMatrixTexture = function (texture) {
- var _this = this;
- this.throwIfDisposed();
- if (this.outputTexture === texture) {
- unbindColorTextureFromFramebuffer(this.gl, this.debug, this.framebuffer);
- this.outputTexture = null;
- }
- callAndCheck(this.gl, this.debug, function () { return _this.gl.deleteTexture(texture); });
- };
- GPGPUContext.prototype.downloadByteEncodedFloatMatrixFromOutputTexture = function (texture, rows, columns) {
- var _this = this;
- return this.downloadMatrixDriver(texture, function () { return downloadByteEncodedFloatMatrixFromOutputTexture(_this.gl, _this.debug, rows, columns, _this.textureConfig); });
- };
- GPGPUContext.prototype.downloadPackedMatrixFromBuffer = function (buffer, batch, rows, columns, physicalRows, physicalCols) {
- return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
- };
- GPGPUContext.prototype.downloadFloat32MatrixFromBuffer = function (buffer, size) {
- return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
- };
- GPGPUContext.prototype.createBufferFromTexture = function (texture, rows, columns) {
- this.bindTextureToFrameBuffer(texture);
- var result = createBufferFromOutputTexture(this.gl, this.debug, rows, columns, this.textureConfig);
- this.unbindTextureToFrameBuffer();
- return result;
- };
- GPGPUContext.prototype.createAndWaitForFence = function () {
- var fenceContext = this.createFence(this.gl);
- return this.pollFence(fenceContext);
- };
- GPGPUContext.prototype.createFence = function (gl) {
- var _this = this;
- var query;
- var isFencePassed;
- if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
- var gl2_1 = gl;
- var sync_1 = gl2_1.fenceSync(gl2_1.SYNC_GPU_COMMANDS_COMPLETE, 0);
- gl.flush();
- isFencePassed = function () {
- var status = gl2_1.clientWaitSync(sync_1, 0, 0);
- return status === gl2_1.ALREADY_SIGNALED ||
- status === gl2_1.CONDITION_SATISFIED;
- };
- query = sync_1;
- }
- else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
- query = this.beginQuery();
- this.endQuery();
- isFencePassed = function () { return _this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); };
- }
- else {
- // If we have no way to fence, return true immediately. This will fire in
- // WebGL 1.0 when there is no disjoint query timer. In this case, because
- // the fence passes immediately, we'll immediately ask for a download of
- // the texture, which will cause the UI thread to hang.
- isFencePassed = function () { return true; };
- }
- return { query: query, isFencePassed: isFencePassed };
- };
- GPGPUContext.prototype.downloadMatrixFromPackedTexture = function (texture, physicalRows, physicalCols) {
- var _this = this;
- return this.downloadMatrixDriver(texture, function () { return downloadMatrixFromPackedOutputTexture(_this.gl, _this.debug, physicalRows, physicalCols); });
- };
- GPGPUContext.prototype.createProgram = function (fragmentShaderSource) {
- this.throwIfDisposed();
- var gl = this.gl;
- var fragmentShader = createFragmentShader(gl, this.debug, fragmentShaderSource);
- var vertexShader = createVertexShader$1(gl, this.debug);
- var program = createProgram(gl, this.debug);
- callAndCheck(gl, this.debug, function () { return gl.attachShader(program, vertexShader); });
- callAndCheck(gl, this.debug, function () { return gl.attachShader(program, fragmentShader); });
- linkProgram(gl, this.debug, program);
- if (this.debug) {
- validateProgram(gl, this.debug, program);
- }
- if (!this.vertexAttrsAreBound) {
- this.setProgram(program);
- this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.debug, this.program, this.vertexBuffer);
- }
- return program;
- };
- GPGPUContext.prototype.deleteProgram = function (program) {
- var _this = this;
- this.throwIfDisposed();
- if (program === this.program) {
- this.program = null;
- }
- if (program != null) {
- callAndCheck(this.gl, this.debug, function () { return _this.gl.deleteProgram(program); });
- }
- };
- GPGPUContext.prototype.setProgram = function (program) {
- var _this = this;
- this.throwIfDisposed();
- this.program = program;
- if ((this.program != null) && this.debug) {
- validateProgram(this.gl, this.debug, this.program);
- }
- callAndCheck(this.gl, this.debug, function () { return _this.gl.useProgram(program); });
- };
- GPGPUContext.prototype.getUniformLocation = function (program, uniformName, shouldThrow) {
- if (shouldThrow === void 0) { shouldThrow = true; }
- this.throwIfDisposed();
- if (shouldThrow) {
- return getProgramUniformLocationOrThrow(this.gl, this.debug, program, uniformName);
- }
- else {
- return getProgramUniformLocation(this.gl, program, uniformName);
- }
- };
- GPGPUContext.prototype.getAttributeLocation = function (program, attribute) {
- var _this = this;
- this.throwIfDisposed();
- return callAndCheck(this.gl, this.debug, function () { return _this.gl.getAttribLocation(program, attribute); });
- };
- GPGPUContext.prototype.getUniformLocationNoThrow = function (program, uniformName) {
- this.throwIfDisposed();
- return this.gl.getUniformLocation(program, uniformName);
- };
- GPGPUContext.prototype.setInputMatrixTexture = function (inputMatrixTexture, uniformLocation, textureUnit) {
- this.throwIfDisposed();
- this.throwIfNoProgram();
- bindTextureToProgramUniformSampler(this.gl, this.debug, this.program, inputMatrixTexture, uniformLocation, textureUnit);
- };
- GPGPUContext.prototype.setOutputMatrixTexture = function (outputMatrixTexture, rows, columns) {
- this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
- };
- GPGPUContext.prototype.setOutputPackedMatrixTexture = function (outputPackedMatrixTexture, rows, columns) {
- this.throwIfDisposed();
- var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
- this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
- };
- GPGPUContext.prototype.setOutputMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) {
- this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
- };
- GPGPUContext.prototype.setOutputPackedMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) {
- throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
- };
- GPGPUContext.prototype.debugValidate = function () {
- if (this.program != null) {
- validateProgram(this.gl, this.debug, this.program);
- }
- validateFramebuffer(this.gl);
- };
- GPGPUContext.prototype.executeProgram = function () {
- this.throwIfDisposed();
- this.throwIfNoProgram();
- var gl = this.gl;
- if (this.debug) {
- this.debugValidate();
- }
- callAndCheck(gl, this.debug, function () { return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0); });
- };
- GPGPUContext.prototype.blockUntilAllProgramsCompleted = function () {
- var _this = this;
- this.throwIfDisposed();
- callAndCheck(this.gl, this.debug, function () { return _this.gl.finish(); });
- };
- GPGPUContext.prototype.getQueryTimerExtension = function () {
- if (this.disjointQueryTimerExtension == null) {
- this.disjointQueryTimerExtension =
- getExtensionOrThrow(this.gl, this.debug, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
- 'EXT_disjoint_timer_query_webgl2' :
- 'EXT_disjoint_timer_query');
- }
- return this.disjointQueryTimerExtension;
- };
- GPGPUContext.prototype.getQueryTimerExtensionWebGL2 = function () {
- return this.getQueryTimerExtension();
- };
- GPGPUContext.prototype.getQueryTimerExtensionWebGL1 = function () {
- return this.getQueryTimerExtension();
- };
- GPGPUContext.prototype.beginQuery = function () {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
- var gl2 = this.gl;
- var ext_1 = this.getQueryTimerExtensionWebGL2();
- var query_1 = gl2.createQuery();
- gl2.beginQuery(ext_1.TIME_ELAPSED_EXT, query_1);
- return query_1;
- }
- var ext = this.getQueryTimerExtensionWebGL1();
- var query = ext.createQueryEXT();
- ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
- return query;
- };
- GPGPUContext.prototype.endQuery = function () {
- if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
- var gl2 = this.gl;
- var ext_2 = this.getQueryTimerExtensionWebGL2();
- gl2.endQuery(ext_2.TIME_ELAPSED_EXT);
- return;
- }
- var ext = this.getQueryTimerExtensionWebGL1();
- ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
- };
- GPGPUContext.prototype.waitForQueryAndGetTime = function (query) {
- return __awaiter(this, void 0, void 0, function () {
- var _this = this;
- return __generator(this, function (_a) {
- switch (_a.label) {
- case 0: return [4 /*yield*/, repeatedTry(function () { return _this.disposed || // while testing contexts are created / disposed
- // in rapid succession, so without this check we
- // may poll for the query timer indefinitely
- _this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); })];
- case 1:
- _a.sent();
- return [2 /*return*/, this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))];
- }
- });
- });
- };
- GPGPUContext.prototype.getQueryTime = function (query, queryTimerVersion) {
- if (queryTimerVersion === 0) {
- return null;
- }
- if (queryTimerVersion === 2) {
- var gl2 = this.gl;
- var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
- // Return milliseconds.
- return timeElapsedNanos / 1000000;
- }
- else {
- var ext = this.getQueryTimerExtensionWebGL1();
- var timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
- // Return milliseconds.
- return timeElapsedNanos / 1000000;
- }
- };
- GPGPUContext.prototype.isQueryAvailable = function (query, queryTimerVersion) {
- if (queryTimerVersion === 0) {
- return true;
- }
- if (queryTimerVersion === 2) {
- var gl2 = this.gl;
- var ext = this.getQueryTimerExtensionWebGL2();
- var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
- if (this.disjoint == null) {
- this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
- }
- return available && !this.disjoint;
- }
- else {
- var ext = this.getQueryTimerExtensionWebGL1();
- var available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
- if (this.disjoint == null) {
- this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
- }
- return available && !this.disjoint;
- }
- };
- GPGPUContext.prototype.pollFence = function (fenceContext) {
- var _this = this;
- return new Promise(function (resolve) {
- _this.addItemToPoll(function () { return fenceContext.isFencePassed(); }, function () { return resolve(); });
- });
- };
- GPGPUContext.prototype.pollItems = function () {
- // Find the last query that has finished.
- var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) { return x.isDoneFn; }));
- for (var i = 0; i <= index; ++i) {
- var resolveFn = this.itemsToPoll[i].resolveFn;
- resolveFn();
- }
- this.itemsToPoll = this.itemsToPoll.slice(index + 1);
- };
- GPGPUContext.prototype.addItemToPoll = function (isDoneFn, resolveFn) {
- var _this = this;
- this.itemsToPoll.push({ isDoneFn: isDoneFn, resolveFn: resolveFn });
- if (this.itemsToPoll.length > 1) {
- // We already have a running loop that polls.
- return;
- }
- // Start a new loop that polls.
- repeatedTry(function () {
- _this.pollItems();
- // End the loop if no more items to poll.
- return _this.itemsToPoll.length === 0;
- });
- };
- GPGPUContext.prototype.bindTextureToFrameBuffer = function (texture) {
- this.throwIfDisposed();
- bindColorTextureToFramebuffer(this.gl, this.debug, texture, this.framebuffer);
- if (this.debug) {
- validateFramebuffer(this.gl);
- }
- };
- GPGPUContext.prototype.unbindTextureToFrameBuffer = function () {
- if (this.outputTexture != null) {
- bindColorTextureToFramebuffer(this.gl, this.debug, this.outputTexture, this.framebuffer);
- if (this.debug) {
- validateFramebuffer(this.gl);
- }
- }
- else {
- unbindColorTextureFromFramebuffer(this.gl, this.debug, this.framebuffer);
- }
- };
- GPGPUContext.prototype.downloadMatrixDriver = function (texture, downloadAndDecode) {
- this.bindTextureToFrameBuffer(texture);
- var result = downloadAndDecode();
- this.unbindTextureToFrameBuffer();
- return result;
- };
- GPGPUContext.prototype.setOutputMatrixTextureDriver = function (outputMatrixTextureMaybePacked, width, height) {
- this.throwIfDisposed();
- var gl = this.gl;
- bindColorTextureToFramebuffer(gl, this.debug, outputMatrixTextureMaybePacked, this.framebuffer);
- if (this.debug) {
- validateFramebuffer(gl);
- }
- this.outputTexture = outputMatrixTextureMaybePacked;
- callAndCheck(gl, this.debug, function () { return gl.viewport(0, 0, width, height); });
- callAndCheck(gl, this.debug, function () { return gl.scissor(0, 0, width, height); });
- };
- GPGPUContext.prototype.setOutputMatrixWriteRegionDriver = function (x, y, width, height) {
- var _this = this;
- this.throwIfDisposed();
- callAndCheck(this.gl, this.debug, function () { return _this.gl.scissor(x, y, width, height); });
- };
- GPGPUContext.prototype.throwIfDisposed = function () {
- if (this.disposed) {
- throw new Error('Attempted to use disposed GPGPUContext.');
- }
- };
- GPGPUContext.prototype.throwIfNoProgram = function () {
- if (this.program == null) {
- throw new Error('No GPU program is currently set.');
- }
- };
- return GPGPUContext;
- }());
- /**
- * Finds the index of the last true element using linear search.
- * Note: We can't do binary search because Chrome expects us to explicitly
- * test all fences before download:
- * https://github.com/tensorflow/tfjs/issues/1145
- */
- function linearSearchLastTrue(arr) {
- var i = 0;
- for (; i < arr.length; ++i) {
- var isDone = arr[i]();
- if (!isDone) {
- break;
- }
- }
- return i - 1;
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- function compileProgram(gpgpu, program, inputs, output) {
- var userCode = program.userCode;
- var inputInfos = inputs.map(function (input, i) {
- var shapeInfo = {
- logicalShape: input.shape,
- texShape: input.isUniform ? null : input.texData.texShape,
- isUniform: input.isUniform,
- isPacked: input.isUniform ? false : input.texData.isPacked,
- flatOffset: null
- };
- if (input.texData != null && input.texData.slice != null &&
- input.texData.slice.flatOffset > 0) {
- shapeInfo.flatOffset = input.texData.slice.flatOffset;
- }
- return { name: program.variableNames[i], shapeInfo: shapeInfo };
- });
- var inShapeInfos = inputInfos.map(function (x) { return x.shapeInfo; });
- var outShapeInfo = {
- logicalShape: output.shape,
- texShape: output.texData.texShape,
- isUniform: false,
- isPacked: output.texData.isPacked,
- flatOffset: null
- };
- var source = makeShader(inputInfos, outShapeInfo, userCode, program.packedInputs);
- var webGLProgram = gpgpu.createProgram(source);
- // Add special uniforms (NAN, INFINITY)
- var infLoc = null;
- var nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
- if (env().getNumber('WEBGL_VERSION') === 1) {
- infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
- }
- // Add user-defined uniforms
- var uniformLocations = {};
- for (var i = 0; i < program.variableNames.length; i++) {
- var varName = program.variableNames[i];
- var shouldThrow = false;
- uniformLocations[varName] =
- gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow);
- uniformLocations["offset" + varName] =
- gpgpu.getUniformLocation(webGLProgram, "offset" + varName, shouldThrow);
- }
- return {
- program: program,
- source: source,
- webGLProgram: webGLProgram,
- uniformLocations: uniformLocations,
- inShapeInfos: inShapeInfos,
- outShapeInfo: outShapeInfo,
- infLoc: infLoc,
- nanLoc: nanLoc,
- };
- }
- function validateBinaryAndProgram(shapeInfos, inputs) {
- if (shapeInfos.length !== inputs.length) {
- throw Error("Binary was compiled with " + shapeInfos.length + " inputs, but " +
- ("was executed with " + inputs.length + " inputs"));
- }
- shapeInfos.forEach(function (s, i) {
- var shapeA = s.logicalShape;
- var input = inputs[i];
- var shapeB = input.shape;
- if (!arraysEqual(shapeA, shapeB)) {
- throw Error("Binary was compiled with different shapes than " +
- ("the current args. Shapes " + shapeA + " and " + shapeB + " must match"));
- }
- // The input is uploaded as uniform.
- if (s.isUniform && input.isUniform) {
- return;
- }
- var texShapeA = s.texShape;
- var texShapeB = input.isUniform ? null : input.texData.texShape;
- if (!arraysEqual(texShapeA, texShapeB)) {
- throw Error("Binary was compiled with different texture shapes than the" +
- (" current args. Shape " + texShapeA + " and " + texShapeB + " must match"));
- }
- });
- }
- function runProgram(gpgpu, binary, inputs, output, customSetup) {
- validateBinaryAndProgram(binary.inShapeInfos, inputs);
- validateBinaryAndProgram([binary.outShapeInfo], [output]);
- var outTex = output.texData.texture;
- var outTexShape = output.texData.texShape;
- if (output.texData.isPacked) {
- gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
- }
- else {
- gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
- }
- gpgpu.setProgram(binary.webGLProgram);
- // Set special uniforms (NAN, INFINITY)
- if (env().getNumber('WEBGL_VERSION') === 1) {
- if (binary.infLoc !== null) {
- gpgpu.gl.uniform1f(binary.infLoc, Infinity);
- }
- }
- if (binary.nanLoc !== null) {
- gpgpu.gl.uniform1f(binary.nanLoc, NaN);
- }
- // Set user-defined inputs
- inputs.forEach(function (input, i) {
- var varName = binary.program.variableNames[i];
- var varLoc = binary.uniformLocations[varName];
- var varOffsetLoc = binary.uniformLocations["offset" + varName];
- if (varLoc == null) {
- // The compiler inferred that this variable is not used in this shader.
- return;
- }
- if (input.isUniform) {
- // Upload the values of the tensor as uniform.
- if (sizeFromShape(input.shape) < 2) {
- gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
- }
- else {
- var vals = input.uniformValues;
- if (!(vals instanceof Float32Array)) {
- vals = new Float32Array(vals);
- }
- gpgpu.gl.uniform1fv(varLoc, vals);
- }
- return;
- }
- // If the input was sliced, upload the flat offset index.
- if (input.texData.slice != null && varOffsetLoc != null) {
- gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
- }
- gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i);
- });
- if (customSetup != null) {
- customSetup(gpgpu, binary.webGLProgram);
- }
- gpgpu.executeProgram();
- }
- function makeShaderKey(program, inputs, output) {
- var keyInputs = '';
- inputs.concat(output).forEach(function (x) {
- var hasOffset = x.texData != null && x.texData.slice != null &&
- x.texData.slice.flatOffset > 0;
- var texShape = x.isUniform ? 'uniform' : x.texData.texShape;
- keyInputs += x.shape + "_" + texShape + "_" + hasOffset;
- });
- var keyUserCode = program.userCode;
- var key = program.constructor.name;
- // Fast string concat. See https://jsperf.com/string-concatenation/14.
- key += '_' + keyInputs + '_' + keyUserCode;
- return key;
- }
-
- /**
- * @license
- * Copyright 2019 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var Im2ColPackedProgram = /** @class */ (function () {
- function Im2ColPackedProgram(outputShape, inputShape, convInfo) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- var filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, strideWidth = convInfo.strideWidth, strideHeight = convInfo.strideHeight, padInfo = convInfo.padInfo, outWidth = convInfo.outWidth, dilationWidth = convInfo.dilationWidth, dilationHeight = convInfo.dilationHeight, dataFormat = convInfo.dataFormat;
- var left = padInfo.left, top = padInfo.top;
- var itemsPerBlockRow = inChannels * filterWidth;
- var glsl = getGlslDifferences();
- var isChannelsLast = dataFormat === 'channelsLast';
- var rowDim = isChannelsLast ? 0 : 1;
- var colDim = isChannelsLast ? 1 : 2;
- var unrolled = "";
- for (var row = 0; row <= 1; row++) {
- for (var col = 0; col <= 1; col++) {
- unrolled += "\n blockIndex = rc.y + " + col + ";\n pos = rc.x + " + row + ";\n\n if(blockIndex < " + outputShape[1] + " && pos < " + outputShape[0] + ") {\n offsetY = int(blockIndex / (" + outWidth + ")) * " + strideHeight + " - " + top + ";\n d0 = offsetY + " + dilationHeight + " * (pos / " + itemsPerBlockRow + ");\n\n if(d0 < " + inputShape[rowDim] + " && d0 >= 0) {\n\n offsetX = int(mod(float(blockIndex), " + outWidth + ".) * " + strideWidth + ". - " + left + ".);\n d1 = offsetX + " + dilationWidth + " * (int(mod(float(pos), " + itemsPerBlockRow + ".) / " + inChannels + ".));\n\n if(d1 < " + inputShape[colDim] + " && d1 >= 0) {\n\n ch = int(mod(float(pos), " + inChannels + ".));\n\n if (" + isChannelsLast + ") {\n innerDims = vec2(d1, ch);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n ";
- }
- }
- this.userCode = "\n void main() {\n ivec2 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n " + unrolled + "\n\n " + glsl.output + " = result;\n }\n ";
- }
- return Im2ColPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var LRNProgram = /** @class */ (function () {
- function LRNProgram(xShape, radius, bias, alpha, beta) {
- this.variableNames = ['x'];
- this.outputShape = [];
- var rad = radius;
- var maxD = xShape[3] - 1;
- this.outputShape = xShape;
- // optimize pow(bias + alpha * sum, -beta)
- // src: https://github.com/tensorflow/tensorflow/..
- // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
- // tensorflow/core/kernels/mkl_lrn_op.cc#L320
- var powOperator;
- var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
- if (beta === 0.5) {
- powOperator = "inversesqrt(" + basis + ")";
- }
- else if (beta === 1.0) {
- powOperator = "1.0/(" + basis + ")";
- }
- else {
- powOperator = "exp(log(" + basis + ") * float(-" + beta + "));";
- }
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -" + rad + "; j <= " + rad + "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= " + maxD + ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * " + powOperator + ";\n setOutput(val);\n }\n ";
- }
- return LRNProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var LRNGradProgram = /** @class */ (function () {
- function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) {
- this.variableNames = ['inputImage', 'outputImage', 'dy'];
- this.outputShape = [];
- this.outputShape = inputShape;
- this.depth = inputShape[3];
- this.depthRadius = depthRadius;
- this.bias = bias;
- this.alpha = alpha;
- this.beta = beta;
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < " + this.depth + "; ++d) {\n int depthBegin = int(max(0.0, float(d - " + depthRadius + ")));\n int depthEnd = int(min(float(" + this.depth + "),\n float(d + " + depthRadius + " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = " + this.depth + ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(" + alpha + ") * norm + float(" + bias + ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(" + alpha + ")\n * float(" + beta + ")\n * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * " + beta + ");\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n ";
- }
- return LRNGradProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google LLC All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var LRNPackedProgram = /** @class */ (function () {
- function LRNPackedProgram(xShape, radius, bias, alpha, beta) {
- this.variableNames = ['x'];
- this.outputShape = [];
- this.packedInputs = true;
- this.packedOutput = true;
- var rad = radius;
- var maxD = xShape[3] - 1;
- this.outputShape = xShape;
- // optimize pow(bias + alpha * sum, -beta)
- // src: https://github.com/tensorflow/tensorflow/..
- // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
- // tensorflow/core/kernels/mkl_lrn_op.cc#L320
- var powOperator;
- var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
- if (beta === 0.5) {
- powOperator = "inversesqrt(" + basis + ")";
- }
- else if (beta === 1.0) {
- powOperator = "1.0/(" + basis + ")";
- }
- else {
- powOperator = "exp(log(" + basis + ") * float(-" + beta + "));";
- }
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < " + this.outputShape[3] + ";\n bool hasNextRow = c < " + this.outputShape[2] + ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - " + rad + ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - " + rad + "; j <= " + rad + "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * " + powOperator + ";\n setOutput(result);\n }\n ";
- }
- return LRNPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var MaxPool2DBackpropProgram = /** @class */ (function () {
- function MaxPool2DBackpropProgram(convInfo) {
- this.variableNames = ['dy', 'maxPos'];
- this.outputShape = convInfo.inShape;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationHeight = convInfo.dilationHeight;
- var effectiveFilterHeight = convInfo.effectiveFilterHeight;
- var effectiveFilterWidth = convInfo.effectiveFilterWidth;
- var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
- this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = " + lastIndex + " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return MaxPool2DBackpropProgram;
- }());
- var MaxPool3DBackpropProgram = /** @class */ (function () {
- function MaxPool3DBackpropProgram(convInfo) {
- this.variableNames = ['dy', 'maxPos'];
- this.outputShape = convInfo.inShape;
- var strideDepth = convInfo.strideDepth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationDepth = convInfo.dilationDepth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var effectiveFilterDepth = convInfo.effectiveFilterDepth;
- var effectiveFilterHeight = convInfo.effectiveFilterHeight;
- var effectiveFilterWidth = convInfo.effectiveFilterWidth;
- var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
- var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
- var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
- var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
- this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = " + lastIndex + " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
- }
- return MaxPool3DBackpropProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var MatMulPackedProgram = /** @class */ (function () {
- function MatMulPackedProgram(aShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation) {
- if (transposeA === void 0) { transposeA = false; }
- if (transposeB === void 0) { transposeB = false; }
- if (addBias === void 0) { addBias = false; }
- if (activation === void 0) { activation = null; }
- if (hasPreluActivation === void 0) { hasPreluActivation = false; }
- this.variableNames = ['matrixA', 'matrixB'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- var sharedDim = transposeA ? aShape[1] : aShape[2];
- var sharedDimensionPacked = Math.ceil(sharedDim / 2);
- var aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
- var bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
- var aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
- var bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
- var activationSnippet = '', applyActivationSnippet = '';
- if (activation) {
- if (hasPreluActivation) {
- activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
- }
- else {
- activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
- }
- applyActivationSnippet = "result = activation(result);";
- }
- var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
- if (addBias) {
- this.variableNames.push('bias');
- }
- if (hasPreluActivation) {
- this.variableNames.push('preluActivationWeights');
- }
- this.userCode = "\n " + activationSnippet + "\n\n const float sharedDimension = " + sharedDimensionPacked + ".0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n for (int i = 0; i < " + sharedDimensionPacked + "; i++) {\n vec4 a = getMatrixA(rc.x, " + aSample + ");\n vec4 b = getMatrixB(rc.x, " + bSample + ");\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (" + aSwizzle[0] + " * " + bSwizzle[0] + ");\n result += (" + aSwizzle[1] + " * " + bSwizzle[1] + ");\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n " + addBiasSnippet + "\n\n " + applyActivationSnippet + "\n\n setOutput(result);\n }\n ";
- }
- return MatMulPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var MultinomialProgram = /** @class */ (function () {
- function MultinomialProgram(batchSize, numOutcomes, numSamples) {
- this.variableNames = ['probs'];
- this.outputShape = [batchSize, numSamples];
- this.userCode = "\n uniform float seed;\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < " + (numOutcomes - 1) + "; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(" + (numOutcomes - 1) + "));\n }\n ";
- }
- MultinomialProgram.prototype.getCustomSetupFunc = function (seed) {
- var _this = this;
- return function (gpgpu, webGLProgram) {
- if (_this.seedLoc == null) {
- _this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed');
- }
- gpgpu.gl.uniform1f(_this.seedLoc, seed);
- };
- };
- return MultinomialProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var OneHotProgram = /** @class */ (function () {
- function OneHotProgram(numIndices, depth, onValue, offValue) {
- this.variableNames = ['indices'];
- this.outputShape = [numIndices, depth];
- this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(" + offValue + "), float(" + onValue + "),\n float(index == coords.y)));\n }\n ";
- }
- return OneHotProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var PackProgram = /** @class */ (function () {
- function PackProgram(outputShape) {
- this.variableNames = ['A'];
- this.packedInputs = false;
- this.packedOutput = true;
- // Only input / output 3D tensors.
- this.outputShape = outputShape;
- var rank = outputShape.length;
- if (rank === 0) {
- this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n ";
- }
- else {
- var channels = getChannels('rc', rank);
- var dtype = getCoordsDataType(rank);
- var outOfBoundsCondition = getOutOfBoundsCondition(rank, outputShape, channels);
- var setup = getSetup(rank, outputShape[outputShape.length - 1], outputShape[outputShape.length - 2], channels);
- var output = getOutput(outputShape, channels);
- this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n\n if(" + outOfBoundsCondition + ") {\n setOutput(vec4(0));\n } else {\n " + setup + "\n\n setOutput(vec4(" + output + "));\n }\n }\n ";
- }
- }
- return PackProgram;
- }());
- function getSourceCoordsArr(rank, dims) {
- var coords = [];
- for (var row = 0; row <= 1; row++) {
- for (var col = 0; col <= 1; col++) {
- var coord = (row === 0 ? 'r' : 'rp1') + ", " + (col === 0 ? 'c' : 'cp1');
- for (var d = 2; d < rank; d++) {
- coord = dims[dims.length - 1 - d] + "," + coord;
- }
- coords.push(coord);
- }
- }
- return coords;
- }
- function getOutOfBoundsCondition(rank, shape, dims) {
- if (rank === 1) {
- return "rc > " + shape[0];
- }
- var cond = '';
- for (var i = rank - 2; i < rank; i++) {
- cond += dims[i] + " >= " + shape[i];
- if (i < rank - 1) {
- cond += '||';
- }
- }
- return cond;
- }
- function getSetup(rank, cols, rows, dims) {
- if (rank === 1) {
- return '';
- }
- var innerDims = dims.slice(-2);
- return "\n int r = " + innerDims[0] + ";\n int c = " + innerDims[1] + ";\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= " + cols + ";\n bool rEdge = rp1 >= " + rows + ";\n ";
- }
- function getOutput(shape, dims) {
- var rank = shape.length;
- var sourceCoords = getSourceCoordsArr(rank, dims);
- if (rank === 1) {
- return "getA(rc),\n rc + 1 >= " + shape[0] + " ? 0. : getA(rc + 1),\n 0, 0";
- }
- return "getA(" + sourceCoords[0] + "),\n cEdge ? 0. : getA(" + sourceCoords[1] + "),\n rEdge ? 0. : getA(" + sourceCoords[2] + "),\n rEdge || cEdge ? 0. : getA(" + sourceCoords[3] + ")";
- }
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var PadProgram = /** @class */ (function () {
- function PadProgram(xShape, paddings, constantValue) {
- this.variableNames = ['x'];
- this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */);
- var rank = xShape.length;
- var type = getCoordsDataType(rank);
- var start = paddings.map(function (p) { return p[0]; }).join(',');
- var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(',');
- var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
- if (rank === 1) {
- this.userCode = "\n int start = " + start + ";\n int end = " + end + ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(float(" + constantValue + "));\n } else {\n setOutput(getX(outC - start));\n }\n }\n ";
- return;
- }
- this.userCode = "\n " + type + " start = " + type + "(" + start + ");\n " + type + " end = " + type + "(" + end + ");\n\n void main() {\n " + type + " outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(float(" + constantValue + "));\n } else {\n " + type + " coords = outC - start;\n setOutput(getX(" + unpackedCoords + "));\n }\n }\n ";
- }
- return PadProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var PadPackedProgram = /** @class */ (function () {
- function PadPackedProgram(xShape, paddings, constantValue) {
- this.variableNames = ['x'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */);
- var rank = xShape.length;
- var dtype = getCoordsDataType(rank);
- var start = paddings.map(function (p) { return p[0]; }).join(',');
- var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(',');
- var coords = getChannels('rc', rank);
- var source = getChannels('source', rank);
- var cLimit = coords[rank - 1] + " < " + this.outputShape[rank - 1];
- var innerDims = rank === 1 ? 'source' : "vec2(" + source.slice(-2).join() + ")";
- var componentSetup = [
- dtype + " rc = outputLoc;", coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n ",
- rank === 1 ? '' : "}\n rc = outputLoc;\n " + coords[rank - 2] + " += 1;\n if(" + coords[rank - 2] + " < " + this.outputShape[rank - 2] + ") {",
- rank === 1 ? '' : " " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {"
- ];
- var paddingArea = rank === 1 ?
- 'rc < start || rc >= end' :
- 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
- var mainLoop = '';
- for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
- mainLoop += "\n " + componentSetup[i] + "\n if (" + paddingArea + ") {\n result[" + i + "] = float(" + constantValue + ");\n } else {\n " + dtype + " source = rc - start;\n result[" + i + "] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n ";
- }
- mainLoop += (rank === 1 ? "} " : "}}");
- this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n ";
- }
- return PadPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var Pool2DProgram = /** @class */ (function () {
- function Pool2DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
- if (flattenPositions === void 0) { flattenPositions = false; }
- if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
- this.variableNames = ['x'];
- if (poolType === 'avg' && computePositions) {
- throw new Error('Cannot compute positions for average pool.');
- }
- var filterWidth = convInfo.filterWidth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var effectiveFilterHeight = convInfo.effectiveFilterHeight;
- var effectiveFilterWidth = convInfo.effectiveFilterWidth;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- this.outputShape = convInfo.outShape;
- var isAvgPool = poolType === 'avg';
- var batchFlattenPositionStr = "((batch * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
- var flattenPositionStr = "(xR * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
- var initializationValue = '0.0';
- if (!isAvgPool) {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '-1.0 / 1e-20';
- }
- if (computePositions) {
- var compareOp_1 = '>=';
- this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_1 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :
- flattenPositionStr) :
- "wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ";
- return;
- }
- var compareOp = 'max';
- var returnValue = poolType + "(" + poolType + "(" + poolType + "(" +
- 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
- if (poolType === 'avg') {
- returnValue = "avgValue / count";
- }
- var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
- var filterWidthVec4Remainder = filterWidth % 4;
- var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
- this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n getValue(batch, xR, xC + 3 * " + dilationWidth + ", d)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n ";
- }
- return Pool2DProgram;
- }());
- var Pool3DProgram = /** @class */ (function () {
- function Pool3DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
- if (flattenPositions === void 0) { flattenPositions = false; }
- if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
- this.variableNames = ['x'];
- if (poolType === 'avg' && computePositions) {
- throw new Error('Cannot compute positions for average pool.');
- }
- var filterWidth = convInfo.filterWidth;
- var strideDepth = convInfo.strideDepth;
- var strideHeight = convInfo.strideHeight;
- var strideWidth = convInfo.strideWidth;
- var dilationDepth = convInfo.dilationDepth;
- var dilationHeight = convInfo.dilationHeight;
- var dilationWidth = convInfo.dilationWidth;
- var effectiveFilterDepth = convInfo.effectiveFilterDepth;
- var effectiveFilterHeight = convInfo.effectiveFilterHeight;
- var effectiveFilterWidth = convInfo.effectiveFilterWidth;
- var padFront = convInfo.padInfo.front;
- var padTop = convInfo.padInfo.top;
- var padLeft = convInfo.padInfo.left;
- this.outputShape = convInfo.outShape;
- var isAvgPool = poolType === 'avg';
- var initializationValue = '0.0';
- if (!isAvgPool) {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '-1.0 / 1e-20';
- }
- if (computePositions) {
- var compareOp_2 = '>=';
- this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_2 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ?
- (includeBatchInIndex ?
- "(((batch * " + convInfo.inDepth + " + xD) * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch" :
- "((xD * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch") :
- "wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ";
- return;
- }
- var compareOp = 'max';
- var returnValue = poolType + "(" + poolType + "(" + poolType + "(" +
- 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
- if (poolType === 'avg') {
- returnValue = "avgValue / count";
- }
- var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
- var filterWidthVec4Remainder = filterWidth % 4;
- var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
- this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 3 * " + dilationWidth + ", ch)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n }\n ";
- }
- return Pool3DProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ReduceProgram = /** @class */ (function () {
- function ReduceProgram(reduceInfo, reduceType) {
- this.variableNames = ['x'];
- var windowSize = reduceInfo.windowSize;
- var batchSize = reduceInfo.batchSize;
- var inSize = reduceInfo.inSize;
- var outSize = Math.ceil(inSize / windowSize);
- this.outputShape = [batchSize, outSize];
- var initializationValue = '0.0';
- var compareOp = "";
- if (reduceType === 'prod') {
- initializationValue = '1.0';
- }
- else if (reduceType === 'min') {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '1.0 / 1e-20';
- compareOp = "min";
- }
- else if (reduceType === 'max') {
- // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
- initializationValue = '-1.0 / 1e-20';
- compareOp = "max";
- }
- var returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" +
- 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
- if (reduceType === 'sum') {
- returnValue = "sumValue";
- }
- else if (reduceType === 'prod') {
- returnValue = "prodValue";
- }
- else if (reduceType === 'all') {
- returnValue = "allValue";
- }
- else if (reduceType === 'any') {
- returnValue = "anyValue";
- }
- var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
- var windowSizeVec4Remainder = windowSize % 4;
- var updateSnippet = "\n if (" + (reduceType === 'sum') + ") {\n sumValue += dot(values, ones);\n } else if (" + (reduceType === 'prod') + ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
- var vecType = "vec4";
- if (reduceType === 'all') {
- initializationValue = '1.0';
- updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
- vecType = "bvec4";
- }
- else if (reduceType === 'any') {
- initializationValue = '0.0';
- updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
- vecType = "bvec4";
- }
- var checkOutOfBounds = '';
- if (inSize % windowSize > 0) {
- checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
- }
- this.userCode = "\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n ";
- }
- return ReduceProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ReshapePackedProgram = /** @class */ (function () {
- function ReshapePackedProgram(outputShape, inputShape) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = outputShape;
- var mainLoop = "";
- for (var i = 0; i < 4; i++) {
- var thisRC = "thisRC = rc;";
- if (i % 2 === 1) {
- thisRC += "thisRC.z += 1;";
- }
- if (i > 1) {
- thisRC += "thisRC.y += 1;";
- }
- mainLoop += "\n " + thisRC + "\n " + (i > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : '') + "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[" + i + "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n " + (i > 0 ? '}' : '') + "\n ";
- }
- this.userCode = "\n " + getReshapedInputCoords(inputShape) + "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = " + outputShape[1] + ";\n int cols = " + outputShape[2] + ";\n\n " + mainLoop + "\n\n setOutput(result);\n }\n ";
- }
- return ReshapePackedProgram;
- }());
- function getReshapedInputCoords(shape) {
- var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
- return "\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n ";
- }
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ResizeBilinearBackpropProgram = /** @class */ (function () {
- function ResizeBilinearBackpropProgram(dy, x, alignCorners) {
- this.variableNames = ['dy'];
- this.outputShape = [];
- this.outputShape = x.shape;
- var _a = x.shape, xHeight = _a[1], xWidth = _a[2];
- var _b = dy.shape, yHeight = _b[1], yWidth = _b[2];
- // In the backwards pass, we want to find the pixels that were generated for
- // each pixel in the input image the forward pass and add the corresponding
- // coefficient from dy to the gradient (with some interpolation).
- var effectiveXSize = [
- (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
- (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
- ];
- var effectiveYSize = [
- (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
- (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
- ];
- var heightScale = effectiveXSize[0] / effectiveYSize[0];
- var widthScale = effectiveXSize[1] / effectiveYSize[1];
- var invHeightScale = 1 / heightScale;
- var invWidthScale = 1 / widthScale;
- // This defines the size of the window of values around a particular
- // index in dy that we want to search for contributions to dx.
- var winHeight = (Math.ceil(invHeightScale) * 2) + 2;
- var winWidth = (Math.ceil(invWidthScale) * 2) + 2;
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), " + (xHeight - 1) + ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), " + (xWidth - 1) + ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
- }
- return ResizeBilinearBackpropProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ResizeBilinearProgram = /** @class */ (function () {
- function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners) {
- this.variableNames = ['A'];
- this.outputShape = [];
- var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
- this.outputShape = [batch, newHeight, newWidth, depth];
- var effectiveInSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- var effectiveOutSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n ";
- }
- return ResizeBilinearProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ResizeBilinearPackedProgram = /** @class */ (function () {
- function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners) {
- this.variableNames = ['A'];
- this.packedInputs = true;
- this.packedOutput = true;
- this.outputShape = [];
- var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
- this.outputShape = [batch, newHeight, newWidth, depth];
- var effectiveInSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- var effectiveOutSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n ";
- }
- return ResizeBilinearPackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google LLC All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ResizeNearestNeigborBackpropProgram = /** @class */ (function () {
- function ResizeNearestNeigborBackpropProgram(dy, x, alignCorners) {
- this.variableNames = ['dy'];
- this.outputShape = [];
- this.outputShape = x.shape;
- var _a = x.shape, xHeight = _a[1], xWidth = _a[2];
- var _b = dy.shape, yHeight = _b[1], yWidth = _b[2];
- // In the backwards pass, we want to find the pixels that were generated for
- // each pixel in the input image the forward pass and add the corresponding
- // coefficient from dy to the gradient (with some interpolation).
- var effectiveXSize = [
- (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
- (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
- ];
- var effectiveYSize = [
- (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
- (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
- ];
- var heightScale = effectiveXSize[0] / effectiveYSize[0];
- var widthScale = effectiveXSize[1] / effectiveYSize[1];
- var invHeightScale = 1 / heightScale;
- var invWidthScale = 1 / widthScale;
- // This defines the size of the window of values around a particular
- // index in dy that we want to search for contributions to dx.
- var winHeight = (Math.ceil(invHeightScale) * 2) + 2;
- var winWidth = (Math.ceil(invWidthScale) * 2) + 2;
- this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float sourceFracRow =\n float(" + effectiveXSize[0] + ") *\n (float(dyR) / float(" + effectiveYSize[0] + "));\n\n float sourceFracCol =\n float(" + effectiveXSize[1] + ") *\n (float(dyC) / float(" + effectiveYSize[1] + "));\n\n int sourceNearestRow = int(min(\n float(int(" + xHeight + ") - 1),\n " + alignCorners + " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(" + xWidth + ") - 1),\n " + alignCorners + " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
- }
- return ResizeNearestNeigborBackpropProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ResizeNearestNeighborProgram = /** @class */ (function () {
- function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners) {
- this.variableNames = ['A'];
- this.outputShape = [];
- var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
- this.outputShape = [batch, newHeight, newWidth, depth];
- var effectiveInSize = [
- (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
- (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
- ];
- var effectiveOutSize = [
- (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
- (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
- ];
- // When align corners is false, we rounds the value with floor.
- var roundBase = alignCorners ? '0.5' : '0.0';
- this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n ";
- }
- return ResizeNearestNeighborProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ReverseProgram = /** @class */ (function () {
- function ReverseProgram(xShape, axis) {
- this.variableNames = ['x'];
- var rank = xShape.length;
- if (rank > 4) {
- throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
- }
- this.outputShape = xShape;
- if (rank === 1) {
- this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(" + xShape[0] + " - coord - 1));\n }\n ";
- return;
- }
- var getInCoord = function (i) {
- if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
- return xShape[i] + " - coords[" + i + "] - 1";
- }
- return "coords[" + i + "]";
- };
- var inCoords = xShape.map(function (_, i) { return getInCoord(i); }).join(',');
- var type = getCoordsDataType(rank);
- this.userCode = "\n void main() {\n " + type + " coords = getOutputCoords();\n setOutput(getX(" + inCoords + "));\n }\n ";
- }
- return ReverseProgram;
- }());
-
- /**
- * @license
- * Copyright 2019 Google LLC All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ReversePackedProgram = /** @class */ (function () {
- function ReversePackedProgram(xShape, axis) {
- this.variableNames = ['x'];
- this.packedInputs = true;
- this.packedOutput = true;
- var rank = xShape.length;
- if (rank > 4) {
- throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
- }
- this.outputShape = xShape;
- var channels = getChannels('rc', rank);
- var nextColumn = channels[rank - 1] + " + 1 < " + this.outputShape[rank - 1];
- var nextRow = channels[rank - 2] + " + 1 < " + this.outputShape[rank - 2];
- var type = getCoordsDataType(rank);
- if (rank === 1) {
- this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(" + xShape[0] + " - rc - 1),\n " + xShape[0] + " - rc - 1);\n if(" + nextColumn + "){\n result.g = getChannel(getX(" + xShape[0] + " - (rc + 1) - 1),\n " + xShape[0] + " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n ";
- }
- else {
- this.userCode = "\n void main() {\n " + type + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = " + getR(channels.slice()) + ";\n if(" + nextColumn + "){\n result.g = " + getG(channels.slice()) + ";\n }\n if(" + nextRow + ") {\n result.b = " + getB(channels.slice()) + ";\n if(" + nextColumn + ") {\n result.a = " + getA(channels.slice()) + ";\n }\n }\n setOutput(result);\n }\n ";
- }
- function getR(channels) {
- return getChannel(channels);
- }
- function getG(channels) {
- channels[rank - 1] = '(' + channels[rank - 1] + " + 1)";
- return getChannel(channels);
- }
- function getB(channels) {
- channels[rank - 2] = '(' + channels[rank - 2] + " + 1)";
- return getChannel(channels);
- }
- function getA(channels) {
- channels[rank - 1] = '(' + channels[rank - 1] + " + 1)";
- channels[rank - 2] = '(' + channels[rank - 2] + " + 1)";
- return getChannel(channels);
- }
- function getChannel(channels) {
- var inCoordsArray = xShape.map(function (_, i) { return getInCoord(i, channels); });
- var inCoords = inCoordsArray.join(',');
- var innerDims = inCoordsArray.slice(-2).join(',');
- return "getChannel(getX(" + inCoords + "), vec2(" + innerDims + "))";
- }
- function getInCoord(i, channels1) {
- if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
- return xShape[i] + " - " + channels1[i] + " - 1";
- }
- else {
- return "" + channels1[i];
- }
- }
- }
- return ReversePackedProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var ScatterProgram = /** @class */ (function () {
- function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex) {
- if (summingDupeIndex === void 0) { summingDupeIndex = true; }
- this.variableNames = ['updates', 'indices', 'defaultValue'];
- this.outputShape = shape;
- var stridesType = getCoordsDataType(strides.length);
- var dtype = getCoordsDataType(shape.length);
- var indicesString = '';
- if (indicesRank === 1) {
- indicesString = 'i';
- }
- else if (indicesRank === 2) {
- indicesString = 'i, j';
- }
- var indicesSnippet = "getIndices(" + indicesString + ")";
- var updatesString = '';
- if (updatesRank === 1) {
- updatesString = 'i';
- }
- else if (updatesRank === 2) {
- updatesString = 'i, coords[1]';
- }
- var updatesSnippet = "getUpdates(" + updatesString + ")";
- var strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
- this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < " + updateSize + "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < " + sliceDim + "; j++) {\n int index = round(" + indicesSnippet + ");\n flattenedIndex += index * " + strideString + ";\n }\n if (flattenedIndex == coords[0]) {\n sum += " + updatesSnippet + ";\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n ";
- }
- return ScatterProgram;
- }());
-
- /**
- * @license
- * Copyright 2018 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var SegmentOpProgram = /** @class */ (function () {
- function SegmentOpProgram(segOpInfo, segOpType) {
- this.variableNames = ['x', 'segmentIds'];
- var windowSize = segOpInfo.windowSize;
- var batchSize = segOpInfo.batchSize;
- var inSize = segOpInfo.inSize;
- var numSegments = segOpInfo.numSegments;
- var outSize = numSegments * Math.ceil(inSize / windowSize);
- this.outputShape = [batchSize, outSize];
- var initializationValue = '0.0';
- var returnValue = "sumValue";
- var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
- var windowSizeVec4Remainder = windowSize % 4;
- var updateSnippet = "\n sumValue += dot(values, segFilter);\n ";
- var checkValueOutOfBounds = '';
- if (inSize % windowSize > 0) {
- checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
- }
- var checkSegmentIdOutOfBounds = '';
- if (inSize % windowSize > 0) {
- checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return -1.0;\n }\n ";
- }
- this.userCode = "\n const float initializationValue = " + initializationValue + ";\n\n float getValue(int batch, int inIdx) {\n " + checkValueOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n " + checkSegmentIdOutOfBounds + "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n " + numSegments + ")) * float(" + windowSize + "));\n int currentSeg = int(mod(float(outIdx), float(" + numSegments + ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n ";
- }
- return SegmentOpProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var SelectProgram = /** @class */ (function () {
- function SelectProgram(cRank, shape, rank) {
- this.variableNames = ['c', 'a', 'b'];
- this.outputShape = shape;
- var cCoords;
- var abCoords;
- if (rank > 4) {
- throw Error("Where for rank " + rank + " is not yet supported");
- }
- if (rank === 1) {
- abCoords = "resRC";
- cCoords = "resRC";
- }
- else {
- var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
- var cCoordVars = [];
- var abCoordVars = [];
- for (var i = 0; i < shape.length; i++) {
- abCoordVars.push("" + currentCoords[i]);
- if (i < cRank) {
- cCoordVars.push("" + currentCoords[i]);
- }
- }
- cCoords = cCoordVars.join();
- abCoords = abCoordVars.join();
- }
- var dtype = getCoordsDataType(rank);
- this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n float cVal = getC(" + cCoords + ");\n if (cVal >= 1.0) {\n setOutput(getA(" + abCoords + "));\n } else {\n setOutput(getB(" + abCoords + "));\n }\n }\n ";
- }
- return SelectProgram;
- }());
-
- /**
- * @license
- * Copyright 2017 Google Inc. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- * =============================================================================
- */
- var SliceProgram = /** @class */ (function () {
- function SliceProgram(destSize) {
-