diff --git a/calibration_trials.xlsx b/calibration_trials.xlsx index a74a9b5c56fe73dd1c8ab79af7db98166bf6ee9b..d2ae20d760f111b3499c1b7848b374d0a9437d31 100644 Binary files a/calibration_trials.xlsx and b/calibration_trials.xlsx differ diff --git a/demo_eye_tracking2-legacy-browsers.js b/demo_eye_tracking2-legacy-browsers.js index f859aa3cd23494abd46bb1a21589f19cc2d52de9..848bdd402a11e7b6c07a0c4439af85d7d1ae6070 100644 --- a/demo_eye_tracking2-legacy-browsers.js +++ b/demo_eye_tracking2-legacy-browsers.js @@ -16,7 +16,7 @@ const psychoJS = new PsychoJS({ // open window: psychoJS.openWindow({ fullscr: true, - color: new util.Color([0, 0, 0]), + color: new util.Color([(- 1), (- 1), (- 1)]), units: 'height', waitBlanking: true }); @@ -33,22 +33,22 @@ psychoJS.scheduleCondition(function() { return (psychoJS.gui.dialogComponent.but // flowScheduler gets run if the participants presses OK flowScheduler.add(updateInfo); // add timeStamp flowScheduler.add(experimentInit); -flowScheduler.add(loading_trialRoutineBegin()); -flowScheduler.add(loading_trialRoutineEachFrame()); -flowScheduler.add(loading_trialRoutineEnd()); -flowScheduler.add(webcam_trialRoutineBegin()); -flowScheduler.add(webcam_trialRoutineEachFrame()); -flowScheduler.add(webcam_trialRoutineEnd()); -flowScheduler.add(intro_calibatrion_trialRoutineBegin()); -flowScheduler.add(intro_calibatrion_trialRoutineEachFrame()); -flowScheduler.add(intro_calibatrion_trialRoutineEnd()); +flowScheduler.add(initializeEyetrackingRoutineBegin()); +flowScheduler.add(initializeEyetrackingRoutineEachFrame()); +flowScheduler.add(initializeEyetrackingRoutineEnd()); +flowScheduler.add(inst1RoutineBegin()); +flowScheduler.add(inst1RoutineEachFrame()); +flowScheduler.add(inst1RoutineEnd()); +flowScheduler.add(calibrationIntroRoutineBegin()); +flowScheduler.add(calibrationIntroRoutineEachFrame()); +flowScheduler.add(calibrationIntroRoutineEnd()); const trialsLoopScheduler = new Scheduler(psychoJS); flowScheduler.add(trialsLoopBegin(trialsLoopScheduler)); flowScheduler.add(trialsLoopScheduler); flowScheduler.add(trialsLoopEnd); -flowScheduler.add(tracking_trialRoutineBegin()); -flowScheduler.add(tracking_trialRoutineEachFrame()); -flowScheduler.add(tracking_trialRoutineEnd()); +flowScheduler.add(trackingTrialRoutineBegin()); +flowScheduler.add(trackingTrialRoutineEachFrame()); +flowScheduler.add(trackingTrialRoutineEnd()); flowScheduler.add(quitPsychoJS, '', true); // quit if user presses Cancel in dialog box: @@ -58,7 +58,8 @@ psychoJS.start({ expName: expName, expInfo: expInfo, resources: [ - {'name': 'calibration_trials.xlsx', 'path': 'calibration_trials.xlsx'} + {'name': 'calibration_trials.xlsx', 'path': 'calibration_trials.xlsx'}, + {'name': 'webgazer-2.0.1.tp.js', 'path': 'webgazer-2.0.1.tp.js'} ] }); @@ -86,108 +87,117 @@ async function updateInfo() { } -var loading_trialClock; -var loading_text; -var webcam_trialClock; -var intro_text; -var intro_calibatrion_trialClock; -var calibration_text; -var mouse_2; -var calibration_trialClock; +var initializeEyetrackingClock; +var webcamWarning; +var inst1Clock; +var instruction1Txt; +var inst1_resp; +var calibrationIntroClock; +var calibrationTxt; +var calibrationMouse; +var calibrationClock; var calibration_square; -var mouse_3; -var tracking_trialClock; +var calibrationClick; +var trackingTrialClock; var tracking_square; +var trackingTxt; +var tracking_resp; var globalClock; var routineTimer; async function experimentInit() { - // Initialize components for Routine "loading_trial" - loading_trialClock = new util.Clock(); - // Download the webgazer library and re-download seedrandom.js (since webgazer - // overrides it with a version that conflicts with PsychoJS) - psychoJS.downloadResources([ - { name: 'webgazer.js', path: 'js/webgazer-2.0.1.tp.js' }, - { name: 'seedrandom.js', path: 'https://cdnjs.cloudflare.com/ajax/libs/seedrandom/3.0.1/seedrandom.min.js' } - ]); + // Initialize components for Routine "initializeEyetracking" + initializeEyetrackingClock = new util.Clock(); + //initialize params of the webgazer package (used for eye tracking) + // Initialize x and y arrays; we use these to calculate running averages of // current gaze position; the longer the window, the slower, but more fluent // the updates let averagingWindow = 10; window.xGazes = new Array(averagingWindow ).fill(0); window.yGazes = new Array(averagingWindow ).fill(0); - // Timestamp for last time eyes exited validation box - window.eyesExitedTimestamp= (new Date).getTime(); - // No. of ms to keep webcam thumbnail visible after eyes returned into validation box - window.eyesReturnedDelay = 3000; - // DEBUG - window.psychoJS = psychoJS; - loading_text = new visual.TextStim({ + + webcamWarning = new visual.TextStim({ win: psychoJS.window, - name: 'loading_text', - text: 'Downloading additional resources. \n\nOne moment please...', + name: 'webcamWarning', + text: 'This experiment uses eye tracking. \n\nYou should see your web-browser request access to your webcam. You might need to click on this text to make that happen. Please permit access, and wait a little while. Your webcam video should appear in the top-left of the screen.', font: 'Arial', units: undefined, - pos: [0, 0], height: 0.1, wrapWidth: undefined, ori: 0, - color: new util.Color('white'), opacity: 1, + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, depth: -1.0 }); - // Initialize components for Routine "webcam_trial" - webcam_trialClock = new util.Clock(); - intro_text = new visual.TextStim({ + // Initialize components for Routine "inst1" + inst1Clock = new util.Clock(); + instruction1Txt = new visual.TextStim({ win: psychoJS.window, - name: 'intro_text', - text: 'demo_eye_tracking: starting webcam\n\nThis experiment demonstrates eye tracking via the webgazer library. \n\nYou should see your web-browser request access to your webcam. You might need to click on this text to make that happen. Please permit access, and wait a little while. Your webcam video should appear in the top-left of the screen.', + name: 'instruction1Txt', + text: 'Webgazer initialized. \nPress space to move on', font: 'Arial', units: undefined, - pos: [0, 0], height: 0.04, wrapWidth: undefined, ori: 0, - color: new util.Color('white'), opacity: 1, - depth: 0.0 + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, + depth: -1.0 }); - // Initialize components for Routine "intro_calibatrion_trial" - intro_calibatrion_trialClock = new util.Clock(); - calibration_text = new visual.TextStim({ + inst1_resp = new core.Keyboard({psychoJS: psychoJS, clock: new util.Clock(), waitForStart: true}); + + // Initialize components for Routine "calibrationIntro" + calibrationIntroClock = new util.Clock(); + calibrationTxt = new visual.TextStim({ win: psychoJS.window, - name: 'calibration_text', - text: "demo_eye_tracking: calibration\n\nNow we'll calibrate the eye tracker. Please try to keep your head still and within the rectangle you see in your webcam video. When you do so, the rectangle turns green.\n\nIn the next part of this experiment, the webcam video disappears. It will reappear when your head is too from the rectangle. If this happens, please move back into view. White squares appears at different locations on the screen. Please click each square with your mouse.\n\nClick anywhere to continue...", + name: 'calibrationTxt', + text: "OK great! we are almost ready to get started. \n\nFirst we need to calibrate the eye tracker. Please try to keep your head still. If you move your head too far away, you'r webcam will appear in the top left corner. If this happens, please move back into view. \n\nWhite squares will appear at different locations on the screen. Please click each square with your mouse.\n\nClick anywhere with the mouse to continue...", font: 'Arial', units: undefined, - pos: [0, 0], height: 0.04, wrapWidth: undefined, ori: 0, - color: new util.Color('white'), opacity: 1, + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, depth: 0.0 }); - mouse_2 = new core.Mouse({ + calibrationMouse = new core.Mouse({ win: psychoJS.window, }); - mouse_2.mouseClock = new util.Clock(); - // Initialize components for Routine "calibration_trial" - calibration_trialClock = new util.Clock(); + calibrationMouse.mouseClock = new util.Clock(); + // Initialize components for Routine "calibration" + calibrationClock = new util.Clock(); calibration_square = new visual.Rect ({ win: psychoJS.window, name: 'calibration_square', - width: [0.022, 0.022][0], height: [0.022, 0.022][1], - ori: 0, pos: [0, 0], - lineWidth: 0, lineColor: new util.Color([1, 1, 1]), - fillColor: new util.Color([1, 1, 1]), - opacity: 1, depth: 0, interpolate: true, + width: [0.02, 0.02][0], height: [0.02, 0.02][1], + ori: 0.0, pos: [0, 0], + lineWidth: 1.0, lineColor: new util.Color('white'), + fillColor: new util.Color('white'), + opacity: undefined, depth: -1, interpolate: true, }); - mouse_3 = new core.Mouse({ + calibrationClick = new core.Mouse({ win: psychoJS.window, }); - mouse_3.mouseClock = new util.Clock(); - // Initialize components for Routine "tracking_trial" - tracking_trialClock = new util.Clock(); + calibrationClick.mouseClock = new util.Clock(); + // Initialize components for Routine "trackingTrial" + trackingTrialClock = new util.Clock(); tracking_square = new visual.Rect ({ win: psychoJS.window, name: 'tracking_square', width: [0.02, 0.02][0], height: [0.02, 0.02][1], - ori: 0, pos: [0, 0], - lineWidth: undefined, lineColor: new util.Color([1, 1, 1]), - fillColor: new util.Color([(- 1), (- 1), (- 1)]), - opacity: 1, depth: 0, interpolate: true, + ori: 0.0, pos: [0, 0], + lineWidth: 1.0, lineColor: new util.Color('white'), + fillColor: new util.Color('white'), + opacity: undefined, depth: 0, interpolate: true, }); + trackingTxt = new visual.TextStim({ + win: psychoJS.window, + name: 'trackingTxt', + text: 'Great! we are now tracking your eye movements! look around the screen to see how it works! \n\nPlease remember is important for you to keep your head still during the experiment. \n\nPress space to start', + font: 'Arial', + units: undefined, + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, + depth: -1.0 + }); + + tracking_resp = new core.Keyboard({psychoJS: psychoJS, clock: new util.Clock(), waitForStart: true}); + // Create some handy timers globalClock = new util.Clock(); // to track the time since experiment started routineTimer = new util.CountdownTimer(); // to track time remaining of each (non-slip) routine @@ -199,22 +209,43 @@ async function experimentInit() { var t; var frameN; var continueRoutine; -var loading_trialComponents; -function loading_trialRoutineBegin(snapshot) { +var initializeEyetrackingComponents; +function initializeEyetrackingRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'loading_trial'------- + //------Prepare to start Routine 'initializeEyetracking'------- t = 0; - loading_trialClock.reset(); // clock + initializeEyetrackingClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat + // Show webcam thumbnail and face feedback box, but not face overlay and gaze dot + window.webgazer.params.showVideoPreview = true; + window.webgazer.params.showFaceFeedbackBox = true; + window.webgazer.params.showFaceOverlay = false; + window.webgazer.params.showGazeDot = false + // Start eye tracking + window.webgazer + // Called on each eye tracking update + .setGazeListener(function(data, clock) { + if (data !== null) { + // Remove first element from gazes array, add current gaze at the end + window.xGazes.shift(); + window.xGazes.push(data.x); + window.yGazes.shift(); + window.yGazes.push(data.y); + } + }) + .begin(); + //.showPredictionPoints(true); + + // keep track of which components have finished - loading_trialComponents = []; - loading_trialComponents.push(loading_text); + initializeEyetrackingComponents = []; + initializeEyetrackingComponents.push(webcamWarning); - loading_trialComponents.forEach( function(thisComponent) { + initializeEyetrackingComponents.forEach( function(thisComponent) { if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; }); @@ -223,23 +254,26 @@ function loading_trialRoutineBegin(snapshot) { } -function loading_trialRoutineEachFrame() { +function initializeEyetrackingRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'loading_trial'------- + //------Loop for each frame of Routine 'initializeEyetracking'------- // get current time - t = loading_trialClock.getTime(); + t = initializeEyetrackingClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame - // Continue once the webgazer global is available - continueRoutine = !window.hasOwnProperty('webgazer'); + // Finish routine once everything is ready + continueRoutine = + !window.webgazer.isReady() || + document.getElementById('webgazerFaceFeedbackBox') === null || + document.getElementById('webgazerVideoFeed') === null; - // *loading_text* updates - if (t >= 0.0 && loading_text.status === PsychoJS.Status.NOT_STARTED) { + // *webcamWarning* updates + if (t >= 0.0 && webcamWarning.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - loading_text.tStart = t; // (not accounting for frame time here) - loading_text.frameNStart = frameN; // exact frame index + webcamWarning.tStart = t; // (not accounting for frame time here) + webcamWarning.frameNStart = frameN; // exact frame index - loading_text.setAutoDraw(true); + webcamWarning.setAutoDraw(true); } // check for quit (typically the Esc key) @@ -253,7 +287,7 @@ function loading_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - loading_trialComponents.forEach( function(thisComponent) { + initializeEyetrackingComponents.forEach( function(thisComponent) { if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; } @@ -269,15 +303,15 @@ function loading_trialRoutineEachFrame() { } -function loading_trialRoutineEnd() { +function initializeEyetrackingRoutineEnd() { return async function () { - //------Ending Routine 'loading_trial'------- - loading_trialComponents.forEach( function(thisComponent) { + //------Ending Routine 'initializeEyetracking'------- + initializeEyetrackingComponents.forEach( function(thisComponent) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } }); - // the Routine "loading_trial" was not non-slip safe, so reset the non-slip timer + // the Routine "initializeEyetracking" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -285,42 +319,29 @@ function loading_trialRoutineEnd() { } -var webcam_trialComponents; -function webcam_trialRoutineBegin(snapshot) { +var _inst1_resp_allKeys; +var inst1Components; +function inst1RoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'webcam_trial'------- + //------Prepare to start Routine 'inst1'------- t = 0; - webcam_trialClock.reset(); // clock + inst1Clock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat - // Show webcam thumbnail and face feedback box, but not face overlay and gaze dot - window.webgazer.params.showVideoPreview = true; - window.webgazer.params.showFaceFeedbackBox = true; - window.webgazer.params.showFaceOverlay = false; - window.webgazer.params.showGazeDot = false - // Start eye tracking - window.webgazer - // Called on each eye tracking update - .setGazeListener(function(data, clock) { - if (data !== null) { - // Remove first element from gazes array, add current gaze at the end - window.xGazes.shift(); - window.xGazes.push(data.x); - window.yGazes.shift(); - window.yGazes.push(data.y); - } - }) - .begin(); - //.showPredictionPoints(true); - + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; + inst1_resp.keys = undefined; + inst1_resp.rt = undefined; + _inst1_resp_allKeys = []; // keep track of which components have finished - webcam_trialComponents = []; - webcam_trialComponents.push(intro_text); + inst1Components = []; + inst1Components.push(instruction1Txt); + inst1Components.push(inst1_resp); - webcam_trialComponents.forEach( function(thisComponent) { + inst1Components.forEach( function(thisComponent) { if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; }); @@ -329,28 +350,47 @@ function webcam_trialRoutineBegin(snapshot) { } -function webcam_trialRoutineEachFrame() { +function inst1RoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'webcam_trial'------- + //------Loop for each frame of Routine 'inst1'------- // get current time - t = webcam_trialClock.getTime(); + t = inst1Clock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame - // *intro_text* updates - if (t >= 0.0 && intro_text.status === PsychoJS.Status.NOT_STARTED) { + // *instruction1Txt* updates + if (t >= 0.0 && instruction1Txt.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - intro_text.tStart = t; // (not accounting for frame time here) - intro_text.frameNStart = frameN; // exact frame index + instruction1Txt.tStart = t; // (not accounting for frame time here) + instruction1Txt.frameNStart = frameN; // exact frame index - intro_text.setAutoDraw(true); + instruction1Txt.setAutoDraw(true); } - // Finish routine once everything is ready - continueRoutine = - !window.webgazer.isReady() || - document.getElementById('webgazerFaceFeedbackBox') === null || - document.getElementById('webgazerVideoFeed') === null; + + // *inst1_resp* updates + if (t >= 0.0 && inst1_resp.status === PsychoJS.Status.NOT_STARTED) { + // keep track of start time/frame for later + inst1_resp.tStart = t; // (not accounting for frame time here) + inst1_resp.frameNStart = frameN; // exact frame index + + // keyboard checking is just starting + psychoJS.window.callOnFlip(function() { inst1_resp.clock.reset(); }); // t=0 on next screen flip + psychoJS.window.callOnFlip(function() { inst1_resp.start(); }); // start on screen flip + psychoJS.window.callOnFlip(function() { inst1_resp.clearEvents(); }); + } + + if (inst1_resp.status === PsychoJS.Status.STARTED) { + let theseKeys = inst1_resp.getKeys({keyList: ['space'], waitRelease: false}); + _inst1_resp_allKeys = _inst1_resp_allKeys.concat(theseKeys); + if (_inst1_resp_allKeys.length > 0) { + inst1_resp.keys = _inst1_resp_allKeys[_inst1_resp_allKeys.length - 1].name; // just the last key pressed + inst1_resp.rt = _inst1_resp_allKeys[_inst1_resp_allKeys.length - 1].rt; + // a response ends the routine + continueRoutine = false; + } + } + // check for quit (typically the Esc key) if (psychoJS.experiment.experimentEnded || psychoJS.eventManager.getKeys({keyList:['escape']}).length > 0) { return quitPsychoJS('The [Escape] key was pressed. Goodbye!', false); @@ -362,7 +402,7 @@ function webcam_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - webcam_trialComponents.forEach( function(thisComponent) { + inst1Components.forEach( function(thisComponent) { if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; } @@ -378,15 +418,22 @@ function webcam_trialRoutineEachFrame() { } -function webcam_trialRoutineEnd() { +function inst1RoutineEnd() { return async function () { - //------Ending Routine 'webcam_trial'------- - webcam_trialComponents.forEach( function(thisComponent) { + //------Ending Routine 'inst1'------- + inst1Components.forEach( function(thisComponent) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } }); - // the Routine "webcam_trial" was not non-slip safe, so reset the non-slip timer + psychoJS.experiment.addData('inst1_resp.keys', inst1_resp.keys); + if (typeof inst1_resp.keys !== 'undefined') { // we had a response + psychoJS.experiment.addData('inst1_resp.rt', inst1_resp.rt); + routineTimer.reset(); + } + + inst1_resp.stop(); + // the Routine "inst1" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -395,25 +442,25 @@ function webcam_trialRoutineEnd() { var gotValidClick; -var intro_calibatrion_trialComponents; -function intro_calibatrion_trialRoutineBegin(snapshot) { +var calibrationIntroComponents; +function calibrationIntroRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'intro_calibatrion_trial'------- + //------Prepare to start Routine 'calibrationIntro'------- t = 0; - intro_calibatrion_trialClock.reset(); // clock + calibrationIntroClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat - // setup some python lists for storing info about the mouse_2 + // setup some python lists for storing info about the calibrationMouse gotValidClick = false; // until a click is received // keep track of which components have finished - intro_calibatrion_trialComponents = []; - intro_calibatrion_trialComponents.push(calibration_text); - intro_calibatrion_trialComponents.push(mouse_2); + calibrationIntroComponents = []; + calibrationIntroComponents.push(calibrationTxt); + calibrationIntroComponents.push(calibrationMouse); - intro_calibatrion_trialComponents.forEach( function(thisComponent) { + calibrationIntroComponents.forEach( function(thisComponent) { if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; }); @@ -424,35 +471,35 @@ function intro_calibatrion_trialRoutineBegin(snapshot) { var prevButtonState; var _mouseButtons; -function intro_calibatrion_trialRoutineEachFrame() { +function calibrationIntroRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'intro_calibatrion_trial'------- + //------Loop for each frame of Routine 'calibrationIntro'------- // get current time - t = intro_calibatrion_trialClock.getTime(); + t = calibrationIntroClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame - // *calibration_text* updates - if (t >= 0.0 && calibration_text.status === PsychoJS.Status.NOT_STARTED) { + // *calibrationTxt* updates + if (t >= 0.0 && calibrationTxt.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - calibration_text.tStart = t; // (not accounting for frame time here) - calibration_text.frameNStart = frameN; // exact frame index + calibrationTxt.tStart = t; // (not accounting for frame time here) + calibrationTxt.frameNStart = frameN; // exact frame index - calibration_text.setAutoDraw(true); + calibrationTxt.setAutoDraw(true); } - // *mouse_2* updates - if (t >= 0.0 && mouse_2.status === PsychoJS.Status.NOT_STARTED) { + // *calibrationMouse* updates + if (t >= 0.0 && calibrationMouse.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - mouse_2.tStart = t; // (not accounting for frame time here) - mouse_2.frameNStart = frameN; // exact frame index + calibrationMouse.tStart = t; // (not accounting for frame time here) + calibrationMouse.frameNStart = frameN; // exact frame index - mouse_2.status = PsychoJS.Status.STARTED; - mouse_2.mouseClock.reset(); - prevButtonState = mouse_2.getPressed(); // if button is down already this ISN'T a new click + calibrationMouse.status = PsychoJS.Status.STARTED; + calibrationMouse.mouseClock.reset(); + prevButtonState = calibrationMouse.getPressed(); // if button is down already this ISN'T a new click } - if (mouse_2.status === PsychoJS.Status.STARTED) { // only update if started and not finished! - _mouseButtons = mouse_2.getPressed(); + if (calibrationMouse.status === PsychoJS.Status.STARTED) { // only update if started and not finished! + _mouseButtons = calibrationMouse.getPressed(); if (!_mouseButtons.every( (e,i,) => (e == prevButtonState[i]) )) { // button state changed? prevButtonState = _mouseButtons; if (_mouseButtons.reduce( (e, acc) => (e+acc) ) > 0) { // state changed to a new click @@ -472,7 +519,7 @@ function intro_calibatrion_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - intro_calibatrion_trialComponents.forEach( function(thisComponent) { + calibrationIntroComponents.forEach( function(thisComponent) { if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; } @@ -489,23 +536,23 @@ function intro_calibatrion_trialRoutineEachFrame() { var _mouseXYs; -function intro_calibatrion_trialRoutineEnd() { +function calibrationIntroRoutineEnd() { return async function () { - //------Ending Routine 'intro_calibatrion_trial'------- - intro_calibatrion_trialComponents.forEach( function(thisComponent) { + //------Ending Routine 'calibrationIntro'------- + calibrationIntroComponents.forEach( function(thisComponent) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } }); // store data for psychoJS.experiment (ExperimentHandler) - _mouseXYs = mouse_2.getPos(); - _mouseButtons = mouse_2.getPressed(); - psychoJS.experiment.addData('mouse_2.x', _mouseXYs[0]); - psychoJS.experiment.addData('mouse_2.y', _mouseXYs[1]); - psychoJS.experiment.addData('mouse_2.leftButton', _mouseButtons[0]); - psychoJS.experiment.addData('mouse_2.midButton', _mouseButtons[1]); - psychoJS.experiment.addData('mouse_2.rightButton', _mouseButtons[2]); - // the Routine "intro_calibatrion_trial" was not non-slip safe, so reset the non-slip timer + _mouseXYs = calibrationMouse.getPos(); + _mouseButtons = calibrationMouse.getPressed(); + psychoJS.experiment.addData('calibrationMouse.x', _mouseXYs[0]); + psychoJS.experiment.addData('calibrationMouse.y', _mouseXYs[1]); + psychoJS.experiment.addData('calibrationMouse.leftButton', _mouseButtons[0]); + psychoJS.experiment.addData('calibrationMouse.midButton', _mouseButtons[1]); + psychoJS.experiment.addData('calibrationMouse.rightButton', _mouseButtons[2]); + // the Routine "calibrationIntro" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -535,9 +582,9 @@ function trialsLoopBegin(trialsLoopScheduler, snapshot) { const snapshot = trials.getSnapshot(); trialsLoopScheduler.add(importConditions(snapshot)); - trialsLoopScheduler.add(calibration_trialRoutineBegin(snapshot)); - trialsLoopScheduler.add(calibration_trialRoutineEachFrame()); - trialsLoopScheduler.add(calibration_trialRoutineEnd()); + trialsLoopScheduler.add(calibrationRoutineBegin(snapshot)); + trialsLoopScheduler.add(calibrationRoutineEachFrame()); + trialsLoopScheduler.add(calibrationRoutineEnd()); trialsLoopScheduler.add(endLoopIteration(trialsLoopScheduler, snapshot)); }); @@ -553,20 +600,19 @@ async function trialsLoopEnd() { } -var calibration_trialComponents; -function calibration_trialRoutineBegin(snapshot) { +var callib_color; +var calibrationComponents; +function calibrationRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'calibration_trial'------- + //------Prepare to start Routine 'calibration'------- t = 0; - calibration_trialClock.reset(); // clock + calibrationClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise + routineTimer.add(3.500000); // update component parameters for each repeat - // setup some python lists for storing info about the mouse_3 - mouse_3.clicked_name = []; - gotValidClick = false; // until a click is received // Position calibration_square using width and height of window var canvas = psychoJS.window.size; var scaling = [ @@ -578,13 +624,18 @@ function calibration_trialRoutineBegin(snapshot) { calibration_y * scaling[1] ]; console.log(newPos); - calibration_square.setPos(newPos); + //calibration_square.setPos(newPos); + callib_color = 'white'; + calibration_square.setPos([calibration_x, calibration_y]); + // setup some python lists for storing info about the calibrationClick + calibrationClick.clicked_name = []; + gotValidClick = false; // until a click is received // keep track of which components have finished - calibration_trialComponents = []; - calibration_trialComponents.push(calibration_square); - calibration_trialComponents.push(mouse_3); + calibrationComponents = []; + calibrationComponents.push(calibration_square); + calibrationComponents.push(calibrationClick); - calibration_trialComponents.forEach( function(thisComponent) { + calibrationComponents.forEach( function(thisComponent) { if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; }); @@ -593,16 +644,46 @@ function calibration_trialRoutineBegin(snapshot) { } -function calibration_trialRoutineEachFrame() { +var frameRemains; +function calibrationRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'calibration_trial'------- + //------Loop for each frame of Routine 'calibration'------- // get current time - t = calibration_trialClock.getTime(); + t = calibrationClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame + // returns type error - checking fix + + // Hide webcam thumbnail if eyes are in validation box + if (webgazer.checkEyesInValidationBox() === true) { + // If last time that eyes were outside of validation box was longer than + // window.eyesReturnedDelay ago, hide thumbnail + if ( + document.getElementById('webgazerFaceFeedbackBox').style.display != 'none' && + (new Date).getTime() > window.eyesExitedTimestamp + window.eyesReturnedDelay + ) { + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; + } + } else { + // Eyes outside of validation box; show thumbnail + window.eyesExitedTimestamp = (new Date).getTime(); + document.getElementById('webgazerFaceFeedbackBox').style.display = 'block'; + document.getElementById('webgazerVideoFeed').style.display = 'block'; + } + + + if ( + document.getElementById('webgazerFaceFeedbackBox').style.display != 'none' && + (new Date).getTime() > window.eyesExitedTimestamp + window.eyesReturnedDelay + ) { + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; + } + // *calibration_square* updates - if (t >= 0.0 && calibration_square.status === PsychoJS.Status.NOT_STARTED) { + if (t >= 0.5 && calibration_square.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later calibration_square.tStart = t; // (not accounting for frame time here) calibration_square.frameNStart = frameN; // exact frame index @@ -610,53 +691,46 @@ function calibration_trialRoutineEachFrame() { calibration_square.setAutoDraw(true); } - // *mouse_3* updates - if (t >= 0.0 && mouse_3.status === PsychoJS.Status.NOT_STARTED) { + frameRemains = 0.5 + 3 - psychoJS.window.monitorFramePeriod * 0.75; // most of one frame period left + if (calibration_square.status === PsychoJS.Status.STARTED && t >= frameRemains) { + calibration_square.setAutoDraw(false); + } + + if (calibration_square.status === PsychoJS.Status.STARTED){ // only update if being drawn + calibration_square.setFillColor(new util.Color(callib_color), false); + } + // *calibrationClick* updates + if (t >= 0.5 && calibrationClick.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - mouse_3.tStart = t; // (not accounting for frame time here) - mouse_3.frameNStart = frameN; // exact frame index + calibrationClick.tStart = t; // (not accounting for frame time here) + calibrationClick.frameNStart = frameN; // exact frame index - mouse_3.status = PsychoJS.Status.STARTED; - mouse_3.mouseClock.reset(); - prevButtonState = mouse_3.getPressed(); // if button is down already this ISN'T a new click + calibrationClick.status = PsychoJS.Status.STARTED; + calibrationClick.mouseClock.reset(); + prevButtonState = calibrationClick.getPressed(); // if button is down already this ISN'T a new click } - if (mouse_3.status === PsychoJS.Status.STARTED) { // only update if started and not finished! - _mouseButtons = mouse_3.getPressed(); + frameRemains = 0.5 + 3 - psychoJS.window.monitorFramePeriod * 0.75; // most of one frame period left + if (calibrationClick.status === PsychoJS.Status.STARTED && t >= frameRemains) { + calibrationClick.status = PsychoJS.Status.FINISHED; + } + if (calibrationClick.status === PsychoJS.Status.STARTED) { // only update if started and not finished! + _mouseButtons = calibrationClick.getPressed(); if (!_mouseButtons.every( (e,i,) => (e == prevButtonState[i]) )) { // button state changed? prevButtonState = _mouseButtons; if (_mouseButtons.reduce( (e, acc) => (e+acc) ) > 0) { // state changed to a new click // check if the mouse was inside our 'clickable' objects gotValidClick = false; for (const obj of [calibration_square]) { - if (obj.contains(mouse_3)) { + if (obj.contains(calibrationClick)) { gotValidClick = true; - mouse_3.clicked_name.push(obj.name) + calibrationClick.clicked_name.push(obj.name) } } - if (gotValidClick === true) { // abort routine on response - continueRoutine = false; - } + // abort routine on response + continueRoutine = false; } } } - // Hide webcam thumbnail if eyes are in validation box - if (webgazer.checkEyesInValidationBox() === true) { - // If last time that eyes were outside of validation box was longer than - // window.eyesReturnedDelay ago, hide thumbnail - if ( - document.getElementById('webgazerFaceFeedbackBox').style.display != 'none' && - (new Date).getTime() > window.eyesExitedTimestamp + window.eyesReturnedDelay - ) { - document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; - document.getElementById('webgazerVideoFeed').style.display = 'none'; - } - } else { - // Eyes outside of validation box; show thumbnail - window.eyesExitedTimestamp = (new Date).getTime(); - document.getElementById('webgazerFaceFeedbackBox').style.display = 'block'; - document.getElementById('webgazerVideoFeed').style.display = 'block'; - } - // check for quit (typically the Esc key) if (psychoJS.experiment.experimentEnded || psychoJS.eventManager.getKeys({keyList:['escape']}).length > 0) { return quitPsychoJS('The [Escape] key was pressed. Goodbye!', false); @@ -668,14 +742,14 @@ function calibration_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - calibration_trialComponents.forEach( function(thisComponent) { + calibrationComponents.forEach( function(thisComponent) { if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; } }); // refresh the screen if continuing - if (continueRoutine) { + if (continueRoutine && routineTimer.getTime() > 0) { return Scheduler.Event.FLIP_REPEAT; } else { return Scheduler.Event.NEXT; @@ -684,50 +758,57 @@ function calibration_trialRoutineEachFrame() { } -function calibration_trialRoutineEnd() { +function calibrationRoutineEnd() { return async function () { - //------Ending Routine 'calibration_trial'------- - calibration_trialComponents.forEach( function(thisComponent) { + //------Ending Routine 'calibration'------- + calibrationComponents.forEach( function(thisComponent) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } }); // store data for psychoJS.experiment (ExperimentHandler) - _mouseXYs = mouse_3.getPos(); - _mouseButtons = mouse_3.getPressed(); - psychoJS.experiment.addData('mouse_3.x', _mouseXYs[0]); - psychoJS.experiment.addData('mouse_3.y', _mouseXYs[1]); - psychoJS.experiment.addData('mouse_3.leftButton', _mouseButtons[0]); - psychoJS.experiment.addData('mouse_3.midButton', _mouseButtons[1]); - psychoJS.experiment.addData('mouse_3.rightButton', _mouseButtons[2]); - if (mouse_3.clicked_name.length > 0) { - psychoJS.experiment.addData('mouse_3.clicked_name', mouse_3.clicked_name[0]);} - // the Routine "calibration_trial" was not non-slip safe, so reset the non-slip timer - routineTimer.reset(); - + _mouseXYs = calibrationClick.getPos(); + _mouseButtons = calibrationClick.getPressed(); + psychoJS.experiment.addData('calibrationClick.x', _mouseXYs[0]); + psychoJS.experiment.addData('calibrationClick.y', _mouseXYs[1]); + psychoJS.experiment.addData('calibrationClick.leftButton', _mouseButtons[0]); + psychoJS.experiment.addData('calibrationClick.midButton', _mouseButtons[1]); + psychoJS.experiment.addData('calibrationClick.rightButton', _mouseButtons[2]); + if (calibrationClick.clicked_name.length > 0) { + psychoJS.experiment.addData('calibrationClick.clicked_name', calibrationClick.clicked_name[0]);} return Scheduler.Event.NEXT; }; } -var tracking_trialComponents; -function tracking_trialRoutineBegin(snapshot) { +var _tracking_resp_allKeys; +var trackingTrialComponents; +function trackingTrialRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'tracking_trial'------- + //------Prepare to start Routine 'trackingTrial'------- t = 0; - tracking_trialClock.reset(); // clock + trackingTrialClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat + tracking_resp.keys = undefined; + tracking_resp.rt = undefined; + _tracking_resp_allKeys = []; // Remove the click tracker used for calibration window.webgazer.removeMouseEventListeners(); + + //hide the video thumbnail + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; // keep track of which components have finished - tracking_trialComponents = []; - tracking_trialComponents.push(tracking_square); + trackingTrialComponents = []; + trackingTrialComponents.push(tracking_square); + trackingTrialComponents.push(trackingTxt); + trackingTrialComponents.push(tracking_resp); - tracking_trialComponents.forEach( function(thisComponent) { + trackingTrialComponents.forEach( function(thisComponent) { if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; }); @@ -736,11 +817,11 @@ function tracking_trialRoutineBegin(snapshot) { } -function tracking_trialRoutineEachFrame() { +function trackingTrialRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'tracking_trial'------- + //------Loop for each frame of Routine 'trackingTrial'------- // get current time - t = tracking_trialClock.getTime(); + t = trackingTrialClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame @@ -753,6 +834,40 @@ function tracking_trialRoutineEachFrame() { tracking_square.setAutoDraw(true); } + + // *trackingTxt* updates + if (t >= 0.0 && trackingTxt.status === PsychoJS.Status.NOT_STARTED) { + // keep track of start time/frame for later + trackingTxt.tStart = t; // (not accounting for frame time here) + trackingTxt.frameNStart = frameN; // exact frame index + + trackingTxt.setAutoDraw(true); + } + + + // *tracking_resp* updates + if (t >= 0.0 && tracking_resp.status === PsychoJS.Status.NOT_STARTED) { + // keep track of start time/frame for later + tracking_resp.tStart = t; // (not accounting for frame time here) + tracking_resp.frameNStart = frameN; // exact frame index + + // keyboard checking is just starting + psychoJS.window.callOnFlip(function() { tracking_resp.clock.reset(); }); // t=0 on next screen flip + psychoJS.window.callOnFlip(function() { tracking_resp.start(); }); // start on screen flip + psychoJS.window.callOnFlip(function() { tracking_resp.clearEvents(); }); + } + + if (tracking_resp.status === PsychoJS.Status.STARTED) { + let theseKeys = tracking_resp.getKeys({keyList: ['space'], waitRelease: false}); + _tracking_resp_allKeys = _tracking_resp_allKeys.concat(theseKeys); + if (_tracking_resp_allKeys.length > 0) { + tracking_resp.keys = _tracking_resp_allKeys[_tracking_resp_allKeys.length - 1].name; // just the last key pressed + tracking_resp.rt = _tracking_resp_allKeys[_tracking_resp_allKeys.length - 1].rt; + // a response ends the routine + continueRoutine = false; + } + } + // Hide webcam thumbnail if eyes are in validation box if (webgazer.checkEyesInValidationBox() === true) { // If last time that eyes were outside of validation box was longer than @@ -784,7 +899,6 @@ function tracking_trialRoutineEachFrame() { psychoJS.window ) ); - // check for quit (typically the Esc key) if (psychoJS.experiment.experimentEnded || psychoJS.eventManager.getKeys({keyList:['escape']}).length > 0) { return quitPsychoJS('The [Escape] key was pressed. Goodbye!', false); @@ -796,7 +910,7 @@ function tracking_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - tracking_trialComponents.forEach( function(thisComponent) { + trackingTrialComponents.forEach( function(thisComponent) { if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; } @@ -812,15 +926,22 @@ function tracking_trialRoutineEachFrame() { } -function tracking_trialRoutineEnd() { +function trackingTrialRoutineEnd() { return async function () { - //------Ending Routine 'tracking_trial'------- - tracking_trialComponents.forEach( function(thisComponent) { + //------Ending Routine 'trackingTrial'------- + trackingTrialComponents.forEach( function(thisComponent) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } }); - // the Routine "tracking_trial" was not non-slip safe, so reset the non-slip timer + psychoJS.experiment.addData('tracking_resp.keys', tracking_resp.keys); + if (typeof tracking_resp.keys !== 'undefined') { // we had a response + psychoJS.experiment.addData('tracking_resp.rt', tracking_resp.rt); + routineTimer.reset(); + } + + tracking_resp.stop(); + // the Routine "trackingTrial" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -872,6 +993,8 @@ async function quitPsychoJS(message, isCompleted) { + + psychoJS.window.close(); psychoJS.quit({message: message, isCompleted: isCompleted}); diff --git a/demo_eye_tracking2.js b/demo_eye_tracking2.js index 7bca6ca1eee262c11c9a598c18c109c0a7406274..0f97f96aba4b49a4e51014fc7941686cba7e5329 100644 --- a/demo_eye_tracking2.js +++ b/demo_eye_tracking2.js @@ -24,7 +24,7 @@ const psychoJS = new PsychoJS({ // open window: psychoJS.openWindow({ fullscr: true, - color: new util.Color([0, 0, 0]), + color: new util.Color([(- 1), (- 1), (- 1)]), units: 'height', waitBlanking: true }); @@ -41,22 +41,22 @@ psychoJS.scheduleCondition(function() { return (psychoJS.gui.dialogComponent.but // flowScheduler gets run if the participants presses OK flowScheduler.add(updateInfo); // add timeStamp flowScheduler.add(experimentInit); -flowScheduler.add(loading_trialRoutineBegin()); -flowScheduler.add(loading_trialRoutineEachFrame()); -flowScheduler.add(loading_trialRoutineEnd()); -flowScheduler.add(webcam_trialRoutineBegin()); -flowScheduler.add(webcam_trialRoutineEachFrame()); -flowScheduler.add(webcam_trialRoutineEnd()); -flowScheduler.add(intro_calibatrion_trialRoutineBegin()); -flowScheduler.add(intro_calibatrion_trialRoutineEachFrame()); -flowScheduler.add(intro_calibatrion_trialRoutineEnd()); +flowScheduler.add(initializeEyetrackingRoutineBegin()); +flowScheduler.add(initializeEyetrackingRoutineEachFrame()); +flowScheduler.add(initializeEyetrackingRoutineEnd()); +flowScheduler.add(inst1RoutineBegin()); +flowScheduler.add(inst1RoutineEachFrame()); +flowScheduler.add(inst1RoutineEnd()); +flowScheduler.add(calibrationIntroRoutineBegin()); +flowScheduler.add(calibrationIntroRoutineEachFrame()); +flowScheduler.add(calibrationIntroRoutineEnd()); const trialsLoopScheduler = new Scheduler(psychoJS); flowScheduler.add(trialsLoopBegin(trialsLoopScheduler)); flowScheduler.add(trialsLoopScheduler); flowScheduler.add(trialsLoopEnd); -flowScheduler.add(tracking_trialRoutineBegin()); -flowScheduler.add(tracking_trialRoutineEachFrame()); -flowScheduler.add(tracking_trialRoutineEnd()); +flowScheduler.add(trackingTrialRoutineBegin()); +flowScheduler.add(trackingTrialRoutineEachFrame()); +flowScheduler.add(trackingTrialRoutineEnd()); flowScheduler.add(quitPsychoJS, '', true); // quit if user presses Cancel in dialog box: @@ -66,7 +66,8 @@ psychoJS.start({ expName: expName, expInfo: expInfo, resources: [ - {'name': 'calibration_trials.xlsx', 'path': 'calibration_trials.xlsx'} + {'name': 'calibration_trials.xlsx', 'path': 'calibration_trials.xlsx'}, + {'name': 'webgazer-2.0.1.tp.js', 'path': 'webgazer-2.0.1.tp.js'} ] }); @@ -94,108 +95,117 @@ async function updateInfo() { } -var loading_trialClock; -var loading_text; -var webcam_trialClock; -var intro_text; -var intro_calibatrion_trialClock; -var calibration_text; -var mouse_2; -var calibration_trialClock; +var initializeEyetrackingClock; +var webcamWarning; +var inst1Clock; +var instruction1Txt; +var inst1_resp; +var calibrationIntroClock; +var calibrationTxt; +var calibrationMouse; +var calibrationClock; var calibration_square; -var mouse_3; -var tracking_trialClock; +var calibrationClick; +var trackingTrialClock; var tracking_square; +var trackingTxt; +var tracking_resp; var globalClock; var routineTimer; async function experimentInit() { - // Initialize components for Routine "loading_trial" - loading_trialClock = new util.Clock(); - // Download the webgazer library and re-download seedrandom.js (since webgazer - // overrides it with a version that conflicts with PsychoJS) - psychoJS.downloadResources([ - { name: 'webgazer.js', path: 'js/webgazer-2.0.1.tp.js' }, - { name: 'seedrandom.js', path: 'https://cdnjs.cloudflare.com/ajax/libs/seedrandom/3.0.1/seedrandom.min.js' } - ]); + // Initialize components for Routine "initializeEyetracking" + initializeEyetrackingClock = new util.Clock(); + //initialize params of the webgazer package (used for eye tracking) + // Initialize x and y arrays; we use these to calculate running averages of // current gaze position; the longer the window, the slower, but more fluent // the updates let averagingWindow = 10; window.xGazes = new Array(averagingWindow ).fill(0); window.yGazes = new Array(averagingWindow ).fill(0); - // Timestamp for last time eyes exited validation box - window.eyesExitedTimestamp= (new Date).getTime(); - // No. of ms to keep webcam thumbnail visible after eyes returned into validation box - window.eyesReturnedDelay = 3000; - // DEBUG - window.psychoJS = psychoJS; - loading_text = new visual.TextStim({ + + webcamWarning = new visual.TextStim({ win: psychoJS.window, - name: 'loading_text', - text: 'Downloading additional resources. \n\nOne moment please...', + name: 'webcamWarning', + text: 'This experiment uses eye tracking. \n\nYou should see your web-browser request access to your webcam. You might need to click on this text to make that happen. Please permit access, and wait a little while. Your webcam video should appear in the top-left of the screen.', font: 'Arial', units: undefined, - pos: [0, 0], height: 0.1, wrapWidth: undefined, ori: 0, - color: new util.Color('white'), opacity: 1, + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, depth: -1.0 }); - // Initialize components for Routine "webcam_trial" - webcam_trialClock = new util.Clock(); - intro_text = new visual.TextStim({ + // Initialize components for Routine "inst1" + inst1Clock = new util.Clock(); + instruction1Txt = new visual.TextStim({ win: psychoJS.window, - name: 'intro_text', - text: 'demo_eye_tracking: starting webcam\n\nThis experiment demonstrates eye tracking via the webgazer library. \n\nYou should see your web-browser request access to your webcam. You might need to click on this text to make that happen. Please permit access, and wait a little while. Your webcam video should appear in the top-left of the screen.', + name: 'instruction1Txt', + text: 'Webgazer initialized. \nPress space to move on', font: 'Arial', units: undefined, - pos: [0, 0], height: 0.04, wrapWidth: undefined, ori: 0, - color: new util.Color('white'), opacity: 1, - depth: 0.0 + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, + depth: -1.0 }); - // Initialize components for Routine "intro_calibatrion_trial" - intro_calibatrion_trialClock = new util.Clock(); - calibration_text = new visual.TextStim({ + inst1_resp = new core.Keyboard({psychoJS: psychoJS, clock: new util.Clock(), waitForStart: true}); + + // Initialize components for Routine "calibrationIntro" + calibrationIntroClock = new util.Clock(); + calibrationTxt = new visual.TextStim({ win: psychoJS.window, - name: 'calibration_text', - text: "demo_eye_tracking: calibration\n\nNow we'll calibrate the eye tracker. Please try to keep your head still and within the rectangle you see in your webcam video. When you do so, the rectangle turns green.\n\nIn the next part of this experiment, the webcam video disappears. It will reappear when your head is too from the rectangle. If this happens, please move back into view. White squares appears at different locations on the screen. Please click each square with your mouse.\n\nClick anywhere to continue...", + name: 'calibrationTxt', + text: "OK great! we are almost ready to get started. \n\nFirst we need to calibrate the eye tracker. Please try to keep your head still. If you move your head too far away, you'r webcam will appear in the top left corner. If this happens, please move back into view. \n\nWhite squares will appear at different locations on the screen. Please click each square with your mouse.\n\nClick anywhere with the mouse to continue...", font: 'Arial', units: undefined, - pos: [0, 0], height: 0.04, wrapWidth: undefined, ori: 0, - color: new util.Color('white'), opacity: 1, + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, depth: 0.0 }); - mouse_2 = new core.Mouse({ + calibrationMouse = new core.Mouse({ win: psychoJS.window, }); - mouse_2.mouseClock = new util.Clock(); - // Initialize components for Routine "calibration_trial" - calibration_trialClock = new util.Clock(); + calibrationMouse.mouseClock = new util.Clock(); + // Initialize components for Routine "calibration" + calibrationClock = new util.Clock(); calibration_square = new visual.Rect ({ win: psychoJS.window, name: 'calibration_square', - width: [0.022, 0.022][0], height: [0.022, 0.022][1], - ori: 0, pos: [0, 0], - lineWidth: 0, lineColor: new util.Color([1, 1, 1]), - fillColor: new util.Color([1, 1, 1]), - opacity: 1, depth: 0, interpolate: true, + width: [0.02, 0.02][0], height: [0.02, 0.02][1], + ori: 0.0, pos: [0, 0], + lineWidth: 1.0, lineColor: new util.Color('white'), + fillColor: new util.Color('white'), + opacity: undefined, depth: -1, interpolate: true, }); - mouse_3 = new core.Mouse({ + calibrationClick = new core.Mouse({ win: psychoJS.window, }); - mouse_3.mouseClock = new util.Clock(); - // Initialize components for Routine "tracking_trial" - tracking_trialClock = new util.Clock(); + calibrationClick.mouseClock = new util.Clock(); + // Initialize components for Routine "trackingTrial" + trackingTrialClock = new util.Clock(); tracking_square = new visual.Rect ({ win: psychoJS.window, name: 'tracking_square', width: [0.02, 0.02][0], height: [0.02, 0.02][1], - ori: 0, pos: [0, 0], - lineWidth: undefined, lineColor: new util.Color([1, 1, 1]), - fillColor: new util.Color([(- 1), (- 1), (- 1)]), - opacity: 1, depth: 0, interpolate: true, + ori: 0.0, pos: [0, 0], + lineWidth: 1.0, lineColor: new util.Color('white'), + fillColor: new util.Color('white'), + opacity: undefined, depth: 0, interpolate: true, }); + trackingTxt = new visual.TextStim({ + win: psychoJS.window, + name: 'trackingTxt', + text: 'Great! we are now tracking your eye movements! look around the screen to see how it works! \n\nPlease remember is important for you to keep your head still during the experiment. \n\nPress space to start', + font: 'Arial', + units: undefined, + pos: [0, 0], height: 0.05, wrapWidth: undefined, ori: 0.0, + color: new util.Color('white'), opacity: undefined, + depth: -1.0 + }); + + tracking_resp = new core.Keyboard({psychoJS: psychoJS, clock: new util.Clock(), waitForStart: true}); + // Create some handy timers globalClock = new util.Clock(); // to track the time since experiment started routineTimer = new util.CountdownTimer(); // to track time remaining of each (non-slip) routine @@ -207,22 +217,43 @@ async function experimentInit() { var t; var frameN; var continueRoutine; -var loading_trialComponents; -function loading_trialRoutineBegin(snapshot) { +var initializeEyetrackingComponents; +function initializeEyetrackingRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'loading_trial'------- + //------Prepare to start Routine 'initializeEyetracking'------- t = 0; - loading_trialClock.reset(); // clock + initializeEyetrackingClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat + // Show webcam thumbnail and face feedback box, but not face overlay and gaze dot + window.webgazer.params.showVideoPreview = true; + window.webgazer.params.showFaceFeedbackBox = true; + window.webgazer.params.showFaceOverlay = false; + window.webgazer.params.showGazeDot = false + // Start eye tracking + window.webgazer + // Called on each eye tracking update + .setGazeListener(function(data, clock) { + if (data !== null) { + // Remove first element from gazes array, add current gaze at the end + window.xGazes.shift(); + window.xGazes.push(data.x); + window.yGazes.shift(); + window.yGazes.push(data.y); + } + }) + .begin(); + //.showPredictionPoints(true); + + // keep track of which components have finished - loading_trialComponents = []; - loading_trialComponents.push(loading_text); + initializeEyetrackingComponents = []; + initializeEyetrackingComponents.push(webcamWarning); - for (const thisComponent of loading_trialComponents) + for (const thisComponent of initializeEyetrackingComponents) if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; return Scheduler.Event.NEXT; @@ -230,23 +261,26 @@ function loading_trialRoutineBegin(snapshot) { } -function loading_trialRoutineEachFrame() { +function initializeEyetrackingRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'loading_trial'------- + //------Loop for each frame of Routine 'initializeEyetracking'------- // get current time - t = loading_trialClock.getTime(); + t = initializeEyetrackingClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame - // Continue once the webgazer global is available - continueRoutine = !window.hasOwnProperty('webgazer'); + // Finish routine once everything is ready + continueRoutine = + !window.webgazer.isReady() || + document.getElementById('webgazerFaceFeedbackBox') === null || + document.getElementById('webgazerVideoFeed') === null; - // *loading_text* updates - if (t >= 0.0 && loading_text.status === PsychoJS.Status.NOT_STARTED) { + // *webcamWarning* updates + if (t >= 0.0 && webcamWarning.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - loading_text.tStart = t; // (not accounting for frame time here) - loading_text.frameNStart = frameN; // exact frame index + webcamWarning.tStart = t; // (not accounting for frame time here) + webcamWarning.frameNStart = frameN; // exact frame index - loading_text.setAutoDraw(true); + webcamWarning.setAutoDraw(true); } // check for quit (typically the Esc key) @@ -260,7 +294,7 @@ function loading_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - for (const thisComponent of loading_trialComponents) + for (const thisComponent of initializeEyetrackingComponents) if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; break; @@ -276,15 +310,15 @@ function loading_trialRoutineEachFrame() { } -function loading_trialRoutineEnd() { +function initializeEyetrackingRoutineEnd() { return async function () { - //------Ending Routine 'loading_trial'------- - for (const thisComponent of loading_trialComponents) { + //------Ending Routine 'initializeEyetracking'------- + for (const thisComponent of initializeEyetrackingComponents) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } } - // the Routine "loading_trial" was not non-slip safe, so reset the non-slip timer + // the Routine "initializeEyetracking" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -292,42 +326,29 @@ function loading_trialRoutineEnd() { } -var webcam_trialComponents; -function webcam_trialRoutineBegin(snapshot) { +var _inst1_resp_allKeys; +var inst1Components; +function inst1RoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'webcam_trial'------- + //------Prepare to start Routine 'inst1'------- t = 0; - webcam_trialClock.reset(); // clock + inst1Clock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat - // Show webcam thumbnail and face feedback box, but not face overlay and gaze dot - window.webgazer.params.showVideoPreview = true; - window.webgazer.params.showFaceFeedbackBox = true; - window.webgazer.params.showFaceOverlay = false; - window.webgazer.params.showGazeDot = false - // Start eye tracking - window.webgazer - // Called on each eye tracking update - .setGazeListener(function(data, clock) { - if (data !== null) { - // Remove first element from gazes array, add current gaze at the end - window.xGazes.shift(); - window.xGazes.push(data.x); - window.yGazes.shift(); - window.yGazes.push(data.y); - } - }) - .begin(); - //.showPredictionPoints(true); - + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; + inst1_resp.keys = undefined; + inst1_resp.rt = undefined; + _inst1_resp_allKeys = []; // keep track of which components have finished - webcam_trialComponents = []; - webcam_trialComponents.push(intro_text); + inst1Components = []; + inst1Components.push(instruction1Txt); + inst1Components.push(inst1_resp); - for (const thisComponent of webcam_trialComponents) + for (const thisComponent of inst1Components) if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; return Scheduler.Event.NEXT; @@ -335,28 +356,47 @@ function webcam_trialRoutineBegin(snapshot) { } -function webcam_trialRoutineEachFrame() { +function inst1RoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'webcam_trial'------- + //------Loop for each frame of Routine 'inst1'------- // get current time - t = webcam_trialClock.getTime(); + t = inst1Clock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame - // *intro_text* updates - if (t >= 0.0 && intro_text.status === PsychoJS.Status.NOT_STARTED) { + // *instruction1Txt* updates + if (t >= 0.0 && instruction1Txt.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - intro_text.tStart = t; // (not accounting for frame time here) - intro_text.frameNStart = frameN; // exact frame index + instruction1Txt.tStart = t; // (not accounting for frame time here) + instruction1Txt.frameNStart = frameN; // exact frame index - intro_text.setAutoDraw(true); + instruction1Txt.setAutoDraw(true); } - // Finish routine once everything is ready - continueRoutine = - !window.webgazer.isReady() || - document.getElementById('webgazerFaceFeedbackBox') === null || - document.getElementById('webgazerVideoFeed') === null; + + // *inst1_resp* updates + if (t >= 0.0 && inst1_resp.status === PsychoJS.Status.NOT_STARTED) { + // keep track of start time/frame for later + inst1_resp.tStart = t; // (not accounting for frame time here) + inst1_resp.frameNStart = frameN; // exact frame index + + // keyboard checking is just starting + psychoJS.window.callOnFlip(function() { inst1_resp.clock.reset(); }); // t=0 on next screen flip + psychoJS.window.callOnFlip(function() { inst1_resp.start(); }); // start on screen flip + psychoJS.window.callOnFlip(function() { inst1_resp.clearEvents(); }); + } + + if (inst1_resp.status === PsychoJS.Status.STARTED) { + let theseKeys = inst1_resp.getKeys({keyList: ['space'], waitRelease: false}); + _inst1_resp_allKeys = _inst1_resp_allKeys.concat(theseKeys); + if (_inst1_resp_allKeys.length > 0) { + inst1_resp.keys = _inst1_resp_allKeys[_inst1_resp_allKeys.length - 1].name; // just the last key pressed + inst1_resp.rt = _inst1_resp_allKeys[_inst1_resp_allKeys.length - 1].rt; + // a response ends the routine + continueRoutine = false; + } + } + // check for quit (typically the Esc key) if (psychoJS.experiment.experimentEnded || psychoJS.eventManager.getKeys({keyList:['escape']}).length > 0) { return quitPsychoJS('The [Escape] key was pressed. Goodbye!', false); @@ -368,7 +408,7 @@ function webcam_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - for (const thisComponent of webcam_trialComponents) + for (const thisComponent of inst1Components) if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; break; @@ -384,15 +424,22 @@ function webcam_trialRoutineEachFrame() { } -function webcam_trialRoutineEnd() { +function inst1RoutineEnd() { return async function () { - //------Ending Routine 'webcam_trial'------- - for (const thisComponent of webcam_trialComponents) { + //------Ending Routine 'inst1'------- + for (const thisComponent of inst1Components) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } } - // the Routine "webcam_trial" was not non-slip safe, so reset the non-slip timer + psychoJS.experiment.addData('inst1_resp.keys', inst1_resp.keys); + if (typeof inst1_resp.keys !== 'undefined') { // we had a response + psychoJS.experiment.addData('inst1_resp.rt', inst1_resp.rt); + routineTimer.reset(); + } + + inst1_resp.stop(); + // the Routine "inst1" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -401,25 +448,25 @@ function webcam_trialRoutineEnd() { var gotValidClick; -var intro_calibatrion_trialComponents; -function intro_calibatrion_trialRoutineBegin(snapshot) { +var calibrationIntroComponents; +function calibrationIntroRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'intro_calibatrion_trial'------- + //------Prepare to start Routine 'calibrationIntro'------- t = 0; - intro_calibatrion_trialClock.reset(); // clock + calibrationIntroClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat - // setup some python lists for storing info about the mouse_2 + // setup some python lists for storing info about the calibrationMouse gotValidClick = false; // until a click is received // keep track of which components have finished - intro_calibatrion_trialComponents = []; - intro_calibatrion_trialComponents.push(calibration_text); - intro_calibatrion_trialComponents.push(mouse_2); + calibrationIntroComponents = []; + calibrationIntroComponents.push(calibrationTxt); + calibrationIntroComponents.push(calibrationMouse); - for (const thisComponent of intro_calibatrion_trialComponents) + for (const thisComponent of calibrationIntroComponents) if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; return Scheduler.Event.NEXT; @@ -429,35 +476,35 @@ function intro_calibatrion_trialRoutineBegin(snapshot) { var prevButtonState; var _mouseButtons; -function intro_calibatrion_trialRoutineEachFrame() { +function calibrationIntroRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'intro_calibatrion_trial'------- + //------Loop for each frame of Routine 'calibrationIntro'------- // get current time - t = intro_calibatrion_trialClock.getTime(); + t = calibrationIntroClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame - // *calibration_text* updates - if (t >= 0.0 && calibration_text.status === PsychoJS.Status.NOT_STARTED) { + // *calibrationTxt* updates + if (t >= 0.0 && calibrationTxt.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - calibration_text.tStart = t; // (not accounting for frame time here) - calibration_text.frameNStart = frameN; // exact frame index + calibrationTxt.tStart = t; // (not accounting for frame time here) + calibrationTxt.frameNStart = frameN; // exact frame index - calibration_text.setAutoDraw(true); + calibrationTxt.setAutoDraw(true); } - // *mouse_2* updates - if (t >= 0.0 && mouse_2.status === PsychoJS.Status.NOT_STARTED) { + // *calibrationMouse* updates + if (t >= 0.0 && calibrationMouse.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - mouse_2.tStart = t; // (not accounting for frame time here) - mouse_2.frameNStart = frameN; // exact frame index + calibrationMouse.tStart = t; // (not accounting for frame time here) + calibrationMouse.frameNStart = frameN; // exact frame index - mouse_2.status = PsychoJS.Status.STARTED; - mouse_2.mouseClock.reset(); - prevButtonState = mouse_2.getPressed(); // if button is down already this ISN'T a new click + calibrationMouse.status = PsychoJS.Status.STARTED; + calibrationMouse.mouseClock.reset(); + prevButtonState = calibrationMouse.getPressed(); // if button is down already this ISN'T a new click } - if (mouse_2.status === PsychoJS.Status.STARTED) { // only update if started and not finished! - _mouseButtons = mouse_2.getPressed(); + if (calibrationMouse.status === PsychoJS.Status.STARTED) { // only update if started and not finished! + _mouseButtons = calibrationMouse.getPressed(); if (!_mouseButtons.every( (e,i,) => (e == prevButtonState[i]) )) { // button state changed? prevButtonState = _mouseButtons; if (_mouseButtons.reduce( (e, acc) => (e+acc) ) > 0) { // state changed to a new click @@ -477,7 +524,7 @@ function intro_calibatrion_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - for (const thisComponent of intro_calibatrion_trialComponents) + for (const thisComponent of calibrationIntroComponents) if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; break; @@ -494,23 +541,23 @@ function intro_calibatrion_trialRoutineEachFrame() { var _mouseXYs; -function intro_calibatrion_trialRoutineEnd() { +function calibrationIntroRoutineEnd() { return async function () { - //------Ending Routine 'intro_calibatrion_trial'------- - for (const thisComponent of intro_calibatrion_trialComponents) { + //------Ending Routine 'calibrationIntro'------- + for (const thisComponent of calibrationIntroComponents) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } } // store data for psychoJS.experiment (ExperimentHandler) - _mouseXYs = mouse_2.getPos(); - _mouseButtons = mouse_2.getPressed(); - psychoJS.experiment.addData('mouse_2.x', _mouseXYs[0]); - psychoJS.experiment.addData('mouse_2.y', _mouseXYs[1]); - psychoJS.experiment.addData('mouse_2.leftButton', _mouseButtons[0]); - psychoJS.experiment.addData('mouse_2.midButton', _mouseButtons[1]); - psychoJS.experiment.addData('mouse_2.rightButton', _mouseButtons[2]); - // the Routine "intro_calibatrion_trial" was not non-slip safe, so reset the non-slip timer + _mouseXYs = calibrationMouse.getPos(); + _mouseButtons = calibrationMouse.getPressed(); + psychoJS.experiment.addData('calibrationMouse.x', _mouseXYs[0]); + psychoJS.experiment.addData('calibrationMouse.y', _mouseXYs[1]); + psychoJS.experiment.addData('calibrationMouse.leftButton', _mouseButtons[0]); + psychoJS.experiment.addData('calibrationMouse.midButton', _mouseButtons[1]); + psychoJS.experiment.addData('calibrationMouse.rightButton', _mouseButtons[2]); + // the Routine "calibrationIntro" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -539,9 +586,9 @@ function trialsLoopBegin(trialsLoopScheduler, snapshot) { for (const thisTrial of trials) { const snapshot = trials.getSnapshot(); trialsLoopScheduler.add(importConditions(snapshot)); - trialsLoopScheduler.add(calibration_trialRoutineBegin(snapshot)); - trialsLoopScheduler.add(calibration_trialRoutineEachFrame()); - trialsLoopScheduler.add(calibration_trialRoutineEnd()); + trialsLoopScheduler.add(calibrationRoutineBegin(snapshot)); + trialsLoopScheduler.add(calibrationRoutineEachFrame()); + trialsLoopScheduler.add(calibrationRoutineEnd()); trialsLoopScheduler.add(endLoopIteration(trialsLoopScheduler, snapshot)); } @@ -557,20 +604,19 @@ async function trialsLoopEnd() { } -var calibration_trialComponents; -function calibration_trialRoutineBegin(snapshot) { +var callib_color; +var calibrationComponents; +function calibrationRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'calibration_trial'------- + //------Prepare to start Routine 'calibration'------- t = 0; - calibration_trialClock.reset(); // clock + calibrationClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise + routineTimer.add(3.500000); // update component parameters for each repeat - // setup some python lists for storing info about the mouse_3 - mouse_3.clicked_name = []; - gotValidClick = false; // until a click is received // Position calibration_square using width and height of window var canvas = psychoJS.window.size; var scaling = [ @@ -582,13 +628,18 @@ function calibration_trialRoutineBegin(snapshot) { calibration_y * scaling[1] ]; console.log(newPos); - calibration_square.setPos(newPos); + //calibration_square.setPos(newPos); + callib_color = 'white'; + calibration_square.setPos([calibration_x, calibration_y]); + // setup some python lists for storing info about the calibrationClick + calibrationClick.clicked_name = []; + gotValidClick = false; // until a click is received // keep track of which components have finished - calibration_trialComponents = []; - calibration_trialComponents.push(calibration_square); - calibration_trialComponents.push(mouse_3); + calibrationComponents = []; + calibrationComponents.push(calibration_square); + calibrationComponents.push(calibrationClick); - for (const thisComponent of calibration_trialComponents) + for (const thisComponent of calibrationComponents) if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; return Scheduler.Event.NEXT; @@ -596,16 +647,46 @@ function calibration_trialRoutineBegin(snapshot) { } -function calibration_trialRoutineEachFrame() { +var frameRemains; +function calibrationRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'calibration_trial'------- + //------Loop for each frame of Routine 'calibration'------- // get current time - t = calibration_trialClock.getTime(); + t = calibrationClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame + // returns type error - checking fix + + // Hide webcam thumbnail if eyes are in validation box + if (webgazer.checkEyesInValidationBox() === true) { + // If last time that eyes were outside of validation box was longer than + // window.eyesReturnedDelay ago, hide thumbnail + if ( + document.getElementById('webgazerFaceFeedbackBox').style.display != 'none' && + (new Date).getTime() > window.eyesExitedTimestamp + window.eyesReturnedDelay + ) { + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; + } + } else { + // Eyes outside of validation box; show thumbnail + window.eyesExitedTimestamp = (new Date).getTime(); + document.getElementById('webgazerFaceFeedbackBox').style.display = 'block'; + document.getElementById('webgazerVideoFeed').style.display = 'block'; + } + + + if ( + document.getElementById('webgazerFaceFeedbackBox').style.display != 'none' && + (new Date).getTime() > window.eyesExitedTimestamp + window.eyesReturnedDelay + ) { + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; + } + // *calibration_square* updates - if (t >= 0.0 && calibration_square.status === PsychoJS.Status.NOT_STARTED) { + if (t >= 0.5 && calibration_square.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later calibration_square.tStart = t; // (not accounting for frame time here) calibration_square.frameNStart = frameN; // exact frame index @@ -613,53 +694,46 @@ function calibration_trialRoutineEachFrame() { calibration_square.setAutoDraw(true); } - // *mouse_3* updates - if (t >= 0.0 && mouse_3.status === PsychoJS.Status.NOT_STARTED) { + frameRemains = 0.5 + 3 - psychoJS.window.monitorFramePeriod * 0.75; // most of one frame period left + if (calibration_square.status === PsychoJS.Status.STARTED && t >= frameRemains) { + calibration_square.setAutoDraw(false); + } + + if (calibration_square.status === PsychoJS.Status.STARTED){ // only update if being drawn + calibration_square.setFillColor(new util.Color(callib_color), false); + } + // *calibrationClick* updates + if (t >= 0.5 && calibrationClick.status === PsychoJS.Status.NOT_STARTED) { // keep track of start time/frame for later - mouse_3.tStart = t; // (not accounting for frame time here) - mouse_3.frameNStart = frameN; // exact frame index + calibrationClick.tStart = t; // (not accounting for frame time here) + calibrationClick.frameNStart = frameN; // exact frame index - mouse_3.status = PsychoJS.Status.STARTED; - mouse_3.mouseClock.reset(); - prevButtonState = mouse_3.getPressed(); // if button is down already this ISN'T a new click + calibrationClick.status = PsychoJS.Status.STARTED; + calibrationClick.mouseClock.reset(); + prevButtonState = calibrationClick.getPressed(); // if button is down already this ISN'T a new click } - if (mouse_3.status === PsychoJS.Status.STARTED) { // only update if started and not finished! - _mouseButtons = mouse_3.getPressed(); + frameRemains = 0.5 + 3 - psychoJS.window.monitorFramePeriod * 0.75; // most of one frame period left + if (calibrationClick.status === PsychoJS.Status.STARTED && t >= frameRemains) { + calibrationClick.status = PsychoJS.Status.FINISHED; + } + if (calibrationClick.status === PsychoJS.Status.STARTED) { // only update if started and not finished! + _mouseButtons = calibrationClick.getPressed(); if (!_mouseButtons.every( (e,i,) => (e == prevButtonState[i]) )) { // button state changed? prevButtonState = _mouseButtons; if (_mouseButtons.reduce( (e, acc) => (e+acc) ) > 0) { // state changed to a new click // check if the mouse was inside our 'clickable' objects gotValidClick = false; for (const obj of [calibration_square]) { - if (obj.contains(mouse_3)) { + if (obj.contains(calibrationClick)) { gotValidClick = true; - mouse_3.clicked_name.push(obj.name) + calibrationClick.clicked_name.push(obj.name) } } - if (gotValidClick === true) { // abort routine on response - continueRoutine = false; - } + // abort routine on response + continueRoutine = false; } } } - // Hide webcam thumbnail if eyes are in validation box - if (webgazer.checkEyesInValidationBox() === true) { - // If last time that eyes were outside of validation box was longer than - // window.eyesReturnedDelay ago, hide thumbnail - if ( - document.getElementById('webgazerFaceFeedbackBox').style.display != 'none' && - (new Date).getTime() > window.eyesExitedTimestamp + window.eyesReturnedDelay - ) { - document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; - document.getElementById('webgazerVideoFeed').style.display = 'none'; - } - } else { - // Eyes outside of validation box; show thumbnail - window.eyesExitedTimestamp = (new Date).getTime(); - document.getElementById('webgazerFaceFeedbackBox').style.display = 'block'; - document.getElementById('webgazerVideoFeed').style.display = 'block'; - } - // check for quit (typically the Esc key) if (psychoJS.experiment.experimentEnded || psychoJS.eventManager.getKeys({keyList:['escape']}).length > 0) { return quitPsychoJS('The [Escape] key was pressed. Goodbye!', false); @@ -671,14 +745,14 @@ function calibration_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - for (const thisComponent of calibration_trialComponents) + for (const thisComponent of calibrationComponents) if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; break; } // refresh the screen if continuing - if (continueRoutine) { + if (continueRoutine && routineTimer.getTime() > 0) { return Scheduler.Event.FLIP_REPEAT; } else { return Scheduler.Event.NEXT; @@ -687,50 +761,57 @@ function calibration_trialRoutineEachFrame() { } -function calibration_trialRoutineEnd() { +function calibrationRoutineEnd() { return async function () { - //------Ending Routine 'calibration_trial'------- - for (const thisComponent of calibration_trialComponents) { + //------Ending Routine 'calibration'------- + for (const thisComponent of calibrationComponents) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } } // store data for psychoJS.experiment (ExperimentHandler) - _mouseXYs = mouse_3.getPos(); - _mouseButtons = mouse_3.getPressed(); - psychoJS.experiment.addData('mouse_3.x', _mouseXYs[0]); - psychoJS.experiment.addData('mouse_3.y', _mouseXYs[1]); - psychoJS.experiment.addData('mouse_3.leftButton', _mouseButtons[0]); - psychoJS.experiment.addData('mouse_3.midButton', _mouseButtons[1]); - psychoJS.experiment.addData('mouse_3.rightButton', _mouseButtons[2]); - if (mouse_3.clicked_name.length > 0) { - psychoJS.experiment.addData('mouse_3.clicked_name', mouse_3.clicked_name[0]);} - // the Routine "calibration_trial" was not non-slip safe, so reset the non-slip timer - routineTimer.reset(); - + _mouseXYs = calibrationClick.getPos(); + _mouseButtons = calibrationClick.getPressed(); + psychoJS.experiment.addData('calibrationClick.x', _mouseXYs[0]); + psychoJS.experiment.addData('calibrationClick.y', _mouseXYs[1]); + psychoJS.experiment.addData('calibrationClick.leftButton', _mouseButtons[0]); + psychoJS.experiment.addData('calibrationClick.midButton', _mouseButtons[1]); + psychoJS.experiment.addData('calibrationClick.rightButton', _mouseButtons[2]); + if (calibrationClick.clicked_name.length > 0) { + psychoJS.experiment.addData('calibrationClick.clicked_name', calibrationClick.clicked_name[0]);} return Scheduler.Event.NEXT; }; } -var tracking_trialComponents; -function tracking_trialRoutineBegin(snapshot) { +var _tracking_resp_allKeys; +var trackingTrialComponents; +function trackingTrialRoutineBegin(snapshot) { return async function () { TrialHandler.fromSnapshot(snapshot); // ensure that .thisN vals are up to date - //------Prepare to start Routine 'tracking_trial'------- + //------Prepare to start Routine 'trackingTrial'------- t = 0; - tracking_trialClock.reset(); // clock + trackingTrialClock.reset(); // clock frameN = -1; continueRoutine = true; // until we're told otherwise // update component parameters for each repeat + tracking_resp.keys = undefined; + tracking_resp.rt = undefined; + _tracking_resp_allKeys = []; // Remove the click tracker used for calibration window.webgazer.removeMouseEventListeners(); + + //hide the video thumbnail + document.getElementById('webgazerFaceFeedbackBox').style.display = 'none'; + document.getElementById('webgazerVideoFeed').style.display = 'none'; // keep track of which components have finished - tracking_trialComponents = []; - tracking_trialComponents.push(tracking_square); + trackingTrialComponents = []; + trackingTrialComponents.push(tracking_square); + trackingTrialComponents.push(trackingTxt); + trackingTrialComponents.push(tracking_resp); - for (const thisComponent of tracking_trialComponents) + for (const thisComponent of trackingTrialComponents) if ('status' in thisComponent) thisComponent.status = PsychoJS.Status.NOT_STARTED; return Scheduler.Event.NEXT; @@ -738,11 +819,11 @@ function tracking_trialRoutineBegin(snapshot) { } -function tracking_trialRoutineEachFrame() { +function trackingTrialRoutineEachFrame() { return async function () { - //------Loop for each frame of Routine 'tracking_trial'------- + //------Loop for each frame of Routine 'trackingTrial'------- // get current time - t = tracking_trialClock.getTime(); + t = trackingTrialClock.getTime(); frameN = frameN + 1;// number of completed frames (so 0 is the first frame) // update/draw components on each frame @@ -755,6 +836,40 @@ function tracking_trialRoutineEachFrame() { tracking_square.setAutoDraw(true); } + + // *trackingTxt* updates + if (t >= 0.0 && trackingTxt.status === PsychoJS.Status.NOT_STARTED) { + // keep track of start time/frame for later + trackingTxt.tStart = t; // (not accounting for frame time here) + trackingTxt.frameNStart = frameN; // exact frame index + + trackingTxt.setAutoDraw(true); + } + + + // *tracking_resp* updates + if (t >= 0.0 && tracking_resp.status === PsychoJS.Status.NOT_STARTED) { + // keep track of start time/frame for later + tracking_resp.tStart = t; // (not accounting for frame time here) + tracking_resp.frameNStart = frameN; // exact frame index + + // keyboard checking is just starting + psychoJS.window.callOnFlip(function() { tracking_resp.clock.reset(); }); // t=0 on next screen flip + psychoJS.window.callOnFlip(function() { tracking_resp.start(); }); // start on screen flip + psychoJS.window.callOnFlip(function() { tracking_resp.clearEvents(); }); + } + + if (tracking_resp.status === PsychoJS.Status.STARTED) { + let theseKeys = tracking_resp.getKeys({keyList: ['space'], waitRelease: false}); + _tracking_resp_allKeys = _tracking_resp_allKeys.concat(theseKeys); + if (_tracking_resp_allKeys.length > 0) { + tracking_resp.keys = _tracking_resp_allKeys[_tracking_resp_allKeys.length - 1].name; // just the last key pressed + tracking_resp.rt = _tracking_resp_allKeys[_tracking_resp_allKeys.length - 1].rt; + // a response ends the routine + continueRoutine = false; + } + } + // Hide webcam thumbnail if eyes are in validation box if (webgazer.checkEyesInValidationBox() === true) { // If last time that eyes were outside of validation box was longer than @@ -786,7 +901,6 @@ function tracking_trialRoutineEachFrame() { psychoJS.window ) ); - // check for quit (typically the Esc key) if (psychoJS.experiment.experimentEnded || psychoJS.eventManager.getKeys({keyList:['escape']}).length > 0) { return quitPsychoJS('The [Escape] key was pressed. Goodbye!', false); @@ -798,7 +912,7 @@ function tracking_trialRoutineEachFrame() { } continueRoutine = false; // reverts to True if at least one component still running - for (const thisComponent of tracking_trialComponents) + for (const thisComponent of trackingTrialComponents) if ('status' in thisComponent && thisComponent.status !== PsychoJS.Status.FINISHED) { continueRoutine = true; break; @@ -814,15 +928,22 @@ function tracking_trialRoutineEachFrame() { } -function tracking_trialRoutineEnd() { +function trackingTrialRoutineEnd() { return async function () { - //------Ending Routine 'tracking_trial'------- - for (const thisComponent of tracking_trialComponents) { + //------Ending Routine 'trackingTrial'------- + for (const thisComponent of trackingTrialComponents) { if (typeof thisComponent.setAutoDraw === 'function') { thisComponent.setAutoDraw(false); } } - // the Routine "tracking_trial" was not non-slip safe, so reset the non-slip timer + psychoJS.experiment.addData('tracking_resp.keys', tracking_resp.keys); + if (typeof tracking_resp.keys !== 'undefined') { // we had a response + psychoJS.experiment.addData('tracking_resp.rt', tracking_resp.rt); + routineTimer.reset(); + } + + tracking_resp.stop(); + // the Routine "trackingTrial" was not non-slip safe, so reset the non-slip timer routineTimer.reset(); return Scheduler.Event.NEXT; @@ -874,6 +995,8 @@ async function quitPsychoJS(message, isCompleted) { + + psychoJS.window.close(); psychoJS.quit({message: message, isCompleted: isCompleted}); diff --git a/demo_eye_tracking2.psyexp b/demo_eye_tracking2.psyexp index 49fa129930b062a25b9497da15137ecc83f41b02..a4167787fda546d13d42dcda8fac3b7bfbee1a22 100644 --- a/demo_eye_tracking2.psyexp +++ b/demo_eye_tracking2.psyexp @@ -7,13 +7,13 @@ - + - + @@ -25,9 +25,9 @@ - + - + @@ -45,7 +45,7 @@ - + @@ -54,50 +54,6 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @@ -142,21 +98,38 @@ - - - + + + + + + + + + + + + + + + + + + + + - + - - - - - + + + + + @@ -164,19 +137,64 @@ - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - + @@ -187,91 +205,148 @@ - + + + + + + + + + + + + + + + + + + - - + - - - + + - - - + + + - - + + - + - + - - + + - - + + - + - + - + - + - + + + + + + + + + + + + + + + + + + + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - + - - - + + - - - + + + - - + + @@ -281,16 +356,59 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + @@ -301,12 +419,12 @@ - - - + + + - + @@ -315,8 +433,8 @@ - + - + diff --git a/js/webgazer-1.7.3.js b/js/webgazer-1.7.3.js deleted file mode 100644 index a523cf03f2eaa6fac1d464863d2c80a9d3ca4823..0000000000000000000000000000000000000000 --- a/js/webgazer-1.7.3.js +++ /dev/null @@ -1,45936 +0,0 @@ -/** WebGazer.js: Scalable Webcam EyeTracking Using User Interactions - * - * Copyright (c) 2016-2020, Brown HCI Group - -* Licensed under GPLv3. Companies with a valuation of less than $10M can use WebGazer.js under LGPLv3. -*/ - -/*! - localForage -- Offline Storage, Improved - Version 1.7.3 - https://localforage.github.io/localForage - (c) 2013-2017 Mozilla, Apache License 2.0 -*/ -!function(a){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=a();else if("function"==typeof define&&define.amd)define([],a);else{var b;b="undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:this,b.localforage=a()}}(function(){return function a(b,c,d){function e(g,h){if(!c[g]){if(!b[g]){var i="function"==typeof require&&require;if(!h&&i)return i(g,!0);if(f)return f(g,!0);var j=new Error("Cannot find module '"+g+"'");throw j.code="MODULE_NOT_FOUND",j}var k=c[g]={exports:{}};b[g][0].call(k.exports,function(a){var c=b[g][1][a];return e(c||a)},k,k.exports,a,b,c,d)}return c[g].exports}for(var f="function"==typeof require&&require,g=0;g=43)}}).catch(function(){return!1})}function n(a){return"boolean"==typeof xa?va.resolve(xa):m(a).then(function(a){return xa=a})}function o(a){var b=ya[a.name],c={};c.promise=new va(function(a,b){c.resolve=a,c.reject=b}),b.deferredOperations.push(c),b.dbReady?b.dbReady=b.dbReady.then(function(){return c.promise}):b.dbReady=c.promise}function p(a){var b=ya[a.name],c=b.deferredOperations.pop();if(c)return c.resolve(),c.promise}function q(a,b){var c=ya[a.name],d=c.deferredOperations.pop();if(d)return d.reject(b),d.promise}function r(a,b){return new va(function(c,d){if(ya[a.name]=ya[a.name]||B(),a.db){if(!b)return c(a.db);o(a),a.db.close()}var e=[a.name];b&&e.push(a.version);var f=ua.open.apply(ua,e);b&&(f.onupgradeneeded=function(b){var c=f.result;try{c.createObjectStore(a.storeName),b.oldVersion<=1&&c.createObjectStore(wa)}catch(c){if("ConstraintError"!==c.name)throw c;console.warn('The database "'+a.name+'" has been upgraded from version '+b.oldVersion+" to version "+b.newVersion+', but the storage "'+a.storeName+'" already exists.')}}),f.onerror=function(a){a.preventDefault(),d(f.error)},f.onsuccess=function(){c(f.result),p(a)}})}function s(a){return r(a,!1)}function t(a){return r(a,!0)}function u(a,b){if(!a.db)return!0;var c=!a.db.objectStoreNames.contains(a.storeName),d=a.versiona.db.version;if(d&&(a.version!==b&&console.warn('The database "'+a.name+"\" can't be downgraded from version "+a.db.version+" to version "+a.version+"."),a.version=a.db.version),e||c){if(c){var f=a.db.version+1;f>a.version&&(a.version=f)}return!0}return!1}function v(a){return new va(function(b,c){var d=new FileReader;d.onerror=c,d.onloadend=function(c){var d=btoa(c.target.result||"");b({__local_forage_encoded_blob:!0,data:d,type:a.type})},d.readAsBinaryString(a)})}function w(a){return g([l(atob(a.data))],{type:a.type})}function x(a){return a&&a.__local_forage_encoded_blob}function y(a){var b=this,c=b._initReady().then(function(){var a=ya[b._dbInfo.name];if(a&&a.dbReady)return a.dbReady});return i(c,a,a),c}function z(a){o(a);for(var b=ya[a.name],c=b.forages,d=0;d0&&(!a.db||"InvalidStateError"===e.name||"NotFoundError"===e.name))return va.resolve().then(function(){if(!a.db||"NotFoundError"===e.name&&!a.db.objectStoreNames.contains(a.storeName)&&a.version<=a.db.version)return a.db&&(a.version=a.db.version+1),t(a)}).then(function(){return z(a).then(function(){A(a,b,c,d-1)})}).catch(c);c(e)}}function B(){return{forages:[],db:null,dbReady:null,deferredOperations:[]}}function C(a){function b(){return va.resolve()}var c=this,d={db:null};if(a)for(var e in a)d[e]=a[e];var f=ya[d.name];f||(f=B(),ya[d.name]=f),f.forages.push(c),c._initReady||(c._initReady=c.ready,c.ready=y);for(var g=[],h=0;h>4,k[i++]=(15&d)<<4|e>>2,k[i++]=(3&e)<<6|63&f;return j}function O(a){var b,c=new Uint8Array(a),d="";for(b=0;b>2],d+=Da[(3&c[b])<<4|c[b+1]>>4],d+=Da[(15&c[b+1])<<2|c[b+2]>>6],d+=Da[63&c[b+2]];return c.length%3==2?d=d.substring(0,d.length-1)+"=":c.length%3==1&&(d=d.substring(0,d.length-2)+"=="),d}function P(a,b){var c="";if(a&&(c=Ua.call(a)),a&&("[object ArrayBuffer]"===c||a.buffer&&"[object ArrayBuffer]"===Ua.call(a.buffer))){var d,e=Ga;a instanceof ArrayBuffer?(d=a,e+=Ia):(d=a.buffer,"[object Int8Array]"===c?e+=Ka:"[object Uint8Array]"===c?e+=La:"[object Uint8ClampedArray]"===c?e+=Ma:"[object Int16Array]"===c?e+=Na:"[object Uint16Array]"===c?e+=Pa:"[object Int32Array]"===c?e+=Oa:"[object Uint32Array]"===c?e+=Qa:"[object Float32Array]"===c?e+=Ra:"[object Float64Array]"===c?e+=Sa:b(new Error("Failed to get type for BinaryArray"))),b(e+O(d))}else if("[object Blob]"===c){var f=new FileReader;f.onload=function(){var c=Ea+a.type+"~"+O(this.result);b(Ga+Ja+c)},f.readAsArrayBuffer(a)}else try{b(JSON.stringify(a))}catch(c){console.error("Couldn't convert value into a JSON string: ",a),b(null,c)}}function Q(a){if(a.substring(0,Ha)!==Ga)return JSON.parse(a);var b,c=a.substring(Ta),d=a.substring(Ha,Ta);if(d===Ja&&Fa.test(c)){var e=c.match(Fa);b=e[1],c=c.substring(e[0].length)}var f=N(c);switch(d){case Ia:return f;case Ja:return g([f],{type:b});case Ka:return new Int8Array(f);case La:return new Uint8Array(f);case Ma:return new Uint8ClampedArray(f);case Na:return new Int16Array(f);case Pa:return new Uint16Array(f);case Oa:return new Int32Array(f);case Qa:return new Uint32Array(f);case Ra:return new Float32Array(f);case Sa:return new Float64Array(f);default:throw new Error("Unkown type: "+d)}}function R(a,b,c,d){a.executeSql("CREATE TABLE IF NOT EXISTS "+b.storeName+" (id INTEGER PRIMARY KEY, key unique, value)",[],c,d)}function S(a){var b=this,c={db:null};if(a)for(var d in a)c[d]="string"!=typeof a[d]?a[d].toString():a[d];var e=new va(function(a,d){try{c.db=openDatabase(c.name,String(c.version),c.description,c.size)}catch(a){return d(a)}c.db.transaction(function(e){R(e,c,function(){b._dbInfo=c,a()},function(a,b){d(b)})},d)});return c.serializer=Va,e}function T(a,b,c,d,e,f){a.executeSql(c,d,e,function(a,g){g.code===g.SYNTAX_ERR?a.executeSql("SELECT name FROM sqlite_master WHERE type='table' AND name = ?",[b.storeName],function(a,h){h.rows.length?f(a,g):R(a,b,function(){a.executeSql(c,d,e,f)},f)},f):f(a,g)},f)}function U(a,b){var c=this;a=j(a);var d=new va(function(b,d){c.ready().then(function(){var e=c._dbInfo;e.db.transaction(function(c){T(c,e,"SELECT * FROM "+e.storeName+" WHERE key = ? LIMIT 1",[a],function(a,c){var d=c.rows.length?c.rows.item(0).value:null;d&&(d=e.serializer.deserialize(d)),b(d)},function(a,b){d(b)})})}).catch(d)});return h(d,b),d}function V(a,b){var c=this,d=new va(function(b,d){c.ready().then(function(){var e=c._dbInfo;e.db.transaction(function(c){T(c,e,"SELECT * FROM "+e.storeName,[],function(c,d){for(var f=d.rows,g=f.length,h=0;h0)return void f(W.apply(e,[a,h,c,d-1]));g(b)}})})}).catch(g)});return h(f,c),f}function X(a,b,c){return W.apply(this,[a,b,c,1])}function Y(a,b){var c=this;a=j(a);var d=new va(function(b,d){c.ready().then(function(){var e=c._dbInfo;e.db.transaction(function(c){T(c,e,"DELETE FROM "+e.storeName+" WHERE key = ?",[a],function(){b()},function(a,b){d(b)})})}).catch(d)});return h(d,b),d}function Z(a){var b=this,c=new va(function(a,c){b.ready().then(function(){var d=b._dbInfo;d.db.transaction(function(b){T(b,d,"DELETE FROM "+d.storeName,[],function(){a()},function(a,b){c(b)})})}).catch(c)});return h(c,a),c}function $(a){var b=this,c=new va(function(a,c){b.ready().then(function(){var d=b._dbInfo;d.db.transaction(function(b){T(b,d,"SELECT COUNT(key) as c FROM "+d.storeName,[],function(b,c){var d=c.rows.item(0).c;a(d)},function(a,b){c(b)})})}).catch(c)});return h(c,a),c}function _(a,b){var c=this,d=new va(function(b,d){c.ready().then(function(){var e=c._dbInfo;e.db.transaction(function(c){T(c,e,"SELECT key FROM "+e.storeName+" WHERE id = ? LIMIT 1",[a+1],function(a,c){var d=c.rows.length?c.rows.item(0).key:null;b(d)},function(a,b){d(b)})})}).catch(d)});return h(d,b),d}function aa(a){var b=this,c=new va(function(a,c){b.ready().then(function(){var d=b._dbInfo;d.db.transaction(function(b){T(b,d,"SELECT key FROM "+d.storeName,[],function(b,c){for(var d=[],e=0;e '__WebKitDatabaseInfoTable__'",[],function(c,d){for(var e=[],f=0;f0}function ha(a){var b=this,c={};if(a)for(var d in a)c[d]=a[d];return c.keyPrefix=ea(a,b._defaultConfig),ga()?(b._dbInfo=c,c.serializer=Va,va.resolve()):va.reject()}function ia(a){var b=this,c=b.ready().then(function(){for(var a=b._dbInfo.keyPrefix,c=localStorage.length-1;c>=0;c--){var d=localStorage.key(c);0===d.indexOf(a)&&localStorage.removeItem(d)}});return h(c,a),c}function ja(a,b){var c=this;a=j(a);var d=c.ready().then(function(){var b=c._dbInfo,d=localStorage.getItem(b.keyPrefix+a);return d&&(d=b.serializer.deserialize(d)),d});return h(d,b),d}function ka(a,b){var c=this,d=c.ready().then(function(){for(var b=c._dbInfo,d=b.keyPrefix,e=d.length,f=localStorage.length,g=1,h=0;h=0;b--){var c=localStorage.key(b);0===c.indexOf(a)&&localStorage.removeItem(c)}}):va.reject("Invalid arguments"),h(d,b),d}function ra(a,b){a[b]=function(){var c=arguments;return a.ready().then(function(){return a[b].apply(a,c)})}}function sa(){for(var a=1;a3;o-=4)t(),t(),t(),t();while(o>0)t(),o--;i=new Date;if(i-r>n)break}for(o=s;o>3;o-=4)t(),t(),t(),t();while(o>0)t(),o--;return i=new Date,1e3*(3*s-1)/(i-r)},numeric._myIndexOf=function(t){var n=this.length,r;for(r=0;rnumeric.largeArray)return r.push("...Large Array..."),!0;var f=!1;r.push("[");for(t=0;t0&&(r.push(","),f&&r.push("\n ")),f=i(e[t]);return r.push("]"),!0}r.push("{");var f=!1;for(t in e)e.hasOwnProperty(t)&&(f&&r.push(",\n"),f=!0,r.push(t),r.push(": \n"),i(e[t]));return r.push("}"),!0}var r=[];return i(t),r.join("")},numeric.parseDate=function(t){function n(e){if(typeof e=="string")return Date.parse(e.replace(/-/g,"/"));if(e instanceof Array){var t=[],r;for(r=0;r0){s[f]=[];for(r=0;r>2,u=((r&3)<<4)+(i>>4),a=((i&15)<<2)+(s>>6),f=s&63,n+1>=t?a=f=64:n+2>=t&&(f=64),c+=l.charAt(o)+l.charAt(u)+l.charAt(a)+l.charAt(f);return c}function r(e,t,n){typeof t=="undefined"&&(t=0),typeof n=="undefined"&&(n=e.length);var r=[0,1996959894,3993919788,2567524794,124634137,1886057615,3915621685,2657392035,249268274,2044508324,3772115230,2547177864,162941995,2125561021,3887607047,2428444049,498536548,1789927666,4089016648,2227061214,450548861,1843258603,4107580753,2211677639,325883990,1684777152,4251122042,2321926636,335633487,1661365465,4195302755,2366115317,997073096,1281953886,3579855332,2724688242,1006888145,1258607687,3524101629,2768942443,901097722,1119000684,3686517206,2898065728,853044451,1172266101,3705015759,2882616665,651767980,1373503546,3369554304,3218104598,565507253,1454621731,3485111705,3099436303,671266974,1594198024,3322730930,2970347812,795835527,1483230225,3244367275,3060149565,1994146192,31158534,2563907772,4023717930,1907459465,112637215,2680153253,3904427059,2013776290,251722036,2517215374,3775830040,2137656763,141376813,2439277719,3865271297,1802195444,476864866,2238001368,4066508878,1812370925,453092731,2181625025,4111451223,1706088902,314042704,2344532202,4240017532,1658658271,366619977,2362670323,4224994405,1303535960,984961486,2747007092,3569037538,1256170817,1037604311,2765210733,3554079995,1131014506,879679996,2909243462,3663771856,1141124467,855842277,2852801631,3708648649,1342533948,654459306,3188396048,3373015174,1466479909,544179635,3110523913,3462522015,1591671054,702138776,2966460450,3352799412,1504918807,783551873,3082640443,3233442989,3988292384,2596254646,62317068,1957810842,3939845945,2647816111,81470997,1943803523,3814918930,2489596804,225274430,2053790376,3826175755,2466906013,167816743,2097651377,4027552580,2265490386,503444072,1762050814,4150417245,2154129355,426522225,1852507879,4275313526,2312317920,282753626,1742555852,4189708143,2394877945,397917763,1622183637,3604390888,2714866558,953729732,1340076626,3518719985,2797360999,1068828381,1219638859,3624741850,2936675148,906185462,1090812512,3747672003,2825379669,829329135,1181335161,3412177804,3160834842,628085408,1382605366,3423369109,3138078467,570562233,1426400815,3317316542,2998733608,733239954,1555261956,3268935591,3050360625,752459403,1541320221,2607071920,3965973030,1969922972,40735498,2617837225,3943577151,1913087877,83908371,2512341634,3803740692,2075208622,213261112,2463272603,3855990285,2094854071,198958881,2262029012,4057260610,1759359992,534414190,2176718541,4139329115,1873836001,414664567,2282248934,4279200368,1711684554,285281116,2405801727,4167216745,1634467795,376229701,2685067896,3608007406,1308918612,956543938,2808555105,3495958263,1231636301,1047427035,2932959818,3654703836,1088359270,936918e3,2847714899,3736837829,1202900863,817233897,3183342108,3401237130,1404277552,615818150,3134207493,3453421203,1423857449,601450431,3009837614,3294710456,1567103746,711928724,3020668471,3272380065,1510334235,755167117],i=-1,s=0,o=e.length,u;for(u=t;u>>8^r[s];return i^-1}var i=t[0].length,s=t[0][0].length,o,u,a,f,l,c,h,p,d,v,m,g=[137,80,78,71,13,10,26,10,0,0,0,13,73,72,68,82,s>>24&255,s>>16&255,s>>8&255,s&255,i>>24&255,i>>16&255,i>>8&255,i&255,8,2,0,0,0,-1,-2,-3,-4,-5,-6,-7,-8,73,68,65,84,8,29];m=r(g,12,29),g[29]=m>>24&255,g[30]=m>>16&255,g[31]=m>>8&255,g[32]=m&255,o=1,u=0;for(p=0;p>8&255,g.push(c),g.push(h),g.push(~c&255),g.push(~h&255),p===0&&g.push(0);for(d=0;d255?c=255:c<0?c=0:c=Math.round(c),o=(o+c)%65521,u=(u+o)%65521,g.push(c);g.push(0)}return v=(u<<16)+o,g.push(v>>24&255),g.push(v>>16&255),g.push(v>>8&255),g.push(v&255),l=g.length-41,g[33]=l>>24&255,g[34]=l>>16&255,g[35]=l>>8&255,g[36]=l&255,m=r(g,37),g.push(m>>24&255),g.push(m>>16&255),g.push(m>>8&255),g.push(m&255),g.push(0),g.push(0),g.push(0),g.push(0),g.push(73),g.push(69),g.push(78),g.push(68),g.push(174),g.push(66),g.push(96),g.push(130),"data:image/png;base64,"+n(g)},numeric._dim=function(t){var n=[];while(typeof t=="object")n.push(t.length),t=t[0];return n},numeric.dim=function(t){var n,r;if(typeof t=="object")return n=t[0],typeof n=="object"?(r=n[0],typeof r=="object"?numeric._dim(t):[t.length,n.length]):[t.length];return[]},numeric.mapreduce=function(t,n){return Function("x","accum","_s","_k",'if(typeof accum === "undefined") accum = '+n+";\n"+'if(typeof x === "number") { var xi = x; '+t+"; return accum; }\n"+'if(typeof _s === "undefined") _s = numeric.dim(x);\n'+'if(typeof _k === "undefined") _k = 0;\n'+"var _n = _s[_k];\n"+"var i,xi;\n"+"if(_k < _s.length-1) {\n"+" for(i=_n-1;i>=0;i--) {\n"+" accum = arguments.callee(x[i],accum,_s,_k+1);\n"+" }"+" return accum;\n"+"}\n"+"for(i=_n-1;i>=1;i-=2) { \n"+" xi = x[i];\n"+" "+t+";\n"+" xi = x[i-1];\n"+" "+t+";\n"+"}\n"+"if(i === 0) {\n"+" xi = x[i];\n"+" "+t+"\n"+"}\n"+"return accum;")},numeric.mapreduce2=function(t,n){return Function("x","var n = x.length;\nvar i,xi;\n"+n+";\n"+"for(i=n-1;i!==-1;--i) { \n"+" xi = x[i];\n"+" "+t+";\n"+"}\n"+"return accum;")},numeric.same=function same(e,t){var n,r;if(e instanceof Array&&t instanceof Array){r=e.length;if(r!==t.length)return!1;for(n=0;n=0;o-=2)s[o+1]=n,s[o]=n;return o===-1&&(s[0]=n),s}for(o=i-1;o>=0;o--)s[o]=numeric.rep(t,n,r+1);return s},numeric.dotMMsmall=function(t,n){var r,i,s,o,u,a,f,l,c,h,p,d,v,m;o=t.length,u=n.length,a=n[0].length,f=Array(o);for(r=o-1;r>=0;r--){l=Array(a),c=t[r];for(s=a-1;s>=0;s--){h=c[u-1]*n[u-1][s];for(i=u-2;i>=1;i-=2)p=i-1,h+=c[i]*n[i][s]+c[p]*n[p][s];i===0&&(h+=c[0]*n[0][s]),l[s]=h}f[r]=l}return f},numeric._getCol=function(t,n,r){var i=t.length,s;for(s=i-1;s>0;--s)r[s]=t[s][n],--s,r[s]=t[s][n];s===0&&(r[0]=t[0][n])},numeric.dotMMbig=function(t,n){var r=numeric._getCol,i=n.length,s=Array(i),o=t.length,u=n[0].length,a=new Array(o),f,l=numeric.dotVV,c,h,p,d;--i,--o;for(c=o;c!==-1;--c)a[c]=Array(u);--u;for(c=u;c!==-1;--c){r(n,c,s);for(h=o;h!==-1;--h)d=0,f=t[h],a[h][c]=l(f,s)}return a},numeric.dotMV=function(t,n){var r=t.length,i=n.length,s,o=Array(r),u=numeric.dotVV;for(s=r-1;s>=0;s--)o[s]=u(t[s],n);return o},numeric.dotVM=function(t,n){var r,i,s,o,u,a,f,l,c,h,p,d,v,m,g,y,b,w,E;o=t.length,u=n[0].length,f=Array(u);for(s=u-1;s>=0;s--){h=t[o-1]*n[o-1][s];for(i=o-2;i>=1;i-=2)p=i-1,h+=t[i]*n[i][s]+t[p]*n[p][s];i===0&&(h+=t[0]*n[0][s]),f[s]=h}return f},numeric.dotVV=function(t,n){var r,i=t.length,s,o=t[i-1]*n[i-1];for(r=i-2;r>=1;r-=2)s=r-1,o+=t[r]*n[r]+t[s]*n[s];return r===0&&(o+=t[0]*n[0]),o},numeric.dot=function(t,n){var r=numeric.dim;switch(r(t).length*1e3+r(n).length){case 2002:return n.length<10?numeric.dotMMsmall(t,n):numeric.dotMMbig(t,n);case 2001:return numeric.dotMV(t,n);case 1002:return numeric.dotVM(t,n);case 1001:return numeric.dotVV(t,n);case 1e3:return numeric.mulVS(t,n);case 1:return numeric.mulSV(t,n);case 0:return t*n;default:throw new Error("numeric.dot only works on vectors and matrices")}},numeric.diag=function(t){var n,r,i,s=t.length,o=Array(s),u;for(n=s-1;n>=0;n--){u=Array(s),r=n+2;for(i=s-1;i>=r;i-=2)u[i]=0,u[i-1]=0;i>n&&(u[i]=0),u[n]=t[n];for(i=n-1;i>=1;i-=2)u[i]=0,u[i-1]=0;i===0&&(u[0]=0),o[n]=u}return o},numeric.getDiag=function(e){var t=Math.min(e.length,e[0].length),n,r=Array(t);for(n=t-1;n>=1;--n)r[n]=e[n][n],--n,r[n]=e[n][n];return n===0&&(r[0]=e[0][0]),r},numeric.identity=function(t){return numeric.diag(numeric.rep([t],1))},numeric.pointwise=function(t,n,r){typeof r=="undefined"&&(r="");var i=[],s,o=/\[i\]$/,u,a="",f=!1;for(s=0;s=0;i--) ret[i] = arguments.callee("+t.join(",")+",_s,_k+1);\n"+" return ret;\n"+"}\n"+r+"\n"+"for(i=_n-1;i!==-1;--i) {\n"+" "+n+"\n"+"}\n"+"return ret;",Function.apply(null,i)},numeric.pointwise2=function(t,n,r){typeof r=="undefined"&&(r="");var i=[],s,o=/\[i\]$/,u,a="",f=!1;for(s=0;s=0;s--)_biforeach(typeof e=="object"?e[s]:e,typeof t=="object"?t[s]:t,n,r+1,i)},numeric._biforeach2=function _biforeach2(e,t,n,r,i){if(r===n.length-1)return i(e,t);var s,o=n[r],u=Array(o);for(s=o-1;s>=0;--s)u[s]=_biforeach2(typeof e=="object"?e[s]:e,typeof t=="object"?t[s]:t,n,r+1,i);return u},numeric._foreach=function _foreach(e,t,n,r){if(n===t.length-1){r(e);return}var i,s=t[n];for(i=s-1;i>=0;i--)_foreach(e[i],t,n+1,r)},numeric._foreach2=function _foreach2(e,t,n,r){if(n===t.length-1)return r(e);var i,s=t[n],o=Array(s);for(i=s-1;i>=0;i--)o[i]=_foreach2(e[i],t,n+1,r);return o},numeric.ops2={add:"+",sub:"-",mul:"*",div:"/",mod:"%",and:"&&",or:"||",eq:"===",neq:"!==",lt:"<",gt:">",leq:"<=",geq:">=",band:"&",bor:"|",bxor:"^",lshift:"<<",rshift:">>",rrshift:">>>"},numeric.opseq={addeq:"+=",subeq:"-=",muleq:"*=",diveq:"/=",modeq:"%=",lshifteq:"<<=",rshifteq:">>=",rrshifteq:">>>=",bandeq:"&=",boreq:"|=",bxoreq:"^="},numeric.mathfuns=["abs","acos","asin","atan","ceil","cos","exp","floor","log","round","sin","sqrt","tan","isNaN","isFinite"],numeric.mathfuns2=["atan2","pow","max","min"],numeric.ops1={neg:"-",not:"!",bnot:"~",clone:""},numeric.mapreducers={any:["if(xi) return true;","var accum = false;"],all:["if(!xi) return false;","var accum = true;"],sum:["accum += xi;","var accum = 0;"],prod:["accum *= xi;","var accum = 1;"],norm2Squared:["accum += xi*xi;","var accum = 0;"],norminf:["accum = max(accum,abs(xi));","var accum = 0, max = Math.max, abs = Math.abs;"],norm1:["accum += abs(xi)","var accum = 0, abs = Math.abs;"],sup:["accum = max(accum,xi);","var accum = -Infinity, max = Math.max;"],inf:["accum = min(accum,xi);","var accum = Infinity, min = Math.min;"]},function(){var e,t;for(e=0;em&&(v=h,m=d);a=o[v],o[v]=o[p],o[p]=a,c=f[v],f[v]=f[p],f[p]=c,t=a[p];for(d=p;d!==s;++d)a[d]/=t;for(d=s-1;d!==-1;--d)c[d]/=t;for(h=i-1;h!==-1;--h)if(h!==p){u=o[h],l=f[h],t=u[p];for(d=p+1;d!==s;++d)u[d]-=a[d]*t;for(d=s-1;d>0;--d)l[d]-=c[d]*t,--d,l[d]-=c[d]*t;d===0&&(l[0]-=c[0]*t)}}return f},numeric.det=function(t){var n=numeric.dim(t);if(n.length!==2||n[0]!==n[1])throw new Error("numeric: det() only works on square matrices");var r=n[0],i=1,s,o,u,a=numeric.clone(t),f,l,c,h,p,d,v;for(o=0;oMath.abs(a[u][o])&&(u=s);u!==o&&(h=a[u],a[u]=a[o],a[o]=h,i*=-1),f=a[o];for(s=o+1;s=1;n-=2){a=t[n],u=t[n-1];for(r=s-1;r>=1;--r)f=o[r],f[n]=a[r],f[n-1]=u[r],--r,f=o[r],f[n]=a[r],f[n-1]=u[r];r===0&&(f=o[0],f[n]=a[0],f[n-1]=u[0])}if(n===0){u=t[0];for(r=s-1;r>=1;--r)o[r][0]=u[r],--r,o[r][0]=u[r];r===0&&(o[0][0]=u[0])}return o},numeric.negtranspose=function(t){var n,r,i=t.length,s=t[0].length,o=Array(s),u,a,f;for(r=0;r=1;n-=2){a=t[n],u=t[n-1];for(r=s-1;r>=1;--r)f=o[r],f[n]=-a[r],f[n-1]=-u[r],--r,f=o[r],f[n]=-a[r],f[n-1]=-u[r];r===0&&(f=o[0],f[n]=-a[0],f[n-1]=-u[0])}if(n===0){u=t[0];for(r=s-1;r>=1;--r)o[r][0]=-u[r],--r,o[r][0]=-u[r];r===0&&(o[0][0]=-u[0])}return o},numeric._random=function _random(e,t){var n,r=e[t],i=Array(r),s;if(t===e.length-1){s=Math.random;for(n=r-1;n>=1;n-=2)i[n]=s(),i[n-1]=s();return n===0&&(i[0]=s()),i}for(n=r-1;n>=0;n--)i[n]=_random(e,t+1);return i},numeric.random=function(t){return numeric._random(t,0)},numeric.norm2=function(t){return Math.sqrt(numeric.norm2Squared(t))},numeric.linspace=function(t,n,r){typeof r=="undefined"&&(r=Math.max(Math.round(n-t)+1,1));if(r<2)return r===1?[t]:[];var i,s=Array(r);r--;for(i=r;i>=0;i--)s[i]=(i*n+(r-i)*t)/r;return s},numeric.getBlock=function(t,n,r){function s(e,t){var o,u=n[t],a=r[t]-u,f=Array(a);if(t===i.length-1){for(o=a;o>=0;o--)f[o]=e[o+u];return f}for(o=a;o>=0;o--)f[o]=s(e[o+u],t+1);return f}var i=numeric.dim(t);return s(t,0)},numeric.setBlock=function(t,n,r,i){function o(e,t,i){var u,a=n[i],f=r[i]-a;if(i===s.length-1)for(u=f;u>=0;u--)e[u+a]=t[u];for(u=f;u>=0;u--)o(e[u+a],t[u],i+1)}var s=numeric.dim(t);return o(t,i,0),t},numeric.getRange=function(t,n,r){var i=n.length,s=r.length,o,u,a=Array(i),f,l;for(o=i-1;o!==-1;--o){a[o]=Array(s),f=a[o],l=t[n[o]];for(u=s-1;u!==-1;--u)f[u]=l[r[u]]}return a},numeric.blockMatrix=function(t){var n=numeric.dim(t);if(n.length<4)return numeric.blockMatrix([t]);var r=n[0],i=n[1],s,o,u,a,f;s=0,o=0;for(u=0;u=0;f--){a=Array(o),c=t[f];for(l=o-1;l>=3;--l)a[l]=c*n[l],--l,a[l]=c*n[l],--l,a[l]=c*n[l],--l,a[l]=c*n[l];while(l>=0)a[l]=c*n[l],--l;u[f]=a}return u},numeric.T=function(t,n){this.x=t,this.y=n},numeric.t=function(t,n){return new numeric.T(t,n)},numeric.Tbinop=function(t,n,r,i,s){var o=numeric.indexOf;if(typeof s!="string"){var u;s="";for(u in numeric)numeric.hasOwnProperty(u)&&(t.indexOf(u)>=0||n.indexOf(u)>=0||r.indexOf(u)>=0||i.indexOf(u)>=0)&&u.length>1&&(s+="var "+u+" = numeric."+u+";\n")}return Function(["y"],"var x = this;\nif(!(y instanceof numeric.T)) { y = new numeric.T(y); }\n"+s+"\n"+"if(x.y) {"+" if(y.y) {"+" return new numeric.T("+i+");\n"+" }\n"+" return new numeric.T("+r+");\n"+"}\n"+"if(y.y) {\n"+" return new numeric.T("+n+");\n"+"}\n"+"return new numeric.T("+t+");\n")},numeric.T.prototype.add=numeric.Tbinop("add(x.x,y.x)","add(x.x,y.x),y.y","add(x.x,y.x),x.y","add(x.x,y.x),add(x.y,y.y)"),numeric.T.prototype.sub=numeric.Tbinop("sub(x.x,y.x)","sub(x.x,y.x),neg(y.y)","sub(x.x,y.x),x.y","sub(x.x,y.x),sub(x.y,y.y)"),numeric.T.prototype.mul=numeric.Tbinop("mul(x.x,y.x)","mul(x.x,y.x),mul(x.x,y.y)","mul(x.x,y.x),mul(x.y,y.x)","sub(mul(x.x,y.x),mul(x.y,y.y)),add(mul(x.x,y.y),mul(x.y,y.x))"),numeric.T.prototype.reciprocal=function(){var t=numeric.mul,n=numeric.div;if(this.y){var r=numeric.add(t(this.x,this.x),t(this.y,this.y));return new numeric.T(n(this.x,r),n(numeric.neg(this.y),r))}return new T(n(1,this.x))},numeric.T.prototype.div=function div(e){e instanceof numeric.T||(e=new numeric.T(e));if(e.y)return this.mul(e.reciprocal());var div=numeric.div;return this.y?new numeric.T(div(this.x,e.x),div(this.y,e.x)):new numeric.T(div(this.x,e.x))},numeric.T.prototype.dot=numeric.Tbinop("dot(x.x,y.x)","dot(x.x,y.x),dot(x.x,y.y)","dot(x.x,y.x),dot(x.y,y.x)","sub(dot(x.x,y.x),dot(x.y,y.y)),add(dot(x.x,y.y),dot(x.y,y.x))"),numeric.T.prototype.transpose=function(){var t=numeric.transpose,n=this.x,r=this.y;return r?new numeric.T(t(n),t(r)):new numeric.T(t(n))},numeric.T.prototype.transjugate=function(){var t=numeric.transpose,n=this.x,r=this.y;return r?new numeric.T(t(n),numeric.negtranspose(r)):new numeric.T(t(n))},numeric.Tunop=function(t,n,r){return typeof r!="string"&&(r=""),Function("var x = this;\n"+r+"\n"+"if(x.y) {"+" "+n+";\n"+"}\n"+t+";\n")},numeric.T.prototype.exp=numeric.Tunop("return new numeric.T(ex)","return new numeric.T(mul(cos(x.y),ex),mul(sin(x.y),ex))","var ex = numeric.exp(x.x), cos = numeric.cos, sin = numeric.sin, mul = numeric.mul;"),numeric.T.prototype.conj=numeric.Tunop("return new numeric.T(x.x);","return new numeric.T(x.x,numeric.neg(x.y));"),numeric.T.prototype.neg=numeric.Tunop("return new numeric.T(neg(x.x));","return new numeric.T(neg(x.x),neg(x.y));","var neg = numeric.neg;"),numeric.T.prototype.sin=numeric.Tunop("return new numeric.T(numeric.sin(x.x))","return x.exp().sub(x.neg().exp()).div(new numeric.T(0,2));"),numeric.T.prototype.cos=numeric.Tunop("return new numeric.T(numeric.cos(x.x))","return x.exp().add(x.neg().exp()).div(2);"),numeric.T.prototype.abs=numeric.Tunop("return new numeric.T(numeric.abs(x.x));","return new numeric.T(numeric.sqrt(numeric.add(mul(x.x,x.x),mul(x.y,x.y))));","var mul = numeric.mul;"),numeric.T.prototype.log=numeric.Tunop("return new numeric.T(numeric.log(x.x));","var theta = new numeric.T(numeric.atan2(x.y,x.x)), r = x.abs();\nreturn new numeric.T(numeric.log(r.x),theta.x);"),numeric.T.prototype.norm2=numeric.Tunop("return numeric.norm2(x.x);","var f = numeric.norm2Squared;\nreturn Math.sqrt(f(x.x)+f(x.y));"),numeric.T.prototype.inv=function(){var t=this;if(typeof t.y=="undefined")return new numeric.T(numeric.inv(t.x));var n=t.x.length,r,i,s,o=numeric.identity(n),u=numeric.rep([n,n],0),a=numeric.clone(t.x),f=numeric.clone(t.y),l,c,h,p,d,v,m,g,r,i,s,y,b,w,E,S,x,T;for(r=0;ry&&(s=i,y=b);s!==r&&(T=a[r],a[r]=a[s],a[s]=T,T=f[r],f[r]=f[s],f[s]=T,T=o[r],o[r]=o[s],o[s]=T,T=u[r],u[r]=u[s],u[s]=T),l=a[r],c=f[r],d=o[r],v=u[r],w=l[r],E=c[r];for(i=r+1;i0;r--){d=o[r],v=u[r];for(i=r-1;i>=0;i--){m=o[i],g=u[i],w=a[i][r],E=f[i][r];for(s=n-1;s>=0;s--)S=d[s],x=v[s],m[s]-=w*S-E*x,g[s]-=w*x+E*S}}return new numeric.T(o,u)},numeric.T.prototype.get=function(t){var n=this.x,r=this.y,i=0,s,o=t.length;if(r){while(i=0?1:-1,i=r*numeric.norm2(t);n[0]+=i;var s=numeric.norm2(n);if(s===0)throw new Error("eig: internal error");return numeric.div(n,s)},numeric.toUpperHessenberg=function(t){var n=numeric.dim(t);if(n.length!==2||n[0]!==n[1])throw new Error("numeric: toUpperHessenberg() only works on square matrices");var r=n[0],i,s,o,u,a,f=numeric.clone(t),l,c,h,p,d=numeric.identity(r),v;for(s=0;s0){a=numeric.house(u),l=numeric.getBlock(f,[s+1,s],[r-1,r-1]),c=numeric.tensor(a,numeric.dot(a,l));for(i=s+1;i=4*c){var k,L;k=.5*(h+Math.sqrt(h*h-4*c)),L=.5*(h-Math.sqrt(h*h-4*c)),p=numeric.add(numeric.sub(numeric.dot(p,p),numeric.mul(p,k+L)),numeric.diag(numeric.rep([3],k*L)))}else p=numeric.add(numeric.sub(numeric.dot(p,p),numeric.mul(p,h)),numeric.diag(numeric.rep([3],c)));s=[p[0][0],p[1][0],p[2][0]],o=numeric.house(s),g=[e[0],e[1],e[2]],y=numeric.tensor(o,numeric.dot(o,g));for(w=0;w<3;w++){m=e[w],b=y[w];for(S=0;S=0?(w<0?x=-0.5*(w-A(S)):x=-0.5*(w+A(S)),k=(m-x)*(m-x)+g*g,L=y*y+(b-x)*(b-x),k>L?(k=A(k),N=(m-x)/k,C=g/k):(L=A(L),N=y/L,C=(b-x)/L),p=new s([[C,-N],[N,C]]),h.setRows(u,v,p.dot(h.getRows(u,v)))):(x=-0.5*w,T=.5*A(-S),k=(m-x)*(m-x)+g*g,L=y*y+(b-x)*(b-x),k>L?(k=A(k+T*T),N=(m-x)/k,C=g/k,x=0,T/=k):(L=A(L+T*T),N=y/L,C=(b-x)/L,x=T/L,T=0),p=new s([[C,-N],[N,C]],[[x, -T],[T,-x]]),h.setRows(u,v,p.dot(h.getRows(u,v))))}}var O=h.dot(t).dot(h.transjugate()),o=t.length,M=numeric.T.identity(o);for(v=0;v0)for(a=v-1;a>=0;a--){var _=O.get([a,a]),D=O.get([v,v]);if(!numeric.neq(_.x,D.x)&&!numeric.neq(_.y,D.y)){M.setRow(v,M.getRow(a));continue}x=O.getRow(a).getBlock([a],[v-1]),T=M.getRow(v).getBlock([a],[v-1]),M.set([v,a],O.get([a,v]).neg().sub(x.dot(T)).div(_.sub(D)))}for(v=0;v=u.length)u[u.length]=0;i[o]!==0&&u[o]++}}var r=u.length,a=Array(r+1);a[0]=0;for(s=0;s=d){s[f]=h[u];if(u===0)return;++f,--u,p=l[u],d=c[u]}else a=o[r[p]],i[a]===0?(i[a]=1,l[u]=p,++u,h[u]=a,p=n[a],c[u]=d=n[a+1]):++p},numeric.ccsLPSolve=function(t,n,r,i,s,o,u){var a=t[0],f=t[1],l=t[2],c=a.length-1,h=0,p=n[0],d=n[1],v=n[2],m,g,y,b,w,E,S,x,T,N,C,k;g=p[s],y=p[s+1],i.length=0;for(m=g;mb&&(w=m,b=E)}C(h[d])=v){s[l]=o[p[a]];if(a===0)return;++l,--a,d=c[a],v=h[a]}else f=r[d],i[f]===0?(i[f]=1,c[a]=d,++a,p[a]=f,f=o[f],d=n[f],h[a]=v=n[f+1]):++d}},numeric.ccsLPSolve0=function(t,n,r,i,s,o,u,a){var f=t[0],l=t[1],c=t[2],h=f.length-1,p=0,d=n[0],v=n[1],m=n[2],g,y,b,w,E,S,x,T,N,C,k,L;y=d[s],b=d[s+1],i.length=0;for(g=y;gb&&(w=m,b=E)}C(h[k[d]])t[n]&&(t[n]=e.length);var r;for(r in e)e.hasOwnProperty(r)&&dim(e[r],t,n+1);return t},numeric.sclone=function clone(e,t,n){typeof t=="undefined"&&(t=0),typeof n=="undefined"&&(n=numeric.sdim(e).length);var r,i=Array(e.length);if(t===n-1){for(r in e)e.hasOwnProperty(r)&&(i[r]=e[r]);return i}for(r in e)e.hasOwnProperty(r)&&(i[r]=clone(e[r],t+1,n));return i},numeric.sdiag=function(t){var n=t.length,r,i=Array(n),s,o,u;for(r=n-1;r>=1;r-=2)s=r-1,i[r]=[],i[r][r]=t[r],i[s]=[],i[s][s]=t[s];return r===0&&(i[0]=[],i[0][0]=t[r]),i},numeric.sidentity=function(t){return numeric.sdiag(numeric.rep([t],1))},numeric.stranspose=function(t){var n=[],r=t.length,i,s,o;for(i in t){if(!t.hasOwnProperty(i))continue;o=t[i];for(s in o){if(!o.hasOwnProperty(s))continue;typeof n[s]!="object"&&(n[s]=[]),n[s][i]=o[s]}}return n},numeric.sLUP=function(t,n){throw new Error("The function numeric.sLUP had a bug in it and has been removed. Please use the new numeric.ccsLUP function instead.")},numeric.sdotMM=function(t,n){var r=t.length,i=n.length,s=numeric.stranspose(n),o=s.length,u,a,f,l,c,h,p=Array(r),d;for(f=r-1;f>=0;f--){d=[],u=t[f];for(c=o-1;c>=0;c--){h=0,a=s[c];for(l in u){if(!u.hasOwnProperty(l))continue;l in a&&(h+=u[l]*a[l])}h&&(d[c]=h)}p[f]=d}return p},numeric.sdotMV=function(t,n){var r=t.length,i,s,o,u=Array(r),a;for(s=r-1;s>=0;s--){i=t[s],a=0;for(o in i){if(!i.hasOwnProperty(o))continue;n[o]&&(a+=i[o]*n[o])}a&&(u[s]=a)}return u},numeric.sdotVM=function(t,n){var r,i,s,o,u=[],a;for(r in t){if(!t.hasOwnProperty(r))continue;s=n[r],o=t[r];for(i in s){if(!s.hasOwnProperty(i))continue;u[i]||(u[i]=0),u[i]+=o*s[i]}}return u},numeric.sdotVV=function(t,n){var r,i=0;for(r in t)t[r]&&n[r]&&(i+=t[r]*n[r]);return i},numeric.sdot=function(t,n){var r=numeric.sdim(t).length,i=numeric.sdim(n).length,s=r*1e3+i;switch(s){case 0:return t*n;case 1001:return numeric.sdotVV(t,n);case 2001:return numeric.sdotMV(t,n);case 1002:return numeric.sdotVM(t,n);case 2002:return numeric.sdotMM(t,n);default:throw new Error("numeric.sdot not implemented for tensors of order "+r+" and "+i)}},numeric.sscatter=function(t){var n=t[0].length,r,i,s,o=t.length,u=[],a;for(i=n-1;i>=0;--i){if(!t[o-1][i])continue;a=u;for(s=0;s=0;--i)t[i]=[];for(i=r;i>=0;--i)t[i].push(n[i]);t[r+1].push(s)}}else gather(s,t,n)}return n.length>r&&n.pop(),t},numeric.cLU=function(t){var n=t[0],r=t[1],i=t[2],s=n.length,o=0,u,a,f,l,c,h;for(u=0;uo&&(o=n[u]);o++;var p=Array(o),d=Array(o),v=numeric.rep([o],Infinity),m=numeric.rep([o],-Infinity),g,y,b;for(f=0;fm[u]&&(m[u]=a);for(u=0;um[u+1]&&(m[u+1]=m[u]);for(u=o-1;u>=1;u--)v[u]=0;v--){while(l[g]>v)s[v]-=c[g]*s[l[g]],g--;s[v]/=c[g],g--}return s},numeric.cgrid=function(t,n){typeof t=="number"&&(t=[t,t]);var r=numeric.rep(t,-1),i,s,o;if(typeof n!="function")switch(n){case"L":n=function(e,n){return e>=t[0]/2||nf&&(f=i[u]);f++,r=numeric.rep([f],0);for(u=0;u1)o=u((i+s)/2),n[o]<=t?i=o:s=o;return this._at(t,i)}var r=t.length,c,h=Array(r);for(c=r-1;c!==-1;--c)h[c]=this.at(t[c]);return h},numeric.Spline.prototype.diff=function(){var t=this.x,n=this.yl,r=this.yr,i=this.kl,s=this.kr,o=n.length,u,a,f,l=i,c=s,h=Array(o),p=Array(o),d=numeric.add,v=numeric.mul,m=numeric.div,g=numeric.sub;for(u=o-1;u!==-1;--u)a=t[u+1]-t[u],f=g(r[u+1],n[u]),h[u]=m(d(v(f,6),v(i[u],-4*a),v(s[u+1],-2*a)),a*a),p[u+1]=m(d(v(f,-6),v(i[u],2*a),v(s[u+1],4*a)),a*a);return new numeric.Spline(t,l,c,h,p)},numeric.Spline.prototype.roots=function(){function t(e){return e*e}function n(e,t,n,r,i){var s=n*2-(t-e),o=-r*2+(t-e),u=(i+1)*.5,a=u*(1-u);return(1-u)*e+u*t+s*a*(1-u)+o*a*u}var r=[],i=this.x,s=this.yl,o=this.yr,u=this.kl,a=this.kr;typeof s[0]=="number"&&(s=[s],o=[o],u=[u],a=[a]);var f=s.length,l=i.length-1,c,h,p,d,v,m,g,y,b,w,r=Array(f),E,S,x,T,N,C,k,L,A,O,M,_,D,P,H,B,j,F=Math.sqrt;for(c=0;c!==f;++c){g=s[c],y=o[c],b=u[c],w=a[c],E=[];for(h=0;h!==l;h++){h>0&&y[h]*g[h]<0&&E.push(i[h]),A=i[h+1]-i[h],O=i[h],T=g[h],N=y[h+1],S=b[h]/A,x=w[h+1]/A,L=t(S-x+3*(T-N))+12*x*T,C=x+3*T+2*S-3*N,k=3*(x+S+2*(T-N)),L<=0?(_=C/k,_>i[h]&&_i[h]&&_i[h]&&D0){H=B,_=D;continue}var I=0;for(;;){j=(_*B-D*H)/(_-D);if(j<=H||j>=B)break;P=this._at(j,h);if(P*D>0)B=j,D=P,I===-1&&(_*=.5),I=-1;else{if(!(P*_>0))break;H=j,_=P,I===1&&(D*=.5),I=1}}E.push(j),H=M[p+1],_=this._at(H,h)}D===0&&E.push(B)}r[c]=E}return typeof this.yl[0]=="number"?r[0]:r},numeric.spline=function(t,n,r,i){var s=t.length,o=[],u=[],a=[],f,l=numeric.sub,c=numeric.mul,h=numeric.add;for(f=s-2;f>=0;f--)u[f]=t[f+1]-t[f],a[f]=l(n[f+1],n[f]);if(typeof r=="string"||typeof i=="string")r=i="periodic";var p=[[],[],[]];switch(typeof r){case"undefined":o[0]=c(3/(u[0]*u[0]),a[0]),p[0].push(0,0),p[1].push(0,1),p[2].push(2/u[0],1/u[0]);break;case"string":o[0]=h(c(3/(u[s-2]*u[s-2]),a[s-2]),c(3/(u[0]*u[0]),a[0])),p[0].push(0,0,0),p[1].push(s-2,0,1),p[2].push(1/u[s-2],2/u[s-2]+2/u[0],1/u[0]);break;default:o[0]=r,p[0].push(0),p[1].push(0),p[2].push(1)}for(f=1;f20)throw new Error("Numerical gradient fails");u[o]=n[o]+N,a=t(u),u[o]=n[o]-N,f=t(u),u[o]=n[o];if(isNaN(a)||isNaN(f)){N/=16;continue}l[o]=(a-f)/(2*N),y=n[o]-N,b=n[o],w=n[o]+N,S=(a-i)/N,x=(i-f)/N,T=s(m(l[o]),m(i),m(a),m(f),m(y),m(b),m(w),1e-8),p=g(s(m(S-l[o]),m(x-l[o]),m(S-x))/T,N/T);if(!(p>v))break;N/=16}}return l},numeric.uncmin=function(t,n,r,i,s,o,u){var a=numeric.gradient;typeof u=="undefined"&&(u={}),typeof r=="undefined"&&(r=1e-8),typeof i=="undefined"&&(i=function(e){return a(t,e)}),typeof s=="undefined"&&(s=1e3),n=numeric.clone(n);var f=n.length,l=t(n),c,h;if(isNaN(l))throw new Error("uncmin: f(x0) is a NaN!");var p=Math.max,d=numeric.norm2;r=p(r,numeric.epsilon);var v,m,g,y=u.Hinv||numeric.identity(f),b=numeric.dot,w=numeric.inv,E=numeric.sub,S=numeric.add,x=numeric.tensor,T=numeric.div,N=numeric.mul,C=numeric.all,k=numeric.isFinite,L=numeric.neg,A=0,O,M,_,D,P,H,B,j,F,I,q,R,U="";m=i(n);while(A=.1*F*h||isNaN(c)){F*=.5,++A;continue}break}if(F*I1)i=s(.5*(n+r)),a[i]<=t?n=i:r=i;return this._at(t,n)},numeric.dopri=function(t,n,r,i,s,o,u){typeof s=="undefined"&&(s=1e-6),typeof o=="undefined"&&(o=1e3);var a=[t],f=[r],l=[i(t,r)],c,h,p,d,v,m,g=[],y=.2,b=[.075,.225],w=[44/45,-56/15,32/9],E=[19372/6561,-25360/2187,64448/6561,-212/729],S=[9017/3168,-355/33,46732/5247,49/176,-5103/18656],x=[35/384,0,500/1113,125/192,-2187/6784,11/84],T=[.10013431883002395,0,.3918321794184259,-0.02982460176594817,.05893268337240795,-0.04497888809104361,.023904308236133973],N=[.2,.3,.8,8/9,1,1],C=[-71/57600,0,71/16695,-71/1920,17253/339200,-22/525,.025],k=0,L,A,O=(n-t)/10,M=0,_=numeric.add,D=numeric.mul,P,H,B=Math.max,j=Math.min,F=Math.abs,I=numeric.norminf,q=Math.pow,R=numeric.any,U=numeric.lt,z=numeric.and,W=numeric.sub,X,V,$,J=new numeric.Dopri(a,f,l,g,-1,"");typeof u=="function"&&(X=u(t,r));while(tn&&(O=n-t),c=i(t+N[0]*O,_(r,D(y*O,l[k]))),h=i(t+N[1]*O,_(_(r,D(b[0]*O,l[k])),D(b[1]*O,c))),p=i(t+N[2]*O,_(_(_(r,D(w[0]*O,l[k])),D(w[1]*O,c)),D(w[2]*O,h))),d=i(t+N[3]*O,_(_(_(_(r,D(E[0]*O,l[k])),D(E[1]*O,c)),D(E[2]*O,h)),D(E[3]*O,p))),v=i(t+N[4]*O,_(_(_(_(_(r,D(S[0]*O,l[k])),D(S[1]*O,c)),D(S[2]*O,h)),D(S[3]*O,p)),D(S[4]*O,d))),P=_(_(_(_(_(r,D(l[k],O*x[0])),D(h,O*x[2])),D(p,O*x[3])),D(d,O*x[4])),D(v,O*x[5])),m=i(t+O,P),L=_(_(_(_(_(D(l[k],O*C[0]),D(h,O*C[2])),D(p,O*C[3])),D(d,O*C[4])),D(v,O*C[5])),D(m,O*C[6])),typeof L=="number"?H=F(L):H=I(L);if(H>s){O=.2*O*q(s/H,.25);if(t+O===t){J.msg="Step size became too small";break}continue}g[k]=_(_(_(_(_(_(r,D(l[k],O*T[0])),D(h,O*T[2])),D(p,O*T[3])),D(d,O*T[4])),D(v,O*T[5])),D(m,O*T[6])),++k,a[k]=t+O,f[k]=P,l[k]=m;if(typeof u=="function"){var K,Q=t,G=t+.5*O,Y;V=u(G,g[k-1]),$=z(U(X,0),U(0,V)),R($)||(Q=G,G=t+O,X=V,V=u(G,P),$=z(U(X,0),U(0,V)));if(R($)){var Z,et,tt,nt,rt=0,it=1,st=1;for(;;){if(typeof X=="number")Y=(st*V*Q-it*X*G)/(st*V-it*X);else{Y=G;for(A=X.length-1;A!==-1;--A)X[A]<0&&V[A]>0&&(Y=j(Y,(st*V[A]*Q-it*X[A]*G)/(st*V[A]-it*X[A])))}if(Y<=Q||Y>=G)break;K=J._at(Y,k-1),nt=u(Y,K),tt=z(U(X,0),U(0,nt)),R(tt)?(G=Y,V=nt,$=tt,st=1,rt===-1?it*=.5:it=1,rt=-1):(Q=Y,X=nt,it=1,rt===1?st*=.5:st=1,rt=1)}return P=J._at(.5*(t+Y),k-1),J.f[k]=i(Y,K),J.x[k]=Y,J.y[k]=K,J.ymid[k-1]=P,J.events=$,J.iterations=M,J}}t+=O,r=P,X=V,O=j(.8*O*q(s/H,.25),4*O)}return J.iterations=M,J},numeric.LU=function(e,t){t=t||!1;var n=Math.abs,r,i,s,o,u,a,f,l,c,h=e.length,p=h-1,d=new Array(h);t||(e=numeric.clone(e));for(s=0;s=0;--r){l=s[r];for(i=r+1;iK)E=K;W=d(t,l(E,B)),I=h(J,j);for(X=v-1;X!==-1;--X)I[X][X]+=1;$=q(I,p(W,E),!0);var Q=p(R,h(n,$)),G=1;for(X=m-1;X!==-1;--X)Q[X]<0&&(G=D(G,-0.999*Q[X]));g=c(o,l($,G)),R=c(r,h(n,g));if(!P(H(R,0)))return{solution:o,message:"",iterations:U};o=g;if(E=0?y=!1:y=!0;if(y)return{solution:g,message:"Unbounded",iterations:U}}return{solution:o,message:"maximum iteration count exceeded",iterations:U}},numeric._solveLP=function(t,n,r,i,s){var o=t.length,u=r.length,a,f=numeric.sum,l=numeric.log,c=numeric.mul,h=numeric.sub,p=numeric.dot,d=numeric.div,v=numeric.add,m=numeric.rep([o],0).concat([1]),g=numeric.rep([u,1],-1),y=numeric.blockMatrix([[n,g]]),b=r,a=numeric.rep([o],0).concat(Math.max(0,numeric.sup(numeric.neg(r)))+1),w=numeric.__solveLP(m,y,b,i,s,a,!1),E=numeric.clone(w.solution);E.length=o;var S=numeric.inf(h(r,p(n,E)));if(S<0)return{solution:NaN,message:"Infeasible",iterations:w.iterations};var x=numeric.__solveLP(t,n,r,i,s-w.iterations,E,!0);return x.iterations+=w.iterations,x},numeric.solveLP=function(t,n,r,i,s,o,u){typeof u=="undefined"&&(u=1e3),typeof o=="undefined"&&(o=numeric.epsilon);if(typeof i=="undefined")return numeric._solveLP(t,n,r,o,u);var a=i.length,f=i[0].length,l=n.length,c=numeric.echelonize(i),h=numeric.rep([f],0),p=c.P,d=[],v;for(v=p.length-1;v!==-1;--v)h[p[v]]=1;for(v=f-1;v!==-1;--v)h[v]===0&&d.push(v);var m=numeric.getRange,g=numeric.linspace(0,a-1),y=numeric.linspace(0,l-1),b=m(i,g,d),w=m(n,y,p),E=m(n,y,d),S=numeric.dot,x=numeric.sub,T=S(w,c.I),N=x(E,S(T,b)),C=x(r,S(T,s)),k=Array(p.length),L=Array(d.length);for(v=p.length-1;v!==-1;--v)k[v]=t[p[v]];for(v=d.length-1;v!==-1;--v)L[v]=t[d[v]];var A=x(L,S(k,S(c.I,b))),O=numeric._solveLP(A,N,C,o,u),M=O.solution;if(M!==M)return O;var _=S(c.I,x(s,S(b,M))),D=Array(t.length);for(v=p.length-1;v!==-1;--v)D[p[v]]=_[v];for(v=d.length-1;v!==-1;--v)D[d[v]]=M[v];return{solution:D,message:O.message,iterations:O.iterations}},numeric.MPStoLP=function(t){function y(e){throw new Error("MPStoLP: "+e+"\nLine "+s+": "+t[s]+"\nCurrent state: "+r[n]+"\n")}t instanceof String&&t.split("\n");var n=0,r=["Initial state","NAME","ROWS","COLUMNS","RHS","BOUNDS","ENDATA"],i=t.length,s,o,u,a=0,f={},l=[],c=0,h={},p=0,d,v=[],m=[],g=[];for(s=0;s=s)t/=2,u/=2,a>>>=1;return(t+a)/u},c},o=t.pow(n,r),i=t.pow(2,i),s=i*2,f(t.random(),e)}([],numeric.seedrandom,256,6,52),function(e){function t(e){if(typeof e!="object")return e;var n=[],r,i=e.length;for(r=0;rp)g[E]=P;else{g[E]=-Math.abs(P);if(P>0){for(w=1;w<=o;w+=1)f[w][b]=-f[w][b];l[b]=-l[b]}}}for(b=1;b<=v;b+=1)g[L+d[b]]=0;O=0,D=0;for(b=1;b<=h;b+=1)g[L+b]=1;b-=1){P=g[b],E=k+b*(b+3)/2,S=E-b;for(w=b+1;w<=v;w+=1)P-=g[E]*g[C+w],E+=w;P/=g[S],g[C+b]=P;if(d[b]p)g[L+O]=P;else{g[L+O]=-Math.abs(P);if(P>0){for(w=1;w<=o;w+=1)f[w][O]=-f[w][O];l[O]=-l[O]}}return 700}v+=1,d[v]=O,E=k+(v-1)*v/2+1;for(b=1;b<=v-1;b+=1)g[E]=g[b],E+=1;if(v===o)g[E]=g[o];else{for(b=o;b>=v+1;b-=1){if(g[b]===0)break;j=Math.max(Math.abs(g[b-1]),Math.abs(g[b])),F=Math.min(Math.abs(g[b-1]),Math.abs(g[b])),g[b-1]>=0?D=Math.abs(j*Math.sqrt(1+F*F/(j*j))):D=-Math.abs(j*Math.sqrt(1+F*F/(j*j))),j=g[b-1]/D,F=g[b]/D;if(j===1)break;if(j===0){g[b-1]=F*D;for(w=1;w<=o;w+=1)D=e[w][b-1],e[w][b-1]=e[w][b],e[w][b]=D}else{g[b-1]=D,I=F/(1+j);for(w=1;w<=o;w+=1)D=j*e[w][b-1]+F*e[w][b],e[w][b]=I*(e[w][b-1]+D)-e[w][b],e[w][b-1]=D}}g[E]=g[v]}return 0}function J(){E=k+T*(T+1)/2+1,S=E+T;if(g[S]===0)return 798;j=Math.max(Math.abs(g[S-1]),Math.abs(g[S])),F=Math.min(Math.abs(g[S-1]),Math.abs(g[S])),g[S-1]>=0?D=Math.abs(j*Math.sqrt(1+F*F/(j*j))):D=-Math.abs(j*Math.sqrt(1+F*F/(j*j))),j=g[S-1]/D,F=g[S]/D;if(j===1)return 798;if(j===0){for(b=T+1;b<=v;b+=1)D=g[S-1],g[S-1]=g[S],g[S]=D,S+=b;for(b=1;b<=o;b+=1)D=e[b][T],e[b][T]=e[b][T+1],e[b][T+1]=D}else{I=F/(1+j);for(b=T+1;b<=v;b+=1)D=j*g[S-1]+F*g[S],g[S]=I*(g[S-1]+D)-g[S],g[S-1]=D,S+=b;for(b=1;b<=o;b+=1)D=j*e[b][T]+F*e[b][T+1],e[b][T+1]=I*(e[b][T]+D)-e[b][T+1],e[b][T]=D}return 0}function K(){S=E-T;for(b=1;b<=T;b+=1)g[S]=g[E],E+=1,S+=1;return g[A+T]=g[A+T+1],d[T]=d[T+1],T+=1,Tt?e*Math.sqrt(1+t*t/e/e):t==0?e:t*Math.sqrt(1+e*e/t/t)}var n,r=numeric.epsilon,i=1e-64/r,s=50,o=0,u=0,a=0,f=0,l=0,c=numeric.clone(t),h=c.length,p=c[0].length;if(h=0&&(b=-b),w=y*b-T,c[u][u]=y-b;for(a=l;a=0&&(b=-b),w=y*b-T,c[u][u+1]=y-b;for(a=l;aE&&(E=S)}for(u=p-1;u!=-1;u+=-1){if(b!=0){w=b*c[u][u+1];for(a=l;a=s-1)throw"Error: no convergence.";E=v[l],S=v[f-1],b=d[f-1],w=d[f],y=((S-x)*(S+x)+(b-w)*(b+w))/(2*w*S),b=g(y,1),y<0?y=((E-x)*(E+x)+w*(S/(y-b)-w))/E:y=((E-x)*(E+x)+w*(S/(y+b)-w))/E,o=1,T=1;for(u=l+1;u=0;a--)if(v[a] 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } - if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } - if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } - if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } - if (t[2]) _.ops.pop(); - _.trys.pop(); continue; - } - op = body.call(thisArg, _); - } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } - if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; - } - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. - var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; - /** - * The environment contains evaluated flags as well as the registered platform. - * This is always used as a global singleton and can be retrieved with - * `tf.env()`. - */ - /** @doc {heading: 'Environment'} */ - var Environment = /** @class */ (function () { - // tslint:disable-next-line: no-any - function Environment(global) { - this.global = global; - this.flags = {}; - this.flagRegistry = {}; - this.urlFlags = {}; - this.populateURLFlags(); - } - Environment.prototype.setPlatform = function (platformName, platform) { - if (this.platform != null) { - console.warn("Platform " + this.platformName + " has already been set. " + - ("Overwriting the platform with " + platform + ".")); - } - this.platformName = platformName; - this.platform = platform; - }; - Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) { - this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook }; - // Override the flag value from the URL. This has to happen here because the - // environment is initialized before flags get registered. - if (this.urlFlags[flagName] != null) { - var flagValue = this.urlFlags[flagName]; - console.warn("Setting feature override from URL " + flagName + ": " + flagValue + "."); - this.set(flagName, flagValue); - } - }; - Environment.prototype.get = function (flagName) { - if (flagName in this.flags) { - return this.flags[flagName]; - } - this.flags[flagName] = this.evaluateFlag(flagName); - return this.flags[flagName]; - }; - Environment.prototype.getNumber = function (flagName) { - return this.get(flagName); - }; - Environment.prototype.getBool = function (flagName) { - return this.get(flagName); - }; - Environment.prototype.getFlags = function () { - return this.flags; - }; - Object.defineProperty(Environment.prototype, "features", { - // For backwards compatibility. - get: function () { - return this.flags; - }, - enumerable: true, - configurable: true - }); - Environment.prototype.set = function (flagName, value) { - if (this.flagRegistry[flagName] == null) { - throw new Error("Cannot set flag " + flagName + " as it has not been registered."); - } - this.flags[flagName] = value; - if (this.flagRegistry[flagName].setHook != null) { - this.flagRegistry[flagName].setHook(value); - } - }; - Environment.prototype.evaluateFlag = function (flagName) { - if (this.flagRegistry[flagName] == null) { - throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found."); - } - return this.flagRegistry[flagName].evaluationFn(); - }; - Environment.prototype.setFlags = function (flags) { - this.flags = Object.assign({}, flags); - }; - Environment.prototype.reset = function () { - this.flags = {}; - this.urlFlags = {}; - this.populateURLFlags(); - }; - Environment.prototype.populateURLFlags = function () { - var _this = this; - if (typeof this.global === 'undefined' || - typeof this.global.location === 'undefined' || - typeof this.global.location.search === 'undefined') { - return; - } - var urlParams = getQueryParams(this.global.location.search); - if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { - var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); - keyValues.forEach(function (keyValue) { - var _a = keyValue.split(':'), key = _a[0], value = _a[1]; - _this.urlFlags[key] = parseValue(key, value); - }); - } - }; - return Environment; - }()); - function getQueryParams(queryString) { - var params = {}; - queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) { - var t = []; - for (var _i = 1; _i < arguments.length; _i++) { - t[_i - 1] = arguments[_i]; - } - decodeParam(params, t[0], t[1]); - return t.join('='); - }); - return params; - } - function decodeParam(params, name, value) { - params[decodeURIComponent(name)] = decodeURIComponent(value || ''); - } - function parseValue(flagName, value) { - value = value.toLowerCase(); - if (value === 'true' || value === 'false') { - return value === 'true'; - } - else if ("" + +value === value) { - return +value; - } - throw new Error("Could not parse value flag value " + value + " for flag " + flagName + "."); - } - /** - * Returns the current environment (a global singleton). - * - * The environment object contains the evaluated feature values as well as the - * active platform. - */ - /** @doc {heading: 'Environment'} */ - function env() { - return exports.ENV; - } - exports.ENV = null; - function setEnvironmentGlobal(environment) { - exports.ENV = environment; - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var kernelRegistry = new Map(); - var gradRegistry = new Map(); - /** - * Returns the kernel function (code) associated with the provided names. - * - * @param kernelName The official name of the kernel. - * @param backendName The official name of the backend. - */ - function getKernel(kernelName, backendName) { - var key = makeKey(kernelName, backendName); - return kernelRegistry.get(key); - } - /** - * Returns the registered gradient info associated with the provided kernel. - * @param kernelName The official TF kernel name. - */ - function getGradient(kernelName) { - return gradRegistry.get(kernelName); - } - function getKernelsForBackend(backendName) { - var it = kernelRegistry.entries(); - var result = []; - while (true) { - var _a = it.next(), done = _a.done, value = _a.value; - if (done) { - break; - } - var key = value[0], config = value[1]; - var backend = key.split('_')[0]; - if (backend === backendName) { - result.push(config); - } - } - return result; - } - /** - * Registers the function (forward pass) for the kernel in a global registry. - * - * @param config A config object with the following properties: - * - `kernelName` The official name of the kernel. - * - `backendName` The official name of the backend. - * - `kernelFunc` The function to run during the forward pass of the kernel. - * - `setupFunc` Optional. Gets called once, after the backend initializes. - * - `disposeFunc` Optional. Gets called once, right before the backend is - * disposed. - */ - function registerKernel(config) { - var kernelName = config.kernelName, backendName = config.backendName; - var key = makeKey(kernelName, backendName); - if (kernelRegistry.has(key)) { - throw new Error("The kernel '" + kernelName + "' for backend " + - ("'" + backendName + "' is already registered")); - } - kernelRegistry.set(key, config); - } - /** - * Registers a gradient function for a given kernel in the global registry, - * to be used during the back-propagation of that kernel. - * - * @param config An object with the following properties: - * - `kernelName` The name of the kernel that the gradient function is for. - * - `gradFunc` The function to run during back-propagation. - */ - function registerGradient(config) { - var kernelName = config.kernelName; - if (gradRegistry.has(kernelName)) { - console.warn("Overriding the gradient for '" + kernelName + "'"); - } - gradRegistry.set(kernelName, config); - } - /** - * Removes the kernel function from the registry. - * - * @param kernelName The official name of the kernel. - * @param backendName The official name of the backend. - * - */ - function unregisterKernel(kernelName, backendName) { - var key = makeKey(kernelName, backendName); - if (!kernelRegistry.has(key)) { - throw new Error("The kernel '" + kernelName + "' for backend " + - ("'" + backendName + "' is not registered")); - } - kernelRegistry.delete(key); - } - /** Removes the registered gradient from the global registry. */ - function unregisterGradient(kernelName) { - if (!gradRegistry.has(kernelName)) { - throw new Error("The gradient '" + kernelName + "' for backend is not registered"); - } - gradRegistry.delete(kernelName); - } - function makeKey(kernelName, backendName) { - return backendName + "_" + kernelName; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Shuffles the array in-place using Fisher-Yates algorithm. - * - * ```js - * const a = [1, 2, 3, 4, 5]; - * tf.util.shuffle(a); - * console.log(a); - * ``` - * - * @param array The array to shuffle in-place. - */ - /** @doc {heading: 'Util', namespace: 'util'} */ - // tslint:disable-next-line:no-any - function shuffle(array) { - var counter = array.length; - var temp = 0; - var index = 0; - // While there are elements in the array - while (counter > 0) { - // Pick a random index - index = (Math.random() * counter) | 0; - // Decrease counter by 1 - counter--; - // And swap the last element with it - temp = array[counter]; - array[counter] = array[index]; - array[index] = temp; - } - } - /** Clamps a value to a specified range. */ - function clamp(min, x, max) { - return Math.max(min, Math.min(x, max)); - } - function nearestLargerEven(val) { - return val % 2 === 0 ? val : val + 1; - } - function sum(arr) { - var sum = 0; - for (var i = 0; i < arr.length; i++) { - sum += arr[i]; - } - return sum; - } - /** - * Returns a sample from a uniform [a, b) distribution. - * - * @param a The minimum support (inclusive). - * @param b The maximum support (exclusive). - * @return A pseudorandom number on the half-open interval [a,b). - */ - function randUniform(a, b) { - var r = Math.random(); - return (b * r) + (1 - r) * a; - } - /** Returns the squared Euclidean distance between two vectors. */ - function distSquared(a, b) { - var result = 0; - for (var i = 0; i < a.length; i++) { - var diff = Number(a[i]) - Number(b[i]); - result += diff * diff; - } - return result; - } - /** - * Asserts that the expression is true. Otherwise throws an error with the - * provided message. - * - * ```js - * const x = 2; - * tf.util.assert(x === 2, 'x is not 2'); - * ``` - * - * @param expr The expression to assert (as a boolean). - * @param msg A function that returns the message to report when throwing an - * error. We use a function for performance reasons. - */ - /** @doc {heading: 'Util', namespace: 'util'} */ - function assert(expr, msg) { - if (!expr) { - throw new Error(typeof msg === 'string' ? msg : msg()); - } - } - function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { - if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } - assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); }); - } - function assertNonNull(a) { - assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; }); - } - // NOTE: We explicitly type out what T extends instead of any so that - // util.flatten on a nested array of number doesn't try to infer T as a - // number[][], causing us to explicitly type util.flatten(). - /** - * Flattens an arbitrarily nested array. - * - * ```js - * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; - * const flat = tf.util.flatten(a); - * console.log(flat); - * ``` - * - * @param arr The nested array to flatten. - * @param result The destination array which holds the elements. - * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults - * to false. - */ - /** @doc {heading: 'Util', namespace: 'util'} */ - function flatten(arr, result, skipTypedArray) { - if (result === void 0) { result = []; } - if (skipTypedArray === void 0) { skipTypedArray = false; } - if (result == null) { - result = []; - } - if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { - for (var i = 0; i < arr.length; ++i) { - flatten(arr[i], result, skipTypedArray); - } - } - else { - result.push(arr); - } - return result; - } - /** - * Returns the size (number of elements) of the tensor given its shape. - * - * ```js - * const shape = [3, 4, 2]; - * const size = tf.util.sizeFromShape(shape); - * console.log(size); - * ``` - */ - /** @doc {heading: 'Util', namespace: 'util'} */ - function sizeFromShape(shape) { - if (shape.length === 0) { - // Scalar. - return 1; - } - var size = shape[0]; - for (var i = 1; i < shape.length; i++) { - size *= shape[i]; - } - return size; - } - function isScalarShape(shape) { - return shape.length === 0; - } - function arraysEqual(n1, n2) { - if (n1 === n2) { - return true; - } - if (n1 == null || n2 == null) { - return false; - } - if (n1.length !== n2.length) { - return false; - } - for (var i = 0; i < n1.length; i++) { - if (n1[i] !== n2[i]) { - return false; - } - } - return true; - } - function isInt(a) { - return a % 1 === 0; - } - function tanh(x) { - // tslint:disable-next-line:no-any - if (Math.tanh != null) { - // tslint:disable-next-line:no-any - return Math.tanh(x); - } - if (x === Infinity) { - return 1; - } - else if (x === -Infinity) { - return -1; - } - else { - var e2x = Math.exp(2 * x); - return (e2x - 1) / (e2x + 1); - } - } - function sizeToSquarishShape(size) { - var width = Math.ceil(Math.sqrt(size)); - return [width, Math.ceil(size / width)]; - } - /** - * Creates a new array with randomized indicies to a given quantity. - * - * ```js - * const randomTen = tf.util.createShuffledIndices(10); - * console.log(randomTen); - * ``` - * - * @param number Quantity of how many shuffled indicies to create. - */ - /** @doc {heading: 'Util', namespace: 'util'} */ - function createShuffledIndices(n) { - var shuffledIndices = new Uint32Array(n); - for (var i = 0; i < n; ++i) { - shuffledIndices[i] = i; - } - shuffle(shuffledIndices); - return shuffledIndices; - } - function rightPad(a, size) { - if (size <= a.length) { - return a; - } - return a + ' '.repeat(size - a.length); - } - function repeatedTry(checkFn, delayFn, maxCounter) { - if (delayFn === void 0) { delayFn = function (counter) { return 0; }; } - return new Promise(function (resolve, reject) { - var tryCount = 0; - var tryFn = function () { - if (checkFn()) { - resolve(); - return; - } - tryCount++; - var nextBackoff = delayFn(tryCount); - if (maxCounter != null && tryCount >= maxCounter) { - reject(); - return; - } - setTimeout(tryFn, nextBackoff); - }; - tryFn(); - }); - } - /** - * Given the full size of the array and a shape that may contain -1 as the - * implicit dimension, returns the inferred shape where -1 is replaced. - * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. - * - * @param shape The shape, which may contain -1 in some dimension. - * @param size The full size (number of elements) of the array. - * @return The inferred shape where -1 is replaced with the inferred size. - */ - function inferFromImplicitShape(shape, size) { - var shapeProd = 1; - var implicitIdx = -1; - for (var i = 0; i < shape.length; ++i) { - if (shape[i] >= 0) { - shapeProd *= shape[i]; - } - else if (shape[i] === -1) { - if (implicitIdx !== -1) { - throw Error("Shapes can only have 1 implicit size. " + - ("Found -1 at dim " + implicitIdx + " and dim " + i)); - } - implicitIdx = i; - } - else if (shape[i] < 0) { - throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i); - } - } - if (implicitIdx === -1) { - if (size > 0 && size !== shapeProd) { - throw Error("Size(" + size + ") must match the product of shape " + shape); - } - return shape; - } - if (shapeProd === 0) { - throw Error("Cannot infer the missing size in [" + shape + "] when " + - "there are 0 elements"); - } - if (size % shapeProd !== 0) { - throw Error("The implicit shape can't be a fractional number. " + - ("Got " + size + " / " + shapeProd)); - } - var newShape = shape.slice(); - newShape[implicitIdx] = size / shapeProd; - return newShape; - } - function parseAxisParam(axis, shape) { - var rank = shape.length; - // Normalize input - axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis); - // Check for valid range - assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { - return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + - ("got axis " + axis); - }); - // Check for only integers - assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " + - ("got axis " + axis); }); - // Handle negative axis. - return axis.map(function (a) { return a < 0 ? rank + a : a; }); - } - /** Reduces the shape by removing all dimensions of shape 1. */ - function squeezeShape(shape, axis) { - var newShape = []; - var keptDims = []; - var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; - var axes = (axis == null || isEmptyArray) ? - null : - parseAxisParam(axis, shape).sort(); - var j = 0; - for (var i = 0; i < shape.length; ++i) { - if (axes != null) { - if (axes[j] === i && shape[i] !== 1) { - throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1"); - } - if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { - newShape.push(shape[i]); - keptDims.push(i); - } - if (axes[j] <= i) { - j++; - } - } - if (shape[i] !== 1) { - newShape.push(shape[i]); - keptDims.push(i); - } - } - return { newShape: newShape, keptDims: keptDims }; - } - function getTypedArrayFromDType(dtype, size) { - var values = null; - if (dtype == null || dtype === 'float32') { - values = new Float32Array(size); - } - else if (dtype === 'int32') { - values = new Int32Array(size); - } - else if (dtype === 'bool') { - values = new Uint8Array(size); - } - else { - throw new Error("Unknown data type " + dtype); - } - return values; - } - function getArrayFromDType(dtype, size) { - var values = null; - if (dtype == null || dtype === 'float32') { - values = new Float32Array(size); - } - else if (dtype === 'int32') { - values = new Int32Array(size); - } - else if (dtype === 'bool') { - values = new Uint8Array(size); - } - else if (dtype === 'string') { - values = new Array(size); - } - else { - throw new Error("Unknown data type " + dtype); - } - return values; - } - function checkConversionForErrors(vals, dtype) { - for (var i = 0; i < vals.length; i++) { - var num = vals[i]; - if (isNaN(num) || !isFinite(num)) { - throw Error("A tensor of type " + dtype + " being uploaded contains " + num + "."); - } - } - } - /** Returns true if the dtype is valid. */ - function isValidDtype(dtype) { - return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || - dtype === 'int32' || dtype === 'string'; - } - /** - * Returns true if the new type can't encode the old type without loss of - * precision. - */ - function hasEncodingLoss(oldType, newType) { - if (newType === 'complex64') { - return false; - } - if (newType === 'float32' && oldType !== 'complex64') { - return false; - } - if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { - return false; - } - if (newType === 'bool' && oldType === 'bool') { - return false; - } - return true; - } - function isTypedArray(a) { - return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array; - } - function bytesPerElement(dtype) { - if (dtype === 'float32' || dtype === 'int32') { - return 4; - } - else if (dtype === 'complex64') { - return 8; - } - else if (dtype === 'bool') { - return 1; - } - else { - throw new Error("Unknown dtype " + dtype); - } - } - /** - * Returns the approximate number of bytes allocated in the string array - 2 - * bytes per character. Computing the exact bytes for a native string in JS is - * not possible since it depends on the encoding of the html page that serves - * the website. - */ - function bytesFromStringArray(arr) { - if (arr == null) { - return 0; - } - var bytes = 0; - arr.forEach(function (x) { return bytes += x.length; }); - return bytes; - } - /** Returns true if the value is a string. */ - function isString(value) { - return typeof value === 'string' || value instanceof String; - } - function isBoolean(value) { - return typeof value === 'boolean'; - } - function isNumber(value) { - return typeof value === 'number'; - } - function inferDtype(values) { - if (Array.isArray(values)) { - return inferDtype(values[0]); - } - if (values instanceof Float32Array) { - return 'float32'; - } - else if (values instanceof Int32Array || values instanceof Uint8Array) { - return 'int32'; - } - else if (isNumber(values)) { - return 'float32'; - } - else if (isString(values)) { - return 'string'; - } - else if (isBoolean(values)) { - return 'bool'; - } - return 'float32'; - } - function isFunction(f) { - return !!(f && f.constructor && f.call && f.apply); - } - function nearestDivisor(size, start) { - for (var i = start; i < size; ++i) { - if (size % i === 0) { - return i; - } - } - return size; - } - function computeStrides(shape) { - var rank = shape.length; - if (rank < 2) { - return []; - } - // Last dimension has implicit stride of 1, thus having D-1 (instead of D) - // strides. - var strides = new Array(rank - 1); - strides[rank - 2] = shape[rank - 1]; - for (var i = rank - 3; i >= 0; --i) { - strides[i] = strides[i + 1] * shape[i + 1]; - } - return strides; - } - function toTypedArray(a, dtype, debugMode) { - if (dtype === 'string') { - throw new Error('Cannot convert a string[] to a TypedArray'); - } - if (Array.isArray(a)) { - a = flatten(a); - } - if (debugMode) { - checkConversionForErrors(a, dtype); - } - if (noConversionNeeded(a, dtype)) { - return a; - } - if (dtype == null || dtype === 'float32' || dtype === 'complex64') { - return new Float32Array(a); - } - else if (dtype === 'int32') { - return new Int32Array(a); - } - else if (dtype === 'bool') { - var bool = new Uint8Array(a.length); - for (var i = 0; i < bool.length; ++i) { - if (Math.round(a[i]) !== 0) { - bool[i] = 1; - } - } - return bool; - } - else { - throw new Error("Unknown data type " + dtype); - } - } - function createNestedArray(offset, shape, a) { - var ret = new Array(); - if (shape.length === 1) { - var d = shape[0]; - for (var i = 0; i < d; i++) { - ret[i] = a[offset + i]; - } - } - else { - var d = shape[0]; - var rest = shape.slice(1); - var len = rest.reduce(function (acc, c) { return acc * c; }); - for (var i = 0; i < d; i++) { - ret[i] = createNestedArray(offset + i * len, rest, a); - } - } - return ret; - } - // Provide a nested array of TypedArray in given shape. - function toNestedArray(shape, a) { - if (shape.length === 0) { - // Scalar type should return a single number. - return a[0]; - } - var size = shape.reduce(function (acc, c) { return acc * c; }); - if (size === 0) { - // A tensor with shape zero should be turned into empty list. - return []; - } - if (size !== a.length) { - throw new Error("[" + shape + "] does not match the input size."); - } - return createNestedArray(0, shape, a); - } - function noConversionNeeded(a, dtype) { - return (a instanceof Float32Array && dtype === 'float32') || - (a instanceof Int32Array && dtype === 'int32') || - (a instanceof Uint8Array && dtype === 'bool'); - } - function makeOnesTypedArray(size, dtype) { - var array = makeZerosTypedArray(size, dtype); - for (var i = 0; i < array.length; i++) { - array[i] = 1; - } - return array; - } - function makeZerosTypedArray(size, dtype) { - if (dtype == null || dtype === 'float32' || dtype === 'complex64') { - return new Float32Array(size); - } - else if (dtype === 'int32') { - return new Int32Array(size); - } - else if (dtype === 'bool') { - return new Uint8Array(size); - } - else { - throw new Error("Unknown data type " + dtype); - } - } - /** - * Returns the current high-resolution time in milliseconds relative to an - * arbitrary time in the past. It works across different platforms (node.js, - * browsers). - * - * ```js - * console.log(tf.util.now()); - * ``` - */ - /** @doc {heading: 'Util', namespace: 'util'} */ - function now() { - return env().platform.now(); - } - function assertNonNegativeIntegerDimensions(shape) { - shape.forEach(function (dimSize) { - assert(Number.isInteger(dimSize) && dimSize >= 0, function () { - return "Tensor must have a shape comprised of positive integers but got " + - ("shape [" + shape + "]."); - }); - }); - } - /** - * Returns a platform-specific implementation of - * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). - * - * If `fetch` is defined on the global object (`window`, `process`, etc.), - * `tf.util.fetch` returns that function. - * - * If not, `tf.util.fetch` returns a platform-specific solution. - * - * ```js - * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs'); - * // handle response - * ``` - */ - /** @doc {heading: 'Util'} */ - function fetch$1(path, requestInits) { - return env().platform.fetch(path, requestInits); - } - /** - * Encodes the provided string into bytes using the provided encoding scheme. - * - * @param s The string to encode. - * @param encoding The encoding scheme. Defaults to utf-8. - * - */ - /** @doc {heading: 'Util'} */ - function encodeString(s, encoding) { - if (encoding === void 0) { encoding = 'utf-8'; } - encoding = encoding || 'utf-8'; - return env().platform.encode(s, encoding); - } - /** - * Decodes the provided bytes into a string using the provided encoding scheme. - * @param bytes The bytes to decode. - * - * @param encoding The encoding scheme. Defaults to utf-8. - */ - /** @doc {heading: 'Util'} */ - function decodeString(bytes, encoding) { - if (encoding === void 0) { encoding = 'utf-8'; } - encoding = encoding || 'utf-8'; - return env().platform.decode(bytes, encoding); - } - /** - * Computes flat index for a given location (multidimentionsal index) in a - * Tensor/multidimensional array. - * - * @param locs Location in the tensor. - * @param rank Rank of the tensor. - * @param strides Tensor strides. - */ - function locToIndex(locs, rank, strides) { - if (rank === 0) { - return 0; - } - else if (rank === 1) { - return locs[0]; - } - var index = locs[locs.length - 1]; - for (var i = 0; i < locs.length - 1; ++i) { - index += strides[i] * locs[i]; - } - return index; - } - /** - * Computes the location (multidimensional index) in a tensor/multidimentional - * array for a given flat index. - * - * @param index Index in flat array. - * @param rank Rank of tensor. - * @param strides Strides of tensor. - */ - function indexToLoc(index, rank, strides) { - if (rank === 0) { - return []; - } - else if (rank === 1) { - return [index]; - } - var locs = new Array(rank); - for (var i = 0; i < locs.length - 1; ++i) { - locs[i] = Math.floor(index / strides[i]); - index -= locs[i] * strides[i]; - } - locs[locs.length - 1] = index; - return locs; - } - - var util = /*#__PURE__*/Object.freeze({ - shuffle: shuffle, - clamp: clamp, - nearestLargerEven: nearestLargerEven, - sum: sum, - randUniform: randUniform, - distSquared: distSquared, - assert: assert, - assertShapesMatch: assertShapesMatch, - assertNonNull: assertNonNull, - flatten: flatten, - sizeFromShape: sizeFromShape, - isScalarShape: isScalarShape, - arraysEqual: arraysEqual, - isInt: isInt, - tanh: tanh, - sizeToSquarishShape: sizeToSquarishShape, - createShuffledIndices: createShuffledIndices, - rightPad: rightPad, - repeatedTry: repeatedTry, - inferFromImplicitShape: inferFromImplicitShape, - parseAxisParam: parseAxisParam, - squeezeShape: squeezeShape, - getTypedArrayFromDType: getTypedArrayFromDType, - getArrayFromDType: getArrayFromDType, - checkConversionForErrors: checkConversionForErrors, - isValidDtype: isValidDtype, - hasEncodingLoss: hasEncodingLoss, - isTypedArray: isTypedArray, - bytesPerElement: bytesPerElement, - bytesFromStringArray: bytesFromStringArray, - isString: isString, - isBoolean: isBoolean, - isNumber: isNumber, - inferDtype: inferDtype, - isFunction: isFunction, - nearestDivisor: nearestDivisor, - computeStrides: computeStrides, - toTypedArray: toTypedArray, - toNestedArray: toNestedArray, - makeOnesTypedArray: makeOnesTypedArray, - makeZerosTypedArray: makeZerosTypedArray, - now: now, - assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions, - fetch: fetch$1, - encodeString: encodeString, - decodeString: decodeString, - locToIndex: locToIndex, - indexToLoc: indexToLoc - }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var Profiler = /** @class */ (function () { - function Profiler(backendTimer, logger) { - this.backendTimer = backendTimer; - this.logger = logger; - if (logger == null) { - this.logger = new Logger(); - } - } - Profiler.prototype.profileKernel = function (kernelName, inputs, f) { - var _this = this; - var outputs; - var holdResultWrapperFn = function () { - outputs = f(); - }; - var timer = this.backendTimer.time(holdResultWrapperFn); - outputs.forEach(function (r) { - // Dangling promise here because we don't want to propagate up - // asynchronicity. - r.data().then(function (vals) { - checkComputationForErrors(vals, r.dtype, kernelName); - timer.then(function (timing) { - var extraInfo = ''; - if (timing.getExtraProfileInfo != null) { - extraInfo = timing.getExtraProfileInfo(); - } - _this.logger.logKernelProfile(kernelName, r, vals, timing.kernelMs, inputs, extraInfo); - }); - }); - }); - return outputs; - }; - return Profiler; - }()); - function checkComputationForErrors(vals, dtype, kernelName) { - if (dtype !== 'float32') { - // Only floating point computations will generate NaN values - return false; - } - for (var i = 0; i < vals.length; i++) { - var num = vals[i]; - if (isNaN(num) || !isFinite(num)) { - // Throwing custom exception so behavior is testable. - console.warn("Found " + num + " in the result of '" + kernelName + "'"); - return true; - } - } - return false; - } - var Logger = /** @class */ (function () { - function Logger() { - } - Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) { - var time = typeof timeMs === 'number' ? rightPad(timeMs + "ms", 9) : - timeMs['error']; - var paddedName = rightPad(name, 25); - var rank = result.rank; - var size = result.size; - var shape = rightPad(result.shape.toString(), 14); - var inputShapesDescription = ''; - for (var name_1 in inputs) { - var input = inputs[name_1]; - // The input might be a non-tensor (e.g HTMLImageElement), in which case - // we claim the output shape as input shape. - var inputShape = input.shape || result.shape; - var inputRank = inputShape.length; - inputShapesDescription += - name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " "; - } - console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); - }; - return Logger; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Computes a list of TapeNodes that connect x to y, filtering everything else - * out and preserving the order of the original tape elements. - * - * @param tape The tape elements to filter. - * @param xs The input Tensors. - * @param y The output Tensor. - */ - function getFilteredNodesXToY(tape, xs, y) { - // Forward pass to compute all the nodes and Tensors that are transitively a - // function of x. - var tensorsFromX = {}; - var nodesFromX = {}; - for (var i = 0; i < xs.length; i++) { - tensorsFromX[xs[i].id] = true; - } - for (var i = 0; i < tape.length; i++) { - var node = tape[i]; - var nodeInputs = node.inputs; - for (var inputName in nodeInputs) { - var input = nodeInputs[inputName]; - var anyInputFromX = false; - for (var j = 0; j < xs.length; j++) { - if (tensorsFromX[input.id]) { - node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; }); - anyInputFromX = true; - nodesFromX[node.id] = true; - break; - } - } - if (anyInputFromX) { - break; - } - } - } - // Backward pass to find all of the nodes and Tensors that lead to y. - var tensorsLeadToY = {}; - tensorsLeadToY[y.id] = true; - var nodesToY = {}; - for (var i = tape.length - 1; i >= 0; i--) { - var node = tape[i]; - var nodeInputs = node.inputs; - // If any of the outputs lead to y, mark all of the inputs as leading to y. - for (var j = 0; j < node.outputs.length; j++) { - if (tensorsLeadToY[node.outputs[j].id]) { - for (var inputName in nodeInputs) { - tensorsLeadToY[nodeInputs[inputName].id] = true; - nodesToY[node.id] = true; - } - break; - } - } - } - // Return the paths that come from x and lead to y. - var filteredTape = []; - for (var i = 0; i < tape.length; i++) { - var node = tape[i]; - if (nodesFromX[node.id] && nodesToY[node.id]) { - // Prune the inputs from the node that aren't a function of x. - var prunedInputs = {}; - for (var inputName in node.inputs) { - var nodeInput = node.inputs[inputName]; - if (tensorsFromX[nodeInput.id]) { - prunedInputs[inputName] = nodeInput; - } - } - // Copy the node and overwrite inputsAndArgs to the pruned version. - var prunedNode = Object.assign({}, node); - prunedNode.inputs = prunedInputs; - prunedNode.outputs = node.outputs; - filteredTape.push(prunedNode); - } - } - return filteredTape; - } - /** - * Backpropagate gradients through the filtered TapeNodes. - * - * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map - * is mutated by this method. - * @param filteredTape The filtered TapeNodes to backprop through. - */ - function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) { - var _loop_1 = function (i) { - var node = filteredTape[i]; - var dys = []; - node.outputs.forEach(function (o) { - var gradTensor = tensorAccumulatedGradientMap[o.id]; - if (gradTensor != null) { - dys.push(gradTensor); - } - else { - // This particular output is not in the back-propagation subgraph, so it - // does not affect the final output, thus we put null for its dy. - dys.push(null); - } - }); - if (node.gradient == null) { - throw new Error("Cannot compute gradient: gradient function not found " + - ("for " + node.kernelName + ".")); - } - // Backprop dy through this node and accumulate gradients over the inputs. - var inputGradients = node.gradient(dys); - var _loop_2 = function (inputName) { - if (!(inputName in inputGradients)) { - throw new Error("Cannot backprop through input " + inputName + ". " + - ("Available gradients found: " + Object.keys(inputGradients) + ".")); - } - // Call the gradient function. - var dx = tidy(function () { return inputGradients[inputName](); }); - if (dx.dtype !== 'float32') { - throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + - (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'")); - } - var x = node.inputs[inputName]; - if (!arraysEqual(dx.shape, x.shape)) { - throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + - ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + - ("the shape of the input '" + x.shape + "'")); - } - if (tensorAccumulatedGradientMap[x.id] == null) { - tensorAccumulatedGradientMap[x.id] = dx; - } - else { - var curGradient = tensorAccumulatedGradientMap[x.id]; - tensorAccumulatedGradientMap[x.id] = curGradient.add(dx); - curGradient.dispose(); - } - }; - for (var inputName in node.inputs) { - _loop_2(inputName); - } - }; - // Walk the tape backward and keep a map of Tensor to its gradient. - for (var i = filteredTape.length - 1; i >= 0; i--) { - _loop_1(i); - } - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // Maximum number of values before we decide to show ellipsis. - var FORMAT_LIMIT_NUM_VALS = 20; - // Number of first and last values to show when displaying a, b,...,y, z. - var FORMAT_NUM_FIRST_LAST_VALS = 3; - // Number of significant digits to show. - var FORMAT_NUM_SIG_DIGITS = 7; - function tensorToString(vals, shape, dtype, verbose) { - var strides = computeStrides(shape); - var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); - var rank = shape.length; - var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); - var lines = ['Tensor']; - if (verbose) { - lines.push(" dtype: " + dtype); - lines.push(" rank: " + rank); - lines.push(" shape: [" + shape + "]"); - lines.push(" values:"); - } - lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n')); - return lines.join('\n'); - } - function computeMaxSizePerColumn(vals, shape, dtype, strides) { - var n = sizeFromShape(shape); - var numCols = strides[strides.length - 1]; - var padPerCol = new Array(numCols).fill(0); - var rank = shape.length; - var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals; - if (rank > 1) { - for (var row = 0; row < n / numCols; row++) { - var offset = row * numCols; - for (var j = 0; j < numCols; j++) { - padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); - } - } - } - return padPerCol; - } - function valToString(val, pad, dtype) { - var valStr; - if (Array.isArray(val)) { - valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + - (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j"); - } - else if (isString(val)) { - valStr = "'" + val + "'"; - } - else if (dtype === 'bool') { - valStr = boolNumToString(val); - } - else { - valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); - } - return rightPad(valStr, pad); - } - function boolNumToString(v) { - return v === 0 ? 'false' : 'true'; - } - function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) { - if (isLast === void 0) { isLast = true; } - var storagePerElement = dtype === 'complex64' ? 2 : 1; - var size = shape[0]; - var rank = shape.length; - if (rank === 0) { - if (dtype === 'complex64') { - var complexTuple = createComplexTuples(vals); - return [valToString(complexTuple[0], 0, dtype)]; - } - if (dtype === 'bool') { - return [boolNumToString(vals[0])]; - } - return [vals[0].toString()]; - } - if (rank === 1) { - if (size > FORMAT_LIMIT_NUM_VALS) { - var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; - var firstVals = Array.from(vals.slice(0, firstValsSize)); - var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); - if (dtype === 'complex64') { - firstVals = createComplexTuples(firstVals); - lastVals = createComplexTuples(lastVals); - } - return [ - '[' + - firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) - .join(', ') + - ', ..., ' + - lastVals - .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); }) - .join(', ') + - ']' - ]; - } - var displayVals = dtype === 'complex64' ? createComplexTuples(vals) : - Array.from(vals); - return [ - '[' + - displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) - .join(', ') + - ']' - ]; - } - // The array is rank 2 or more. - var subshape = shape.slice(1); - var substrides = strides.slice(1); - var stride = strides[0] * storagePerElement; - var lines = []; - if (size > FORMAT_LIMIT_NUM_VALS) { - for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { - var start = i * stride; - var end = start + stride; - lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)); - } - lines.push('...'); - for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { - var start = i * stride; - var end = start + stride; - lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); - } - } - else { - for (var i = 0; i < size; i++) { - var start = i * stride; - var end = start + stride; - lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); - } - } - var sep = rank === 2 ? ',' : ''; - lines[0] = '[' + lines[0] + sep; - for (var i = 1; i < lines.length - 1; i++) { - lines[i] = ' ' + lines[i] + sep; - } - var newLineSep = ',\n'; - for (var i = 2; i < rank; i++) { - newLineSep += '\n'; - } - lines[lines.length - 1] = - ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); - return lines; - } - function createComplexTuples(vals) { - var complexTuples = []; - for (var i = 0; i < vals.length; i += 2) { - complexTuples.push([vals[i], vals[i + 1]]); - } - return complexTuples; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * A mutable object, similar to `tf.Tensor`, that allows users to set values - * at locations before converting to an immutable `tf.Tensor`. - * - * See `tf.buffer` for creating a tensor buffer. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - var TensorBuffer = /** @class */ (function () { - function TensorBuffer(shape, dtype, values) { - var _this = this; - this.dtype = dtype; - this.shape = shape.slice(); - this.size = sizeFromShape(shape); - if (values != null) { - var n_1 = values.length; - assert(n_1 === this.size, function () { return "Length of values '" + n_1 + "' does not match the size " + - ("inferred by the shape '" + _this.size + "'."); }); - } - if (dtype === 'complex64') { - throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + - "a TensorBuffer for the real and imaginary parts separately and " + - "call tf.complex(real, imag)."); - } - this.values = values || getArrayFromDType(dtype, this.size); - this.strides = computeStrides(shape); - } - /** - * Sets a value in the buffer at a given location. - * - * @param value The value to set. - * @param locs The location indices. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - TensorBuffer.prototype.set = function (value) { - var _this = this; - var locs = []; - for (var _i = 1; _i < arguments.length; _i++) { - locs[_i - 1] = arguments[_i]; - } - if (locs.length === 0) { - locs = [0]; - } - assert(locs.length === this.rank, function () { return "The number of provided coordinates (" + locs.length + ") must " + - ("match the rank (" + _this.rank + ")"); }); - var index = this.locToIndex(locs); - this.values[index] = value; - }; - /** - * Returns the value in the buffer at the provided location. - * - * @param locs The location indices. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - TensorBuffer.prototype.get = function () { - var locs = []; - for (var _i = 0; _i < arguments.length; _i++) { - locs[_i] = arguments[_i]; - } - if (locs.length === 0) { - locs = [0]; - } - var i = 0; - for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) { - var loc = locs_1[_a]; - if (loc < 0 || loc >= this.shape[i]) { - var msg = "Requested out of range element at " + locs + ". " + - (" Buffer shape=" + this.shape); - throw new Error(msg); - } - i++; - } - var index = locs[locs.length - 1]; - for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) { - index += this.strides[i_1] * locs[i_1]; - } - return this.values[index]; - }; - TensorBuffer.prototype.locToIndex = function (locs) { - if (this.rank === 0) { - return 0; - } - else if (this.rank === 1) { - return locs[0]; - } - var index = locs[locs.length - 1]; - for (var i = 0; i < locs.length - 1; ++i) { - index += this.strides[i] * locs[i]; - } - return index; - }; - TensorBuffer.prototype.indexToLoc = function (index) { - if (this.rank === 0) { - return []; - } - else if (this.rank === 1) { - return [index]; - } - var locs = new Array(this.shape.length); - for (var i = 0; i < locs.length - 1; ++i) { - locs[i] = Math.floor(index / this.strides[i]); - index -= locs[i] * this.strides[i]; - } - locs[locs.length - 1] = index; - return locs; - }; - Object.defineProperty(TensorBuffer.prototype, "rank", { - get: function () { - return this.shape.length; - }, - enumerable: true, - configurable: true - }); - /** - * Creates an immutable `tf.Tensor` object from the buffer. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - TensorBuffer.prototype.toTensor = function () { - return trackerFn().makeTensor(this.values, this.shape, this.dtype); - }; - return TensorBuffer; - }()); - // For tracking tensor creation and disposal. - var trackerFn = null; - // Used by chaining methods to call into ops. - var opHandler = null; - // Used to warn about deprecated methods. - var deprecationWarningFn = null; - /** - * An external consumer can register itself as the tensor tracker. This way - * the Tensor class can notify the tracker for every tensor created and - * disposed. - */ - function setTensorTracker(fn) { - trackerFn = fn; - } - /** - * An external consumer can register itself as the op handler. This way the - * Tensor class can have chaining methods that call into ops via the op - * handler. - */ - function setOpHandler(handler) { - opHandler = handler; - } - /** - * Sets the deprecation warning function to be used by this file. This way the - * Tensor class can be a leaf but still use the environment. - */ - function setDeprecationWarningFn(fn) { - deprecationWarningFn = fn; - } - /** - * A `tf.Tensor` object represents an immutable, multidimensional array of - * numbers that has a shape and a data type. - * - * See `tf.tensor` for details on how to create a `tf.Tensor`. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - var Tensor = /** @class */ (function () { - function Tensor(shape, dtype, dataId, id) { - /** Whether this tensor has been globally kept. */ - this.kept = false; - this.isDisposedInternal = false; - this.shape = shape.slice(); - this.dtype = dtype || 'float32'; - this.size = sizeFromShape(shape); - this.strides = computeStrides(shape); - this.dataId = dataId; - this.id = id; - this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher'); - } - /** Flatten a Tensor to a 1D array. */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.flatten = function () { - this.throwIfDisposed(); - return this.as1D(); - }; - /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`. */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.asScalar = function () { - this.throwIfDisposed(); - assert(this.size === 1, function () { return 'The array must have only 1 element.'; }); - return this.reshape([]); - }; - /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.as1D = function () { - this.throwIfDisposed(); - return this.reshape([this.size]); - }; - /** - * Converts a `tf.Tensor` to a `tf.Tensor2D`. - * - * @param rows Number of rows in `tf.Tensor2D`. - * @param columns Number of columns in `tf.Tensor2D`. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.as2D = function (rows, columns) { - this.throwIfDisposed(); - return this.reshape([rows, columns]); - }; - /** - * Converts a `tf.Tensor` to a `tf.Tensor3D`. - * - * @param rows Number of rows in `tf.Tensor3D`. - * @param columns Number of columns in `tf.Tensor3D`. - * @param depth Depth of `tf.Tensor3D`. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.as3D = function (rows, columns, depth) { - this.throwIfDisposed(); - return this.reshape([rows, columns, depth]); - }; - /** - * Converts a `tf.Tensor` to a `tf.Tensor4D`. - * - * @param rows Number of rows in `tf.Tensor4D`. - * @param columns Number of columns in `tf.Tensor4D`. - * @param depth Depth of `tf.Tensor4D`. - * @param depth2 4th dimension of `tf.Tensor4D`. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.as4D = function (rows, columns, depth, depth2) { - this.throwIfDisposed(); - return this.reshape([rows, columns, depth, depth2]); - }; - /** - * Converts a `tf.Tensor` to a `tf.Tensor5D`. - * - * @param rows Number of rows in `tf.Tensor5D`. - * @param columns Number of columns in `tf.Tensor5D`. - * @param depth Depth of `tf.Tensor5D`. - * @param depth2 4th dimension of `tf.Tensor5D`. - * @param depth3 5th dimension of 'tf.Tensor5D' - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.as5D = function (rows, columns, depth, depth2, depth3) { - this.throwIfDisposed(); - return this.reshape([rows, columns, depth, depth2, depth3]); - }; - /** - * Casts a `tf.Tensor` to a specified dtype. - * - * @param dtype Data-type to cast the tensor to. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.asType = function (dtype) { - this.throwIfDisposed(); - return opHandler.cast(this, dtype); - }; - Object.defineProperty(Tensor.prototype, "rank", { - get: function () { - return this.shape.length; - }, - enumerable: true, - configurable: true - }); - /** - * Returns a promise of `tf.TensorBuffer` that holds the underlying data. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.buffer = function () { - return __awaiter(this, void 0, void 0, function () { - var vals; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: return [4 /*yield*/, this.data()]; - case 1: - vals = _a.sent(); - return [2 /*return*/, opHandler.buffer(this.shape, this.dtype, vals)]; - } - }); - }); - }; - /** Returns a `tf.TensorBuffer` that holds the underlying data. */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.bufferSync = function () { - return opHandler.buffer(this.shape, this.dtype, this.dataSync()); - }; - /** - * Returns the tensor data as a nested array. The transfer of data is done - * asynchronously. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.array = function () { - return __awaiter(this, void 0, void 0, function () { - var vals; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: return [4 /*yield*/, this.data()]; - case 1: - vals = _a.sent(); - return [2 /*return*/, toNestedArray(this.shape, vals)]; - } - }); - }); - }; - /** - * Returns the tensor data as a nested array. The transfer of data is done - * synchronously. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.arraySync = function () { - return toNestedArray(this.shape, this.dataSync()); - }; - /** - * Asynchronously downloads the values from the `tf.Tensor`. Returns a - * promise of `TypedArray` that resolves when the computation has finished. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.data = function () { - return __awaiter(this, void 0, void 0, function () { - var data, bytes; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: - this.throwIfDisposed(); - data = trackerFn().read(this.dataId); - if (!(this.dtype === 'string')) return [3 /*break*/, 2]; - return [4 /*yield*/, data]; - case 1: - bytes = _a.sent(); - try { - return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })]; - } - catch (_b) { - throw new Error('Failed to decode the string bytes into utf-8. ' + - 'To get the original bytes, call tensor.bytes().'); - } - _a.label = 2; - case 2: return [2 /*return*/, data]; - } - }); - }); - }; - /** - * Synchronously downloads the values from the `tf.Tensor`. This blocks the - * UI thread until the values are ready, which can cause performance issues. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.dataSync = function () { - this.throwIfDisposed(); - var data = trackerFn().readSync(this.dataId); - if (this.dtype === 'string') { - try { - return data.map(function (b) { return decodeString(b); }); - } - catch (_a) { - throw new Error('Failed to decode the string bytes into utf-8. ' + - 'To get the original bytes, call tensor.bytes().'); - } - } - return data; - }; - /** Returns the underlying bytes of the tensor's data. */ - Tensor.prototype.bytes = function () { - return __awaiter(this, void 0, void 0, function () { - var data; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: - this.throwIfDisposed(); - return [4 /*yield*/, trackerFn().read(this.dataId)]; - case 1: - data = _a.sent(); - if (this.dtype === 'string') { - return [2 /*return*/, data]; - } - else { - return [2 /*return*/, new Uint8Array(data.buffer)]; - } - return [2 /*return*/]; - } - }); - }); - }; - /** - * Disposes `tf.Tensor` from memory. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.dispose = function () { - if (this.isDisposed) { - return; - } - trackerFn().disposeTensor(this); - this.isDisposedInternal = true; - }; - Object.defineProperty(Tensor.prototype, "isDisposed", { - get: function () { - return this.isDisposedInternal; - }, - enumerable: true, - configurable: true - }); - Tensor.prototype.throwIfDisposed = function () { - if (this.isDisposed) { - throw new Error("Tensor is disposed."); - } - }; - /** Casts the array to type `float32` */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.toFloat = function () { - return this.asType('float32'); - }; - /** Casts the array to type `int32` */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.toInt = function () { - return this.asType('int32'); - }; - /** Casts the array to type `bool` */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.toBool = function () { - return this.asType('bool'); - }; - /** - * Prints the `tf.Tensor`. See `tf.print` for details. - * - * @param verbose Whether to print verbose information about the tensor, - * including dtype and size. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.print = function (verbose) { - if (verbose === void 0) { verbose = false; } - return opHandler.print(this, verbose); - }; - /** - * Reshapes the tensor into the provided shape. - * See `tf.reshape` for more details. - * - * @param newShape An array of integers defining the output tensor shape. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.reshape = function (newShape) { - this.throwIfDisposed(); - return opHandler.reshape(this, newShape); - }; - /** - * Reshapes the tensor into the shape of the provided tensor. - * - * @param x The tensor of required shape. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.reshapeAs = function (x) { - this.throwIfDisposed(); - return this.reshape(x.shape); - }; - /** - * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension - * into the tensor's shape. See `tf.expandDims` for details. - * - * @param axis The dimension index at which to insert shape of 1. Defaults to - * 0 (the first dimension). - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.expandDims = function (axis) { - if (axis === void 0) { axis = 0; } - return opHandler.expandDims(this, axis); - }; - /** - * Returns the cumulative sum of the `tf.Tensor` along `axis`. - * - * @param axis The axis along which to sum. Optional. Defaults to 0. - * @param exclusive Whether to perform exclusive cumulative sum. Defaults to - * false. If set to true then the sum of each tensor entry does not - * include its own value, but only the values previous to it along the - * specified axis. - * @param reverse Whether to sum in the opposite direction. Defaults to - * false. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.cumsum = function (axis, exclusive, reverse) { - if (axis === void 0) { axis = 0; } - if (exclusive === void 0) { exclusive = false; } - if (reverse === void 0) { reverse = false; } - return opHandler.cumsum(this, axis, exclusive, reverse); - }; - /** - * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape. - * See `tf.squeeze` for more details. - * - * @param axis A list of numbers. If specified, only squeezes the - * dimensions listed. The dimension index starts at 0. It is an error to - * squeeze a dimension that is not 1. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.squeeze = function (axis) { - this.throwIfDisposed(); - return opHandler.squeeze(this, axis); - }; - /** Returns a copy of the tensor. See `tf.clone` for details. */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.clone = function () { - this.throwIfDisposed(); - return opHandler.clone(this); - }; - /** - * Returns a human-readable description of the tensor. Useful for logging. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Tensor.prototype.toString = function (verbose) { - if (verbose === void 0) { verbose = false; } - var vals = this.dataSync(); - return tensorToString(vals, this.shape, this.dtype, verbose); - }; - // Below is chain API that is not exposed to docs to avoid repetition. To - // expose a method, move it above this comment and add @doc and jsdoc. - Tensor.prototype.gather = function (indices, axis) { - if (axis === void 0) { axis = 0; } - this.throwIfDisposed(); - return opHandler.gather(this, indices, axis); - }; - Tensor.prototype.matMul = function (b, transposeA, transposeB) { - if (transposeA === void 0) { transposeA = false; } - if (transposeB === void 0) { transposeB = false; } - this.throwIfDisposed(); - return opHandler.matMul(this, b, transposeA, transposeB); - }; - Tensor.prototype.dot = function (b) { - this.throwIfDisposed(); - return opHandler.dot(this, b); - }; - Tensor.prototype.norm = function (ord, axis, keepDims) { - if (ord === void 0) { ord = 'euclidean'; } - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.norm(this, ord, axis, keepDims); - }; - Tensor.prototype.slice = function (begin, size) { - this.throwIfDisposed(); - return opHandler.slice(this, begin, size); - }; - Tensor.prototype.reverse = function (axis) { - this.throwIfDisposed(); - return opHandler.reverse(this, axis); - }; - Tensor.prototype.concat = function (x, axis) { - if (axis === void 0) { axis = 0; } - this.throwIfDisposed(); - if (x instanceof Tensor) { - x = [x]; - } - return opHandler.concat([this].concat(x), axis); - }; - Tensor.prototype.split = function (numOrSizeSplits, axis) { - if (axis === void 0) { axis = 0; } - this.throwIfDisposed(); - return opHandler.split(this, numOrSizeSplits, axis); - }; - Tensor.prototype.stack = function (x, axis) { - if (axis === void 0) { axis = 0; } - return opHandler.stack([this, x], axis); - }; - Tensor.prototype.unstack = function (axis) { - if (axis === void 0) { axis = 0; } - return opHandler.unstack(this, axis); - }; - /** - * @deprecated Use `tf.batchNorm` instead, and note the positional argument - * change of scale, offset, and varianceEpsilon. - */ - Tensor.prototype.batchNormalization = function (mean, variance, varianceEpsilon, scale, offset) { - if (varianceEpsilon === void 0) { varianceEpsilon = .001; } - deprecationWarningFn('tf.batchNormalization() is going away. ' + - 'Use tf.batchNorm() instead, and note the positional argument change ' + - 'of scale, offset, and varianceEpsilon'); - return this.batchNorm(mean, variance, offset, scale, varianceEpsilon); - }; - // Reduction ops. - Tensor.prototype.all = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.all(this, axis, keepDims); - }; - Tensor.prototype.any = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.any(this, axis, keepDims); - }; - Tensor.prototype.logSumExp = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.logSumExp(this, axis, keepDims); - }; - Tensor.prototype.sum = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.sum(this, axis, keepDims); - }; - Tensor.prototype.prod = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.prod(this, axis, keepDims); - }; - Tensor.prototype.mean = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.mean(this, axis, keepDims); - }; - Tensor.prototype.min = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.min(this, axis, keepDims); - }; - Tensor.prototype.max = function (axis, keepDims) { - if (axis === void 0) { axis = null; } - if (keepDims === void 0) { keepDims = false; } - this.throwIfDisposed(); - return opHandler.max(this, axis, keepDims); - }; - Tensor.prototype.argMin = function (axis) { - if (axis === void 0) { axis = null; } - this.throwIfDisposed(); - return opHandler.argMin(this, axis); - }; - Tensor.prototype.argMax = function (axis) { - if (axis === void 0) { axis = null; } - this.throwIfDisposed(); - return opHandler.argMax(this, axis); - }; - // Transformations - Tensor.prototype.cast = function (dtype) { - this.throwIfDisposed(); - return opHandler.cast(this, dtype); - }; - // Binary ops. - Tensor.prototype.addStrict = function (x) { - this.throwIfDisposed(); - return opHandler.addStrict(this, x); - }; - Tensor.prototype.atan2 = function (x) { - this.throwIfDisposed(); - return opHandler.atan2(this, x); - }; - Tensor.prototype.sub = function (x) { - this.throwIfDisposed(); - return opHandler.sub(this, x); - }; - Tensor.prototype.subStrict = function (x) { - this.throwIfDisposed(); - return opHandler.subStrict(this, x); - }; - Tensor.prototype.pow = function (exp) { - this.throwIfDisposed(); - return opHandler.pow(this, exp); - }; - Tensor.prototype.powStrict = function (exp) { - this.throwIfDisposed(); - return opHandler.powStrict(this, exp); - }; - Tensor.prototype.mul = function (x) { - this.throwIfDisposed(); - return opHandler.mul(this, x); - }; - Tensor.prototype.mulStrict = function (x) { - this.throwIfDisposed(); - return opHandler.mulStrict(this, x); - }; - Tensor.prototype.floorDiv = function (x) { - this.throwIfDisposed(); - return opHandler.floorDiv(this, x); - }; - Tensor.prototype.divStrict = function (x) { - this.throwIfDisposed(); - return opHandler.divStrict(this, x); - }; - Tensor.prototype.minimum = function (x) { - this.throwIfDisposed(); - return opHandler.minimum(this, x); - }; - Tensor.prototype.minimumStrict = function (x) { - this.throwIfDisposed(); - return opHandler.minimumStrict(this, x); - }; - Tensor.prototype.maximum = function (x) { - this.throwIfDisposed(); - return opHandler.maximum(this, x); - }; - Tensor.prototype.maximumStrict = function (x) { - this.throwIfDisposed(); - return opHandler.maximumStrict(this, x); - }; - Tensor.prototype.mod = function (x) { - this.throwIfDisposed(); - return opHandler.mod(this, x); - }; - Tensor.prototype.modStrict = function (x) { - this.throwIfDisposed(); - return opHandler.modStrict(this, x); - }; - Tensor.prototype.squaredDifferenceStrict = function (x) { - this.throwIfDisposed(); - return opHandler.squaredDifferenceStrict(this, x); - }; - // Compare ops. - Tensor.prototype.notEqual = function (x) { - this.throwIfDisposed(); - return opHandler.notEqual(this, x); - }; - Tensor.prototype.notEqualStrict = function (x) { - this.throwIfDisposed(); - return opHandler.notEqualStrict(this, x); - }; - Tensor.prototype.less = function (x) { - this.throwIfDisposed(); - return opHandler.less(this, x); - }; - Tensor.prototype.lessStrict = function (x) { - this.throwIfDisposed(); - return opHandler.lessStrict(this, x); - }; - Tensor.prototype.equal = function (x) { - this.throwIfDisposed(); - return opHandler.equal(this, x); - }; - Tensor.prototype.equalStrict = function (x) { - this.throwIfDisposed(); - return opHandler.equalStrict(this, x); - }; - Tensor.prototype.lessEqual = function (x) { - this.throwIfDisposed(); - return opHandler.lessEqual(this, x); - }; - Tensor.prototype.lessEqualStrict = function (x) { - this.throwIfDisposed(); - return opHandler.lessEqualStrict(this, x); - }; - Tensor.prototype.greater = function (x) { - this.throwIfDisposed(); - return opHandler.greater(this, x); - }; - Tensor.prototype.greaterStrict = function (x) { - this.throwIfDisposed(); - return opHandler.greaterStrict(this, x); - }; - Tensor.prototype.greaterEqual = function (x) { - this.throwIfDisposed(); - return opHandler.greaterEqual(this, x); - }; - Tensor.prototype.greaterEqualStrict = function (x) { - this.throwIfDisposed(); - return opHandler.greaterEqualStrict(this, x); - }; - // Compare ops. - Tensor.prototype.logicalAnd = function (x) { - this.throwIfDisposed(); - return opHandler.logicalAnd(this, x); - }; - Tensor.prototype.logicalOr = function (x) { - this.throwIfDisposed(); - return opHandler.logicalOr(this, x); - }; - Tensor.prototype.logicalNot = function () { - this.throwIfDisposed(); - return opHandler.logicalNot(this); - }; - Tensor.prototype.logicalXor = function (x) { - this.throwIfDisposed(); - return opHandler.logicalXor(this, x); - }; - Tensor.prototype.where = function (condition, x) { - this.throwIfDisposed(); - return opHandler.where(condition, this, x); - }; - // Unary ops. - Tensor.prototype.neg = function () { - this.throwIfDisposed(); - return opHandler.neg(this); - }; - Tensor.prototype.ceil = function () { - this.throwIfDisposed(); - return opHandler.ceil(this); - }; - Tensor.prototype.floor = function () { - this.throwIfDisposed(); - return opHandler.floor(this); - }; - Tensor.prototype.sign = function () { - this.throwIfDisposed(); - return opHandler.sign(this); - }; - Tensor.prototype.isNaN = function () { - this.throwIfDisposed(); - return opHandler.isNaN(this); - }; - Tensor.prototype.isInf = function () { - this.throwIfDisposed(); - return opHandler.isInf(this); - }; - Tensor.prototype.isFinite = function () { - this.throwIfDisposed(); - return opHandler.isFinite(this); - }; - Tensor.prototype.exp = function () { - this.throwIfDisposed(); - return opHandler.exp(this); - }; - Tensor.prototype.expm1 = function () { - this.throwIfDisposed(); - return opHandler.expm1(this); - }; - Tensor.prototype.log = function () { - this.throwIfDisposed(); - return opHandler.log(this); - }; - Tensor.prototype.log1p = function () { - this.throwIfDisposed(); - return opHandler.log1p(this); - }; - Tensor.prototype.sqrt = function () { - this.throwIfDisposed(); - return opHandler.sqrt(this); - }; - Tensor.prototype.rsqrt = function () { - this.throwIfDisposed(); - return opHandler.rsqrt(this); - }; - Tensor.prototype.square = function () { - this.throwIfDisposed(); - return opHandler.square(this); - }; - Tensor.prototype.reciprocal = function () { - this.throwIfDisposed(); - return opHandler.reciprocal(this); - }; - Tensor.prototype.abs = function () { - this.throwIfDisposed(); - return opHandler.abs(this); - }; - Tensor.prototype.clipByValue = function (min, max) { - this.throwIfDisposed(); - return opHandler.clipByValue(this, min, max); - }; - Tensor.prototype.relu = function () { - this.throwIfDisposed(); - return opHandler.relu(this); - }; - Tensor.prototype.relu6 = function () { - this.throwIfDisposed(); - return opHandler.relu6(this); - }; - Tensor.prototype.elu = function () { - this.throwIfDisposed(); - return opHandler.elu(this); - }; - Tensor.prototype.selu = function () { - this.throwIfDisposed(); - return opHandler.selu(this); - }; - Tensor.prototype.leakyRelu = function (alpha) { - if (alpha === void 0) { alpha = 0.2; } - this.throwIfDisposed(); - return opHandler.leakyRelu(this, alpha); - }; - Tensor.prototype.prelu = function (alpha) { - this.throwIfDisposed(); - return opHandler.prelu(this, alpha); - }; - Tensor.prototype.sigmoid = function () { - this.throwIfDisposed(); - return opHandler.sigmoid(this); - }; - Tensor.prototype.logSigmoid = function () { - this.throwIfDisposed(); - return opHandler.logSigmoid(this); - }; - Tensor.prototype.softplus = function () { - this.throwIfDisposed(); - return opHandler.softplus(this); - }; - Tensor.prototype.zerosLike = function () { - this.throwIfDisposed(); - return opHandler.zerosLike(this); - }; - Tensor.prototype.onesLike = function () { - this.throwIfDisposed(); - return opHandler.onesLike(this); - }; - Tensor.prototype.sin = function () { - this.throwIfDisposed(); - return opHandler.sin(this); - }; - Tensor.prototype.cos = function () { - this.throwIfDisposed(); - return opHandler.cos(this); - }; - Tensor.prototype.tan = function () { - this.throwIfDisposed(); - return opHandler.tan(this); - }; - Tensor.prototype.asin = function () { - this.throwIfDisposed(); - return opHandler.asin(this); - }; - Tensor.prototype.acos = function () { - this.throwIfDisposed(); - return opHandler.acos(this); - }; - Tensor.prototype.atan = function () { - this.throwIfDisposed(); - return opHandler.atan(this); - }; - Tensor.prototype.sinh = function () { - this.throwIfDisposed(); - return opHandler.sinh(this); - }; - Tensor.prototype.cosh = function () { - this.throwIfDisposed(); - return opHandler.cosh(this); - }; - Tensor.prototype.tanh = function () { - this.throwIfDisposed(); - return opHandler.tanh(this); - }; - Tensor.prototype.asinh = function () { - this.throwIfDisposed(); - return opHandler.asinh(this); - }; - Tensor.prototype.acosh = function () { - this.throwIfDisposed(); - return opHandler.acosh(this); - }; - Tensor.prototype.atanh = function () { - this.throwIfDisposed(); - return opHandler.atanh(this); - }; - Tensor.prototype.erf = function () { - this.throwIfDisposed(); - return opHandler.erf(this); - }; - Tensor.prototype.round = function () { - this.throwIfDisposed(); - return opHandler.round(this); - }; - Tensor.prototype.step = function (alpha) { - if (alpha === void 0) { alpha = 0.0; } - this.throwIfDisposed(); - return opHandler.step(this, alpha); - }; - Tensor.prototype.softmax = function (dim) { - if (dim === void 0) { dim = -1; } - this.throwIfDisposed(); - return opHandler.softmax(this, dim); - }; - Tensor.prototype.logSoftmax = function (axis) { - if (axis === void 0) { axis = -1; } - this.throwIfDisposed(); - return opHandler.logSoftmax(this, axis); - }; - // Image ops. - Tensor.prototype.resizeBilinear = function (newShape2D, alignCorners) { - if (alignCorners === void 0) { alignCorners = false; } - this.throwIfDisposed(); - return opHandler.image.resizeBilinear(this, newShape2D, alignCorners); - }; - Tensor.prototype.resizeNearestNeighbor = function (newShape2D, alignCorners) { - if (alignCorners === void 0) { alignCorners = false; } - this.throwIfDisposed(); - return opHandler.image.resizeNearestNeighbor(this, newShape2D, alignCorners); - }; - // Convolutions. - Tensor.prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) { - if (dataFormat === void 0) { dataFormat = 'NWC'; } - if (dilation === void 0) { dilation = 1; } - this.throwIfDisposed(); - return opHandler.conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode); - }; - Tensor.prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) { - if (dataFormat === void 0) { dataFormat = 'NHWC'; } - if (dilations === void 0) { dilations = [1, 1]; } - this.throwIfDisposed(); - return opHandler.conv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode); - }; - Tensor.prototype.conv2dTranspose = function (filter, outputShape, strides, pad, dimRoundingMode) { - this.throwIfDisposed(); - return opHandler.conv2dTranspose(this, filter, outputShape, strides, pad, dimRoundingMode); - }; - Tensor.prototype.depthwiseConv2D = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) { - if (dataFormat === void 0) { dataFormat = 'NHWC'; } - if (dilations === void 0) { dilations = [1, 1]; } - this.throwIfDisposed(); - return opHandler.depthwiseConv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode); - }; - Tensor.prototype.separableConv2d = function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) { - if (dilation === void 0) { dilation = [1, 1]; } - if (dataFormat === void 0) { dataFormat = 'NHWC'; } - this.throwIfDisposed(); - return opHandler.separableConv2d(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat); - }; - // Pooling. - Tensor.prototype.avgPool = function (filterSize, strides, pad, dimRoundingMode) { - this.throwIfDisposed(); - return opHandler.avgPool(this, filterSize, strides, pad, dimRoundingMode); - }; - Tensor.prototype.maxPool = function (filterSize, strides, pad, dimRoundingMode) { - this.throwIfDisposed(); - return opHandler.maxPool(this, filterSize, strides, pad, dimRoundingMode); - }; - Tensor.prototype.localResponseNormalization = function (radius, bias, alpha, beta) { - if (radius === void 0) { radius = 5; } - if (bias === void 0) { bias = 1; } - if (alpha === void 0) { alpha = 1; } - if (beta === void 0) { beta = 0.5; } - return opHandler.localResponseNormalization(this, radius, bias, alpha, beta); - }; - Tensor.prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides) { - this.throwIfDisposed(); - return opHandler.pool(this, windowShape, poolingType, padding, dilationRate, strides); - }; - Tensor.prototype.variable = function (trainable, name, dtype) { - if (trainable === void 0) { trainable = true; } - this.throwIfDisposed(); - return trackerFn().makeVariable(this, trainable, name, dtype); - }; - Tensor.prototype.unsortedSegmentSum = function (segmentIds, numSegments) { - this.throwIfDisposed(); - return opHandler.unsortedSegmentSum(this, segmentIds, numSegments); - }; - Tensor.prototype.batchToSpaceND = function (blockShape, crops) { - this.throwIfDisposed(); - return opHandler.batchToSpaceND(this, blockShape, crops); - }; - Tensor.prototype.spaceToBatchND = function (blockShape, paddings) { - this.throwIfDisposed(); - return opHandler.spaceToBatchND(this, blockShape, paddings); - }; - Tensor.prototype.topk = function (k, sorted) { - if (k === void 0) { k = 1; } - if (sorted === void 0) { sorted = true; } - this.throwIfDisposed(); - return opHandler.topk(this, k, sorted); - }; - Tensor.prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) { - if (beginMask === void 0) { beginMask = 0; } - if (endMask === void 0) { endMask = 0; } - if (ellipsisMask === void 0) { ellipsisMask = 0; } - if (newAxisMask === void 0) { newAxisMask = 0; } - if (shrinkAxisMask === void 0) { shrinkAxisMask = 0; } - this.throwIfDisposed(); - return opHandler.stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - }; - Tensor.prototype.depthToSpace = function (blockSize, dataFormat) { - this.throwIfDisposed(); - return opHandler.depthToSpace(this, blockSize, dataFormat); - }; - Tensor.prototype.fft = function () { - this.throwIfDisposed(); - return opHandler.spectral.fft(this); - }; - Tensor.prototype.ifft = function () { - this.throwIfDisposed(); - return opHandler.spectral.ifft(this); - }; - Tensor.prototype.rfft = function () { - this.throwIfDisposed(); - return opHandler.spectral.rfft(this); - }; - Tensor.prototype.irfft = function () { - this.throwIfDisposed(); - return opHandler.spectral.irfft(this); - }; - return Tensor; - }()); - Object.defineProperty(Tensor, Symbol.hasInstance, { - value: function (instance) { - return !!instance && instance.dataId != null && instance.shape != null && - instance.dtype != null; - } - }); - /** - * A mutable `tf.Tensor`, useful for persisting state, e.g. for training. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - var Variable = /** @class */ (function (_super) { - __extends(Variable, _super); - function Variable(initialValue, trainable, name, tensorId) { - var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this; - _this.trainable = trainable; - _this.name = name; - return _this; - } - /** - * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have - * the same shape and dtype as the old `tf.Tensor`. - * - * @param newValue New tensor to be assigned to this variable. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - Variable.prototype.assign = function (newValue) { - if (newValue.dtype !== this.dtype) { - throw new Error("dtype of the new value (" + newValue.dtype + ") and " + - ("previous value (" + this.dtype + ") must match")); - } - if (!arraysEqual(newValue.shape, this.shape)) { - throw new Error("shape of the new value (" + newValue.shape + ") and " + - ("previous value (" + this.shape + ") must match")); - } - trackerFn().disposeTensor(this); - this.dataId = newValue.dataId; - trackerFn().incRef(this, null /* backend */); - }; - Variable.prototype.dispose = function () { - trackerFn().disposeVariable(this); - this.isDisposedInternal = true; - }; - return Variable; - }(Tensor)); - Object.defineProperty(Variable, Symbol.hasInstance, { - value: function (instance) { - return instance instanceof Tensor && instance.assign != null && - instance.assign instanceof Function; - } - }); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - (function (Rank) { - Rank["R0"] = "R0"; - Rank["R1"] = "R1"; - Rank["R2"] = "R2"; - Rank["R3"] = "R3"; - Rank["R4"] = "R4"; - Rank["R5"] = "R5"; - Rank["R6"] = "R6"; - })(exports.Rank || (exports.Rank = {})); - // Looks for upcasting types. Used, for example, in operations with mixed dtype - // inputs. - var UpcastInt32AndMap; - (function (UpcastInt32AndMap) { - UpcastInt32AndMap["float32"] = "float32"; - UpcastInt32AndMap["int32"] = "int32"; - UpcastInt32AndMap["bool"] = "int32"; - UpcastInt32AndMap["complex64"] = "complex64"; - })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); - var UpcastBoolAndMap; - (function (UpcastBoolAndMap) { - UpcastBoolAndMap["float32"] = "float32"; - UpcastBoolAndMap["int32"] = "int32"; - UpcastBoolAndMap["bool"] = "bool"; - UpcastBoolAndMap["complex64"] = "complex64"; - })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); - var UpcastFloat32AndMap; - (function (UpcastFloat32AndMap) { - UpcastFloat32AndMap["float32"] = "float32"; - UpcastFloat32AndMap["int32"] = "float32"; - UpcastFloat32AndMap["bool"] = "float32"; - UpcastFloat32AndMap["complex64"] = "complex64"; - })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); - var UpcastComplex64AndMap; - (function (UpcastComplex64AndMap) { - UpcastComplex64AndMap["float32"] = "complex64"; - UpcastComplex64AndMap["int32"] = "complex64"; - UpcastComplex64AndMap["bool"] = "complex64"; - UpcastComplex64AndMap["complex64"] = "complex64"; - })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); - var upcastTypeMap = { - 'float32': UpcastFloat32AndMap, - 'int32': UpcastInt32AndMap, - 'bool': UpcastBoolAndMap, - 'complex64': UpcastComplex64AndMap - }; - function upcastType(typeA, typeB) { - if (typeA === 'string' || typeB === 'string') { - if (typeA === 'string' && typeB === 'string') { - return 'string'; - } - throw new Error("Can not upcast " + typeA + " with " + typeB); - } - return upcastTypeMap[typeA][typeB]; - } - /** Returns the output type after summation. */ - function sumOutType(type) { - return upcastType(type, 'int32'); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function makeTypesMatch(a, b) { - if (a.dtype === b.dtype) { - return [a, b]; - } - var dtype = upcastType(a.dtype, b.dtype); - return [a.cast(dtype), b.cast(dtype)]; - } - function assertTypesMatch(a, b) { - assert(a.dtype === b.dtype, function () { return "The dtypes of the first(" + a.dtype + ") and" + - (" second(" + b.dtype + ") input must match"); }); - } - function isTensorInList(tensor, tensorList) { - return tensorList.some(function (x) { return x.id === tensor.id; }); - } - /** - * Extracts any `Tensor`s found within the provided object. - * - * @param container an object that may be a `Tensor` or may directly contain - * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it - * is safe to pass any object here, except that `Promise`s are not - * supported. - * @returns An array of `Tensors` found within the passed object. If the - * argument is simply a `Tensor', a list containing that `Tensor` is - * returned. If the object is not a `Tensor` or does not - * contain `Tensors`, an empty list is returned. - */ - function getTensorsInContainer(result) { - var list = []; - var seen = new Set(); - walkTensorContainer(result, list, seen); - return list; - } - function walkTensorContainer(container, list, seen) { - if (container == null) { - return; - } - if (container instanceof Tensor) { - list.push(container); - return; - } - if (!isIterable(container)) { - return; - } - // Iteration over keys works also for arrays. - var iterable = container; - for (var k in iterable) { - var val = iterable[k]; - if (!seen.has(val)) { - seen.add(val); - walkTensorContainer(val, list, seen); - } - } - } - // tslint:disable-next-line:no-any - function isIterable(obj) { - return Array.isArray(obj) || typeof obj === 'object'; - } - - var tensor_util = /*#__PURE__*/Object.freeze({ - makeTypesMatch: makeTypesMatch, - assertTypesMatch: assertTypesMatch, - isTensorInList: isTensorInList, - getTensorsInContainer: getTensorsInContainer - }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var EngineState = /** @class */ (function () { - function EngineState() { - // Public since optimizers will use it. - this.registeredVariables = {}; - this.nextTapeNodeId = 0; - this.numBytes = 0; - this.numTensors = 0; - this.numStringTensors = 0; - this.numDataBuffers = 0; - // Number of nested tf.grad() statements when computing higher-order - // gradients. E.g. `1` for first-order gradients and `2` for second-order - // gradients. Used to track if the tape should be removed after a backprop. - this.gradientDepth = 0; - // Number of nested kernel calls. When kernel depth is greater than 1, we turn - // off the tape. - this.kernelDepth = 0; - this.scopeStack = []; - /** - * Keeps track of the number of data moves during a kernel execution. We - * maintain a stack since kernels can call other kernels, recursively. - */ - this.numDataMovesStack = []; - this.nextScopeId = 0; - this.tensorInfo = new WeakMap(); - this.profiling = false; - this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null }; - } - EngineState.prototype.dispose = function () { - for (var variableName in this.registeredVariables) { - this.registeredVariables[variableName].dispose(); - } - }; - return EngineState; - }()); - var Engine = /** @class */ (function () { - function Engine(ENV) { - this.ENV = ENV; - this.registry = {}; - this.registryFactory = {}; - this.pendingBackendInitId = 0; - this.state = new EngineState(); - } - Engine.prototype.ready = function () { - return __awaiter(this, void 0, void 0, function () { - var sortedBackends, i, backendName, success; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: - if (this.pendingBackendInit != null) { - return [2 /*return*/, this.pendingBackendInit.then(function () { })]; - } - if (this.backendInstance != null) { - return [2 /*return*/]; - } - sortedBackends = this.getSortedBackends(); - i = 0; - _a.label = 1; - case 1: - if (!(i < sortedBackends.length)) return [3 /*break*/, 5]; - backendName = sortedBackends[i]; - return [4 /*yield*/, this.initializeBackend(backendName).success]; - case 2: - success = _a.sent(); - if (!success) return [3 /*break*/, 4]; - return [4 /*yield*/, this.setBackend(backendName)]; - case 3: - _a.sent(); - return [2 /*return*/]; - case 4: - i++; - return [3 /*break*/, 1]; - case 5: throw new Error("Could not initialize any backends, all backend initializations " + - "failed."); - } - }); - }); - }; - Object.defineProperty(Engine.prototype, "backend", { - get: function () { - if (this.pendingBackendInit != null) { - throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " + - "sure to await tf.ready() or await tf.setBackend() before calling " + - "other methods"); - } - if (this.backendInstance == null) { - var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit; - if (asyncInit) { - throw new Error("The highest priority backend '" + name_1 + "' has not yet been " + - "initialized. Make sure to await tf.ready() or " + - "await tf.setBackend() before calling other methods"); - } - this.setBackend(name_1); - } - return this.backendInstance; - }, - enumerable: true, - configurable: true - }); - Engine.prototype.backendNames = function () { - return Object.keys(this.registryFactory); - }; - Engine.prototype.findBackend = function (backendName) { - if (!(backendName in this.registry)) { - // If the backend hasn't been initialized but we have a registry entry for - // it, initialize it and return it. - if (backendName in this.registryFactory) { - var asyncInit = this.initializeBackend(backendName).asyncInit; - if (asyncInit) { - // Backend is not ready yet. - return null; - } - } - else { - return null; - } - } - return this.registry[backendName]; - }; - Engine.prototype.findBackendFactory = function (backendName) { - if (!(backendName in this.registryFactory)) { - return null; - } - return this.registryFactory[backendName].factory; - }; - Engine.prototype.registerBackend = function (backendName, factory, priority) { - if (priority === void 0) { priority = 1; } - if (backendName in this.registryFactory) { - console.warn(backendName + " backend was already registered. " + - "Reusing existing backend factory."); - return false; - } - this.registryFactory[backendName] = { factory: factory, priority: priority }; - return true; - }; - Engine.prototype.setBackend = function (backendName) { - return __awaiter(this, void 0, void 0, function () { - var _a, success, asyncInit, result, _b; - return __generator(this, function (_c) { - switch (_c.label) { - case 0: - if (this.registryFactory[backendName] == null) { - throw new Error("Backend name '" + backendName + "' not found in registry"); - } - this.backendName = backendName; - if (!(this.registry[backendName] == null)) return [3 /*break*/, 4]; - this.backendInstance = null; - _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; - if (!asyncInit) return [3 /*break*/, 2]; - return [4 /*yield*/, success]; - case 1: - _b = _c.sent(); - return [3 /*break*/, 3]; - case 2: - _b = success; - _c.label = 3; - case 3: - result = _b; - if (!result) { - return [2 /*return*/, false]; - } - _c.label = 4; - case 4: - this.backendInstance = this.registry[backendName]; - this.setupRegisteredKernels(); - // Reset the profiler. - this.profiler = new Profiler(this.backendInstance); - return [2 /*return*/, true]; - } - }); - }); - }; - Engine.prototype.setupRegisteredKernels = function () { - var _this = this; - var kernels = getKernelsForBackend(this.backendName); - kernels.forEach(function (kernel) { - if (kernel.setupFunc != null) { - kernel.setupFunc(_this.backendInstance); - } - }); - }; - Engine.prototype.disposeRegisteredKernels = function (backendName) { - var _this = this; - var kernels = getKernelsForBackend(backendName); - kernels.forEach(function (kernel) { - if (kernel.disposeFunc != null) { - kernel.disposeFunc(_this.registry[backendName]); - } - }); - }; - /** - * Initializes a backend by looking up the backend name in the factory - * registry and calling the factory method. Returns a boolean representing - * whether the initialization of the backend suceeded. Throws an error if - * there is no backend in the factory registry. - */ - Engine.prototype.initializeBackend = function (backendName) { - var _this = this; - var registryFactoryEntry = this.registryFactory[backendName]; - if (registryFactoryEntry == null) { - throw new Error("Cannot initialize backend " + backendName + ", no registration found."); - } - try { - var backend = registryFactoryEntry.factory(); - // Test if the factory returns a promise. - if (Promise.resolve(backend) === backend) { - var promiseId_1 = ++this.pendingBackendInitId; - var success = backend - .then(function (backendInstance) { - // Outdated promise. Another backend was set in the meantime. - if (promiseId_1 < _this.pendingBackendInitId) { - return false; - } - _this.registry[backendName] = backendInstance; - _this.pendingBackendInit = null; - return true; - }) - .catch(function (err) { - // Outdated promise. Another backend was set in the meantime. - if (promiseId_1 < _this.pendingBackendInitId) { - return false; - } - _this.pendingBackendInit = null; - console.warn("Initialization of backend " + backendName + " failed"); - console.warn(err.stack || err.message); - return false; - }); - this.pendingBackendInit = success; - return { success: success, asyncInit: true }; - } - else { - this.registry[backendName] = backend; - return { success: true, asyncInit: false }; - } - } - catch (err) { - console.warn("Initialization of backend " + backendName + " failed"); - console.warn(err.stack || err.message); - return { success: false, asyncInit: false }; - } - }; - Engine.prototype.removeBackend = function (backendName) { - if (!(backendName in this.registryFactory)) { - throw new Error(backendName + " backend not found in registry"); - } - if (this.backendName === backendName && this.pendingBackendInit != null) { - // There is a pending promise of the backend we want to remove. Make it - // obsolete. - this.pendingBackendInitId++; - } - if (backendName in this.registry) { - this.disposeRegisteredKernels(backendName); - this.registry[backendName].dispose(); - delete this.registry[backendName]; - } - delete this.registryFactory[backendName]; - // Unset the backend if it is active. - if (this.backendName === backendName) { - this.pendingBackendInit = null; - this.backendName = null; - this.backendInstance = null; - } - }; - Engine.prototype.getSortedBackends = function () { - var _this = this; - if (Object.keys(this.registryFactory).length === 0) { - throw new Error('No backend found in registry.'); - } - return Object.keys(this.registryFactory).sort(function (a, b) { - // Highest priority comes first. - return _this.registryFactory[b].priority - - _this.registryFactory[a].priority; - }); - }; - Engine.prototype.initializeBackendsAndReturnBest = function () { - var sortedBackends = this.getSortedBackends(); - for (var i = 0; i < sortedBackends.length; i++) { - var backendName = sortedBackends[i]; - var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; - if (asyncInit || success) { - return { name: backendName, asyncInit: asyncInit }; - } - } - throw new Error("Could not initialize any backends, all backend initializations " + - "failed."); - }; - Engine.prototype.moveData = function (destBackend, dataId) { - var info = this.state.tensorInfo.get(dataId); - var srcBackend = info.backend; - var values = this.readSync(dataId); - // Delete the tensor from the old backend and move it to the new - // backend. - srcBackend.disposeData(dataId); - info.backend = destBackend; - destBackend.move(dataId, values, info.shape, info.dtype); - if (this.shouldCheckForMemLeaks()) { - // Track the number of moves during a kernel execution to correctly - // detect memory leaks. - this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; - } - }; - Engine.prototype.tidy = function (nameOrFn, fn) { - var _this = this; - var name = null; - if (fn == null) { - // Called with only 1 argument. - if (typeof nameOrFn !== 'function') { - throw new Error('Please provide a function to tidy()'); - } - fn = nameOrFn; - } - else { - // Called with 2 arguments. - if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { - throw new Error('When calling with two arguments, the first argument ' + - 'to tidy() must be a string'); - } - if (typeof fn !== 'function') { - throw new Error('When calling with two arguments, the 2nd argument ' + - 'to tidy() must be a function'); - } - name = nameOrFn; - // TODO(nsthorat,smilkov): Do operation logging and performance - // profiling. - } - var result; - return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () { - result = fn(); - if (result instanceof Promise) { - console.error('Cannot return a Promise inside of tidy.'); - } - return result; - }); - }; - Engine.prototype.scopedRun = function (start, end, f) { - start(); - try { - var res = f(); - end(); - return res; - } - catch (ex) { - end(); - throw ex; - } - }; - Engine.prototype.nextTensorId = function () { - return Engine.nextTensorId++; - }; - Engine.prototype.nextVariableId = function () { - return Engine.nextVariableId++; - }; - /** - * This method is called instead of the public-facing tensor.clone() when - * saving a tensor for backwards pass. It makes sure to add the clone - * operation to the tape regardless of being called inside a kernel - * execution. - * - * This method will go away once all kernels are modularized since we won't - * need to turn off the tape inside runKernel(). - */ - Engine.prototype.clone = function (x) { - var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype); - var inputs = { x: x }; - var grad = function (dy) { return ({ x: function () { return dy.toFloat(); } }); }; - var saved = []; - this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); - return y; - }; - /** - * Execute a kernel with the given name and return the output tensor. - * - * @param kernelName The name of the kernel to execute. - * @param inputs A map of input names to tensors. - * @param attrs A map of attribute names to their values. An attribute is a - * primitive (non-tensor) input to the kernel. - * @param inputsToSave A list of tensors, inputs to save for the backprop - * computation. - * @param outputsToSave A list of booleans, specifying which output to save - * for the backprop computation. These are booleans since the output - * tensors are not visible to the user. - */ - Engine.prototype.runKernel = function (kernelName, inputs, attrs, inputsToSave, outputsToSave) { - var forwardFunc = null; - var backwardsFunc = null; - // Call runKernel as a stop-gap until we modularize all kernels. - // Once we modularize all kernels, we will remove the existing - // `runKernelFunc`. - return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave); - }; - Engine.prototype.shouldCheckForMemLeaks = function () { - return this.ENV.getBool('IS_TEST'); - }; - Engine.prototype.checkKernelForMemLeak = function (kernelName, numDataIdsBefore, outInfos) { - var numDataIdsAfter = this.backend.numDataIds(); - // Count the number of data ids associated with the result of the kernel. - var numOutputDataIds = 0; - outInfos.forEach(function (info) { - // Complex numbers allocate 3 data ids, one for 'real', one for - // 'imaginary', and one for the container that holds the former two. - numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); - }); - // Account for the number of moves during kernel execution. A "data move" - // can happen in the middle of a kernel execution, placing a new (key,value) - // pair in the data storage. Since data moves have net zero effect (we - // always remove the data from the old backend), we have to cancel them out - // when detecting memory leaks. - var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; - var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; - if (dataIdsLeaked > 0) { - throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + - ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'")); - } - }; - /** - * @deprecated Use `runKernel` for newly added kernels. Keep using this method - * only for kernels that are not yet fully modularized. - */ - Engine.prototype.runKernelFunc = function (forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) { - var _this = this; - var outputs; - var saved = []; - var isTapeOn = this.isTapeOn(); - if (kernelName == null) { - kernelName = - this.state.activeScope != null ? this.state.activeScope.name : ''; - } - var startingBytecount = this.state.numBytes; - var startingNumTensors = this.state.numTensors; - if (this.shouldCheckForMemLeaks()) { - this.state.numDataMovesStack.push(0); - } - var kernelFunc; - var kernel = getKernel(kernelName, this.backendName); - var out; - if (kernel != null) { - kernelFunc = function () { - var numDataIdsBefore = _this.backend.numDataIds(); - out = kernel.kernelFunc({ inputs: inputs, attrs: attrs, backend: _this.backend }); - var outInfos = Array.isArray(out) ? out : [out]; - if (_this.shouldCheckForMemLeaks()) { - _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos); - } - var outTensors = outInfos.map(function (_a) { - var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype; - return _this.makeTensorFromDataId(dataId, shape, dtype); - }); - // Save the inputs and outputs. - // Do not save unless we are recording to the tape. Otherwise it would - // cause a mem leak since we would never run backprop, which disposes - // the kept tensors. - if (isTapeOn) { - var tensorsToSave = _this.getTensorsForGradient(kernelName, inputs, outTensors); - if (tensorsToSave == null) { - // Fallback for ops that call runKernelFunc and pass in - // inputsToSave and outputsToSave. Currently this is the set of ops - // with kernel support in the WASM backend. Once those ops and - // respective gradients are modularised we can remove this path. - if (outputsToSave == null) { - outputsToSave = []; - } - var outsToSave = outTensors.filter(function (_, i) { return outputsToSave[i]; }); - tensorsToSave = (inputsToSave || []).slice().concat(outsToSave); - } - saved = _this.saveTensorsForBackwardMode(tensorsToSave); - } - return outTensors; - }; - } - else { - var saveFunc_1 = function (tensors) { - // Do not save unless we are recording to the tape. Otherwise it would - // cause a mem leak since we would never run backprop, which disposes - // the kept tensors. - if (!isTapeOn) { - return; - } - saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); - }; - kernelFunc = function () { - var numDataIdsBefore = _this.backend.numDataIds(); - out = _this.tidy(function () { return forwardFunc(_this.backend, saveFunc_1); }); - var outs = (Array.isArray(out) ? out : [out]); - if (_this.shouldCheckForMemLeaks()) { - _this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outs); - } - return outs; - }; - } - // Stop recording to a tape when running a kernel. - this.scopedRun(function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () { - if (!_this.ENV.getBool('DEBUG')) { - outputs = kernelFunc(); - } - else { - outputs = _this.profiler.profileKernel(kernelName, inputs, function () { return kernelFunc(); }); - } - }); - if (isTapeOn) { - this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs); - } - if (this.state.profiling) { - this.state.activeProfile.kernels.push({ - name: kernelName, - bytesAdded: this.state.numBytes - startingBytecount, - totalBytesSnapshot: this.state.numBytes, - tensorsAdded: this.state.numTensors - startingNumTensors, - totalTensorsSnapshot: this.state.numTensors, - inputShapes: Object.keys(inputs).map(function (key) { return inputs[key].shape; }), - outputShapes: outputs.map(function (item) { return item.shape; }) - }); - } - return (Array.isArray(out) ? outputs : outputs[0]); - }; - /** - * Saves tensors used in forward mode for use in backward mode. - * - * @param tensors the list of tensors to save. - */ - Engine.prototype.saveTensorsForBackwardMode = function (tensors) { - var _this = this; - var saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); - return saved; - }; - /** - * Returns a list of tensors to save for a given gradient calculation. - * - * Returns undefined if their is no registered gradient for this kernel in the - * gradient registry. - * - * @param kernelName name of kernel to look up gradient for. - * @param inputs a map of input tensors. - * @param outputs an array of output tensors from forward mode of kernel. - */ - Engine.prototype.getTensorsForGradient = function (kernelName, inputs, outputs) { - var gradConfig = getGradient(kernelName); - if (gradConfig != null) { - var inputsToSave = gradConfig.inputsToSave || []; - var outputsToSave_1 = gradConfig.outputsToSave || []; - // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs - // specified in inputsToSave will be saved. - var inputTensorsToSave = void 0; - if (gradConfig.saveAllInputs) { - assert(Array.isArray(inputs), function () { return 'saveAllInputs is true, expected inputs to be an array.'; }); - inputTensorsToSave = Object.keys(inputs).map(function (key) { return inputs[key]; }); - } - else { - inputTensorsToSave = inputsToSave.map(function (inputName) { return inputs[inputName]; }); - } - var outputTensorsToSave = outputs.filter(function (_, i) { return outputsToSave_1[i]; }); - return inputTensorsToSave.concat(outputTensorsToSave); - } - // TODO(yassogba) throw exception here once all runkernelFunc calls with - // inputsToSave/outputsToSave are removed - return null; - }; - /** - * Internal method used by public APIs for tensor creation. Makes a new - * tensor with the provided shape, dtype and values. It always - * creates a new data id and writes the values to the underlying backend. - */ - Engine.prototype.makeTensor = function (values, shape, dtype, backend) { - if (values == null) { - throw new Error('Values passed to engine.makeTensor() are null'); - } - dtype = dtype || 'float32'; - backend = backend || this.backend; - var backendVals = values; - if (dtype === 'string' && isString(values[0])) { - backendVals = values.map(function (d) { return encodeString(d); }); - } - var dataId = backend.write(backendVals, shape, dtype); - var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); - this.incRef(t, backend); - // Count bytes for string tensors. - if (dtype === 'string') { - var info = this.state.tensorInfo.get(dataId); - var newBytes = bytesFromStringArray(backendVals); - this.state.numBytes += newBytes - info.bytes; - info.bytes = newBytes; - } - return t; - }; - /** - * Internal method used by backends. Makes a new tensor - * that is a wrapper around an existing data id. It doesn't create - * a new data id, only increments the ref count used in memory tracking. - */ - Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) { - dtype = dtype || 'float32'; - var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); - this.incRef(t, backend); - return t; - }; - Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) { - if (trainable === void 0) { trainable = true; } - name = name || this.nextVariableId().toString(); - if (dtype != null && dtype !== initialValue.dtype) { - initialValue = initialValue.asType(dtype); - } - var v = new Variable(initialValue, trainable, name, this.nextTensorId()); - if (this.state.registeredVariables[v.name] != null) { - throw new Error("Variable with name " + v.name + " was already registered"); - } - this.state.registeredVariables[v.name] = v; - this.incRef(v, this.backend); - return v; - }; - Engine.prototype.incRef = function (a, backend) { - var refCount = this.state.tensorInfo.has(a.dataId) ? - this.state.tensorInfo.get(a.dataId).refCount : - 0; - this.state.numTensors++; - if (a.dtype === 'string') { - this.state.numStringTensors++; - } - if (refCount === 0) { - this.state.numDataBuffers++; - // Bytes for complex numbers are counted by their components. Bytes for - // string tensors are counted when writing values. - var bytes = 0; - if (a.dtype !== 'complex64' && a.dtype !== 'string') { - bytes = a.size * bytesPerElement(a.dtype); - } - this.state.tensorInfo.set(a.dataId, { - backend: backend || this.backend, - dtype: a.dtype, - shape: a.shape, - bytes: bytes, - refCount: 0 - }); - this.state.numBytes += bytes; - } - this.state.tensorInfo.get(a.dataId).refCount++; - if (!(a instanceof Variable)) { - this.track(a); - } - }; - Engine.prototype.disposeTensor = function (a) { - if (!this.state.tensorInfo.has(a.dataId)) { - return; - } - this.state.numTensors--; - if (a.dtype === 'string') { - this.state.numStringTensors--; - } - var info = this.state.tensorInfo.get(a.dataId); - var refCount = info.refCount; - if (refCount <= 1) { - // Don't count bytes for complex numbers as they are counted by their - // components. - if (a.dtype !== 'complex64') { - this.state.numBytes -= info.bytes; - } - this.state.numDataBuffers--; - info.backend.disposeData(a.dataId); - this.state.tensorInfo.delete(a.dataId); - } - else { - this.state.tensorInfo.get(a.dataId).refCount--; - } - // TODO(nsthorat): Construct an error and save the stack trace for - // debugging when in debug mode. Creating a stack trace is too expensive - // to do unconditionally. - }; - Engine.prototype.disposeVariables = function () { - for (var varName in this.state.registeredVariables) { - var v = this.state.registeredVariables[varName]; - this.disposeVariable(v); - } - }; - Engine.prototype.disposeVariable = function (v) { - this.disposeTensor(v); - if (this.state.registeredVariables[v.name] != null) { - delete this.state.registeredVariables[v.name]; - } - }; - Engine.prototype.memory = function () { - var info = this.backend.memory(); - info.numTensors = this.state.numTensors; - info.numDataBuffers = this.state.numDataBuffers; - info.numBytes = this.state.numBytes; - if (this.state.numStringTensors > 0) { - info.unreliable = true; - if (info.reasons == null) { - info.reasons = []; - } - info.reasons.push('Memory usage by string tensors is approximate ' + - '(2 bytes per character)'); - } - return info; - }; - Engine.prototype.profile = function (query) { - return __awaiter(this, void 0, void 0, function () { - var startBytes, startNumTensors; - return __generator(this, function (_a) { - this.state.profiling = true; - startBytes = this.state.numBytes; - startNumTensors = this.state.numTensors; - this.state.activeProfile.kernels = []; - this.state.activeProfile.result = query(); - this.state.profiling = false; - this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; })); - this.state.activeProfile.newBytes = this.state.numBytes - startBytes; - this.state.activeProfile.newTensors = - this.state.numTensors - startNumTensors; - return [2 /*return*/, this.state.activeProfile]; - }); - }); - }; - Engine.prototype.isTapeOn = function () { - return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; - }; - Engine.prototype.addTapeNode = function (kernelName, inputs, outputs, gradientsFunc, saved, attrs) { - var _this = this; - var tapeNode = { id: this.state.nextTapeNodeId++, kernelName: kernelName, inputs: inputs, outputs: outputs, saved: saved }; - var gradConfig = getGradient(kernelName); - if (gradConfig != null) { - gradientsFunc = gradConfig.gradFunc; - } - if (gradientsFunc != null) { - tapeNode.gradient = function (dys) { - // TODO(smilkov): To optimize back-prop, pass dys that are not used in - // the backprop graph to the user as null instead of zeros - dys = dys.map(function (dy, i) { - if (dy == null) { - var output = outputs[i]; - var vals = makeZerosTypedArray(output.size, output.dtype); - return _this.makeTensor(vals, output.shape, output.dtype); - } - return dy; - }); - // Grad functions of ops with single outputs expect a dy, while ops - // with multiple outputs expect dys (array of dy). - return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); - }; - } - this.state.activeTape.push(tapeNode); - }; - Engine.prototype.keep = function (result) { - result.kept = true; - return result; - }; - Engine.prototype.startTape = function () { - if (this.state.gradientDepth === 0) { - this.state.activeTape = []; - } - this.state.gradientDepth++; - }; - Engine.prototype.endTape = function () { - this.state.gradientDepth--; - }; - /** - * Start a scope. Use this with endScope() to achieve the same functionality - * as scope() without the need for a function closure. - */ - Engine.prototype.startScope = function (name) { - var scopeInfo = { - track: [], - name: 'unnamed scope', - id: this.state.nextScopeId++ - }; - if (name) { - scopeInfo.name = name; - } - this.state.scopeStack.push(scopeInfo); - this.state.activeScope = scopeInfo; - }; - /** - * End a scope. Use this with startScope() to achieve the same functionality - * as scope() without the need for a function closure. - */ - Engine.prototype.endScope = function (result) { - var _this = this; - var tensorsToTrackInParent = getTensorsInContainer(result); - var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; })); - // Dispose the arrays tracked in this scope. - for (var i = 0; i < this.state.activeScope.track.length; i++) { - var tensor = this.state.activeScope.track[i]; - if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) { - tensor.dispose(); - } - } - var oldScope = this.state.scopeStack.pop(); - this.state.activeScope = this.state.scopeStack.length === 0 ? - null : - this.state.scopeStack[this.state.scopeStack.length - 1]; - // Track the current result in the parent scope. - tensorsToTrackInParent.forEach(function (tensor) { - // Only track the tensor if was allocated in the inner scope and is not - // globally kept. - if (!tensor.kept && tensor.scopeId === oldScope.id) { - _this.track(tensor); - } - }); - }; - /** - * Returns gradients of `f` with respect to each of the `xs`. The gradients - * returned are of the same length as `xs`, but some might be null if `f` - * was not a function of that `x`. It also takes optional dy to multiply the - * gradient, which defaults to `1`. - */ - Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) { - var _this = this; - if (allowNoGradients === void 0) { allowNoGradients = false; } - assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; }); - if (dy != null && dy.dtype !== 'float32') { - throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'"); - } - var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); }); - assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; }); - // Filter out the nodes that don't connect x => y. - var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); - if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { - throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' + - 'that the f you passed encloses all operations that lead from x ' + - 'to y.'); - } - return this.tidy('backward', function () { - var accumulatedGradientMap = {}; - accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy; - // Backprop gradients through the filtered nodes. - backpropagateGradients(accumulatedGradientMap, filteredTape, - // Pass the tidy function to avoid circular dep with `tape.ts`. - function (f) { return _this.tidy(f); }); - var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; }); - if (_this.state.gradientDepth === 0) { - // This means that we are not computing higher-order gradients - // and can clean up the tape. - _this.state.activeTape.forEach(function (node) { - for (var _i = 0, _a = node.saved; _i < _a.length; _i++) { - var tensor = _a[_i]; - tensor.dispose(); - } - }); - _this.state.activeTape = null; - } - return { value: y, grads: grads }; - }); - }; - Engine.prototype.customGrad = function (f) { - var _this = this; - assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; }); - return function () { - var inputs = []; - for (var _i = 0; _i < arguments.length; _i++) { - inputs[_i] = arguments[_i]; - } - assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) must all be ' + - 'tensors'; }); - var res; - var inputMap = {}; - inputs.forEach(function (input, i) { - inputMap[i] = input; - }); - return _this.runKernelFunc(function (_, save) { - res = f.apply(void 0, inputs.concat([save])); - assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.value` is a tensor'; }); - assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.gradFunc` is a function.'; }); - return res.value; - }, inputMap, function (dy, saved) { - var gradRes = res.gradFunc(dy, saved); - var grads = Array.isArray(gradRes) ? gradRes : [gradRes]; - assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.gradFunc` is a function that returns ' + - 'the same number of tensors as inputs passed to f(...).'; }); - assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' + - 'object where `obj.gradFunc` is a function that returns ' + - 'a list of only tensors.'; }); - var gradMap = {}; - grads.forEach(function (grad, i) { - gradMap[i] = function () { return grad; }; - }); - return gradMap; - }); - }; - }; - Engine.prototype.readSync = function (dataId) { - // Route the read to the correct backend. - var info = this.state.tensorInfo.get(dataId); - return info.backend.readSync(dataId); - }; - Engine.prototype.read = function (dataId) { - // Route the read to the correct backend. - var info = this.state.tensorInfo.get(dataId); - return info.backend.read(dataId); - }; - Engine.prototype.time = function (query) { - return __awaiter(this, void 0, void 0, function () { - var start, timingInfo; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: - start = now(); - return [4 /*yield*/, this.backend.time(query)]; - case 1: - timingInfo = _a.sent(); - timingInfo.wallMs = now() - start; - return [2 /*return*/, timingInfo]; - } - }); - }); - }; - /** - * Tracks a Tensor in the current scope to be automatically cleaned up - * when the current scope ends, and returns the value. - * - * @param result The Tensor to track in the current scope. - */ - Engine.prototype.track = function (result) { - if (this.state.activeScope != null) { - result.scopeId = this.state.activeScope.id; - this.state.activeScope.track.push(result); - } - return result; - }; - Object.defineProperty(Engine.prototype, "registeredVariables", { - get: function () { - return this.state.registeredVariables; - }, - enumerable: true, - configurable: true - }); - /** - * Resets the engine state. Removes all backends but does not remove - * registered backend factories. - */ - Engine.prototype.reset = function () { - // Make any pending promise obsolete. - this.pendingBackendInitId++; - this.state.dispose(); - this.ENV.reset(); - this.state = new EngineState(); - for (var backendName in this.registry) { - this.disposeRegisteredKernels(backendName); - this.registry[backendName].dispose(); - delete this.registry[backendName]; - } - this.backendName = null; - this.backendInstance = null; - this.pendingBackendInit = null; - }; - Engine.nextTensorId = 0; - Engine.nextVariableId = 0; - return Engine; - }()); - function ones(shape) { - var values = makeOnesTypedArray(sizeFromShape(shape), 'float32'); - return ENGINE.makeTensor(values, shape, 'float32'); - } - var GLOBAL; - function getGlobalNamespace() { - if (GLOBAL == null) { - // tslint:disable-next-line:no-any - var ns = void 0; - if (typeof (window) !== 'undefined') { - ns = window; - } - else if (typeof (global) !== 'undefined') { - ns = global; - } - else if (typeof (process) !== 'undefined') { - ns = process; - } - else if (typeof (self) !== 'undefined') { - ns = self; - } - else { - throw new Error('Could not find a global object'); - } - GLOBAL = ns; - } - return GLOBAL; - } - function getOrMakeEngine() { - var ns = getGlobalNamespace(); - if (ns._tfengine == null) { - var environment = new Environment(ns); - ns._tfengine = new Engine(environment); - } - setEnvironmentGlobal(ns._tfengine.ENV); - // Tell the current tensor interface that the global engine is responsible - // for tracking. - setTensorTracker(function () { return ns._tfengine; }); - return ns._tfengine; - } - var ENGINE = getOrMakeEngine(); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function isMobile() { - // tslint:disable-next-line:no-any - var a = navigator.userAgent || navigator.vendor || window.opera; - // tslint:disable-next-line:max-line-length - return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i - .test(a) || - // tslint:disable-next-line:max-line-length - /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i - .test(a.substr(0, 4)); - } - function isBrowser() { - return (typeof window !== 'undefined' && window.document != null) || - //@ts-ignore - (typeof WorkerGlobalScope !== 'undefined'); - } - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ENV = env(); - /** - * This file contains environment-related flag registrations. - */ - /** Whether to enable debug mode. */ - ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) { - if (debugValue) { - console.warn('Debugging mode is ON. The output of every math call will ' + - 'be downloaded to CPU and checked for NaNs. ' + - 'This significantly impacts performance.'); - } - }); - /** Whether we are in a browser (as versus, say, node.js) environment. */ - ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); }); - /** Whether we are in a browser (as versus, say, node.js) environment. */ - ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') && - (typeof process.versions !== 'undefined') && - (typeof process.versions.node !== 'undefined'); }); - /** Whether this browser is Chrome. */ - ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null && - navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && - /Google Inc/.test(navigator.vendor); }); - /** - * True when the environment is "production" where we disable safety checks - * to gain performance. - */ - ENV.registerFlag('PROD', function () { return false; }); - /** - * Whether to do sanity checks when inferring a shape from user-provided - * values, used when creating a new tensor. - */ - ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); }); - /** Whether deprecation warnings are enabled. */ - ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; }); - /** True if running unit tests. */ - ENV.registerFlag('IS_TEST', function () { return false; }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var contexts = {}; - var WEBGL_ATTRIBUTES = { - alpha: false, - antialias: false, - premultipliedAlpha: false, - preserveDrawingBuffer: false, - depth: false, - stencil: false, - failIfMajorPerformanceCaveat: true - }; - function setWebGLContext(webGLVersion, gl) { - contexts[webGLVersion] = gl; - } - function getWebGLContext(webGLVersion) { - if (!(webGLVersion in contexts)) { - contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion); - } - var gl = contexts[webGLVersion]; - if (gl.isContextLost()) { - delete contexts[webGLVersion]; - return getWebGLContext(webGLVersion); - } - gl.disable(gl.DEPTH_TEST); - gl.disable(gl.STENCIL_TEST); - gl.disable(gl.BLEND); - gl.disable(gl.DITHER); - gl.disable(gl.POLYGON_OFFSET_FILL); - gl.disable(gl.SAMPLE_COVERAGE); - gl.enable(gl.SCISSOR_TEST); - gl.enable(gl.CULL_FACE); - gl.cullFace(gl.BACK); - return contexts[webGLVersion]; - } - function createCanvas(webGLVersion) { - if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) { - return new OffscreenCanvas(300, 150); - } - else if (typeof document !== 'undefined') { - return document.createElement('canvas'); - } - else { - throw new Error('Cannot create a canvas in this context'); - } - } - function getWebGLRenderingContext(webGLVersion) { - if (webGLVersion !== 1 && webGLVersion !== 2) { - throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); - } - var canvas = createCanvas(webGLVersion); - canvas.addEventListener('webglcontextlost', function (ev) { - ev.preventDefault(); - delete contexts[webGLVersion]; - }, false); - if (webGLVersion === 1) { - return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) || - canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES)); - } - return canvas.getContext('webgl2', WEBGL_ATTRIBUTES); - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var PackingScheme; - (function (PackingScheme) { - /** - * All values in a single texel are densely packed without any constraints. - * - * This is how the shader encodes a tensor with shape = [2, 3, 4] - * (indices are [batch, row, col]). - * - * 000|001 010|011 020|021 - * ------- ------- ------- - * 002|003 012|013 022|023 - * - * 100|101 110|111 120|121 - * ------- ------- ------- - * 102|103 112|113 122|123 - * - */ - PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE"; - /** - * Single texels contain only values from the same batch, and from adjacent - * rows and columns. - * - * This is how the shader encodes a tensor with shape = [2, 3, 5] - * (indices are [batch, row, col]). - * - * 000|001 002|003 004|xxx 020|021 022|023 024|xxx - * ------- ------- ------- ------- ------- ------- - * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx - * - * 100|101 102|103 104|xxx 120|121 122|123 124|xxx - * ------- ------- ------- ------- ------- ------- - * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx - * - */ - PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH"; - })(PackingScheme || (PackingScheme = {})); - var TextureUsage; - (function (TextureUsage) { - TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER"; - TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD"; - TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS"; - TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD"; - })(TextureUsage || (TextureUsage = {})); - var PhysicalTextureType; - (function (PhysicalTextureType) { - PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16"; - PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32"; - PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE"; - PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32"; - PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16"; - })(PhysicalTextureType || (PhysicalTextureType = {})); - function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) { - return [columns, rows]; - } - function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) { - return matrixSize * channelsPerTexture; - } - /** - * Get shape for densely packed RGBA texture. - */ - function getDenseTexShape(shape) { - var size = sizeFromShape(shape); - var texelsNeeded = Math.ceil(size / 4); - return sizeToSquarishShape(texelsNeeded); - } - function getPackedMatrixTextureShapeWidthHeight(rows, columns) { - return [ - Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2)) - ]; - } - function getPackedRGBAArraySizeFromMatrixShape(rows, columns) { - var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), w = _a[0], h = _a[1]; - return w * h * 4; - } - function getTextureConfig( - // tslint:disable-next-line:no-any - gl, textureHalfFloatExtension) { - // tslint:disable-next-line:no-any - var glany = gl; - var internalFormatFloat; - var internalFormatHalfFloat; - var internalFormatPackedHalfFloat; - var internalFormatPackedFloat; - var textureFormatFloat; - var downloadTextureFormat; - var downloadUnpackNumChannels; - var defaultNumChannels; - var textureTypeHalfFloat; - var textureTypeFloat; - if (env().getNumber('WEBGL_VERSION') === 2) { - internalFormatFloat = glany.R32F; - internalFormatHalfFloat = glany.R16F; - internalFormatPackedHalfFloat = glany.RGBA16F; - internalFormatPackedFloat = glany.RGBA32F; - textureFormatFloat = glany.RED; - downloadUnpackNumChannels = 4; - defaultNumChannels = 1; - textureTypeHalfFloat = glany.HALF_FLOAT; - textureTypeFloat = glany.FLOAT; - } - else { - internalFormatFloat = gl.RGBA; - internalFormatHalfFloat = gl.RGBA; - internalFormatPackedHalfFloat = gl.RGBA; - internalFormatPackedFloat = glany.RGBA; - textureFormatFloat = gl.RGBA; - downloadUnpackNumChannels = 4; - defaultNumChannels = 4; - textureTypeHalfFloat = textureHalfFloatExtension != null ? - textureHalfFloatExtension.HALF_FLOAT_OES : - null; - textureTypeFloat = gl.FLOAT; - } - downloadTextureFormat = gl.RGBA; - return { - internalFormatFloat: internalFormatFloat, - internalFormatHalfFloat: internalFormatHalfFloat, - internalFormatPackedHalfFloat: internalFormatPackedHalfFloat, - internalFormatPackedFloat: internalFormatPackedFloat, - textureFormatFloat: textureFormatFloat, - downloadTextureFormat: downloadTextureFormat, - downloadUnpackNumChannels: downloadUnpackNumChannels, - defaultNumChannels: defaultNumChannels, - textureTypeHalfFloat: textureTypeHalfFloat, - textureTypeFloat: textureTypeFloat - }; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function callAndCheck(gl, debugMode, func) { - var returnValue = func(); - if (debugMode) { - checkWebGLError(gl); - } - return returnValue; - } - function checkWebGLError(gl) { - var error = gl.getError(); - if (error !== gl.NO_ERROR) { - throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error)); - } - } - // https://en.wikipedia.org/wiki/Half-precision_floating-point_format - var MIN_FLOAT16 = 5.96e-8; - var MAX_FLOAT16 = 65504; - function canBeRepresented(num) { - if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || - (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { - return true; - } - return false; - } - function getWebGLErrorMessage(gl, status) { - switch (status) { - case gl.NO_ERROR: - return 'NO_ERROR'; - case gl.INVALID_ENUM: - return 'INVALID_ENUM'; - case gl.INVALID_VALUE: - return 'INVALID_VALUE'; - case gl.INVALID_OPERATION: - return 'INVALID_OPERATION'; - case gl.INVALID_FRAMEBUFFER_OPERATION: - return 'INVALID_FRAMEBUFFER_OPERATION'; - case gl.OUT_OF_MEMORY: - return 'OUT_OF_MEMORY'; - case gl.CONTEXT_LOST_WEBGL: - return 'CONTEXT_LOST_WEBGL'; - default: - return "Unknown error code " + status; - } - } - function getExtensionOrThrow(gl, debug, extensionName) { - return throwIfNull(gl, debug, function () { return gl.getExtension(extensionName); }, 'Extension "' + extensionName + '" not supported on this browser.'); - } - function createVertexShader(gl, debug, vertexShaderSource) { - var vertexShader = throwIfNull(gl, debug, function () { return gl.createShader(gl.VERTEX_SHADER); }, 'Unable to create vertex WebGLShader.'); - callAndCheck(gl, debug, function () { return gl.shaderSource(vertexShader, vertexShaderSource); }); - callAndCheck(gl, debug, function () { return gl.compileShader(vertexShader); }); - if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) { - console.log(gl.getShaderInfoLog(vertexShader)); - throw new Error('Failed to compile vertex shader.'); - } - return vertexShader; - } - function createFragmentShader(gl, debug, fragmentShaderSource) { - var fragmentShader = throwIfNull(gl, debug, function () { return gl.createShader(gl.FRAGMENT_SHADER); }, 'Unable to create fragment WebGLShader.'); - callAndCheck(gl, debug, function () { return gl.shaderSource(fragmentShader, fragmentShaderSource); }); - callAndCheck(gl, debug, function () { return gl.compileShader(fragmentShader); }); - if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) { - logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader)); - throw new Error('Failed to compile fragment shader.'); - } - return fragmentShader; - } - var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g; - function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) { - var lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog); - if (lineNumberRegexResult == null) { - console.log("Couldn't parse line number in error: " + shaderInfoLog); - console.log(shaderSource); - return; - } - var lineNumber = +lineNumberRegexResult[1]; - var shaderLines = shaderSource.split('\n'); - var pad = shaderLines.length.toString().length + 2; - var linesWithLineNumbers = shaderLines.map(function (line, lineNumber) { - return rightPad((lineNumber + 1).toString(), pad) + line; - }); - var maxLineLength = 0; - for (var i = 0; i < linesWithLineNumbers.length; i++) { - maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength); - } - var beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1); - var errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber); - var afterErrorLines = linesWithLineNumbers.slice(lineNumber); - console.log(beforeErrorLines.join('\n')); - console.log(shaderInfoLog.split('\n')[0]); - console.log("%c " + rightPad(errorLine[0], maxLineLength), 'border:1px solid red; background-color:#e3d2d2; color:#a61717'); - console.log(afterErrorLines.join('\n')); - } - function createProgram(gl, debug) { - return throwIfNull(gl, debug, function () { return gl.createProgram(); }, 'Unable to create WebGLProgram.'); - } - function linkProgram(gl, debug, program) { - callAndCheck(gl, debug, function () { return gl.linkProgram(program); }); - if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) { - console.log(gl.getProgramInfoLog(program)); - throw new Error('Failed to link vertex and fragment shaders.'); - } - } - function validateProgram(gl, debug, program) { - callAndCheck(gl, debug, function () { return gl.validateProgram(program); }); - if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) { - console.log(gl.getProgramInfoLog(program)); - throw new Error('Shader program validation failed.'); - } - } - function createStaticVertexBuffer(gl, debug, data) { - var buffer = throwIfNull(gl, debug, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer'); - callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); }); - callAndCheck(gl, debug, function () { return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW); }); - return buffer; - } - function createStaticIndexBuffer(gl, debug, data) { - var buffer = throwIfNull(gl, debug, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer'); - callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer); }); - callAndCheck(gl, debug, function () { return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW); }); - return buffer; - } - function getNumChannels() { - if (env().getNumber('WEBGL_VERSION') === 2) { - return 1; - } - return 4; - } - function createTexture(gl, debug) { - return throwIfNull(gl, debug, function () { return gl.createTexture(); }, 'Unable to create WebGLTexture.'); - } - function validateTextureSize(width, height) { - var maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - if ((width <= 0) || (height <= 0)) { - var requested = "[" + width + "x" + height + "]"; - throw new Error('Requested texture size ' + requested + ' is invalid.'); - } - if ((width > maxTextureSize) || (height > maxTextureSize)) { - var requested = "[" + width + "x" + height + "]"; - var max = "[" + maxTextureSize + "x" + maxTextureSize + "]"; - throw new Error('Requested texture size ' + requested + - ' greater than WebGL maximum on this browser / GPU ' + max + '.'); - } - } - function createFramebuffer(gl, debug) { - return throwIfNull(gl, debug, function () { return gl.createFramebuffer(); }, 'Unable to create WebGLFramebuffer.'); - } - function bindVertexBufferToProgramAttribute(gl, debug, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) { - var loc = gl.getAttribLocation(program, attribute); - if (loc === -1) { - // The GPU compiler decided to strip out this attribute because it's unused, - // thus no need to bind. - return false; - } - callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); }); - callAndCheck(gl, debug, function () { return gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes); }); - callAndCheck(gl, debug, function () { return gl.enableVertexAttribArray(loc); }); - return true; - } - function bindTextureUnit(gl, debug, texture, textureUnit) { - validateTextureUnit(gl, textureUnit); - callAndCheck(gl, debug, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); }); - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); - } - function unbindTextureUnit(gl, debug, textureUnit) { - validateTextureUnit(gl, textureUnit); - callAndCheck(gl, debug, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); }); - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); - } - function getProgramUniformLocationOrThrow(gl, debug, program, uniformName) { - return throwIfNull(gl, debug, function () { return gl.getUniformLocation(program, uniformName); }, 'uniform "' + uniformName + '" not present in program.'); - } - function getProgramUniformLocation(gl, program, uniformName) { - return gl.getUniformLocation(program, uniformName); - } - function bindTextureToProgramUniformSampler(gl, debug, program, texture, uniformSamplerLocation, textureUnit) { - callAndCheck(gl, debug, function () { return bindTextureUnit(gl, debug, texture, textureUnit); }); - callAndCheck(gl, debug, function () { return gl.uniform1i(uniformSamplerLocation, textureUnit); }); - } - function bindCanvasToFramebuffer(gl, debug) { - callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); }); - callAndCheck(gl, debug, function () { return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height); }); - callAndCheck(gl, debug, function () { return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height); }); - } - function bindColorTextureToFramebuffer(gl, debug, texture, framebuffer) { - callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); }); - callAndCheck(gl, debug, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); }); - } - function unbindColorTextureFromFramebuffer(gl, debug, framebuffer) { - callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); }); - callAndCheck(gl, debug, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0); }); - } - function validateFramebuffer(gl) { - var status = gl.checkFramebufferStatus(gl.FRAMEBUFFER); - if (status !== gl.FRAMEBUFFER_COMPLETE) { - throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status)); - } - } - function getFramebufferErrorMessage(gl, status) { - switch (status) { - case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT: - return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT'; - case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT: - return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT'; - case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS: - return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS'; - case gl.FRAMEBUFFER_UNSUPPORTED: - return 'FRAMEBUFFER_UNSUPPORTED'; - default: - return "unknown error " + status; - } - } - function throwIfNull(gl, debug, returnTOrNull, failureMessage) { - var tOrNull = callAndCheck(gl, debug, function () { return returnTOrNull(); }); - if (tOrNull == null) { - throw new Error(failureMessage); - } - return tOrNull; - } - function validateTextureUnit(gl, textureUnit) { - var maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1; - var glTextureUnit = textureUnit + gl.TEXTURE0; - if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) { - var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE" + maxTextureUnit + "]"; - throw new Error("textureUnit must be in " + textureUnitRange + "."); - } - } - function getBatchDim(shape, dimsToSkip) { - if (dimsToSkip === void 0) { dimsToSkip = 2; } - return sizeFromShape(shape.slice(0, shape.length - dimsToSkip)); - } - function getRowsCols(shape) { - if (shape.length === 0) { - throw Error('Cannot get rows and columns of an empty shape array.'); - } - return [ - shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1] - ]; - } - function getShapeAs3D(shape) { - var shapeAs3D = [1, 1, 1]; - var isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1); - if (!isScalar) { - shapeAs3D = - [getBatchDim(shape)].concat(getRowsCols(shape)); - } - return shapeAs3D; - } - function getTextureShapeFromLogicalShape(logShape, isPacked) { - var _a; - if (isPacked === void 0) { isPacked = false; } - var maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); - if (isPacked) { - maxTexSize = maxTexSize * 2; - // This logic ensures we accurately count the number of packed texels needed - // to accommodate the tensor. We can only pack values in the same texel if - // they are from adjacent pairs of rows/cols within the same batch. So if a - // tensor has 3 rows, we pretend it has 4 rows in order to account for the - // fact that the texels containing the third row are half empty. - logShape = logShape.map(function (d, i) { return i >= logShape.length - 2 ? - nearestLargerEven(logShape[i]) : - logShape[i]; }); - // Packed texture height is at least 2 (the channel height of a single - // texel). - if (logShape.length === 1) { - logShape = [2, logShape[0]]; - } - } - // If logical shape is 2, we don't squeeze, since we want to match physical. - if (logShape.length !== 2) { - var squeezeResult = squeezeShape(logShape); - logShape = squeezeResult.newShape; - } - var size = sizeFromShape(logShape); - if (logShape.length <= 1 && size <= maxTexSize) { - return [1, size]; - } - else if (logShape.length === 2 && logShape[0] <= maxTexSize && - logShape[1] <= maxTexSize) { - return logShape; - } - else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize && - logShape[2] <= maxTexSize) { - return [logShape[0] * logShape[1], logShape[2]]; - } - else if (logShape.length === 3 && logShape[0] <= maxTexSize && - logShape[1] * logShape[2] <= maxTexSize) { - return [logShape[0], logShape[1] * logShape[2]]; - } - else if (logShape.length === 4 && - logShape[0] * logShape[1] * logShape[2] <= maxTexSize && - logShape[3] <= maxTexSize) { - return [logShape[0] * logShape[1] * logShape[2], logShape[3]]; - } - else if (logShape.length === 4 && logShape[0] <= maxTexSize && - logShape[1] * logShape[2] * logShape[3] <= maxTexSize) { - return [logShape[0], logShape[1] * logShape[2] * logShape[3]]; - } - else { - if (isPacked) { - // For packed textures size equals the number of channels required to - // accommodate the texture data. However in order to squarify such that - // inner dimensions stay even, we rewrite size to equal the number of - // texels. Then in the return statement we rehydrate the squarified - // dimensions to channel units. - var batchDim = getBatchDim(logShape); - var rows = 2, cols = 2; - if (logShape.length) { - _a = getRowsCols(logShape), rows = _a[0], cols = _a[1]; - } - size = batchDim * (rows / 2) * (cols / 2); - return sizeToSquarishShape(size).map(function (d) { return d * 2; }); - } - return sizeToSquarishShape(size); - } - } - function isEven(n) { - return n % 2 === 0; - } - /** - * This determines whether reshaping a packed texture requires rearranging - * the data within the texture, assuming 2x2 packing. - */ - function isReshapeFree(shape1, shape2) { - shape1 = shape1.slice(-2); - shape2 = shape2.slice(-2); - if (arraysEqual(shape1, shape2)) { - return true; - } - if (!shape1.length || !shape2.length) { // One of the shapes is a scalar. - return true; - } - if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 || - shape2[1] === 0) { - return true; - } - if (shape1.length !== shape2.length) { // One of the shapes is a vector. - var shape1Cols = shape1.slice(-1)[0]; - var shape2Cols = shape2.slice(-1)[0]; - if (shape1Cols === shape2Cols) { - return true; - } - if (isEven(shape1Cols) && isEven(shape2Cols) && - (shape1[0] === 1 || shape2[0] === 1)) { - return true; - } - } - return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]); - } - // We cache webgl params because the environment gets reset between - // unit tests and we don't want to constantly query the WebGLContext for - // MAX_TEXTURE_SIZE. - var MAX_TEXTURE_SIZE; - var MAX_TEXTURES_IN_SHADER; - function getWebGLMaxTextureSize(webGLVersion) { - if (MAX_TEXTURE_SIZE == null) { - var gl = getWebGLContext(webGLVersion); - MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE); - } - return MAX_TEXTURE_SIZE; - } - function resetMaxTextureSize() { - MAX_TEXTURE_SIZE = null; - } - function resetMaxTexturesInShader() { - MAX_TEXTURES_IN_SHADER = null; - } - function getMaxTexturesInShader(webGLVersion) { - if (MAX_TEXTURES_IN_SHADER == null) { - var gl = getWebGLContext(webGLVersion); - MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS); - } - // We cap at 16 to avoid spurious runtime "memory exhausted" error. - return Math.min(16, MAX_TEXTURES_IN_SHADER); - } - function getWebGLDisjointQueryTimerVersion(webGLVersion) { - if (webGLVersion === 0) { - return 0; - } - var queryTimerVersion; - var gl = getWebGLContext(webGLVersion); - if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && - webGLVersion === 2) { - queryTimerVersion = 2; - } - else if (hasExtension(gl, 'EXT_disjoint_timer_query')) { - queryTimerVersion = 1; - } - else { - queryTimerVersion = 0; - } - return queryTimerVersion; - } - function hasExtension(gl, extensionName) { - var ext = gl.getExtension(extensionName); - return ext != null; - } - function isWebGLVersionEnabled(webGLVersion) { - try { - var gl = getWebGLContext(webGLVersion); - if (gl != null) { - return true; - } - } - catch (e) { - return false; - } - return false; - } - function isCapableOfRenderingToFloatTexture(webGLVersion) { - if (webGLVersion === 0) { - return false; - } - var gl = getWebGLContext(webGLVersion); - if (webGLVersion === 1) { - if (!hasExtension(gl, 'OES_texture_float')) { - return false; - } - } - else { - if (!hasExtension(gl, 'EXT_color_buffer_float')) { - return false; - } - } - var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); - return isFrameBufferComplete; - } - /** - * Check if we can download values from a float/half-float texture. - * - * Note that for performance reasons we use binding a texture to a framebuffer - * as a proxy for ability to download float values later using readPixels. The - * texture params of this texture will not match those in readPixels exactly - * but if we are unable to bind some kind of float texture to the frameBuffer - * then we definitely will not be able to read float values from it. - */ - function isDownloadFloatTextureEnabled(webGLVersion) { - if (webGLVersion === 0) { - return false; - } - var gl = getWebGLContext(webGLVersion); - if (webGLVersion === 1) { - if (!hasExtension(gl, 'OES_texture_float')) { - return false; - } - if (!hasExtension(gl, 'WEBGL_color_buffer_float')) { - return false; - } - } - else { - if (hasExtension(gl, 'EXT_color_buffer_float')) { - return createFloatTextureAndBindToFramebuffer(gl); - } - var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; - if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) { - var textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT); - return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension); - } - return false; - } - var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); - return isFrameBufferComplete; - } - function createFloatTextureAndBindToFramebuffer(gl) { - var texConfig = getTextureConfig(gl); - var texture = gl.createTexture(); - gl.bindTexture(gl.TEXTURE_2D, texture); - var width = 1; - var height = 1; - gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null); - var frameBuffer = gl.createFramebuffer(); - gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); - gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); - var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; - gl.bindTexture(gl.TEXTURE_2D, null); - gl.bindFramebuffer(gl.FRAMEBUFFER, null); - gl.deleteTexture(texture); - gl.deleteFramebuffer(frameBuffer); - return isFrameBufferComplete; - } - function createHalfFloatTextureAndBindToFramebuffer( - // tslint:disable-next-line:no-any - gl, textureHalfFloatExtension) { - var texConfig = getTextureConfig(gl, textureHalfFloatExtension); - var texture = gl.createTexture(); - gl.bindTexture(gl.TEXTURE_2D, texture); - var width = 1; - var height = 1; - gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null); - var frameBuffer = gl.createFramebuffer(); - gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); - gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); - var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; - gl.bindTexture(gl.TEXTURE_2D, null); - gl.bindFramebuffer(gl.FRAMEBUFFER, null); - gl.deleteTexture(texture); - gl.deleteFramebuffer(frameBuffer); - return isFrameBufferComplete; - } - function isWebGLFenceEnabled(webGLVersion) { - if (webGLVersion !== 2) { - return false; - } - var gl = getWebGLContext(webGLVersion); - // tslint:disable-next-line:no-any - var isEnabled = gl.fenceSync != null; - return isEnabled; - } - - var webgl_util = /*#__PURE__*/Object.freeze({ - callAndCheck: callAndCheck, - canBeRepresented: canBeRepresented, - getWebGLErrorMessage: getWebGLErrorMessage, - getExtensionOrThrow: getExtensionOrThrow, - createVertexShader: createVertexShader, - createFragmentShader: createFragmentShader, - createProgram: createProgram, - linkProgram: linkProgram, - validateProgram: validateProgram, - createStaticVertexBuffer: createStaticVertexBuffer, - createStaticIndexBuffer: createStaticIndexBuffer, - getNumChannels: getNumChannels, - createTexture: createTexture, - validateTextureSize: validateTextureSize, - createFramebuffer: createFramebuffer, - bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute, - bindTextureUnit: bindTextureUnit, - unbindTextureUnit: unbindTextureUnit, - getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow, - getProgramUniformLocation: getProgramUniformLocation, - bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler, - bindCanvasToFramebuffer: bindCanvasToFramebuffer, - bindColorTextureToFramebuffer: bindColorTextureToFramebuffer, - unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer, - validateFramebuffer: validateFramebuffer, - getFramebufferErrorMessage: getFramebufferErrorMessage, - getBatchDim: getBatchDim, - getRowsCols: getRowsCols, - getShapeAs3D: getShapeAs3D, - getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape, - isReshapeFree: isReshapeFree, - getWebGLMaxTextureSize: getWebGLMaxTextureSize, - resetMaxTextureSize: resetMaxTextureSize, - resetMaxTexturesInShader: resetMaxTexturesInShader, - getMaxTexturesInShader: getMaxTexturesInShader, - getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion, - hasExtension: hasExtension, - isWebGLVersionEnabled: isWebGLVersionEnabled, - isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture, - isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled, - isWebGLFenceEnabled: isWebGLFenceEnabled - }); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ENV$1 = env(); - /** - * This file contains WebGL-specific flag registrations. - */ - /** - * True if WebGL is supported. - */ - ENV$1.registerFlag('HAS_WEBGL', function () { return ENV$1.getNumber('WEBGL_VERSION') > 0; }); - /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */ - ENV$1.registerFlag('WEBGL_VERSION', function () { - if (isWebGLVersionEnabled(2)) { - return 2; - } - else if (isWebGLVersionEnabled(1)) { - return 1; - } - return 0; - }); - ENV$1.registerFlag('WEBGL_BUFFER_SUPPORTED', function () { return ENV$1.get('WEBGL_VERSION') === 2; }); - /** Whether the WebGL backend will sometimes forward ops to the CPU. */ - ENV$1.registerFlag('WEBGL_CPU_FORWARD', function () { return true; }); - /** Whether the WebGL backend will always use f16 textures for rendering. */ - ENV$1.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () { return false; }); - /** Whether to turn all packing related flags on. */ - ENV$1.registerFlag('WEBGL_PACK', function () { return ENV$1.getBool('HAS_WEBGL'); }); - /** Whether we will pack the batchnormalization op. */ - ENV$1.registerFlag('WEBGL_PACK_NORMALIZATION', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will pack the clip op. */ - ENV$1.registerFlag('WEBGL_PACK_CLIP', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will pack the depthwise conv op. */ - // TODO: https://github.com/tensorflow/tfjs/issues/1679 - ENV$1.registerFlag('WEBGL_PACK_DEPTHWISECONV', function () { return false; }); - /** Whether we will pack binary ops. */ - ENV$1.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will pack unary ops. */ - ENV$1.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will pack array ops. */ - ENV$1.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will pack image ops. */ - ENV$1.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will pack reduce ops. */ - ENV$1.registerFlag('WEBGL_PACK_REDUCE', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether packed WebGL kernels lazily unpack their outputs. */ - ENV$1.registerFlag('WEBGL_LAZILY_UNPACK', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** Whether we will use the im2col algorithm to speed up convolutions. */ - ENV$1.registerFlag('WEBGL_CONV_IM2COL', function () { return ENV$1.getBool('WEBGL_PACK'); }); - /** The maximum texture dimension. */ - ENV$1.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () { return getWebGLMaxTextureSize(ENV$1.getNumber('WEBGL_VERSION')); }); - /** The maximum texture dimension. */ - ENV$1.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () { return getMaxTexturesInShader(ENV$1.getNumber('WEBGL_VERSION')); }); - /** - * The disjoint_query_timer extension version. - * 0: disabled, 1: EXT_disjoint_timer_query, 2: - * EXT_disjoint_timer_query_webgl2. - * In Firefox with WebGL 2.0, - * EXT_disjoint_timer_query_webgl2 is not available, so we must use the - * WebGL 1.0 extension. - */ - ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () { - var webGLVersion = ENV$1.getNumber('WEBGL_VERSION'); - if (webGLVersion === 0) { - return 0; - } - return getWebGLDisjointQueryTimerVersion(webGLVersion); - }); - /** - * Whether the timer object from the disjoint_query_timer extension gives - * timing information that is reliable. - */ - ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () { return ENV$1.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 && - !isMobile(); }); - /** - * Whether the device is physically capable of rendering to float32 textures. - */ - ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () { return isCapableOfRenderingToFloatTexture(ENV$1.getNumber('WEBGL_VERSION')); }); - /** - * Whether rendering to float32 textures is enabled. If disabled, renders to - * float16 textures. - */ - ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () { - return ENV$1.getBool('WEBGL_FORCE_F16_TEXTURES') ? - false : - ENV$1.getBool('WEBGL_RENDER_FLOAT32_CAPABLE'); - }); - /** - * Whether downloading float textures is enabled (16 or 32 bit). If disabled, - * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading. - */ - ENV$1.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () { return isDownloadFloatTextureEnabled(ENV$1.getNumber('WEBGL_VERSION')); }); - /** Whether the fence API is available. */ - ENV$1.registerFlag('WEBGL_FENCE_API_ENABLED', function () { return isWebGLFenceEnabled(ENV$1.getNumber('WEBGL_VERSION')); }); - /** - * Tensors with size <= than this will be uploaded as uniforms, not textures. - */ - ENV$1.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () { - // Use uniform uploads only when 32bit floats are supported. In - // 16bit - // environments there are problems with comparing a 16bit texture value - // with a 32bit uniform value. - var useUniforms = ENV$1.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); - return useUniforms ? 4 : 0; - }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Enables production mode which disables correctness checks in favor of - * performance. - */ - /** @doc {heading: 'Environment'} */ - function enableProdMode() { - env().set('PROD', true); - } - /** - * Enables debug mode which will log information about all executed kernels: - * the elapsed time of the kernel execution, as well as the rank, shape, and - * size of the output tensor. - * - * Debug mode will significantly slow down your application as it will - * download the result of every operation to the CPU. This should not be used in - * production. Debug mode does not affect the timing information of the kernel - * execution as we do not measure download time in the kernel execution time. - * - * See also: `tf.profile`, `tf.memory`. - */ - /** @doc {heading: 'Environment'} */ - function enableDebugMode() { - env().set('DEBUG', true); - } - /** Globally disables deprecation warnings */ - function disableDeprecationWarnings() { - env().set('DEPRECATION_WARNINGS_ENABLED', false); - console.warn("TensorFlow.js deprecation warnings have been disabled."); - } - /** Warn users about deprecated functionality. */ - function deprecationWarn(msg) { - if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) { - console.warn(msg + ' You can disable deprecation warnings with ' + - 'tf.disableDeprecationWarnings().'); - } - } - setDeprecationWarningFn(deprecationWarn); - /** - * Dispose all variables kept in backend engine. - */ - /** @doc {heading: 'Environment'} */ - function disposeVariables() { - ENGINE.disposeVariables(); - } - /** - * It returns the global engine that keeps track of all tensors and backends. - */ - /** @doc {heading: 'Environment'} */ - function engine() { - return ENGINE; - } - /** - * Returns memory info at the current time in the program. The result is an - * object with the following properties: - * - * - `numBytes`: Number of bytes allocated (undisposed) at this time. - * - `numTensors`: Number of unique tensors allocated. - * - `numDataBuffers`: Number of unique data buffers allocated - * (undisposed) at this time, which is ≤ the number of tensors - * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same - * data buffer with `a`). - * - `unreliable`: True if the memory usage is unreliable. See `reasons` when - * `unreliable` is true. - * - `reasons`: `string[]`, reasons why the memory is unreliable, present if - * `unreliable` is true. - * - * WebGL Properties: - * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at - * this time. - */ - /** @doc {heading: 'Performance', subheading: 'Memory'} */ - function memory() { - return ENGINE.memory(); - } - /** - * Executes the provided function `f()` and returns a promise that resolves - * with information about the function's memory use: - * - `newBytes`: the number of new bytes allocated - * - `newTensors`: the number of new tensors created - * - `peakBytes`: the peak number of bytes allocated - * - `kernels`: an array of objects for each kernel involved that reports - * their input and output shapes, number of bytes used, and number of new - * tensors created. - * - * ```js - * const profile = await tf.profile(() => { - * const x = tf.tensor1d([1, 2, 3]); - * let x2 = x.square(); - * x2.dispose(); - * x2 = x.square(); - * x2.dispose(); - * return x; - * }); - * - * console.log(`newBytes: ${profile.newBytes}`); - * console.log(`newTensors: ${profile.newTensors}`); - * console.log(`byte usage over all kernels: ${profile.kernels.map(k => - * k.totalBytesSnapshot)}`); - * ``` - * - */ - /** @doc {heading: 'Performance', subheading: 'Profile'} */ - function profile(f) { - return ENGINE.profile(f); - } - /** - * Executes the provided function `fn` and after it is executed, cleans up all - * intermediate tensors allocated by `fn` except those returned by `fn`. - * `fn` must not return a Promise (async functions not allowed). The returned - * result can be a complex object. - * - * Using this method helps avoid memory leaks. In general, wrap calls to - * operations in `tf.tidy` for automatic memory cleanup. - * - * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to - * dispose variables, please use `tf.disposeVariables` or call dispose() - * directly on variables. - * - * ```js - * // y = 2 ^ 2 + 1 - * const y = tf.tidy(() => { - * // a, b, and one will be cleaned up when the tidy ends. - * const one = tf.scalar(1); - * const a = tf.scalar(2); - * const b = a.square(); - * - * console.log('numTensors (in tidy): ' + tf.memory().numTensors); - * - * // The value returned inside the tidy function will return - * // through the tidy, in this case to the variable y. - * return b.add(one); - * }); - * - * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); - * y.print(); - * ``` - * - * @param nameOrFn The name of the closure, or the function to execute. - * If a name is provided, the 2nd argument should be the function. - * If debug mode is on, the timing and the memory usage of the function - * will be tracked and displayed on the console using the provided name. - * @param fn The function to execute. - */ - /** @doc {heading: 'Performance', subheading: 'Memory'} */ - function tidy(nameOrFn, fn) { - return ENGINE.tidy(nameOrFn, fn); - } - /** - * Disposes any `tf.Tensor`s found within the provided object. - * - * @param container an object that may be a `tf.Tensor` or may directly - * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If - * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing - * happens. In general it is safe to pass any object here, except that - * `Promise`s are not supported. - */ - /** @doc {heading: 'Performance', subheading: 'Memory'} */ - function dispose(container) { - var tensors = getTensorsInContainer(container); - tensors.forEach(function (tensor) { return tensor.dispose(); }); - } - /** - * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed - * automatically. - * - * ```js - * let b; - * const y = tf.tidy(() => { - * const one = tf.scalar(1); - * const a = tf.scalar(2); - * - * // b will not be cleaned up by the tidy. a and one will be cleaned up - * // when the tidy ends. - * b = tf.keep(a.square()); - * - * console.log('numTensors (in tidy): ' + tf.memory().numTensors); - * - * // The value returned inside the tidy function will return - * // through the tidy, in this case to the variable y. - * return b.add(one); - * }); - * - * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); - * console.log('y:'); - * y.print(); - * console.log('b:'); - * b.print(); - * ``` - * - * @param result The tensor to keep from being disposed. - */ - /** @doc {heading: 'Performance', subheading: 'Memory'} */ - function keep(result) { - return ENGINE.keep(result); - } - /** - * Executes `f()` and returns a promise that resolves with timing - * information. - * - * The result is an object with the following properties: - * - * - `wallMs`: Wall execution time. - * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the - * WebGL backend and the query timer extension is not available, this will - * return an error object. - * - On `WebGL` The following additional properties exist: - * - `uploadWaitMs`: CPU blocking time on texture uploads. - * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels). - * - * ```js - * const x = tf.randomNormal([20, 20]); - * const time = await tf.time(() => x.matMul(x)); - * - * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`); - * ``` - * - * @param f The function to execute and time. - */ - /** @doc {heading: 'Performance', subheading: 'Timing'} */ - function time(f) { - return ENGINE.time(f); - } - /** - * Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and - * executing operations on those tensors. Returns a promise that resolves - * to a boolean if the backend initialization was successful. - * - * Note this disposes the current backend, if any, as well as any tensors - * associated with it. A new backend is initialized, even if it is of the - * same type as the previous one. - * - * @param backendName The name of the backend. Currently supports - * `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js - * (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm). - */ - /** @doc {heading: 'Backends'} */ - function setBackend(backendName) { - return ENGINE.setBackend(backendName); - } - /** - * Returns a promise that resolves when the currently selected backend (or the - * highest priority one) has initialized. Await this promise when you are using - * a backend that has async initialization. - */ - /** @doc {heading: 'Backends'} */ - function ready() { - return ENGINE.ready(); - } - /** - * Returns the current backend name (cpu, webgl, etc). The backend is - * responsible for creating tensors and executing operations on those tensors. - */ - /** @doc {heading: 'Backends'} */ - function getBackend() { - return ENGINE.backendName; - } - /** - * Removes a backend and the registered factory. - */ - /** @doc {heading: 'Backends'} */ - function removeBackend(name) { - ENGINE.removeBackend(name); - } - /** - * Finds the backend registered under the provided name. Returns null if the - * name is not in the registry, or the registration hasn't finished yet. - */ - function findBackend(name) { - return ENGINE.findBackend(name); - } - /** - * Finds the backend factory registered under the provided name. Returns a - * function that produces a new backend when called. Returns null if the name - * is not in the registry. - */ - function findBackendFactory(name) { - return ENGINE.findBackendFactory(name); - } - /** - * Registers a global backend. The registration should happen when importing - * a module file (e.g. when importing `backend_webgl.ts`), and is used for - * modular builds (e.g. custom tfjs bundle with only webgl support). - * - * @param factory The backend factory function. When called, it should - * return a backend instance, or a promise of an instance. - * @param priority The priority of the backend (higher = more important). - * In case multiple backends are registered, the priority is used to find - * the best backend. Defaults to 1. - * @return False if there is already a registered backend under this name, true - * if not. - */ - /** @doc {heading: 'Backends'} */ - function registerBackend(name, factory, priority) { - if (priority === void 0) { priority = 1; } - return ENGINE.registerBackend(name, factory, priority); - } - /** - * Gets the current backend. If no backends have been initialized, this will - * attempt to initialize the best backend. Will throw an error if the highest - * priority backend has async initialization, in which case, you should call - * 'await tf.ready()' before running other code. - */ - /** @doc {heading: 'Backends'} */ - function backend() { - return ENGINE.backend; - } - /** - * Sets the global platform. - * - * @param platformName The name of this platform. - * @param platform A platform implementation. - */ - function setPlatform(platformName, platform) { - env().setPlatform(platformName, platform); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function warn() { - var msg = []; - for (var _i = 0; _i < arguments.length; _i++) { - msg[_i] = arguments[_i]; - } - if (!env().getBool('IS_TEST')) { - console.warn.apply(console, msg); - } - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function inferShape(val, dtype) { - var firstElem = val; - if (isTypedArray(val)) { - return dtype === 'string' ? [] : [val.length]; - } - if (!Array.isArray(val)) { - return []; // Scalar. - } - var shape = []; - while (Array.isArray(firstElem) || - isTypedArray(firstElem) && dtype !== 'string') { - shape.push(firstElem.length); - firstElem = firstElem[0]; - } - if (Array.isArray(val) && - env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { - deepAssertShapeConsistency(val, shape, []); - } - return shape; - } - function deepAssertShapeConsistency(val, shape, indices) { - indices = indices || []; - if (!(Array.isArray(val)) && !isTypedArray(val)) { - assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " + - ("but should be an array/TypedArray of " + shape[0] + " elements"); }); - return; - } - assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " + - ("but is an array of " + val.length + " elements"); }); - assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " + - ("elements, but has " + val.length + " elements"); }); - var subShape = shape.slice(1); - for (var i = 0; i < val.length; ++i) { - deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); - } - } - function assertDtype(expectedDtype, actualDType, argName, functionName) { - if (expectedDtype == null) { - return; - } - if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || - expectedDtype === 'numeric' && actualDType === 'string') { - throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + - ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor")); - } - } - function convertToTensor(x, argName, functionName, parseAsDtype) { - if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } - if (x instanceof Tensor) { - assertDtype(parseAsDtype, x.dtype, argName, functionName); - return x; - } - var inferredDtype = inferDtype(x); - // If the user expects a bool/int/float, use that info to update the - // inferredDtype when it is not a string. - if (inferredDtype !== 'string' && - ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { - inferredDtype = parseAsDtype; - } - assertDtype(parseAsDtype, inferredDtype, argName, functionName); - if ((x == null) || - (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && - typeof x !== 'boolean' && typeof x !== 'string')) { - var type = x == null ? 'null' : x.constructor.name; - throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + - ("Tensor or TensorLike, but got '" + type + "'")); - } - var inferredShape = inferShape(x, inferredDtype); - if (!isTypedArray(x) && !Array.isArray(x)) { - x = [x]; - } - var skipTypedArray = true; - var values = inferredDtype !== 'string' ? - toTypedArray(x, inferredDtype, env().getBool('DEBUG')) : - flatten(x, [], skipTypedArray); - return ENGINE.makeTensor(values, inferredShape, inferredDtype); - } - function convertToTensorArray(arg, argName, functionName, parseAsDtype) { - if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } - if (!Array.isArray(arg)) { - throw new Error("Argument " + argName + " passed to " + functionName + " must be a " + - '`Tensor[]` or `TensorLike[]`'); - } - var tensors = arg; - return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName); }, parseAsDtype); - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Returns true if the axis specifies the inner most dimensions of the - * array. - */ - function axesAreInnerMostDims(axes, rank) { - for (var i = 0; i < axes.length; ++i) { - if (axes[axes.length - i - 1] !== rank - 1 - i) { - return false; - } - } - return true; - } - function combineLocations(outputLoc, reduceLoc, axes) { - var rank = outputLoc.length + reduceLoc.length; - var loc = []; - var outIdx = 0; - var reduceIdx = 0; - for (var dim = 0; dim < rank; dim++) { - if (axes.indexOf(dim) === -1) { - loc.push(outputLoc[outIdx++]); - } - else { - loc.push(reduceLoc[reduceIdx++]); - } - } - return loc; - } - function computeOutAndReduceShapes(aShape, axes) { - var outShape = []; - var rank = aShape.length; - for (var dim = 0; dim < rank; dim++) { - if (axes.indexOf(dim) === -1) { - outShape.push(aShape[dim]); - } - } - var reduceShape = axes.map(function (dim) { return aShape[dim]; }); - return [outShape, reduceShape]; - } - function expandShapeToKeepDim(shape, axes) { - var reduceSubShape = axes.map(function (x) { return 1; }); - return combineLocations(shape, reduceSubShape, axes); - } - function assertAxesAreInnerMostDims(msg, axes, rank) { - assert(axesAreInnerMostDims(axes, rank), function () { return msg + " supports only inner-most axes for now. " + - ("Got axes " + axes + " and rank-" + rank + " input."); }); - } - /** - * Returns the axes permutation to be used with `tf.transpose`, if such - * permutation is necessary. Otherwise it returns null. This method is used by - * operations that operate only on inner-most axes. - */ - function getAxesPermutation(axes, rank) { - if (axesAreInnerMostDims(axes, rank)) { - return null; - } - var result = []; - for (var i = 0; i < rank; ++i) { - if (axes.indexOf(i) === -1) { - result.push(i); - } - } - axes.forEach(function (axis) { return result.push(axis); }); - return result; - } - /** Returns the axes permutation that undoes the original permutation. */ - function getUndoAxesPermutation(axes) { - return axes.map(function (axis, i) { return [i, axis]; }) - .sort(function (a, b) { return a[1] - b[1]; }) - .map(function (x) { return x[0]; }); - } - function getInnerMostAxes(numAxes, rank) { - var res = []; - for (var i = rank - numAxes; i < rank; ++i) { - res.push(i); - } - return res; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function assertParamsConsistent(shapes, axis) { - var rank = shapes[0].length; - shapes.forEach(function (shape, i) { - assert(shape.length === rank, function () { - return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " + - ("as the rank of the rest (" + rank + ")"); - }); - }); - assert(axis >= 0 && axis < rank, function () { return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + "."; }); - var firstShape = shapes[0]; - shapes.forEach(function (shape, i) { - for (var r = 0; r < rank; r++) { - assert((r === axis) || (shape[r] === firstShape[r]), function () { return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " + - ("does not match the shape of the rest (" + firstShape + ") ") + - ("along the non-concatenated axis " + i + "."); }); - } - }); - } - function computeOutShape(shapes, axis) { - var outputShape = shapes[0].slice(); - for (var i = 1; i < shapes.length; i++) { - outputShape[axis] += shapes[i][axis]; - } - return outputShape; - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Used for wrapping functions that perform math operations on - * Tensors. The function will be wrapped in a named scope that cleans all - * memory usage after the function is done. - */ - function op(f) { - var keys = Object.keys(f); - if (keys.length !== 1) { - throw new Error("Please provide an object with a single key " + - "(operation name) mapping to a function. Got an object with " + - (keys.length + " keys.")); - } - var opName = keys[0]; - var fn = f[opName]; - // Strip the underscore from the end of the function name. - if (opName.endsWith('_')) { - opName = opName.substring(0, opName.length - 1); - } - // tslint:disable-next-line:no-any - var f2 = function () { - var args = []; - for (var _i = 0; _i < arguments.length; _i++) { - args[_i] = arguments[_i]; - } - ENGINE.startScope(opName); - try { - var result = fn.apply(void 0, args); - if (result instanceof Promise) { - console.error('Cannot return a Promise inside of tidy.'); - } - ENGINE.endScope(result); - return result; - } - catch (ex) { - ENGINE.endScope(null); - throw ex; - } - }; - Object.defineProperty(f2, 'name', { value: opName, configurable: true }); - // tslint:disable-next-line:no-any - return f2; - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Converts two real numbers to a complex number. - * - * Given a tensor `real` representing the real part of a complex number, and a - * tensor `imag` representing the imaginary part of a complex number, this - * operation returns complex numbers elementwise of the form [r0, i0, r1, i1], - * where r represents the real part and i represents the imag part. - * - * The input tensors real and imag must have the same shape. - * - * ```js - * const real = tf.tensor1d([2.25, 3.25]); - * const imag = tf.tensor1d([4.75, 5.75]); - * const complex = tf.complex(real, imag); - * - * complex.print(); - * ``` - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function complex_(real, imag) { - var $real = convertToTensor(real, 'real', 'complex'); - var $imag = convertToTensor(imag, 'imag', 'complex'); - assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " + - "must match in call to tf.complex()."); - return ENGINE.runKernelFunc(function (backend) { return backend.complex($real, $imag); }, { $real: $real, $imag: $imag }); - } - /** - * Returns the real part of a complex (or real) tensor. - * - * Given a tensor input, this operation returns a tensor of type float that is - * the real part of each element in input considered as a complex number. - * - * If the input is real, it simply makes a clone. - * - * ```js - * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); - * tf.real(x).print(); - * ``` - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function real_(input) { - var $input = convertToTensor(input, 'input', 'real'); - return ENGINE.runKernelFunc(function (backend) { return backend.real($input); }, { $input: $input }); - } - /** - * Returns the imaginary part of a complex (or real) tensor. - * - * Given a tensor input, this operation returns a tensor of type float that is - * the imaginary part of each element in input considered as a complex number. - * If input is real, a tensor of all zeros is returned. - * - * ```js - * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); - * tf.imag(x).print(); - * ``` - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function imag_(input) { - var $input = convertToTensor(input, 'input', 'imag'); - return ENGINE.runKernelFunc(function (backend) { return backend.imag($input); }, { $input: $input }); - } - var complex = op({ complex_: complex_ }); - var real = op({ real_: real_ }); - var imag = op({ imag_: imag_ }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Creates a `tf.Tensor` with the provided values, shape and dtype. - * - * ```js - * // Pass an array of values to create a vector. - * tf.tensor([1, 2, 3, 4]).print(); - * ``` - * - * ```js - * // Pass a nested array of values to make a matrix or a higher - * // dimensional tensor. - * tf.tensor([[1, 2], [3, 4]]).print(); - * ``` - * - * ```js - * // Pass a flat array and specify a shape yourself. - * tf.tensor([1, 2, 3, 4], [2, 2]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. If the values are strings, - * they will be encoded as utf-8 and kept as `Uint8Array[]`. - * @param shape The shape of the tensor. Optional. If not provided, - * it is inferred from `values`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor(values, shape, dtype) { - var inferredShape = inferShape(values, dtype); - return makeTensor(values, shape, inferredShape, dtype); - } - /** This is shared code across all tensor creation methods. */ - function makeTensor(values, shape, inferredShape, dtype) { - if (dtype == null) { - dtype = inferDtype(values); - } - if (dtype === 'complex64') { - throw new Error("Cannot construct a complex64 tensor directly. " + - "Please use tf.complex(real, imag)."); - } - if (!isTypedArray(values) && !Array.isArray(values) && - typeof values !== 'number' && typeof values !== 'boolean' && - typeof values !== 'string') { - throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + - 'an array of numbers/booleans/strings, or a TypedArray'); - } - if (shape != null) { - assertNonNegativeIntegerDimensions(shape); - var providedSize_1 = sizeFromShape(shape); - var inferredSize_1 = sizeFromShape(inferredShape); - assert(providedSize_1 === inferredSize_1, function () { - return "Based on the provided shape, [" + shape + "], the tensor should have " + - (providedSize_1 + " values but has " + inferredSize_1); - }); - for (var i = 0; i < inferredShape.length; ++i) { - var inferred = inferredShape[i]; - var flatDimsDontMatch = i === inferredShape.length - 1 ? - inferred !== sizeFromShape(shape.slice(i)) : - true; - assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " + - ("(" + inferredShape + ") does not match the provided ") + - ("shape (" + shape + "). "); }); - } - } - if (!isTypedArray(values) && !Array.isArray(values)) { - values = [values]; - } - shape = shape || inferredShape; - values = dtype !== 'string' ? - toTypedArray(values, dtype, env().getBool('DEBUG')) : - flatten(values, [], true); - return ENGINE.makeTensor(values, shape, dtype); - } - /** - * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.scalar` as it makes the code more readable. - * - * ```js - * tf.scalar(3.14).print(); - * ``` - * - * @param value The value of the scalar. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function scalar(value, dtype) { - if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) && - dtype !== 'complex64') { - throw new Error('Error creating a new Scalar: value must be a primitive ' + - '(number|boolean|string)'); - } - if (dtype === 'string' && isTypedArray(value) && - !(value instanceof Uint8Array)) { - throw new Error('When making a scalar from encoded string, ' + - 'the value must be `Uint8Array`.'); - } - var shape = []; - var inferredShape = []; - return makeTensor(value, shape, inferredShape, dtype); - } - /** - * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor1d` as it makes the code more readable. - * - * ```js - * tf.tensor1d([1, 2, 3]).print(); - * ``` - * - * @param values The values of the tensor. Can be array of numbers, - * or a `TypedArray`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor1d(values, dtype) { - assertNonNull(values); - var inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 1) { - throw new Error('tensor1d() requires values to be a flat/TypedArray'); - } - var shape = null; - return makeTensor(values, shape, inferredShape, dtype); - } - /** - * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor2d` as it makes the code more readable. - * - * ```js - * // Pass a nested array. - * tf.tensor2d([[1, 2], [3, 4]]).print(); - * ``` - * ```js - * // Pass a flat array and specify a shape. - * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. - * @param shape The shape of the tensor. If not provided, it is inferred from - * `values`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor2d(values, shape, dtype) { - assertNonNull(values); - if (shape != null && shape.length !== 2) { - throw new Error('tensor2d() requires shape to have two numbers'); - } - var inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 2 && inferredShape.length !== 1) { - throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray'); - } - if (inferredShape.length === 1 && shape == null) { - throw new Error('tensor2d() requires shape to be provided when `values` ' + - 'are a flat/TypedArray'); - } - return makeTensor(values, shape, inferredShape, dtype); - } - /** - * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor3d` as it makes the code more readable. - * - * ```js - * // Pass a nested array. - * tf.tensor3d([[[1], [2]], [[3], [4]]]).print(); - * ``` - * ```js - * // Pass a flat array and specify a shape. - * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. - * @param shape The shape of the tensor. If not provided, it is inferred from - * `values`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor3d(values, shape, dtype) { - assertNonNull(values); - if (shape != null && shape.length !== 3) { - throw new Error('tensor3d() requires shape to have three numbers'); - } - var inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 3 && inferredShape.length !== 1) { - throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray'); - } - if (inferredShape.length === 1 && shape == null) { - throw new Error('tensor3d() requires shape to be provided when `values` ' + - 'are a flat array'); - } - return makeTensor(values, shape, inferredShape, dtype); - } - /** - * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor4d` as it makes the code more readable. - * - * ```js - * // Pass a nested array. - * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print(); - * ``` - * ```js - * // Pass a flat array and specify a shape. - * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. - * @param shape The shape of the tensor. Optional. If not provided, - * it is inferred from `values`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor4d(values, shape, dtype) { - assertNonNull(values); - if (shape != null && shape.length !== 4) { - throw new Error('tensor4d() requires shape to have four numbers'); - } - var inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 4 && inferredShape.length !== 1) { - throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray'); - } - if (inferredShape.length === 1 && shape == null) { - throw new Error('tensor4d() requires shape to be provided when `values` ' + - 'are a flat array'); - } - return makeTensor(values, shape, inferredShape, dtype); - } - /** - * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor5d` as it makes the code more readable. - * - * ```js - * // Pass a nested array. - * tf.tensor5d([[[[[1], [2]], [[3], [4]]]]]).print(); - * ``` - * ```js - * // Pass a flat array and specify a shape. - * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. - * @param shape The shape of the tensor. Optional. If not provided, - * it is inferred from `values`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor5d(values, shape, dtype) { - assertNonNull(values); - if (shape != null && shape.length !== 5) { - throw new Error('tensor5d() requires shape to have five numbers'); - } - var inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 5 && inferredShape.length !== 1) { - throw new Error('tensor5d() requires values to be ' + - 'number[][][][][] or flat/TypedArray'); - } - if (inferredShape.length === 1 && shape == null) { - throw new Error('tensor5d() requires shape to be provided when `values` ' + - 'are a flat array'); - } - return makeTensor(values, shape, inferredShape, dtype); - } - /** - * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype. - * - * The same functionality can be achieved with `tf.tensor`, but in general - * we recommend using `tf.tensor6d` as it makes the code more readable. - * - * ```js - * // Pass a nested array. - * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print(); - * ``` - * ```js - * // Pass a flat array and specify a shape. - * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print(); - * ``` - * - * @param values The values of the tensor. Can be nested array of numbers, - * or a flat array, or a `TypedArray`. - * @param shape The shape of the tensor. Optional. If not provided, - * it is inferred from `values`. - * @param dtype The data type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function tensor6d(values, shape, dtype) { - assertNonNull(values); - if (shape != null && shape.length !== 6) { - throw new Error('tensor6d() requires shape to have six numbers'); - } - var inferredShape = inferShape(values, dtype); - if (inferredShape.length !== 6 && inferredShape.length !== 1) { - throw new Error('tensor6d() requires values to be number[][][][][][] or ' + - 'flat/TypedArray'); - } - if (inferredShape.length === 1 && shape == null) { - throw new Error('tensor6d() requires shape to be provided when `values` ' + - 'are a flat array'); - } - shape = shape || - inferredShape; - return makeTensor(values, shape, inferredShape, dtype); - } - /** - * Creates a new variable with the provided initial value. - * ```js - * const x = tf.variable(tf.tensor([1, 2, 3])); - * x.assign(tf.tensor([4, 5, 6])); - * - * x.print(); - * ``` - * - * @param initialValue Initial value for the tensor. - * @param trainable If true, optimizers are allowed to update it. - * @param name Name of the variable. Defaults to a unique id. - * @param dtype If set, initialValue will be converted to the given type. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function variable(initialValue, trainable, name, dtype) { - if (trainable === void 0) { trainable = true; } - return ENGINE.makeVariable(initialValue, trainable, name, dtype); - } - /** - * Creates a `tf.Tensor` with all elements set to 1. - * - * ```js - * tf.ones([2, 2]).print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param dtype The type of an element in the resulting tensor. Defaults to - * 'float'. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function ones$1(shape, dtype) { - if (dtype === void 0) { dtype = 'float32'; } - if (dtype === 'complex64') { - var real_1 = ones$1(shape, 'float32'); - var imag_1 = zeros(shape, 'float32'); - return complex(real_1, imag_1); - } - var values = makeOnesTypedArray(sizeFromShape(shape), dtype); - return ENGINE.makeTensor(values, shape, dtype); - } - /** - * Creates a `tf.Tensor` with all elements set to 0. - * - * ```js - * tf.zeros([2, 2]).print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param dtype The type of an element in the resulting tensor. Can - * be 'float32', 'int32' or 'bool'. Defaults to 'float'. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function zeros(shape, dtype) { - if (dtype === void 0) { dtype = 'float32'; } - if (dtype === 'complex64') { - var real_2 = zeros(shape, 'float32'); - var imag_2 = zeros(shape, 'float32'); - return complex(real_2, imag_2); - } - var values = makeZerosTypedArray(sizeFromShape(shape), dtype); - return ENGINE.makeTensor(values, shape, dtype); - } - /** - * Creates a `tf.Tensor` filled with a scalar value. - * - * ```js - * tf.fill([2, 2], 4).print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param value The scalar value to fill the tensor with. - * @param dtype The type of an element in the resulting tensor. Defaults to - * 'float'. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function fill(shape, value, dtype) { - return ENGINE.runKernelFunc(function (backend) { return backend.fill(shape, value, dtype); }, {}); - } - /** - * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the - * given tensor. - * - * ```js - * const x = tf.tensor([1, 2]); - * tf.onesLike(x).print(); - * ``` - * @param x A tensor. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function onesLike_(x) { - var $x = convertToTensor(x, 'x', 'onesLike'); - if ($x.dtype === 'complex64') { - var r = onesLike(real($x)); - var i = zerosLike(imag($x)); - return complex(r, i); - } - var der = function (dy, saved) { return ({ x: function () { return zerosLike(dy); } }); }; - return ENGINE.runKernelFunc(function (backend) { return backend.onesLike($x); }, { x: $x }, der, 'OnesLike'); - } - /** - * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the - * given tensor. - * - * ```js - * const x = tf.tensor([1, 2]); - * tf.zerosLike(x).print(); - * ``` - * - * @param x The tensor of required shape. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function zerosLike_(x) { - var $x = convertToTensor(x, 'x', 'zerosLike'); - var der = function (dy, saved) { return ({ x: function () { return zerosLike(dy); } }); }; - return ENGINE.runKernelFunc(function (backend) { return backend.zerosLike($x); }, { x: $x }, der, 'ZerosLike'); - } - /** - * Return an evenly spaced sequence of numbers over the given interval. - * - * ```js - * tf.linspace(0, 9, 10).print(); - * ``` - * @param start The start value of the sequence. - * @param stop The end value of the sequence. - * @param num The number of values to generate. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function linspace(start, stop, num) { - if (num <= 0) { - throw new Error('The number of values should be positive.'); - } - return ENGINE.runKernelFunc(function (backend) { return backend.linspace(start, stop, num); }, {}); - } - /** - * Creates a new `tf.Tensor1D` filled with the numbers in the range provided. - * - * The tensor is a is half-open interval meaning it includes start, but - * excludes stop. Decrementing ranges and negative step values are also - * supported. - * - * ```js - * tf.range(0, 9, 2).print(); - * ``` - * - * @param start An integer start value - * @param stop An integer stop value - * @param step An integer increment (will default to 1 or -1) - * @param dtype The data type of the output tensor. Defaults to 'float32'. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function range(start, stop, step, dtype) { - if (step === void 0) { step = 1; } - if (dtype === void 0) { dtype = 'float32'; } - if (step === 0) { - throw new Error('Cannot have a step of zero'); - } - var sameStartStop = start === stop; - var increasingRangeNegativeStep = start < stop && step < 0; - var decreasingRangePositiveStep = stop < start && step > 1; - if (sameStartStop || increasingRangeNegativeStep || - decreasingRangePositiveStep) { - return zeros([0], dtype); - } - var numElements = Math.abs(Math.ceil((stop - start) / step)); - var values = makeZerosTypedArray(numElements, dtype); - if (stop < start && step === 1) { - // Auto adjust the step's sign if it hasn't been set - // (or was set to 1) - step = -1; - } - values[0] = start; - for (var i = 1; i < values.length; i++) { - values[i] = values[i - 1] + step; - } - return tensor1d(values, dtype); - } - var onesLike = op({ onesLike_: onesLike_ }); - var zerosLike = op({ zerosLike_: zerosLike_ }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details. - * - * For example, if: - * A: shape(3) = |r1, g1, b1| - * B: shape(2) = |r2, g2| - * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2| - * - * @param tensors A list of`tf.Tensor`s to concatenate. - * @return The concatenated array. - */ - function concat1d_(tensors) { - return concat(tensors, 0 /* axis */); - } - /** - * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details. - * - * For example, if: - * A: shape(2, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * - * B: shape(2, 3) = | r3, g3, b3 | - * | r4, g4, b4 | - * - * C = tf.concat2d([A, B], axis) - * - * if axis = 0: - * C: shape(4, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * | r3, g3, b3 | - * | r4, g4, b4 | - * - * if axis = 1: - * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 | - * | r2, g2, b2, r4, g4, b4 | - * - * - * @param tensors A list of `tf.Tensor`s to concatenate. - * @param axis The axis to concatenate along. - * @return The concatenated array. - */ - function concat2d_(tensors, axis) { - return concat(tensors, axis); - } - /** - * Concatenates a list of `tf.Tensor3D`s along an axis. - * See `concat` for details. - * - * For example, if: - * A: shape(2, 1, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * - * B: shape(2, 1, 3) = | r3, g3, b3 | - * | r4, g4, b4 | - * - * C = tf.concat3d([A, B], axis) - * - * if axis = 0: - * C: shape(4, 1, 3) = | r1, g1, b1 | - * | r2, g2, b2 | - * | r3, g3, b3 | - * | r4, g4, b4 | - * - * if axis = 1: - * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 | - * | r2, g2, b2, r4, g4, b4 | - * - * if axis = 2: - * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 | - * | r2, g2, b2, r4, g4, b4 | - * - * @param tensors A list of`tf.Tensor`s to concatenate. - * @param axis The axis to concate along. - * @return The concatenated array. - */ - function concat3d_(tensors, axis) { - return concat(tensors, axis); - } - /** - * Concatenates a list of `tf.Tensor4D`s along an axis. - * See `concat` for details. - * - * @param tensors A list of `tf.Tensor`s to concatenate. - * @param axis The axis to concate along. - * @return The concatenated array. - */ - function concat4d_(tensors, axis) { - return concat(tensors, axis); - } - /** - * Concatenates a list of `tf.Tensor`s along a given axis. - * - * The tensors ranks and types must match, and their sizes must match in all - * dimensions except `axis`. - * - * Also available are stricter rank-specific methods that assert that - * `tensors` are of the given rank: - * - `tf.concat1d` - * - `tf.concat2d` - * - `tf.concat3d` - * - `tf.concat4d` - * - * Except `tf.concat1d` (which does not have axis param), all methods have - * same signature as this method. - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor1d([3, 4]); - * a.concat(b).print(); // or a.concat(b) - * ``` - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor1d([3, 4]); - * const c = tf.tensor1d([5, 6]); - * tf.concat([a, b, c]).print(); - * ``` - * - * ```js - * const a = tf.tensor2d([[1, 2], [10, 20]]); - * const b = tf.tensor2d([[3, 4], [30, 40]]); - * const axis = 1; - * tf.concat([a, b], axis).print(); - * ``` - * @param tensors A list of tensors to concatenate. - * @param axis The axis to concate along. Defaults to 0 (the first dim). - */ - /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ - function concat_(tensors, axis) { - if (axis === void 0) { axis = 0; } - assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; }); - var $tensors = convertToTensorArray(tensors, 'tensors', 'concat'); - if ($tensors[0].dtype === 'complex64') { - $tensors.forEach(function (tensor) { - if (tensor.dtype !== 'complex64') { - throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". "); - } - }); - } - axis = parseAxisParam(axis, $tensors[0].shape)[0]; - var outShape = computeOutShape($tensors.map(function (t) { return t.shape; }), axis); - if (sizeFromShape(outShape) === 0) { - return tensor([], outShape); - } - // Keep only non-empty tensors (ignore tensors with 0 in their shape). - $tensors = $tensors.filter(function (t) { return t.size > 0; }); - if ($tensors.length === 1) { - return $tensors[0]; - } - var shapes = $tensors.map(function (t) { return t.shape; }); - assertParamsConsistent(shapes, axis); - var der = function (dy) { - var sizeSplits = shapes.map(function (s) { return s[axis]; }); - var derTensors = split(dy, sizeSplits, axis); - return derTensors.map(function (t) { return function () { return t; }; }); - }; - var inputs = $tensors; - var attr = { axis: axis }; - return ENGINE.runKernelFunc(function (backend) { return backend.concat($tensors, axis); }, inputs, der, 'Concat', attr); - } - /** - * Splits a `tf.Tensor` into sub tensors. - * - * If `numOrSizeSplits` is a number, splits `x` along dimension `axis` - * into `numOrSizeSplits` smaller tensors. - * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`. - * - * If `numOrSizeSplits` is a number array, splits `x` into - * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the - * same size as `x` except along dimension `axis` where the size is - * `numOrSizeSplits[i]`. - * - * ```js - * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - * const [a, b] = tf.split(x, 2, 1); - * a.print(); - * b.print(); - * - * const [c, d, e] = tf.split(x, [1, 2, 1], 1); - * c.print(); - * d.print(); - * e.print(); - * ``` - * - * @param x The input tensor to split. - * @param numOrSizeSplits Either an integer indicating the number of - * splits along the axis or an array of integers containing the sizes of - * each output tensor along the axis. If a number then it must evenly divide - * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. - * @param axis The dimension along which to split. Defaults to 0 (the first - * dim). - */ - /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ - function split_(x, numOrSizeSplits, axis) { - if (axis === void 0) { axis = 0; } - var $x = convertToTensor(x, 'x', 'split'); - axis = parseAxisParam(axis, $x.shape)[0]; - var splitSizes; - if (typeof (numOrSizeSplits) === 'number') { - assert($x.shape[axis] % numOrSizeSplits === 0, function () { return 'Number of splits must evenly divide the axis.'; }); - splitSizes = - new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits); - } - else { - assert($x.shape[axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }), function () { return 'The sum of sizes must match the size of the axis dimension.'; }); - splitSizes = numOrSizeSplits; - } - var der = function (dy) { return ({ $x: function () { return concat(dy, axis); } }); }; - return ENGINE.runKernelFunc(function (backend) { return backend.split($x, splitSizes, axis); }, { $x: $x }, der); - } - var concat = op({ concat_: concat_ }); - var concat1d = op({ concat1d_: concat1d_ }); - var concat2d = op({ concat2d_: concat2d_ }); - var concat3d = op({ concat3d_: concat3d_ }); - var concat4d = op({ concat4d_: concat4d_ }); - var split = op({ split_: split_ }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Reshapes a `tf.Tensor` to a given shape. - * - * Given an input tensor, returns a new tensor with the same values as the - * input tensor with shape `shape`. - * - * If one component of shape is the special value -1, the size of that - * dimension is computed so that the total size remains constant. In - * particular, a shape of [-1] flattens into 1-D. At most one component of - * shape can be -1. - * - * If shape is 1-D or higher, then the operation returns a tensor with shape - * shape filled with the values of tensor. In this case, the number of - * elements implied by shape must be the same as the number of elements in - * tensor. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * x.reshape([2, 2]).print(); - * ``` - * - * @param x The input tensor to be reshaped. - * @param shape An array of integers defining the output tensor shape. - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function reshape_(x, shape) { - var $x = convertToTensor(x, 'x', 'reshape', null); - shape = inferFromImplicitShape(shape, $x.size); - assert($x.size === sizeFromShape(shape), function () { return 'new shape and old shape must have the same number of elements.'; }); - var grad = function (dy) { - return { x: function () { return dy.reshape($x.shape); } }; - }; - var attrs = { shape: shape }; - return ENGINE.runKernelFunc(function (backend) { return backend.reshape($x, shape); }, { x: $x }, grad, 'Reshape', attrs); - } - /** - * Removes dimensions of size 1 from the shape of a `tf.Tensor`. - * - * ```js - * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]); - * x.squeeze().print(); - * ``` - * - * @param x The input tensor to be squeezed. - * @param axis An optional list of numbers. If specified, only - * squeezes the dimensions listed. The dimension index starts at 0. It - * is an error to squeeze a dimension that is not 1. - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function squeeze_(x, axis) { - var $x = convertToTensor(x, 'x', 'squeeze'); - return reshape($x, squeezeShape($x.shape, axis).newShape); - } - /** - * Casts a `tf.Tensor` to a new dtype. - * - * ```js - * const x = tf.tensor1d([1.5, 2.5, 3]); - * tf.cast(x, 'int32').print(); - * ``` - * @param x The input tensor to be casted. - * @param dtype The dtype to cast the input tensor to. - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function cast_(x, dtype) { - var $x = convertToTensor(x, 'x', 'cast'); - // Sanity checks. - if (!isValidDtype(dtype)) { - throw new Error("Failed to cast to unknown dtype " + dtype); - } - if (dtype === 'string' && $x.dtype !== 'string' || - dtype !== 'string' && $x.dtype === 'string') { - throw new Error('Only strings can be casted to strings'); - } - var grad = function (dy) { - return { x: function () { return dy.clone(); } }; - }; - var attrs = { dtype: dtype }; - return ENGINE.runKernelFunc(function (backend) { return backend.cast($x, dtype); }, { x: $x }, grad, 'Cast', attrs); - } - /** - * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`. - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor1d([3, 4]); - * const c = tf.tensor1d([5, 6]); - * tf.stack([a, b, c]).print(); - * ``` - * - * @param tensors A list of tensor objects with the same shape and dtype. - * @param axis The axis to stack along. Defaults to 0 (the first dim). - */ - /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ - function stack_(tensors, axis) { - if (axis === void 0) { axis = 0; } - var $tensors = convertToTensorArray(tensors, 'tensors', 'stack'); - assert($tensors.length >= 1, function () { return 'Pass at least one tensor to tf.stack'; }); - if ($tensors.length === 1) { - return $tensors[0].expandDims(axis); - } - var rank = $tensors[0].rank; - var shape = $tensors[0].shape; - var dtype = $tensors[0].dtype; - assert(axis <= rank, function () { return 'Axis must be <= rank of the tensor'; }); - $tensors.forEach(function (t) { - assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes'); - }); - $tensors.forEach(function (t) { - assert(dtype === t.dtype, function () { return 'All tensors passed to stack must have matching dtypes'; }); - }); - var expandedTensors = $tensors.map(function (t) { return t.expandDims(axis); }); - return concat(expandedTensors, axis); - } - /** - * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of - * shape `blockShape + [batch]`, interleaves these blocks back into the grid - * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with - * the same rank as the input. The spatial dimensions of this intermediate - * result are then optionally cropped according to `crops` to produce the - * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise - * description. - * - * ```js - * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]); - * const blockShape = [2, 2]; - * const crops = [[0, 0], [0, 0]]; - * - * x.batchToSpaceND(blockShape, crops).print(); - * ``` - * - * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + - * remainingShape`, where spatialShape has `M` dimensions. - * @param blockShape A 1-D array. Must have shape `[M]`, all values must - * be >= 1. - * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0. - * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input - * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required - * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]` - * - * This operation is equivalent to the following steps: - * - * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ..., - * blockShape[M-1], batch / prod(blockShape), x.shape[1], ..., - * x.shape[N-1]]` - * - * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch / - * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M], - * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` - * - * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch / - * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] * - * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` - * - * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted` - * according to `crops` to produce the output of shape: `[batch / - * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1], - * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] - - * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]` - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function batchToSpaceND_(x, blockShape, crops) { - var $x = convertToTensor(x, 'x', 'batchToSpaceND'); - var prod = blockShape.reduce(function (a, b) { return a * b; }); - assert($x.rank >= 1 + blockShape.length, function () { return "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length; }); - assert(crops.length === blockShape.length, function () { return "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length; }); - assert($x.shape[0] % prod === 0, function () { return "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " + - ("the elements of blockShape " + blockShape.join(' * ') + " === " + prod); }); - var grad = function (dy) { - return { $x: function () { return dy.spaceToBatchND(blockShape, crops); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.batchToSpaceND($x, blockShape, crops); }, { $x: $x }, grad); - } - /** - * This operation divides "spatial" dimensions `[1, ..., M]` of the input into - * a grid of blocks of shape `blockShape`, and interleaves these blocks with - * the "batch" dimension (0) such that in the output, the spatial - * dimensions `[1, ..., M]` correspond to the position within the grid, - * and the batch dimension combines both the position within a spatial block - * and the original batch position. Prior to division into blocks, - * the spatial dimensions of the input are optionally zero padded - * according to `paddings`. See below for a precise description. - * - * ```js - * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); - * const blockShape = [2, 2]; - * const paddings = [[0, 0], [0, 0]]; - * - * x.spaceToBatchND(blockShape, paddings).print(); - * ``` - * - * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + - * remainingShape`, where spatialShape has `M` dimensions. - * @param blockShape A 1-D array. Must have shape `[M]`, all values must - * be >= 1. - * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >= - * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad - * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It - * is required that - * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0` - * - * This operation is equivalent to the following steps: - * - * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input - * according to `paddings` to produce `padded` of shape paddedShape. - * - * 2. Reshape `padded` to `reshapedPadded` of shape: - * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ..., - * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape` - * - * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded` - * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ..., - * paddedShape[M] / blockShape[M-1]] + remainingShape` - * - * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the - * batch dimension, producing an output tensor of shape: - * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ..., - * paddedShape[M] / blockShape[M-1]] + remainingShape` - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function spaceToBatchND_(x, blockShape, paddings) { - var $x = convertToTensor(x, 'x', 'spaceToBatchND'); - assert($x.rank >= 1 + blockShape.length, function () { return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length; }); - assert(paddings.length === blockShape.length, function () { return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length; }); - assert($x.shape.reduce(function (a, b, i) { - if (i > 0 && i <= blockShape.length) { - return a && - ((b + paddings[i - 1][0] + paddings[i - 1][1]) % - blockShape[i - 1] === - 0); - } - return a; - }, true), function () { return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString(); }); - var grad = function (dy) { - return { $x: function () { return dy.batchToSpaceND(blockShape, paddings); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.spaceToBatchND($x, blockShape, paddings); }, { $x: $x }, grad); - } - /** - * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s. - * - * ```js - * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * tf.unstack(a).forEach(tensor => tensor.print()); - * ``` - * - * @param x A tensor object. - * @param axis The axis to unstack along. Defaults to 0 (the first dim). - */ - /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ - function unstack_(x, axis) { - if (axis === void 0) { axis = 0; } - axis = axis || 0; - var $x = convertToTensor(x, 'x', 'unstack'); - assert(axis >= -$x.shape.length && axis < $x.shape.length, function () { - return "Axis = " + axis + " is not in [-" + $x.shape.length + ", " + $x.shape.length + ")"; - }); - if (axis < 0) { - axis += $x.shape.length; - } - var grad = function (dy) { - return { x: function () { return stack(dy, axis); } }; - }; - var attrs = { axis: axis }; - return ENGINE.runKernelFunc(function (backend) { return backend.unstack($x, axis); }, { x: $x }, grad, 'Unpack', attrs); - } - /** - * Computes the cumulative sum of a `tf.Tensor` along `axis`. - * - * ```js - * const x = tf.tensor([1, 2, 3, 4]); - * x.cumsum().print(); - * ``` - * ```js - * const x = tf.tensor([[1, 2], [3, 4]]); - * x.cumsum().print(); - * ``` - * - * @param x The input tensor to be summed. - * @param axis The axis along which to sum. Optional. Defaults to 0. - * @param exclusive Whether to perform exclusive cumulative sum. Optional. - * Defaults to false. If set to true then the sum of each tensor entry - * does not include its own value, but only the values previous to it - * along the specified axis. - * @param reverse Whether to sum in the opposite direction. Optional. - * Defaults to false. - */ - /** @doc {heading: 'Operations', subheading: 'Scan'} */ - function cumsum_(x, axis, exclusive, reverse) { - if (axis === void 0) { axis = 0; } - if (exclusive === void 0) { exclusive = false; } - if (reverse === void 0) { reverse = false; } - var $x = convertToTensor(x, 'x', 'cumsum'); - axis = axis | 0; - var permutation = getAxesPermutation([axis], $x.rank); - var permutedX = $x; - if (permutation != null) { - permutedX = $x.transpose(permutation); - } - var permutedAxis = getInnerMostAxes(1, $x.rank)[0]; - var grad = function (dy) { - return { permutedX: function () { return dy.cumsum(axis, exclusive, !reverse); } }; - }; - var value = ENGINE.runKernelFunc(function (backend) { return backend.cumsum(permutedX, permutedAxis, exclusive, reverse); }, { permutedX: permutedX }, grad); - if (permutation != null) { - value = value.transpose(permutation); - } - return value; - } - /** - * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension - * into the tensor's shape. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * const axis = 1; - * x.expandDims(axis).print(); - * ``` - * - * @param x The input tensor whose dimensions to be expanded. - * @param axis The dimension index at which to insert shape of `1`. Defaults - * to 0 (the first dimension). - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function expandDims_(x, axis) { - if (axis === void 0) { axis = 0; } - var parseAs = null; - var $x = convertToTensor(x, 'x', 'expandDims', parseAs); - assert(axis <= $x.rank, function () { return 'Axis must be <= rank of the tensor'; }); - var newShape = $x.shape.slice(); - if (axis < 0) { - // Negative value is counted from the tail of rank. - assert(-($x.rank + 1) <= axis, function () { return "Axis must be in the interval [" + -($x.rank + 1) + ", " + $x.rank + "]"; }); - axis = $x.rank + axis + 1; - } - newShape.splice(axis, 0, 1); - return reshape($x, newShape); - } - /** - * Rearranges data from depth into blocks of spatial data. More specifically, - * this op outputs a copy of the input tensor where values from the `depth` - * dimension are moved in spatial blocks to the `height` and `width` dimensions. - * The attr `blockSize` indicates the input block size and how the data is - * moved. - * - * - Chunks of data of size `blockSize * blockSize` from depth are rearranged - * into non-overlapping blocks of size `blockSize x blockSize` - * - * - The width the output tensor is `inputWidth * blockSize`, whereas the - * height is `inputHeight * blockSize` - * - * - The Y, X coordinates within each block of the output image are determined - * by the high order component of the input channel index - * - * - The depth of the input tensor must be divisible by `blockSize * - * blockSize` - * - * The `dataFormat` attr specifies the layout of the input and output tensors - * with the following options: "NHWC": [ `batch, height, width, channels` ] - * "NCHW": [ `batch, channels, height, width` ] - * - * ```js - * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); - * const blockSize = 2; - * const dataFormat = "NHWC"; - * - * tf.depthToSpace(x, blockSize, dataFormat).print(); - * ``` - * - * @param x The input tensor of rank 4 - * @param blockSIze An `int` that is `>= 2`. The size of the spatial block - * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC" - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function depthToSpace_(x, blockSize, dataFormat) { - if (dataFormat === void 0) { dataFormat = 'NHWC'; } - var $x = convertToTensor(x, 'x', 'depthToSpace'); - var inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2]; - var inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3]; - var inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1]; - assert(inputHeight * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputHeight + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; }); - assert(inputWidth * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputWidth + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; }); - assert((inputDepth % (blockSize * blockSize) === 0), function () { return "Dimension size must be evenly divisible by " + blockSize * blockSize + " but is " + inputDepth + " for depthToSpace with input shape " + $x.shape; }); - return ENGINE.runKernelFunc(function (backend) { return backend.depthToSpace($x, blockSize, dataFormat); }, { $x: $x }); - } - /** - * Computes the difference between two lists of numbers. - * - * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out` - * that represents all values that are in `x` but not in `y`. The returned - * Tensor `out` is sorted in the same order that the numbers appear in `x` - * (duplicates are preserved). This operation also returns a Tensor indices that - * represents the position of each out element in `x`. In other words: - * - * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]` - * - * ```js - * const x = [1, 2, 3, 4, 5, 6]; - * const y = [1, 3, 5]; - * - * const [out, indices] = await tf.setdiff1dAsync(x, y); - * out.print(); // [2, 4, 6] - * indices.print(); // [1, 3, 5] - * ``` - * - * @param x 1-D Tensor. Values to keep. - * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the - * output. - * @returns Promise of Tensor tuple [out, indices]. - * out: Tensor with the same type as x. - * indices: A Tensor of type int32. - */ - /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ - function setdiff1dAsync_(x, y) { - return __awaiter(this, void 0, void 0, function () { - var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, i, p; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: - $x = convertToTensor(x, 'x', 'setdiff1d'); - $y = convertToTensor(y, 'y', 'setdiff1d'); - assert($x.dtype === $y.dtype, function () { return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ")."; }); - assert($x.rank === 1, function () { return "x should be 1D tensor, but got x (" + $x.shape + ")."; }); - assert($y.rank === 1, function () { return "y should be 1D tensor, but got y (" + $y.shape + ")."; }); - return [4 /*yield*/, $x.data()]; - case 1: - xVals = _a.sent(); - return [4 /*yield*/, $y.data()]; - case 2: - yVals = _a.sent(); - ySet = new Set(yVals); - outputSize = 0; - for (i = 0; i < xVals.length; i++) { - if (!ySet.has(xVals[i])) { - outputSize++; - } - } - buffer = new TensorBuffer([outputSize], $x.dtype); - indices = new TensorBuffer([outputSize], 'int32'); - for (i = 0, p = 0; i < xVals.length; i++) { - if (!ySet.has(xVals[i])) { - buffer.values[p] = xVals[i]; - indices.values[p] = i; - p++; - } - } - return [2 /*return*/, [buffer.toTensor(), indices.toTensor()]]; - } - }); - }); - } - /** - * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`. - * - * The values are stored in CPU as `TypedArray`. Fill the buffer using - * `buffer.set()`, or by modifying directly `buffer.values`. - * - * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with - * those values. - * - * ```js - * // Create a buffer and set values at particular indices. - * const buffer = tf.buffer([2, 2]); - * buffer.set(3, 0, 0); - * buffer.set(5, 1, 0); - * - * // Convert the buffer back to a tensor. - * buffer.toTensor().print(); - * ``` - * - * @param shape An array of integers defining the output tensor shape. - * @param dtype The dtype of the buffer. Defaults to 'float32'. - * @param values The values of the buffer as `TypedArray`. Defaults to - * zeros. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function buffer(shape, dtype, values) { - if (dtype === void 0) { dtype = 'float32'; } - dtype = dtype || 'float32'; - assertNonNegativeIntegerDimensions(shape); - return new TensorBuffer(shape, dtype, values); - } - /** - * Prints information about the `tf.Tensor` including its data. - * - * ```js - * const verbose = true; - * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose); - * ``` - * @param x The tensor to be printed. - * @param verbose Whether to print verbose information about the ` Tensor`, - * including dtype and size. - */ - /** @doc {heading: 'Tensors', subheading: 'Creation'} */ - function print(x, verbose) { - if (verbose === void 0) { verbose = false; } - console.log(x.toString(verbose)); - } - var batchToSpaceND = op({ batchToSpaceND_: batchToSpaceND_ }); - var cast = op({ cast_: cast_ }); - var cumsum = op({ cumsum_: cumsum_ }); - var depthToSpace = op({ depthToSpace_: depthToSpace_ }); - var expandDims = op({ expandDims_: expandDims_ }); - var reshape = op({ reshape_: reshape_ }); - var spaceToBatchND = op({ spaceToBatchND_: spaceToBatchND_ }); - var squeeze = op({ squeeze_: squeeze_ }); - var stack = op({ stack_: stack_ }); - var unstack = op({ unstack_: unstack_ }); - var setdiff1dAsync = setdiff1dAsync_; - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Gets the new shape of the input Tensor after it's been reshaped - * to: - * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape), - * inputShape[1], ..., inputShape[N-1]] - * - * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd - */ - function getReshaped(inputShape, blockShape, prod, batchToSpace) { - if (batchToSpace === void 0) { batchToSpace = true; } - var reshaped = []; - if (batchToSpace) { - reshaped = reshaped.concat(blockShape.slice(0)); - reshaped.push(inputShape[0] / prod); - reshaped = reshaped.concat(inputShape.slice(1)); - } - else { - reshaped = reshaped.concat(inputShape[0]); - var spatialLength = blockShape.length; - for (var i = 0; i < spatialLength; ++i) { - reshaped = - reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]); - } - reshaped = reshaped.concat(inputShape.slice(spatialLength + 1)); - } - return reshaped; - } - /** - * Gets the permutation that will transpose the dimensions of the - * reshaped tensor to shape: - * - * [batch / prod(block_shape),inputShape[1], blockShape[0], ..., - * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]] - * - * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd - */ - function getPermuted(reshapedRank, blockShapeRank, batchToSpace) { - if (batchToSpace === void 0) { batchToSpace = true; } - var permuted = []; - if (batchToSpace) { - permuted.push(blockShapeRank); - for (var i = blockShapeRank + 1; i < reshapedRank; ++i) { - if (i <= 2 * blockShapeRank) { - permuted.push(i); - permuted.push(i - (blockShapeRank + 1)); - } - else { - permuted.push(i); - } - } - } - else { - var permutedBeforeBatch = []; - var permutedAfterBatch = []; - for (var i = 1; i < reshapedRank; ++i) { - if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) { - permutedAfterBatch.push(i); - } - else { - permutedBeforeBatch.push(i); - } - } - permuted.push.apply(permuted, permutedBeforeBatch); - permuted.push(0); - permuted.push.apply(permuted, permutedAfterBatch); - } - return permuted; - } - /** - * Gets the shape of the reshaped and permuted input Tensor before any cropping - * is applied. The new shape will be: - * - * [batch / prod(blockShape),inputShape[1] * blockShape[0], ..., - * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]] - * - * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd - */ - function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace) { - if (batchToSpace === void 0) { batchToSpace = true; } - var reshapedPermuted = []; - if (batchToSpace) { - reshapedPermuted.push(inputShape[0] / prod); - } - else { - reshapedPermuted.push(inputShape[0] * prod); - } - for (var i = 1; i < inputShape.length; ++i) { - if (i <= blockShape.length) { - if (batchToSpace) { - reshapedPermuted.push(blockShape[i - 1] * inputShape[i]); - } - else { - reshapedPermuted.push(inputShape[i] / blockShape[i - 1]); - } - } - else { - reshapedPermuted.push(inputShape[i]); - } - } - return reshapedPermuted; - } - /** - * Converts the crops argument into the beginning coordinates of a slice - * operation. - */ - function getSliceBeginCoords(crops, blockShape) { - var sliceBeginCoords = [0]; - for (var i = 0; i < blockShape; ++i) { - sliceBeginCoords.push(crops[i][0]); - } - return sliceBeginCoords; - } - /** - * Converts the crops argument into the size of a slice operation. When - * combined with getSliceBeginCoords this function allows the reshaped and - * permuted Tensor to be cropped to its final output shape of: - * - * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ..., - * inputShape[M] * blockShape[M-1] -crops[M-1,0] - - * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]] - * - * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd - */ - function getSliceSize(uncroppedShape, crops, blockShape) { - var sliceSize = uncroppedShape.slice(0, 1); - for (var i = 0; i < blockShape; ++i) { - sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]); - } - return sliceSize; - } - - var Add = 'Add'; - var AddN = 'AddN'; - var Div = 'Div'; - var FusedBatchNorm = 'FusedBatchNorm'; - var SquaredDifference = 'SquaredDifference'; - var Square = 'Square'; - var Transpose = 'Transpose'; - var NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; - var BroadcastTo = 'BroadcastTo'; - var OneHot = 'OneHot'; - var Identity = 'Identity'; - var Tile = 'Tile'; - var PadV2 = 'PadV2'; - /** - * TensorFlow.js-only kernels - */ - var FromPixels = 'FromPixels'; - var MaxPoolWithArgmax = 'MaxPoolWithArgmax'; - - /** - * @license - * Copyright 2020 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting. - * - * We also expose `tf.addStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([1, 2, 3, 4]); - * const b = tf.tensor1d([10, 20, 30, 40]); - * - * a.add(b).print(); // or tf.add(a, b) - * ``` - * - * ```js - * // Broadcast add a with b. - * const a = tf.scalar(5); - * const b = tf.tensor1d([10, 20, 30, 40]); - * - * a.add(b).print(); // or tf.add(a, b) - * ``` - * @param a The first `tf.Tensor` to add. - * @param b The second `tf.Tensor` to add. Must have the same type as `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function add_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'add'); - var $b = convertToTensor(b, 'b', 'add'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - var forward = function (backend, save) { - var res = backend.add($a, $b); - save([$a, $b]); - return res; - }; - var inputs = { a: $a, b: $b }; - return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Add); - } - var add = op({ add_: add_ }); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Returns the dimensions in the input shape that are broadcasted to - * produce the provided output shape. - * - * The returned dimensions are 0-indexed and sorted. An example: - * inShape = [4, 1, 3] - * outShape = [5, 4, 3, 3] - * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3. - */ - function getBroadcastDims(inShape, outShape) { - var inRank = inShape.length; - var dims = []; - for (var i = 0; i < inRank; i++) { - var dim = inRank - 1 - i; - var a = inShape[dim] || 1; - var b = outShape[outShape.length - 1 - i] || 1; - if (b > 1 && a === 1) { - dims.unshift(dim); - } - } - return dims; - } - /** - * Returns the axes in the output space that should be reduced to produce - * the input space. - */ - function getReductionAxes(inShape, outShape) { - var result = []; - for (var i = 0; i < outShape.length; i++) { - var inDim = inShape[inShape.length - i - 1]; - var outAxis = outShape.length - i - 1; - var outDim = outShape[outAxis]; - if (inDim == null || (inDim === 1 && outDim > 1)) { - result.unshift(outAxis); - } - } - return result; - } - function assertAndGetBroadcastShape(shapeA, shapeB) { - var result = []; - var l = Math.max(shapeA.length, shapeB.length); - for (var i = 0; i < l; i++) { - var a = shapeA[shapeA.length - i - 1]; - if (a == null) { - a = 1; - } - var b = shapeB[shapeB.length - i - 1]; - if (b == null) { - b = 1; - } - if (a === 1) { - result.unshift(b); - } - else if (b === 1) { - result.unshift(a); - } - else if (a !== b) { - var errMsg = "Operands could not be broadcast together with shapes " + - (shapeA + " and " + shapeB + "."); - throw Error(errMsg); - } - else { - result.unshift(a); - } - } - return result; - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Computes `-1 * x` element-wise. - * - * ```js - * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]); - * - * x.neg().print(); // or tf.neg(x) - * ``` - * - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function neg_(x) { - var $x = convertToTensor(x, 'x', 'neg'); - var grad = function (dy) { - return { x: function () { return dy.neg(); } }; - }; - var attrs = {}; - var inputsToSave = [$x]; - return ENGINE.runKernelFunc(function (backend) { return backend.neg($x); }, { x: $x }, grad, 'Neg', attrs, inputsToSave); - } - /** - * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)` - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3]); - * - * x.ceil().print(); // or tf.ceil(x) - * ``` - * @param x The input Tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function ceil_(x) { - var $x = convertToTensor(x, 'x', 'ceil'); - // TODO(manrajgrover): Return null for gradients when backprop supports it. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.ceil($x); }, { $x: $x }, grad); - } - /** - * Computes floor of input `tf.Tensor` element-wise: `floor(x)`. - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3]); - * - * x.floor().print(); // or tf.floor(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function floor_(x) { - var $x = convertToTensor(x, 'x', 'floor'); - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.floor($x); }, { $x: $x }, grad); - } - /** - * Returns an element-wise indication of the sign of a number. - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]); - * - * x.sign().print(); // or tf.sign(x) - * ``` - * @param x The input Tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function sign_(x) { - var $x = convertToTensor(x, 'x', 'sign'); - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.sign($x); }, { $x: $x }, grad); - } - /** - * RReturns which elements of x are NaN. - * - * ```js - * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - * - * x.isNaN().print(); // or tf.isNaN(x) - * ``` - * @param x The input Tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function isNaN_(x) { - var $x = convertToTensor(x, 'x', 'isNaN'); - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.isNaN($x); }, { $x: $x }, grad); - } - /** - * Returns which elements of x are Infinity or -Infinity. - * - * ```js - * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - * - * x.isInf().print(); // or tf.isNaN(x) - * ``` - * @param x The input Tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function isInf_(x) { - var $x = convertToTensor(x, 'x', 'isInf'); - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.isInf($x); }, { $x: $x }, grad); - } - /** - * Returns which elements of x are finite. - * - * ```js - * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); - * - * x.isFinite().print(); // or tf.isNaN(x) - * ``` - * @param x The input Tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function isFinite_(x) { - var $x = convertToTensor(x, 'x', 'isFinite'); - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.isFinite($x); }, { $x: $x }, grad); - } - /** - * Computes round of input `tf.Tensor` element-wise: `round(x)`. - * It implements banker's rounding. - * - * ```js - * const x = tf.tensor1d([.6, 1.1, -3.3]); - * - * x.round().print(); // or tf.round(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function round_(x) { - var $x = convertToTensor(x, 'x', 'round'); - // TODO(nsthorat): Let gradients be null for cases where we want to stop - // backpropgation. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.round($x); }, { $x: $x }, grad); - } - /** - * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x` - * - * ```js - * const x = tf.tensor1d([1, 2, -3]); - * - * x.exp().print(); // or tf.exp(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function exp_(x) { - var $x = convertToTensor(x, 'x', 'exp'); - var bck = function (dy, saved) { - return { x: function () { return dy.mulStrict(saved[0]); } }; - }; - var attrs = {}; - var inputsToSave = []; - var outputsToSave = [true]; - return ENGINE.runKernelFunc(function (backend, save) { - var y = backend.exp($x); - save([y]); - return y; - }, { x: $x }, bck, 'Exp', attrs, inputsToSave, outputsToSave); - } - /** - * Computes exponential of the input `tf.Tensor` minus one element-wise. - * `e ^ x - 1` - * - * ```js - * const x = tf.tensor1d([1, 2, -3]); - * - * x.expm1().print(); // or tf.expm1(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function expm1_(x) { - var $x = convertToTensor(x, 'x', 'expm1'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.mul($x.exp()); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.expm1($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)` - * - * ```js - * const x = tf.tensor1d([1, 2, Math.E]); - * - * x.log().print(); // or tf.log(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function log_(x) { - var $x = convertToTensor(x, 'x', 'log'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { x: function () { return dy.div($x.toFloat()); } }; - }; - var attrs = {}; - var inputsToSave = [$x]; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.log($x); - save([$x]); - return res; - }, { x: $x }, grad, 'Log', attrs, inputsToSave); - } - /** - * Computes natural logarithm of the input `tf.Tensor` plus one - * element-wise: `ln(1 + x)` - * - * ```js - * const x = tf.tensor1d([1, 2, Math.E - 1]); - * - * x.log1p().print(); // or tf.log1p(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function log1p_(x) { - var $x = convertToTensor(x, 'x', 'log1p'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.div($x.add(1)); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.log1p($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)` - * - * ```js - * const x = tf.tensor1d([1, 2, 4, -1]); - * - * x.sqrt().print(); // or tf.sqrt(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function sqrt_(x) { - var $x = convertToTensor(x, 'x', 'sqrt'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.div($x.toFloat().sqrt().mul(2)); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.sqrt($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes reciprocal of square root of the input `tf.Tensor` element-wise: - * `y = 1 / sqrt(x)` - * - * ```js - * const x = tf.tensor1d([1, 2, 4, -1]); - * - * x.rsqrt().print(); // or tf.rsqrt(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function rsqrt_(x) { - var $x = convertToTensor(x, 'x', 'rsqrt'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { x: function () { return dy.div($x.pow(1.5).mul(2)).neg(); } }; - }; - var inputsToSave = [$x]; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.rsqrt($x); - save([$x]); - return res; - }, { x: $x }, grad, 'Rsqrt', {} /* attrs */, inputsToSave); - } - /** - * Computes reciprocal of x element-wise: `1 / x` - * - * ```js - * const x = tf.tensor1d([0, 1, 2]); - * - * x.reciprocal().print(); // or tf.reciprocal(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function reciprocal_(x) { - var $x = convertToTensor(x, 'x', 'reciprocal'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.div($x.square().neg()); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.reciprocal($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes absolute value element-wise: `abs(x)` - * - * ```js - * const x = tf.tensor1d([-1, 2, -3, 4]); - * - * x.abs().print(); // or tf.abs(x) - * ``` - * @param x The input `tf.Tensor`. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function abs_(x) { - var $x = convertToTensor(x, 'x', 'abs'); - if ($x.dtype === 'complex64') { - return ENGINE.runKernelFunc(function (backend) { return backend.complexAbs($x); }, { $x: $x }); - } - var grad = function (dy, saved) { - var $x = saved[0]; - return { x: function () { return dy.mul($x.toFloat().step(-1)); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.abs($x); - save([$x]); - return res; - }, { x: $x }, grad, 'Abs'); - } - /** - * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` - * - * ```js - * const x = tf.tensor1d([-1, 2, -3, 4]); - * - * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3) - * ``` - * @param x The input tensor. - * @param clipValueMin Lower-bound of range to be clipped to. - * @param clipValueMax Upper-bound of range to be clipped to. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function clipByValue_(x, clipValueMin, clipValueMax) { - var $x = convertToTensor(x, 'x', 'clipByValue'); - assert((clipValueMin <= clipValueMax), function () { return "Error in clip: min (" + clipValueMin + ") must be " + - ("less than or equal to max (" + clipValueMax + ")."); }); - var grad = function (dy, saved) { - var $x = saved[0]; - return { - x: function () { return dy.where($x.greaterEqual(clipValueMin) - .logicalAnd($x.lessEqual(clipValueMax)), zerosLike(dy)); }, - }; - }; - var inputsToSave = [$x]; - var attr = { min: clipValueMin, max: clipValueMax }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.clip($x, clipValueMin, clipValueMax); - save([$x]); - return res; - }, { x: $x }, grad, 'ClipByValue', attr, inputsToSave); - } - /** - * Computes sigmoid element-wise, `1 / (1 + exp(-x))` - * - * ```js - * const x = tf.tensor1d([0, -1, 2, -3]); - * - * x.sigmoid().print(); // or tf.sigmoid(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function sigmoid_(x) { - var $x = convertToTensor(x, 'x', 'sigmoid'); - var grad = function (dy, saved) { - var y = saved[0]; - return { x: function () { return dy.mul(y.mul(scalar(1).sub(y))); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var y = backend.sigmoid($x); - save([y]); - return y; - }, { x: $x }, grad, 'Sigmoid'); - } - /** - * Computes log sigmoid of the input `tf.Tensor` element-wise: - * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`. - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.logSigmoid().print(); // or tf.logSigmoid(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function logSigmoid_(x) { - var $x = convertToTensor(x, 'x', 'logSigmoid'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.mul($x.neg().sigmoid()); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.softplus($x.neg()).neg(); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.softplus().print(); // or tf.softplus(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function softplus_(x) { - var $x = convertToTensor(x, 'x', 'softplus'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.mul($x.sigmoid()); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.softplus($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes sin of the input Tensor element-wise: `sin(x)` - * - * ```js - * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); - * - * x.sin().print(); // or tf.sin(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function sin_(x) { - var $x = convertToTensor(x, 'x', 'sin'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { x: function () { return $x.toFloat().cos().mul(dy); } }; - }; - var inputsToSave = [$x]; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.sin($x); - save([$x]); - return res; - }, { x: $x }, grad, 'Sin', {} /* attrs */, inputsToSave); - } - /** - * Computes cos of the input `tf.Tensor` element-wise: `cos(x)` - * - * ```js - * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); - * - * x.cos().print(); // or tf.cos(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function cos_(x) { - var $x = convertToTensor(x, 'x', 'cos'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { x: function () { return $x.toFloat().sin().neg().mul(dy); } }; - }; - var inputsToSave = [$x]; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.cos($x); - save([$x]); - return res; - }, { x: $x }, grad, 'Cos', {} /* attrs */, inputsToSave); - } - /** - * Computes tan of the input `tf.Tensor` element-wise, `tan(x)` - * - * ```js - * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); - * - * x.tan().print(); // or tf.tan(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function tan_(x) { - var $x = convertToTensor(x, 'x', 'tan'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.div($x.cos().square()); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.tan($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes asin of the input `tf.Tensor` element-wise: `asin(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.asin().print(); // or tf.asin(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function asin_(x) { - var $x = convertToTensor(x, 'x', 'asin'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { - $x: function () { return dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt()); } - }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.asin($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes acos of the input `tf.Tensor` element-wise: `acos(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.acos().print(); // or tf.acos(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function acos_(x) { - var $x = convertToTensor(x, 'x', 'acos'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { - $x: function () { - return dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt()).neg(); - } - }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.acos($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes atan of the input `tf.Tensor` element-wise: `atan(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.atan().print(); // or tf.atan(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function atan_(x) { - var $x = convertToTensor(x, 'x', 'atan'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.div($x.toFloat().square().add(1)); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.atan($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.sinh().print(); // or tf.sinh(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function sinh_(x) { - var $x = convertToTensor(x, 'x', 'sinh'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return $x.toFloat().cosh().mulStrict(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.sinh($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.cosh().print(); // or tf.cosh(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function cosh_(x) { - var $x = convertToTensor(x, 'x', 'cosh'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return $x.toFloat().sinh().mulStrict(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.cosh($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, 70]); - * - * x.tanh().print(); // or tf.tanh(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function tanh_(x) { - var $x = convertToTensor(x, 'x', 'tanh'); - var grad = function (dy, saved) { - var y = saved[0]; - return { x: function () { return scalar(1).sub(y.square()).mulStrict(dy); } }; - }; - var outputsToSave = [true]; - return ENGINE.runKernelFunc(function (backend, save) { - var y = backend.tanh($x); - save([y]); - return y; - }, { x: $x }, grad, 'Tanh', {} /* attrs */, null /* inputsToSave */, outputsToSave); - } - /** - * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise: - * `asinh(x)` - * - * ```js - * const x = tf.tensor1d([0, 1, -1, .7]); - * - * x.asinh().print(); // or tf.asinh(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function asinh_(x) { - var $x = convertToTensor(x, 'x', 'asinh'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { - $x: function () { return dy.divStrict(scalar(1).add($x.toFloat().square()).sqrt()); } - }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.asinh($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise: - * `acosh(x)` - * - * ```js - * const x = tf.tensor1d([10, 1, 3, 5.7]); - * - * x.acosh().print(); // or tf.acosh(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function acosh_(x) { - var $x = convertToTensor(x, 'x', 'acosh'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.divStrict($x.toFloat().square().sub(1).sqrt()); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.acosh($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise: - * `atanh(x)` - * - * ```js - * const x = tf.tensor1d([0, .1, -.1, .7]); - * - * x.atanh().print(); // or tf.atanh(x) - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function atanh_(x) { - var $x = convertToTensor(x, 'x', 'atanh'); - var grad = function (dy, saved) { - var $x = saved[0]; - return { $x: function () { return dy.div(scalar(1).sub($x.toFloat().square())); } }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.atanh($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes gause error function of the input `tf.Tensor` element-wise: - * `erf(x)` - * - * ```js - * const x = tf.tensor1d([0, .1, -.1, .7]); - * - * x.erf().print(); // or tf.erf(x); - * ``` - * @param x The input tensor. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function erf_(x) { - var $x = convertToTensor(x, 'x', 'erf'); - assert($x.dtype === 'int32' || $x.dtype === 'float32', function () { return 'Input dtype must be `int32` or `float32`.'; }); - if ($x.dtype === 'int32') { - $x = $x.toFloat(); - } - var grad = function (dy, saved) { - var $x = saved[0]; - return { - $x: function () { return dy.mul($x.square().neg().exp().mul(2 / Math.sqrt(Math.PI))); } - }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.erf($x); - save([$x]); - return res; - }, { $x: $x }, grad); - } - /** - * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x` - * - * ```js - * const x = tf.tensor1d([0, 2, -1, -3]); - * - * x.step(.5).print(); // or tf.step(x, .5) - * ``` - * @param x The input tensor. - * @param alpha The gradient when input is negative. - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function step_(x, alpha) { - if (alpha === void 0) { alpha = 0.0; } - var $x = convertToTensor(x, 'x', 'step'); - // TODO(manrajgrover): Return null for gradients when backprop supports - // it. - var grad = function (dy) { - return { $x: function () { return zerosLike(dy); } }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.step($x, alpha); }, { $x: $x }, grad); - } - var abs = op({ abs_: abs_ }); - var acos = op({ acos_: acos_ }); - var acosh = op({ acosh_: acosh_ }); - var asin = op({ asin_: asin_ }); - var asinh = op({ asinh_: asinh_ }); - var atan = op({ atan_: atan_ }); - var atanh = op({ atanh_: atanh_ }); - var ceil = op({ ceil_: ceil_ }); - var clipByValue = op({ clipByValue_: clipByValue_ }); - var cos = op({ cos_: cos_ }); - var cosh = op({ cosh_: cosh_ }); - var erf = op({ erf_: erf_ }); - var exp = op({ exp_: exp_ }); - var expm1 = op({ expm1_: expm1_ }); - var floor = op({ floor_: floor_ }); - var log = op({ log_: log_ }); - var log1p = op({ log1p_: log1p_ }); - var logSigmoid = op({ logSigmoid_: logSigmoid_ }); - var neg = op({ neg_: neg_ }); - var reciprocal = op({ reciprocal_: reciprocal_ }); - var round = op({ round_: round_ }); - var rsqrt = op({ rsqrt_: rsqrt_ }); - var sigmoid = op({ sigmoid_: sigmoid_ }); - var sign = op({ sign_: sign_ }); - var isNaN$1 = op({ isNaN_: isNaN_ }); - var isInf = op({ isInf_: isInf_ }); - var isFinite$1 = op({ isFinite_: isFinite_ }); - var sin = op({ sin_: sin_ }); - var sinh = op({ sinh_: sinh_ }); - var softplus = op({ softplus_: softplus_ }); - var sqrt = op({ sqrt_: sqrt_ }); - var step = op({ step_: step_ }); - var tan = op({ tan_: tan_ }); - var tanh$1 = op({ tanh_: tanh_ }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Adds two `tf.Tensor`s element-wise, A + B. - * - * Inputs must be the same shape. For broadcasting support, use add() instead. - * - * @param a The first Tensor to add element-wise. - * @param b The second Tensor to add element-wise. - */ - function addStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'addStrict'); - var $b = convertToTensor(b, 'b', 'addStrict'); - assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: '); - return $a.add($b); - } - /** - * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting. - * - * We also expose `tf.subStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([10, 20, 30, 40]); - * const b = tf.tensor1d([1, 2, 3, 4]); - * - * a.sub(b).print(); // or tf.sub(a, b) - * ``` - * - * ```js - * // Broadcast subtract a with b. - * const a = tf.tensor1d([10, 20, 30, 40]); - * const b = tf.scalar(5); - * - * a.sub(b).print(); // or tf.sub(a, b) - * ``` - * @param a The first `tf.Tensor` to subtract from. - * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as - * `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function sub_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'sub'); - var $b = convertToTensor(b, 'b', 'sub'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy) { - var derA = function () { - var res = dy; - var reduceAxes = getReductionAxes($a.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($a.shape); - }; - var derB = function () { - var res = dy; - var reduceAxes = getReductionAxes($b.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.neg().reshape($b.shape); - }; - return { a: derA, b: derB }; - }; - return ENGINE.runKernelFunc(function (backend) { return backend.subtract($a, $b); }, { a: $a, b: $b }, der, 'Sub'); - } - /** - * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must - * be the same shape. - * - * For broadcasting support, use `tf.sub` instead. - * - * @param a The first Tensor to subtract element-wise. - * @param b The second Tensor to subtract element-wise. - */ - function subStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'subStrict'); - var $b = convertToTensor(b, 'b', 'subStrict'); - assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: '); - return $a.sub($b); - } - /** - * Computes the power of one `tf.Tensor` to another. Supports broadcasting. - * - * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for - * corresponding elements in x and y. The result's dtype will be the upcasted - * type of the `base` and `exp` dtypes. - * - * ```js - * const a = tf.tensor([[2, 3], [4, 5]]) - * const b = tf.tensor([[1, 2], [3, 0]]).toInt(); - * - * a.pow(b).print(); // or tf.pow(a, b) - * ``` - * - * ```js - * const a = tf.tensor([[1, 2], [3, 4]]) - * const b = tf.tensor(2).toInt(); - * - * a.pow(b).print(); // or tf.pow(a, b) - * ``` - * We also expose `powStrict` which has the same signature as this op and - * asserts that `base` and `exp` are the same shape (does not broadcast). - * - * @param base The base `tf.Tensor` to pow element-wise. - * @param exp The exponent `tf.Tensor` to pow element-wise. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function pow_(base, exp) { - var _a; - var $base = convertToTensor(base, 'base', 'pow'); - var $exp = convertToTensor(exp, 'exp', 'pow'); - _a = makeTypesMatch($base, $exp), $base = _a[0], $exp = _a[1]; - var outShape = assertAndGetBroadcastShape($base.shape, $exp.shape); - var grad = function (dy, saved) { - var $base = saved[0], $exp = saved[1], y = saved[2]; - var derBase = function () { - var expFloat = $exp.toFloat(); - var res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1))))); - var reduceAxes = getReductionAxes($base.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($base.shape); - }; - var derExp = function () { - var condition = $base.greater(0); - var logBase = $base.log().where(condition, zerosLike($base)); - var res = dy.mul(y.mul(logBase)); - var reduceAxes = getReductionAxes($exp.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($exp.shape); - }; - return { a: derBase, b: derExp }; - }; - var attrs = {}; - var inputsToSave = [$base, $exp]; - var outputsToSave = [true]; - return ENGINE.runKernelFunc(function (backend, save) { - var y = backend.pow($base, $exp); - save([$base, $exp, y]); - return y; - }, { a: $base, b: $exp }, grad, 'Pow', attrs, inputsToSave, outputsToSave); - } - /** - * Computes the power of one `tf.Tensor` to another. Inputs must - * be the same shape. - * - * For broadcasting support, use `tf.pow` instead. - * - * @param base The base tensor to pow element-wise. - * @param exp The exponent tensor to pow element-wise. - */ - function powStrict_(base, exp) { - assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: '); - return base.pow(exp); - } - /** - * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting. - * - * We also expose `tf.mulStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([1, 2, 3, 4]); - * const b = tf.tensor1d([2, 3, 4, 5]); - * - * a.mul(b).print(); // or tf.mul(a, b) - * ``` - * - * ```js - * // Broadcast mul a with b. - * const a = tf.tensor1d([1, 2, 3, 4]); - * const b = tf.scalar(5); - * - * a.mul(b).print(); // or tf.mul(a, b) - * ``` - * @param a The first tensor to multiply. - * @param b The second tensor to multiply. Must have the same dtype as `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function mul_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'mul'); - var $b = convertToTensor(b, 'b', 'mul'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy, saved) { - var $a = saved[0], $b = saved[1]; - var derA = function () { - var res = dy.mul($b.toFloat()); - var reduceAxes = getReductionAxes($a.shape, outShape); - if (reduceAxes.length > 0) { - return res.sum(reduceAxes).reshape($a.shape); - } - return res; - }; - var derB = function () { - var res = dy.mul($a.toFloat()); - var reduceAxes = getReductionAxes($b.shape, outShape); - if (reduceAxes.length > 0) { - return res.sum(reduceAxes).reshape($b.shape); - } - return res; - }; - return { a: derA, b: derB }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.multiply($a, $b); - save([$a, $b]); - return res; - }, { a: $a, b: $b }, der, 'Mul'); - } - /** - * Multiplies two `tf.Tensor`s element-wise, A * B. - * - * Inputs must be the same shape. For broadcasting support, use `tf.mul`. - * - * @param a The first tensor to multiply. - * @param b The first tensor to multiply. Must have the same - * dtype as `a`. - */ - function mulStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'mul'); - var $b = convertToTensor(b, 'b', 'mul'); - assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: '); - return $a.mul($b); - } - /** - * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. - * The result is rounded with floor function. - * - * - * ```js - * const a = tf.tensor1d([1, 4, 9, 16]); - * const b = tf.tensor1d([1, 2, 3, 4]); - * - * a.floorDiv(b).print(); // or tf.div(a, b) - * ``` - * - * ```js - * // Broadcast div a with b. - * const a = tf.tensor1d([2, 4, 6, 8]); - * const b = tf.scalar(2); - * - * a.floorDiv(b).print(); // or tf.floorDiv(a, b) - * ``` - * - * @param a The first tensor as the numerator. - * @param b The second tensor as the denominator. Must have the same dtype as - * `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function floorDiv_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'floorDiv'); - var $b = convertToTensor(b, 'b', 'floorDiv'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy, saved) { - var $a = saved[0], $b = saved[1]; - var derA = function () { - var res = dy.div($b.toFloat()); - var reduceAxes = getReductionAxes($a.shape, outShape); - if (reduceAxes.length > 0) { - return res.sum(reduceAxes).reshape($a.shape); - } - return res; - }; - var derB = function () { - var res = dy.mul($a.toFloat()); - var reduceAxes = getReductionAxes($b.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes).reshape($b.shape); - } - var tmp = $b.square(); - return res.div(tmp.toFloat()).neg(); - }; - return { a: derA, b: derB }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.floorDiv($a, $b); - save([$a, $b]); - return res; - }, { a: $a, b: $b }, der, 'FloorDiv'); - } - /** - * Divides two `tf.Tensor`s element-wise, A / B. Inputs must - * be the same shape. - * - * @param a The first tensor as the numerator for element-wise division. - * @param b The second tensor as the denominator for element-wise division. - */ - function divStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'div'); - var $b = convertToTensor(b, 'b', 'div'); - assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: '); - return $a.div($b); - } - /** - * Returns the mod of a and b element-wise. - * `floor(x / y) * y + mod(x, y) = x` - * Supports broadcasting. - * - * We also expose `tf.modStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([1, 4, 3, 16]); - * const b = tf.tensor1d([1, 2, 9, 4]); - * - * a.mod(b).print(); // or tf.mod(a, b) - * ``` - * - * ```js - * // Broadcast a mod b. - * const a = tf.tensor1d([2, 4, 6, 8]); - * const b = tf.scalar(5); - * - * a.mod(b).print(); // or tf.mod(a, b) - * ``` - * - * @param a The first tensor. - * @param b The second tensor. Must have the same type as `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function mod_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'mod'); - var $b = convertToTensor(b, 'b', 'mod'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy, saved) { - var $a = saved[0], $b = saved[1]; - var derA = function () { - var reduceAxes = getReductionAxes($a.shape, outShape); - if (reduceAxes.length > 0) { - return dy.sum(reduceAxes).reshape($a.shape); - } - return dy; - }; - var derB = function () { - var res = dy.mul($a.div($b).floor().neg()); - var reduceAxes = getReductionAxes($b.shape, outShape); - if (reduceAxes.length > 0) { - return res.sum(reduceAxes).reshape($b.shape); - } - return res; - }; - return { $a: derA, $b: derB }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.mod($a, $b); - save([$a, $b]); - return res; - }, { $a: $a, $b: $b }, der); - } - /** - * Returns the mod of a and b (`a < b ? a : b`) element-wise. Inputs must - * be the same shape. For broadcasting support, use mod(). - * - * @param a The first tensor. - * @param b The second tensor. Must have the same dtype as `a`. - */ - function modStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'modStrict'); - var $b = convertToTensor(b, 'b', 'modStrict'); - assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: '); - return $a.mod($b); - } - /** - * Returns the min of a and b (`a < b ? a : b`) element-wise. - * Supports broadcasting. - * - * We also expose `minimumStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([1, 4, 3, 16]); - * const b = tf.tensor1d([1, 2, 9, 4]); - * - * a.minimum(b).print(); // or tf.minimum(a, b) - * ``` - * - * ```js - * // Broadcast minimum a with b. - * const a = tf.tensor1d([2, 4, 6, 8]); - * const b = tf.scalar(5); - * - * a.minimum(b).print(); // or tf.minimum(a, b) - * ``` - * - * @param a The first tensor. - * @param b The second tensor. Must have the same type as `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function minimum_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'minimum'); - var $b = convertToTensor(b, 'b', 'minimum'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - if ($a.dtype === 'bool') { - $a = $a.toInt(); - $b = $b.toInt(); - } - assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy, saved) { - var $a = saved[0], $b = saved[1]; - var derA = function () { return dy.mul($a.lessEqual($b).toFloat()); }; - var derB = function () { return dy.mul($a.greater($b).toFloat()); }; - return { a: derA, b: derB }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.minimum($a, $b); - save([$a, $b]); - return res; - }, { a: $a, b: $b }, der, 'Minimum'); - } - /** - * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must - * be the same shape. For broadcasting support, use minimum(). - * - * @param a The first tensor. - * @param b The second tensor. Must have the same dtype as `a`. - */ - function minimumStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'minimumStrict'); - var $b = convertToTensor(b, 'b', 'minimumStrict'); - assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: '); - return $a.minimum($b); - } - /** - * Returns the max of a and b (`a > b ? a : b`) element-wise. - * Supports broadcasting. - * - * We also expose `tf.maximumStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([1, 4, 3, 16]); - * const b = tf.tensor1d([1, 2, 9, 4]); - * - * a.maximum(b).print(); // or tf.maximum(a, b) - * ``` - * - * ```js - * // Broadcast maximum a with b. - * const a = tf.tensor1d([2, 4, 6, 8]); - * const b = tf.scalar(5); - * - * a.maximum(b).print(); // or tf.maximum(a, b) - * ``` - * - * @param a The first tensor. - * @param b The second tensor. Must have the same type as `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function maximum_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'maximum'); - var $b = convertToTensor(b, 'b', 'maximum'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - if ($a.dtype === 'bool') { - $a = $a.toInt(); - $b = $b.toInt(); - } - assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy, saved) { - var $a = saved[0], $b = saved[1]; - var derA = function () { return dy.mul($a.greaterEqual($b).toFloat()); }; - var derB = function () { return dy.mul($a.less($b).toFloat()); }; - return { a: derA, b: derB }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.maximum($a, $b); - save([$a, $b]); - return res; - }, { a: $a, b: $b }, der, 'Maximum'); - } - /** - * Returns the max of a and b (`a > b ? a : b`) element-wise. Inputs must - * be the same shape. For broadcasting support, use maximum(). - * - * @param a The first tensor. - * @param b The second tensor. Must have the same dtype as `a`. - */ - function maximumStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'maximumStrict'); - var $b = convertToTensor(b, 'b', 'maximumStrict'); - assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: '); - return $a.maximum($b); - } - /** - * Returns (a - b) * (a - b) element-wise. - * - * Inputs must be the same shape. For broadcasting support, use - * `tf.squaredDifference` instead. - * - * @param a The first tensor. - * @param b The second tensor. Must have the same type as `a`. - */ - function squaredDifferenceStrict_(a, b) { - var $a = convertToTensor(a, 'a', 'squaredDifferenceStrict'); - var $b = convertToTensor(b, 'b', 'squaredDifferenceStrict'); - assertShapesMatch($a.shape, $b.shape, 'Error in squaredDifferenceStrict: '); - return $a.squaredDifference($b); - } - /** - * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`. - * Supports broadcasting. - * - * ```js - * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]); - * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]); - * - * tf.atan2(a, b).print() - * ``` - * - * @param a The first tensor. - * @param b The second tensor. Must have the same dtype as `a`. - * - */ - /** @doc {heading: 'Operations', subheading: 'Basic math'} */ - function atan2_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'atan2'); - var $b = convertToTensor(b, 'b', 'atan2'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); - var der = function (dy, saved) { - var $a = saved[0], $b = saved[1]; - var derA = function () { - var d = add($a.square(), $b.square()); - var res = dy.mul($b.div(d)); - var reduceAxes = getReductionAxes($a.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($a.shape); - }; - var derB = function () { - var d = add($a.square(), $b.square()); - var res = neg(dy.mul($a.div(d))); - var reduceAxes = getReductionAxes($b.shape, outShape); - if (reduceAxes.length > 0) { - res = res.sum(reduceAxes); - } - return res.reshape($b.shape); - }; - return { $a: derA, $b: derB }; - }; - return ENGINE.runKernelFunc(function (backend, save) { - var res = backend.atan2($a, $b); - save([$a, $b]); - return res; - }, { $a: $a, $b: $b }, der); - } - var addStrict = op({ addStrict_: addStrict_ }); - var atan2 = op({ atan2_: atan2_ }); - var divStrict = op({ divStrict_: divStrict_ }); - var floorDiv = op({ floorDiv_: floorDiv_ }); - var maximum = op({ maximum_: maximum_ }); - var maximumStrict = op({ maximumStrict_: maximumStrict_ }); - var minimum = op({ minimum_: minimum_ }); - var minimumStrict = op({ minimumStrict_: minimumStrict_ }); - var mod = op({ mod_: mod_ }); - var modStrict = op({ modStrict_: modStrict_ }); - var mul = op({ mul_: mul_ }); - var mulStrict = op({ mulStrict_: mulStrict_ }); - var pow = op({ pow_: pow_ }); - var powStrict = op({ powStrict_: powStrict_ }); - var squaredDifferenceStrict = op({ squaredDifferenceStrict_: squaredDifferenceStrict_ }); - var sub = op({ sub_: sub_ }); - var subStrict = op({ subStrict_: subStrict_ }); - - /** - * @license - * Copyright 2020 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. - * - * We also expose `tf.divStrict` which has the same signature as this op and - * asserts that `a` and `b` are the same shape (does not broadcast). - * - * ```js - * const a = tf.tensor1d([1, 4, 9, 16]); - * const b = tf.tensor1d([1, 2, 3, 4]); - * - * a.div(b).print(); // or tf.div(a, b) - * ``` - * - * ```js - * // Broadcast div a with b. - * const a = tf.tensor1d([2, 4, 6, 8]); - * const b = tf.scalar(2); - * - * a.div(b).print(); // or tf.div(a, b) - * ``` - * - * @param a The first tensor as the numerator. - * @param b The second tensor as the denominator. Must have the same dtype as - * `a`. - */ - /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ - function div_(a, b) { - var _a; - var $a = convertToTensor(a, 'a', 'div'); - var $b = convertToTensor(b, 'b', 'div'); - _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; - if ($a.dtype === 'int32' && $b.dtype === 'int32') { - return floorDiv($a, $b); - } - var forward = function (backend, save) { - var res = backend.realDivide($a, $b); - save([$a, $b]); - return res; - }; - var inputs = { a: $a, b: $b }; - var attrs = {}; - return ENGINE.runKernelFunc(forward, inputs, null /* gradient */, Div, attrs); - } - var div = op({ div_: div_ }); - - /** - * Validate gather nd inputs. - * - * @param tensor The tensor contains the source values. - * @param indices The tensor contains the indices to slice the source. - * - * @returns [resultShape, numUpdates, sliceSize, strides] - */ - function prepareAndValidate(tensor, indices) { - if (tensor.rank < 1) { - throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' + - (" but the rank was " + tensor.rank + ".")); - } - if (indices.rank < 1) { - throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' + - (" but the rank was " + indices.rank + ".")); - } - if (indices.dtype !== 'int32') { - throw new Error('tf.gatherND() expects the indices to be int32 type,' + - (" but the dtype was " + indices.dtype + ".")); - } - if (indices.shape[indices.rank - 1] > tensor.rank) { - throw new Error('index innermost dimension length must be <= tensor rank; saw: ' + - (indices.shape[indices.rank - 1] + " vs. " + tensor.rank)); - } - if (tensor.size === 0) { - throw new Error('Requested more than 0 entries, but input is empty.' + - (" Input shape: " + tensor.shape + ".")); - } - var indicesShape = indices.shape; - var sliceRank = indicesShape[indicesShape.length - 1]; - // The result shape is - // indices.shape[:-1] + params.shape[indices.shape[-1]:] - var nResult = 1; - for (var i = 0; i < indicesShape.length - 1; ++i) { - nResult *= indicesShape[i]; - } - var inputShape = tensor.shape; - var resultShape = indicesShape.slice(); - resultShape.pop(); - var sliceSize = 1; - for (var i = sliceRank; i < tensor.rank; ++i) { - sliceSize *= inputShape[i]; - resultShape.push(inputShape[i]); - } - var strides = computeStrides(tensor.shape).map(function (stride) { return stride / sliceSize; }).concat([1]).slice(0, sliceRank); - return [resultShape, nResult, sliceSize, strides]; - } - - var gather_nd_util = /*#__PURE__*/Object.freeze({ - prepareAndValidate: prepareAndValidate - }); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var PARALLELIZE_THRESHOLD = 30; - function computeOptimalWindowSize(inSize) { - if (inSize <= PARALLELIZE_THRESHOLD) { - return inSize; - } - return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); - } - - /** - * Check whether updates.shape = indices.shape[:batchDim] + - * shape[sliceDim:] - * - * @param x The input tensor. - */ - function validateUpdateShape(shape, indices, updates) { - var sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1; - var batchDim = (indices.rank > 1) ? indices.rank - 1 : 1; - var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' + - ("shape[sliceDim:], got updates.shape: " + updates.shape) + - (", indices.shape: " + indices.shape + ", shape: " + shape) + - (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + "."); - if (updates.rank < batchDim) { - throw new Error(shapeError + (" update.rank < " + batchDim + ". ")); - } - if (shape.length < sliceDim + (updates.rank - batchDim)) { - throw new Error(shapeError + - (" Output shape length < " + (sliceDim + (updates.rank - batchDim)))); - } - if (updates.rank !== batchDim + shape.length - sliceDim) { - throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim))); - } - for (var d = 0; d < batchDim; ++d) { - if (updates.shape[d] !== indices.shape[d]) { - throw new Error(shapeError + - (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ").")); - } - } - for (var d = 0; d < updates.rank - batchDim; ++d) { - if (updates.shape[d + batchDim] !== shape[d + sliceDim]) { - throw new Error(shapeError + - (" updates.shape[" + (d + batchDim) + "] (" + updates.shape[d + batchDim] + ") != shape[" + (d + batchDim) + "] (" + shape[d + batchDim] + ")")); - } - } - } - /** - * Validate scatter nd inputs. - * - * @param update The tensor contains the update values. - * @param indices The tensor contains the indices for the update values. - * @param shape The shape of the output tensor. - */ - function validateInput(updates, indices, shape) { - if (indices.rank < 1) { - throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' + - (" but the rank was " + indices.rank + ".")); - } - if (updates.rank < 1) { - throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' + - (" but the rank was " + updates.rank + ".")); - } - if (indices.dtype !== 'int32') { - throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype); - } - if (shape.length < 1) { - throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape); - } - if (shape.length === 0) { - if (indices.size === 0) { - throw new Error("Indices specified for empty output. indices shape: " + indices.shape); - } - if (updates.size === 0) { - throw new Error("Updates specified for empty output. updates shape: " + updates.shape); - } - } - validateUpdateShape(shape, indices, updates); - } - /** - * Calculate the shape information for the output. - * - * @param update The tensor contains the update values. - * @param indices The tensor contains the indices for the update values. - * @param shape The shape of the output tensor. - * - * @returns ScatterShapeInfo - */ - function calculateShapes(updates, indices, shape) { - // Calculate the number of dimensions in indices - var indicesRank = indices.shape.length; - var sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1; - // Calculate the number of elements that make up each slice of our updated - // tensor. This allows us to work with flattened tensors and copy over whole - // slices at a time. - var totalNd = shape.length; - var sliceSize = 1; - for (var i = sliceRank; i < totalNd; ++i) { - sliceSize *= shape[i]; - } - var safeSliceDim = (sliceRank < 1) ? 1 : sliceRank; - var numUpdates = sizeFromShape(indices.shape) / safeSliceDim; - var strides = computeStrides(shape.slice(0, sliceRank)).concat([1]); - var outputSize = sizeFromShape(shape); - return { sliceRank: sliceRank, numUpdates: numUpdates, sliceSize: sliceSize, strides: strides, outputSize: outputSize }; - } - - var scatter_nd_util = /*#__PURE__*/Object.freeze({ - validateUpdateShape: validateUpdateShape, - validateInput: validateInput, - calculateShapes: calculateShapes - }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function segOpComputeOptimalWindowSize(inSize, numSegments) { - var done = false; - var res; - if (inSize <= PARALLELIZE_THRESHOLD) { - res = inSize; - done = true; - } - else { - res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); - } - while (!done) { - if (res > numSegments || res === inSize) { - done = true; - } - else { - res = nearestDivisor(inSize, res + 1); - } - } - return res; - } - function computeOutShape$1(aShape, axis, numSegments) { - var outShape = []; - var rank = aShape.length; - for (var dim = 0; dim < rank; dim++) { - if (dim !== axis) { - outShape.push(aShape[dim]); - } - else { - outShape.push(numSegments); - } - } - return outShape; - } - function collectGatherOpShapeInfo(x, indices, axis) { - var dimSize = x.shape[axis]; - var outputShape = []; - var batchSize = 1; - var sliceSize = 1; - for (var i = 0; i < axis; i++) { - outputShape.push(x.shape[i]); - batchSize *= x.shape[i]; - } - for (var i = 0; i < indices.rank; i++) { - outputShape.push(indices.shape[i]); - } - for (var i = axis + 1; i < x.rank; i++) { - outputShape.push(x.shape[i]); - sliceSize *= x.shape[i]; - } - return { batchSize: batchSize, sliceSize: sliceSize, dimSize: dimSize, outputShape: outputShape }; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function assertParamsValid(input, begin, size) { - assert(input.rank === begin.length, function () { return "Error in slice" + input.rank + "D: Length of begin " + begin + " must " + - ("match the rank of the array (" + input.rank + ")."); }); - assert(input.rank === size.length, function () { return "Error in slice" + input.rank + "D: Length of size " + size + " must " + - ("match the rank of the array (" + input.rank + ")."); }); - var _loop_1 = function (i) { - assert(begin[i] + size[i] <= input.shape[i], function () { return "Error in slice" + input.rank + "D: begin[" + i + "] + size[" + i + "] " + - ("(" + (begin[i] + size[i]) + ") would overflow input.shape[" + i + "] (" + input.shape[i] + ")"); }); - }; - for (var i = 0; i < input.rank; ++i) { - _loop_1(i); - } - } - /** Converts a binary mask to an array of axes. Used in stridedSlice(). */ - function maskToAxes(mask) { - var axes = []; - var axis = 0; - while (mask > 0) { - if (mask & 1) { - axes.push(axis); - } - mask /= 2; - axis++; - } - return axes; - } - /** Computes the output shape given the strided slice params. */ - function computeOutShape$2(begin, end, strides) { - var size = []; - for (var axis = 0; axis < begin.length; axis++) { - size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]); - } - return size; - } - function startForAxis(beginMask, startIndices, strides, inputShape, axis) { - // Begin with the specified index - var start = startIndices[axis]; - var stride = strides[axis] || 1; - // Check the axis bit from right of beginMask or the begin index is not set - // for the axis. - if (beginMask & 1 << axis || start == null) { - if (stride > 0) { - // Forward iteration - use the first element. These values will get - // clamped below (Note: We could have set them to 0 and axis_size-1, but - // use lowest() and max() to maintain symmetry with StopForAxis()) - start = Number.MIN_SAFE_INTEGER; - } - else { - // Backward iteration - use the last element. - start = Number.MAX_SAFE_INTEGER; - } - } - // Handle negative indices - var axisSize = inputShape[axis]; - if (start < 0) { - start += axisSize; - } - // Clamping - start = clamp(0, start, axisSize - 1); - return start; - } - function stopForAxis(endMask, stopIndices, strides, inputShape, axis) { - // Begin with the specified index - var stop = stopIndices[axis]; - var stride = strides[axis] || 1; - // Check the axis bit from right of endMask or if the stop index is not set - // for this axis. - if (endMask & (1 << axis) || stop == null) { - if (stride > 0) { - // Forward iteration - use the last element. These values will get - // clamped below - stop = Number.MAX_SAFE_INTEGER; - } - else { - // Backward iteration - use the first element. - stop = Number.MIN_SAFE_INTEGER; - } - } - // Handle negative indices - var axisSize = inputShape[axis]; - if (stop < 0) { - stop += axisSize; - } - // Clamping - // Because the end index points one past the last element, we need slightly - // different clamping ranges depending on the direction. - if (stride > 0) { - // Forward iteration - stop = clamp(0, stop, axisSize); - } - else { - // Backward iteration - stop = clamp(-1, stop, axisSize - 1); - } - return stop; - } - /** - * Returns true if the slice occupies a continous set of elements in the - * 'flat' space. - */ - function isSliceContinous(shape, begin, size) { - // Index of the first axis that has size > 1. - var firstNonOneAxis = size.length; - for (var i = 0; i < size.length; i++) { - if (size[i] > 1) { - firstNonOneAxis = i; - break; - } - } - for (var i = firstNonOneAxis + 1; i < size.length; i++) { - if (begin[i] > 0 || size[i] !== shape[i]) { - return false; - } - } - return true; - } - function computeFlatOffset(begin, strides) { - var flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1; - for (var i = 0; i < begin.length - 1; i++) { - flatOffset += begin[i] * strides[i]; - } - return flatOffset; - } - - var slice_util = /*#__PURE__*/Object.freeze({ - assertParamsValid: assertParamsValid, - maskToAxes: maskToAxes, - computeOutShape: computeOutShape$2, - startForAxis: startForAxis, - stopForAxis: stopForAxis, - isSliceContinous: isSliceContinous, - computeFlatOffset: computeFlatOffset - }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the - * gradient of `f(x)` with respect to `x`. - * - * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to - * `x` is computed instead. `f(x)` must take a single tensor `x` and return a - * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead. - * - * ```js - * // f(x) = x ^ 2 - * const f = x => x.square(); - * // f'(x) = 2x - * const g = tf.grad(f); - * - * const x = tf.tensor1d([2, 3]); - * g(x).print(); - * ``` - * - * ```js - * // f(x) = x ^ 3 - * const f = x => x.pow(tf.scalar(3, 'int32')); - * // f'(x) = 3x ^ 2 - * const g = tf.grad(f); - * // f''(x) = 6x - * const gg = tf.grad(g); - * - * const x = tf.tensor1d([2, 3]); - * gg(x).print(); - * ``` - * - * @param f The function f(x), to compute gradient for. - */ - /** @doc {heading: 'Training', subheading: 'Gradients'} */ - function grad(f) { - assert(isFunction(f), function () { return 'The f passed in grad(f) must be a function'; }); - return function (x, dy) { - // x can be of any dtype, thus null as the last argument. - var $x = convertToTensor(x, 'x', 'tf.grad', null); - var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null; - return ENGINE.tidy(function () { - var _a = ENGINE.gradients(function () { return f($x); }, [$x], $dy), value = _a.value, grads = _a.grads; - if ($dy != null) { - assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' + - 'returned by f(x)'); - } - checkGrads(grads); - return grads[0]; - }); - }; - } - /** - * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`, - * which gives an array of gradients of `f()` with respect to each input - * [`x1`,`x2`,...]. - * - * If `dy` is passed when calling `g()`, the gradient of - * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead. - * The provided `f` must take one or more tensors and return a single tensor - * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead. - * - * ```js - * // f(a, b) = a * b - * const f = (a, b) => a.mul(b); - * // df / da = b, df / db = a - * const g = tf.grads(f); - * - * const a = tf.tensor1d([2, 3]); - * const b = tf.tensor1d([-2, -3]); - * const [da, db] = g([a, b]); - * console.log('da'); - * da.print(); - * console.log('db'); - * db.print(); - * ``` - * - * @param f The function `f(x1, x2,...)` to compute gradients for. - */ - /** @doc {heading: 'Training', subheading: 'Gradients'} */ - function grads(f) { - assert(isFunction(f), function () { return 'The f passed in grads(f) must be a function'; }); - return function (args, dy) { - assert(Array.isArray(args), function () { return 'The args passed in grads(f)(args) must be an array ' + - 'of `Tensor`s or `TensorLike`s'; }); - // args can be of any dtype, thus null as the last argument. - var $args = convertToTensorArray(args, 'args', 'tf.grads', null); - var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null; - return ENGINE.tidy(function () { - var _a = ENGINE.gradients(function () { return f.apply(void 0, $args); }, $args, $dy), value = _a.value, grads = _a.grads; - if ($dy != null) { - assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' + - 'match the shape returned by f([x1,...])'); - } - checkGrads(grads); - return grads; - }); - }; - } - /** - * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()` - * returns a metric you want to show. - * - * The result is a rich object with the following properties: - * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`). - * - value: The value returned by `f(x)`. - * - * ```js - * // f(x) = x ^ 2 - * const f = x => x.square(); - * // f'(x) = 2x - * const g = tf.valueAndGrad(f); - * - * const x = tf.tensor1d([2, 3]); - * const {value, grad} = g(x); - * - * console.log('value'); - * value.print(); - * console.log('grad'); - * grad.print(); - * ``` - */ - /** @doc {heading: 'Training', subheading: 'Gradients'} */ - function valueAndGrad(f) { - assert(isFunction(f), function () { return 'The f passed in valueAndGrad(f) must be a function'; }); - return function (x, dy) { - assert(x instanceof Tensor, function () { return 'The x passed in valueAndGrad(f)(x) must be a tensor'; }); - assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor'; }); - var _a = ENGINE.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value; - checkGrads(grads); - return { grad: grads[0], value: value }; - }; - } - /** - * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()` - * returns a metric you want to show. - * - * The result is a rich object with the following properties: - * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`). - * - value: The value returned by `f(x)`. - * - * ```js - * // f(a, b) = a * b - * const f = (a, b) => a.mul(b); - * // df/da = b, df/db = a - * const g = tf.valueAndGrads(f); - * - * const a = tf.tensor1d([2, 3]); - * const b = tf.tensor1d([-2, -3]); - * const {value, grads} = g([a, b]); - * - * const [da, db] = grads; - * - * console.log('value'); - * value.print(); - * - * console.log('da'); - * da.print(); - * console.log('db'); - * db.print(); - * ``` - */ - /** @doc {heading: 'Training', subheading: 'Gradients'} */ - function valueAndGrads(f) { - assert(isFunction(f), function () { return 'The f passed in valueAndGrads(f) must be a function'; }); - return function (args, dy) { - assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof Tensor; }), function () { return 'The args passed in valueAndGrads(f)(args) must be array of ' + - 'tensors'; }); - assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor'; }); - var res = ENGINE.gradients(function () { return f.apply(void 0, args); }, args, dy); - if (dy != null) { - assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' + - 'match the shape returned by f([x1,...])'); - } - checkGrads(res.grads); - return res; - }; - } - /** - * Computes and returns the gradient of f(x) with respect to the list of - * trainable variables provided by `varList`. If no list is provided, it - * defaults to all trainable variables. - * - * ```js - * const a = tf.variable(tf.tensor1d([3, 4])); - * const b = tf.variable(tf.tensor1d([5, 6])); - * const x = tf.tensor1d([1, 2]); - * - * // f(a, b) = a * x ^ 2 + b * x - * const f = () => a.mul(x.square()).add(b.mul(x)).sum(); - * // df/da = x ^ 2, df/db = x - * const {value, grads} = tf.variableGrads(f); - * - * Object.keys(grads).forEach(varName => grads[varName].print()); - * ``` - * - * @param f The function to execute. f() should return a scalar. - * @param varList The list of variables to compute the gradients with respect - * to. Defaults to all trainable variables. - * @returns An object with the following keys and values: - * - `value`: The value of the function `f`. - * - `grads`: A map from the names of the variables to the gradients. - * If the `varList` argument is provided explicitly and contains a subset of - * non-trainable variables, this map in the return value will contain keys - * that map the names of the non-trainable variables to `null`. - */ - /** @doc {heading: 'Training', subheading: 'Gradients'} */ - function variableGrads(f, varList) { - assert(isFunction(f), function () { return 'The f passed in variableGrads(f) must be a function'; }); - assert(varList == null || - Array.isArray(varList) && varList.every(function (v) { return v instanceof Variable; }), function () { - return 'The varList passed in variableGrads(f, varList) must be an array ' + - 'of variables'; - }); - var specifiedVarList = varList != null; - if (!specifiedVarList) { - // Get all of the trainable variables. - varList = []; - for (var varName in ENGINE.registeredVariables) { - varList.push(ENGINE.registeredVariables[varName]); - } - } - var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) { return !variable.trainable; }) : null; - // Prune non-trainable variables. - var originalVarCount = varList.length; - varList = varList.filter(function (variable) { return variable.trainable; }); - assert(varList.length > 0, function () { return "variableGrads() expects at least one of the input variables to " + - ("be trainable, but none of the " + originalVarCount + " variables is ") + - "trainable."; }); - var allowNoGradients = true; - var _a = ENGINE.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads; - assert(grads.some(function (g) { return g != null; }), function () { return 'Cannot find a connection between any variable and the result of ' + - 'the loss function y=f(x). Please make sure the operations that ' + - 'use variables are inside the function f passed to minimize().'; }); - assert(value.rank === 0, function () { return "The f passed in variableGrads(f) must return a scalar, but it " + - ("returned a rank-" + value.rank + " tensor"); }); - var namedGrads = {}; - varList.forEach(function (v, i) { - if (grads[i] != null) { - namedGrads[v.name] = grads[i]; - } - }); - if (specifiedNonTrainable != null) { - // If varList is explicitly provided and contains non-trainable values, - // add them to the returned gradients with `null` values. - specifiedNonTrainable.forEach(function (v) { return namedGrads[v.name] = null; }); - } - return { value: value, grads: namedGrads }; - } - /** - * Overrides the gradient computation of a function `f`. - * - * Takes a function - * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}` - * and returns another function `g(...inputs)` which takes the same inputs as - * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients - * with respect to each input of `f` are computed using `f().gradFunc`. - * - * The `save` function passsed to `f` should be used for saving tensors needed - * in the gradient. And the `saved` passed to the `gradFunc` is a - * `NamedTensorMap`, which contains those saved tensor. - * - * ```js - * const customOp = tf.customGrad((x, save) => { - * // Save x to make sure it's available later for the gradient. - * save([x]); - * // Override gradient of our custom x ^ 2 op to be dy * abs(x); - * return { - * value: x.square(), - * // Note `saved.x` which points to the `x` we saved earlier. - * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())] - * }; - * }); - * - * const x = tf.tensor1d([-1, -2, 3]); - * const dx = tf.grad(x => customOp(x)); - * - * console.log(`f(x):`); - * customOp(x).print(); - * console.log(`f'(x):`); - * dx(x).print(); - * ``` - * - * @param f The function to evaluate in forward mode, which should return - * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc` - * returns the custom gradients of `f` with respect to its inputs. - */ - /** @doc {heading: 'Training', subheading: 'Gradients'} */ - function customGrad(f) { - return ENGINE.customGrad(f); - } - function checkGrads(grads) { - var numNullGradients = grads.filter(function (g) { return g == null; }).length; - if (numNullGradients > 0) { - throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y."); - } - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Computes the softmax normalized vector given the logits. - * - * ```js - * const a = tf.tensor1d([1, 2, 3]); - * - * a.softmax().print(); // or tf.softmax(a) - * ``` - * - * ```js - * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]); - * - * a.softmax().print(); // or tf.softmax(a) - * ``` - * - * @param logits The logits array. - * @param dim The dimension softmax would be performed on. Defaults to `-1` - * which indicates the last dimension. - */ - /** @doc {heading: 'Operations', subheading: 'Normalization'} */ - function softmax_(logits, dim) { - if (dim === void 0) { dim = -1; } - var $logits = convertToTensor(logits, 'logits', 'softmax', 'float32'); - if (dim === -1) { - dim = $logits.rank - 1; - } - if (dim !== $logits.rank - 1) { - throw Error('Softmax along a non-last dimension is not yet supported. ' + - ("Logits was rank " + $logits.rank + " and dim was " + dim)); - } - var inputsToSave = []; - var outputsToSave = [true]; - return ENGINE.runKernelFunc(function (backend, save) { - var y = backend.softmax($logits, dim); - save([y]); - return y; - }, { logits: $logits }, function (dy, saved) { - var y = saved[0]; - var dyTimesY = dy.mul(y); - var keepDims = true; - return { - logits: function () { return dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y)); } - }; - }, 'Softmax', { dim: dim }, inputsToSave, outputsToSave); - } - /** - * Computes the log softmax. - * - * ```js - * const a = tf.tensor1d([1, 2, 3]); - * - * a.logSoftmax().print(); // or tf.logSoftmax(a) - * ``` - * - * ```js - * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]); - * - * a.logSoftmax().print(); // or tf.logSoftmax(a) - * ``` - * - * @param logits The logits array. - * @param axis The dimension softmax would be performed on. Defaults to `-1` - * which indicates the last dimension. - */ - /** @doc {heading: 'Operations', subheading: 'Normalization'} */ - function logSoftmax_(logits, axis) { - if (axis === void 0) { axis = -1; } - var $logits = convertToTensor(logits, 'logits', 'logSoftmax'); - if (axis === -1) { - axis = $logits.rank - 1; - } - if (axis !== $logits.rank - 1) { - throw Error('Log Softmax along a non-last dimension is not yet supported. ' + - ("Logits was rank " + $logits.rank + " and axis was " + axis)); - } - var customOp = customGrad(function (logits, save) { - var keepDims = true; - var xMax = logits.max(axis, true); - var shifted = logits.sub(xMax); - var value = shifted.toFloat().sub(shifted.exp().sum(axis, keepDims).log()); - save([value]); - var gradFunc = function (dy, saved) { - var value = saved[0]; - var softmax = value.exp(); - return dy.sub(dy.sum(axis, keepDims).mul(softmax)); - }; - return { value: value, gradFunc: gradFunc }; - }); - return customOp($logits); - } - var softmax = op({ softmax_: softmax_ }); - var logSoftmax = op({ logSoftmax_: logSoftmax_ }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`. - * - * The returned `tf.Tensor`'s dimension `i` will correspond to the input - * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`, - * where `n` is the rank of the input `tf.Tensor`. Hence by default, this - * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s. - * - * ```js - * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - * - * a.transpose().print(); // or tf.transpose(a) - * ``` - * - * @param x The tensor to transpose. - * @param perm The permutation of the dimensions of a. - */ - /** @doc {heading: 'Operations', subheading: 'Matrices'} */ - function transpose_(x, perm) { - var $x = convertToTensor(x, 'x', 'transpose'); - if (perm == null) { - perm = $x.shape.map(function (s, i) { return i; }).reverse(); - } - assert($x.rank === perm.length, function () { return "Error in transpose: rank of input " + $x.rank + " " + - ("must match length of perm " + perm + "."); }); - perm.forEach(function (axis) { - assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + - (" but got " + perm); }); - }); - if ($x.rank <= 1) { - return $x.clone(); - } - var attrs = { perm: perm }; - return ENGINE.runKernelFunc(function (backend) { return backend.transpose($x, perm); }, { x: $x }, null /* gradient */, 'Transpose', attrs); - } - var transpose = op({ transpose_: transpose_ }); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var EPSILON_FLOAT32 = 1e-7; - var EPSILON_FLOAT16 = 1e-4; - /** Convenient class for storing tensor-related data. */ - var DataStorage = /** @class */ (function () { - function DataStorage(backend, dataMover) { - this.backend = backend; - this.dataMover = dataMover; - this.data = new WeakMap(); - this.dataIdsCount = 0; - } - DataStorage.prototype.get = function (dataId) { - if (!this.data.has(dataId)) { - this.dataMover.moveData(this.backend, dataId); - } - return this.data.get(dataId); - }; - DataStorage.prototype.set = function (dataId, value) { - this.dataIdsCount++; - this.data.set(dataId, value); - }; - DataStorage.prototype.has = function (dataId) { - return this.data.has(dataId); - }; - DataStorage.prototype.delete = function (dataId) { - this.dataIdsCount--; - return this.data.delete(dataId); - }; - DataStorage.prototype.numDataIds = function () { - return this.dataIdsCount; - }; - return DataStorage; - }()); - /** - * The interface that defines the kernels that should be implemented when - * adding a new backend. New backends don't need to implement every one of the - * methods, this can be done gradually (throw an error for unimplemented - * methods). - */ - var KernelBackend = /** @class */ (function () { - function KernelBackend() { - } - KernelBackend.prototype.time = function (f) { - return notYetImplemented('time'); - }; - KernelBackend.prototype.read = function (dataId) { - return notYetImplemented('read'); - }; - KernelBackend.prototype.readSync = function (dataId) { - return notYetImplemented('readSync'); - }; - KernelBackend.prototype.numDataIds = function () { - return notYetImplemented('numDataIds'); - }; - KernelBackend.prototype.disposeData = function (dataId) { - return notYetImplemented('disposeData'); - }; - KernelBackend.prototype.write = function (values, shape, dtype) { - return notYetImplemented('write'); - }; - KernelBackend.prototype.move = function (dataId, values, shape, dtype) { - return notYetImplemented('move'); - }; - KernelBackend.prototype.memory = function () { - return notYetImplemented('memory'); - }; - /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ - KernelBackend.prototype.floatPrecision = function () { - return notYetImplemented('floatPrecision'); - }; - /** Returns the smallest representable number. */ - KernelBackend.prototype.epsilon = function () { - return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; - }; - KernelBackend.prototype.batchMatMul = function (a, b, transposeA, transposeB) { - return notYetImplemented('batchMatMul'); - }; - KernelBackend.prototype.fusedBatchMatMul = function (_a) { - var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; - return notYetImplemented('fusedBatchMatMul'); - }; - KernelBackend.prototype.slice = function (x, begin, size) { - return notYetImplemented('slice'); - }; - KernelBackend.prototype.stridedSlice = function (x, begin, end, strides) { - return notYetImplemented('stridedSlice'); - }; - KernelBackend.prototype.unstack = function (x, axis) { - return notYetImplemented('unstack'); - }; - KernelBackend.prototype.reverse = function (a, axis) { - return notYetImplemented('reverse'); - }; - KernelBackend.prototype.concat = function (tensors, axis) { - return notYetImplemented('concat'); - }; - KernelBackend.prototype.neg = function (a) { - return notYetImplemented('neg'); - }; - KernelBackend.prototype.add = function (a, b) { - return notYetImplemented('add'); - }; - KernelBackend.prototype.addN = function (tensors) { - return notYetImplemented('addN'); - }; - KernelBackend.prototype.subtract = function (a, b) { - return notYetImplemented('subtract'); - }; - KernelBackend.prototype.multiply = function (a, b) { - return notYetImplemented('multiply'); - }; - KernelBackend.prototype.realDivide = function (a, b) { - return notYetImplemented('realDivide'); - }; - KernelBackend.prototype.floorDiv = function (a, b) { - return notYetImplemented('floorDiv'); - }; - KernelBackend.prototype.sum = function (x, axes) { - return notYetImplemented('sum'); - }; - KernelBackend.prototype.prod = function (x, axes) { - return notYetImplemented('prod'); - }; - KernelBackend.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) { - return notYetImplemented('unsortedSegmentSum'); - }; - KernelBackend.prototype.argMin = function (x, axis) { - return notYetImplemented('argMin'); - }; - KernelBackend.prototype.argMax = function (x, axis) { - return notYetImplemented('argMax'); - }; - KernelBackend.prototype.equal = function (a, b) { - return notYetImplemented('equal'); - }; - KernelBackend.prototype.notEqual = function (a, b) { - return notYetImplemented('notEqual'); - }; - KernelBackend.prototype.less = function (a, b) { - return notYetImplemented('less'); - }; - KernelBackend.prototype.lessEqual = function (a, b) { - return notYetImplemented('lessEqual'); - }; - KernelBackend.prototype.greater = function (a, b) { - return notYetImplemented('greater'); - }; - KernelBackend.prototype.greaterEqual = function (a, b) { - return notYetImplemented('greaterEqual'); - }; - KernelBackend.prototype.logicalNot = function (a) { - return notYetImplemented('logicalNot'); - }; - KernelBackend.prototype.logicalAnd = function (a, b) { - return notYetImplemented('logicalAnd'); - }; - KernelBackend.prototype.logicalOr = function (a, b) { - return notYetImplemented('logicalOr'); - }; - KernelBackend.prototype.where = function (condition) { - return notYetImplemented('where'); - }; - KernelBackend.prototype.select = function (condition, a, b) { - return notYetImplemented('select'); - }; - KernelBackend.prototype.topk = function (x, k, sorted) { - return notYetImplemented('topk'); - }; - KernelBackend.prototype.min = function (x, axes) { - return notYetImplemented('min'); - }; - KernelBackend.prototype.minimum = function (a, b) { - return notYetImplemented('minimum'); - }; - KernelBackend.prototype.mod = function (a, b) { - return notYetImplemented('mod'); - }; - KernelBackend.prototype.max = function (x, axes) { - return notYetImplemented('max'); - }; - KernelBackend.prototype.maximum = function (a, b) { - return notYetImplemented('maximum'); - }; - KernelBackend.prototype.all = function (x, axes) { - return notYetImplemented('all'); - }; - KernelBackend.prototype.any = function (x, axes) { - return notYetImplemented('any'); - }; - KernelBackend.prototype.squaredDifference = function (a, b) { - return notYetImplemented('squaredDifference'); - }; - KernelBackend.prototype.ceil = function (x) { - return notYetImplemented('ceil'); - }; - KernelBackend.prototype.floor = function (x) { - return notYetImplemented('floor'); - }; - KernelBackend.prototype.round = function (x) { - return notYetImplemented('round'); - }; - KernelBackend.prototype.sign = function (x) { - return notYetImplemented('sign'); - }; - KernelBackend.prototype.isNaN = function (x) { - return notYetImplemented('isNaN'); - }; - KernelBackend.prototype.isInf = function (x) { - return notYetImplemented('isInf'); - }; - KernelBackend.prototype.isFinite = function (x) { - return notYetImplemented('isFinite'); - }; - KernelBackend.prototype.pow = function (a, b) { - return notYetImplemented('pow'); - }; - KernelBackend.prototype.exp = function (x) { - return notYetImplemented('exp'); - }; - KernelBackend.prototype.expm1 = function (x) { - return notYetImplemented('expm1'); - }; - KernelBackend.prototype.softmax = function (x, dim) { - return notYetImplemented('softmax'); - }; - KernelBackend.prototype.log = function (x) { - return notYetImplemented('log'); - }; - KernelBackend.prototype.log1p = function (x) { - return notYetImplemented('log1p'); - }; - KernelBackend.prototype.sqrt = function (x) { - return notYetImplemented('sqrt'); - }; - KernelBackend.prototype.rsqrt = function (x) { - return notYetImplemented('rsqrt'); - }; - KernelBackend.prototype.square = function (x) { - return notYetImplemented('square'); - }; - KernelBackend.prototype.reciprocal = function (x) { - return notYetImplemented('reciprocal'); - }; - KernelBackend.prototype.relu = function (x) { - return notYetImplemented('relu'); - }; - KernelBackend.prototype.relu6 = function (x) { - return notYetImplemented('relu6'); - }; - KernelBackend.prototype.prelu = function (x, a) { - return notYetImplemented('prelu'); - }; - KernelBackend.prototype.elu = function (x) { - return notYetImplemented('elu'); - }; - KernelBackend.prototype.eluDer = function (dy, y) { - return notYetImplemented('eluDer'); - }; - KernelBackend.prototype.selu = function (x) { - return notYetImplemented('selu'); - }; - KernelBackend.prototype.int = function (x) { - return notYetImplemented('int'); - }; - KernelBackend.prototype.clip = function (x, min, max) { - return notYetImplemented('clip'); - }; - KernelBackend.prototype.abs = function (x) { - return notYetImplemented('abs'); - }; - KernelBackend.prototype.complexAbs = function (x) { - return notYetImplemented('complexAbs'); - }; - KernelBackend.prototype.sigmoid = function (x) { - return notYetImplemented('sigmoid'); - }; - KernelBackend.prototype.softplus = function (x) { - return notYetImplemented('softplus'); - }; - KernelBackend.prototype.sin = function (x) { - return notYetImplemented('sin'); - }; - KernelBackend.prototype.cos = function (x) { - return notYetImplemented('cos'); - }; - KernelBackend.prototype.tan = function (x) { - return notYetImplemented('tan'); - }; - KernelBackend.prototype.asin = function (x) { - return notYetImplemented('asin'); - }; - KernelBackend.prototype.acos = function (x) { - return notYetImplemented('acos'); - }; - KernelBackend.prototype.atan = function (x) { - return notYetImplemented('atan'); - }; - KernelBackend.prototype.atan2 = function (a, b) { - return notYetImplemented('atan2'); - }; - KernelBackend.prototype.sinh = function (x) { - return notYetImplemented('sinh'); - }; - KernelBackend.prototype.cosh = function (x) { - return notYetImplemented('cosh'); - }; - KernelBackend.prototype.tanh = function (x) { - return notYetImplemented('tanh'); - }; - KernelBackend.prototype.asinh = function (x) { - return notYetImplemented('asinh'); - }; - KernelBackend.prototype.acosh = function (x) { - return notYetImplemented('acosh'); - }; - KernelBackend.prototype.atanh = function (x) { - return notYetImplemented('atanh'); - }; - KernelBackend.prototype.erf = function (x) { - return notYetImplemented('erf'); - }; - KernelBackend.prototype.step = function (x, alpha) { - return notYetImplemented('step'); - }; - KernelBackend.prototype.fusedConv2d = function (_a) { - var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; - return notYetImplemented('fusedConv2d'); - }; - KernelBackend.prototype.conv2d = function (x, filter, convInfo) { - return notYetImplemented('conv2d'); - }; - KernelBackend.prototype.conv2dDerInput = function (dy, filter, convInfo) { - return notYetImplemented('conv2dDerInput'); - }; - KernelBackend.prototype.conv2dDerFilter = function (x, dY, convInfo) { - return notYetImplemented('conv2dDerFilter'); - }; - KernelBackend.prototype.fusedDepthwiseConv2D = function (_a) { - var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; - return notYetImplemented('fusedDepthwiseConv2D'); - }; - KernelBackend.prototype.depthwiseConv2D = function (input, filter, convInfo) { - return notYetImplemented('depthwiseConv2D'); - }; - KernelBackend.prototype.depthwiseConv2DDerInput = function (dy, filter, convInfo) { - return notYetImplemented('depthwiseConv2DDerInput'); - }; - KernelBackend.prototype.depthwiseConv2DDerFilter = function (x, dY, convInfo) { - return notYetImplemented('depthwiseConv2DDerFilter'); - }; - KernelBackend.prototype.conv3d = function (x, filter, convInfo) { - return notYetImplemented('conv3d'); - }; - KernelBackend.prototype.conv3dDerInput = function (dy, filter, convInfo) { - return notYetImplemented('conv3dDerInput'); - }; - KernelBackend.prototype.conv3dDerFilter = function (x, dY, convInfo) { - return notYetImplemented('conv3dDerFilter'); - }; - KernelBackend.prototype.maxPool = function (x, convInfo) { - return notYetImplemented('maxPool'); - }; - KernelBackend.prototype.maxPoolBackprop = function (dy, x, y, convInfo) { - return notYetImplemented('maxPoolBackprop'); - }; - KernelBackend.prototype.avgPool = function (x, convInfo) { - return notYetImplemented('avgPool'); - }; - KernelBackend.prototype.avgPoolBackprop = function (dy, x, convInfo) { - return notYetImplemented('avgPoolBackprop'); - }; - KernelBackend.prototype.avgPool3d = function (x, convInfo) { - return notYetImplemented('avgPool3d'); - }; - KernelBackend.prototype.avgPool3dBackprop = function (dy, x, convInfo) { - return notYetImplemented('avgPool3dBackprop'); - }; - KernelBackend.prototype.maxPool3d = function (x, convInfo) { - return notYetImplemented('maxPool3d'); - }; - KernelBackend.prototype.maxPool3dBackprop = function (dy, x, y, convInfo) { - return notYetImplemented('maxPool3dBackprop'); - }; - KernelBackend.prototype.reshape = function (x, shape) { - return notYetImplemented('reshape'); - }; - KernelBackend.prototype.cast = function (x, dtype) { - return notYetImplemented('cast'); - }; - KernelBackend.prototype.tile = function (x, reps) { - return notYetImplemented('tile'); - }; - KernelBackend.prototype.pad = function (x, paddings, constantValue) { - return notYetImplemented('pad'); - }; - KernelBackend.prototype.transpose = function (x, perm) { - return notYetImplemented('transpose'); - }; - KernelBackend.prototype.gather = function (x, indices, axis) { - return notYetImplemented('gather'); - }; - KernelBackend.prototype.gatherND = function (x, indices) { - return notYetImplemented('gatherND'); - }; - KernelBackend.prototype.scatterND = function (indices, updates, shape) { - return notYetImplemented('scatterND'); - }; - KernelBackend.prototype.batchToSpaceND = function (x, blockShape, crops) { - return notYetImplemented('batchToSpaceND'); - }; - KernelBackend.prototype.spaceToBatchND = function (x, blockShape, paddings) { - return notYetImplemented('spaceToBatchND'); - }; - KernelBackend.prototype.resizeBilinear = function (x, newHeight, newWidth, alignCorners) { - return notYetImplemented('resizeBilinear'); - }; - KernelBackend.prototype.resizeBilinearBackprop = function (dy, x, alignCorners) { - return notYetImplemented('resizeBilinearBackprop'); - }; - KernelBackend.prototype.resizeNearestNeighbor = function (x, newHEight, newWidth, alignCorners) { - return notYetImplemented('resizeNearestNeighbor'); - }; - KernelBackend.prototype.resizeNearestNeighborBackprop = function (dy, x, alignCorners) { - return notYetImplemented('resizeNearestNeighborBackprop'); - }; - KernelBackend.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) { - return notYetImplemented('batchNormalization'); - }; - KernelBackend.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) { - return notYetImplemented('localResponseNormalization4D'); - }; - KernelBackend.prototype.LRNGrad = function (dy, inputImage, outputImage, radius, bias, alpha, beta) { - return notYetImplemented('LRNGrad'); - }; - KernelBackend.prototype.multinomial = function (logits, normalized, numSamples, seed) { - return notYetImplemented('multinomial'); - }; - KernelBackend.prototype.oneHot = function (indices, depth, onValue, offValue) { - return notYetImplemented('oneHot'); - }; - KernelBackend.prototype.cumsum = function (x, axis, exclusive, reverse) { - return notYetImplemented('cumsum'); - }; - KernelBackend.prototype.nonMaxSuppression = function (boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { - return notYetImplemented('nonMaxSuppression'); - }; - KernelBackend.prototype.fft = function (x) { - return notYetImplemented('fft'); - }; - KernelBackend.prototype.ifft = function (x) { - return notYetImplemented('ifft'); - }; - KernelBackend.prototype.complex = function (real, imag) { - return notYetImplemented('complex'); - }; - KernelBackend.prototype.real = function (input) { - return notYetImplemented('real'); - }; - KernelBackend.prototype.imag = function (input) { - return notYetImplemented('imag'); - }; - KernelBackend.prototype.cropAndResize = function (image, boxes, boxIndex, cropSize, method, extrapolationValue) { - return notYetImplemented('cropAndResize'); - }; - KernelBackend.prototype.depthToSpace = function (x, blockSize, dataFormat) { - return notYetImplemented('depthToSpace'); - }; - // Aligns with the "SplitV" kernel in TensorFlow. - KernelBackend.prototype.split = function (value, sizeSplits, axis) { - return notYetImplemented('split'); - }; - KernelBackend.prototype.sparseToDense = function (sparseIndices, sparseValues, outputShape, defaultValue) { - return notYetImplemented('sparseToDense'); - }; - KernelBackend.prototype.diag = function (x) { - return notYetImplemented('diag'); - }; - KernelBackend.prototype.fill = function (shape, value, dtype) { - return notYetImplemented('fill'); - }; - KernelBackend.prototype.onesLike = function (x) { - return notYetImplemented('onesLike'); - }; - KernelBackend.prototype.zerosLike = function (x) { - return notYetImplemented('zerosLike'); - }; - KernelBackend.prototype.linspace = function (start, stop, num) { - return notYetImplemented('linspace'); - }; - KernelBackend.prototype.dispose = function () { - return notYetImplemented('dispose'); - }; - return KernelBackend; - }()); - function notYetImplemented(kernelName) { - throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. " + - "Did you forget to import the kernel?"); - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { - if (dataFormat === void 0) { dataFormat = 'channelsLast'; } - var _a = parseTupleParam(filterSize), filterHeight = _a[0], filterWidth = _a[1]; - var filterShape; - if (dataFormat === 'channelsLast') { - filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]]; - } - else if (dataFormat === 'channelsFirst') { - filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]]; - } - else { - throw new Error("Unknown dataFormat " + dataFormat); - } - return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat); - } - /** - * Computes the information for a forward pass of a pooling3D operation. - */ - function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { - if (dataFormat === void 0) { dataFormat = 'NDHWC'; } - var _a = parse3TupleParam(filterSize), filterDepth = _a[0], filterHeight = _a[1], filterWidth = _a[2]; - var filterShape; - var $dataFormat; - if (dataFormat === 'NDHWC') { - $dataFormat = 'channelsLast'; - filterShape = - [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]]; - } - else if (dataFormat === 'NCDHW') { - $dataFormat = 'channelsFirst'; - filterShape = - [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]]; - } - else { - throw new Error("Unknown dataFormat " + dataFormat); - } - return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode); - } - /** - * Computes the information for a forward pass of a convolution/pooling - * operation. - */ - function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) { - if (depthwise === void 0) { depthwise = false; } - if (dataFormat === void 0) { dataFormat = 'channelsLast'; } - var _a = [-1, -1, -1, -1], batchSize = _a[0], inHeight = _a[1], inWidth = _a[2], inChannels = _a[3]; - if (dataFormat === 'channelsLast') { - batchSize = inShape[0], inHeight = inShape[1], inWidth = inShape[2], inChannels = inShape[3]; - } - else if (dataFormat === 'channelsFirst') { - batchSize = inShape[0], inChannels = inShape[1], inHeight = inShape[2], inWidth = inShape[3]; - } - else { - throw new Error("Unknown dataFormat " + dataFormat); - } - var filterHeight = filterShape[0], filterWidth = filterShape[1], filterChannels = filterShape[3]; - var _b = parseTupleParam(strides), strideHeight = _b[0], strideWidth = _b[1]; - var _c = parseTupleParam(dilations), dilationHeight = _c[0], dilationWidth = _c[1]; - var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); - var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); - var _d = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outHeight = _d.outHeight, outWidth = _d.outWidth; - var outChannels = depthwise ? filterChannels * inChannels : filterChannels; - var outShape; - if (dataFormat === 'channelsFirst') { - outShape = [batchSize, outChannels, outHeight, outWidth]; - } - else if (dataFormat === 'channelsLast') { - outShape = [batchSize, outHeight, outWidth, outChannels]; - } - return { - batchSize: batchSize, - dataFormat: dataFormat, - inHeight: inHeight, - inWidth: inWidth, - inChannels: inChannels, - outHeight: outHeight, - outWidth: outWidth, - outChannels: outChannels, - padInfo: padInfo, - strideHeight: strideHeight, - strideWidth: strideWidth, - filterHeight: filterHeight, - filterWidth: filterWidth, - effectiveFilterHeight: effectiveFilterHeight, - effectiveFilterWidth: effectiveFilterWidth, - dilationHeight: dilationHeight, - dilationWidth: dilationWidth, - inShape: inShape, - outShape: outShape, - filterShape: filterShape - }; - } - /** - * Computes the information for a forward pass of a 3D convolution/pooling - * operation. - */ - function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) { - if (depthwise === void 0) { depthwise = false; } - if (dataFormat === void 0) { dataFormat = 'channelsLast'; } - var _a = [-1, -1, -1, -1, -1], batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4]; - if (dataFormat === 'channelsLast') { - batchSize = inShape[0], inDepth = inShape[1], inHeight = inShape[2], inWidth = inShape[3], inChannels = inShape[4]; - } - else if (dataFormat === 'channelsFirst') { - batchSize = inShape[0], inChannels = inShape[1], inDepth = inShape[2], inHeight = inShape[3], inWidth = inShape[4]; - } - else { - throw new Error("Unknown dataFormat " + dataFormat); - } - var filterDepth = filterShape[0], filterHeight = filterShape[1], filterWidth = filterShape[2], filterChannels = filterShape[4]; - var _b = parse3TupleParam(strides), strideDepth = _b[0], strideHeight = _b[1], strideWidth = _b[2]; - var _c = parse3TupleParam(dilations), dilationDepth = _c[0], dilationHeight = _c[1], dilationWidth = _c[2]; - var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); - var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); - var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); - var _d = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outDepth = _d.outDepth, outHeight = _d.outHeight, outWidth = _d.outWidth; - var outChannels = depthwise ? filterChannels * inChannels : filterChannels; - var outShape; - if (dataFormat === 'channelsFirst') { - outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; - } - else if (dataFormat === 'channelsLast') { - outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; - } - return { - batchSize: batchSize, - dataFormat: dataFormat, - inDepth: inDepth, - inHeight: inHeight, - inWidth: inWidth, - inChannels: inChannels, - outDepth: outDepth, - outHeight: outHeight, - outWidth: outWidth, - outChannels: outChannels, - padInfo: padInfo, - strideDepth: strideDepth, - strideHeight: strideHeight, - strideWidth: strideWidth, - filterDepth: filterDepth, - filterHeight: filterHeight, - filterWidth: filterWidth, - effectiveFilterDepth: effectiveFilterDepth, - effectiveFilterHeight: effectiveFilterHeight, - effectiveFilterWidth: effectiveFilterWidth, - dilationDepth: dilationDepth, - dilationHeight: dilationHeight, - dilationWidth: dilationWidth, - inShape: inShape, - outShape: outShape, - filterShape: filterShape - }; - } - function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) { - if (zeroPad == null) { - zeroPad = computeDefaultPad(inShape, fieldSize, stride); - } - var inputRows = inShape[0]; - var inputCols = inShape[1]; - var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " + - "Change the stride and/or zero pad parameters"; }); - var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " + - "Change the stride and/or zero pad parameters"; }); - return [outputRows, outputCols]; - } - function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) { - if (zeroPad == null) { - zeroPad = computeDefaultPad(inShape, fieldSize, stride); - } - var inputDepth = inShape[0]; - var inputRows = inShape[1]; - var inputCols = inShape[2]; - var outputDepths = conditionalRound((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - assert(isInt(outputDepths), function () { return "The output # of depths (" + outputDepths + ") must be an integer. " + - "Change the stride and/or zero pad parameters"; }); - var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " + - "Change the stride and/or zero pad parameters"; }); - var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); - assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " + - "Change the stride and/or zero pad parameters"; }); - return [outputDepths, outputRows, outputCols, outChannels]; - } - function computeDefaultPad(inputShape, fieldSize, stride, dilation) { - if (dilation === void 0) { dilation = 1; } - var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); - return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); - } - function parseTupleParam(param) { - if (typeof param === 'number') { - return [param, param, param]; - } - if (param.length === 2) { - return [param[0], param[1], 1]; - } - return param; - } - function parse3TupleParam(param) { - return typeof param === 'number' ? [param, param, param] : param; - } - /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d - * Atrous convolution is equivalent to standard convolution with upsampled - * filters with effective_filter_height = - * filter_height + (filter_height - 1) * (dilation - 1) - * and effective_filter_width = - * filter_width + (filter_width - 1) * (dilation - 1), - * produced by inserting dilation - 1 zeros along consecutive elements across - * the filters' spatial dimensions. - * When there is a dilation, this converts a filter dimension to the - * effective filter dimension, so it can be used in a standard convolution. - */ - function getEffectiveFilterSize(filterSize, dilation) { - if (dilation <= 1) { - return filterSize; - } - return filterSize + (filterSize - 1) * (dilation - 1); - } - function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) { - var padInfo; - var outHeight; - var outWidth; - if (typeof pad === 'number') { - var padType = (pad === 0) ? 'VALID' : 'NUMBER'; - padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType }; - var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode); - outHeight = outShape[0]; - outWidth = outShape[1]; - } - else if (pad === 'same') { - outHeight = Math.ceil(inHeight / strideHeight); - outWidth = Math.ceil(inWidth / strideWidth); - var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight); - var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth); - var top_1 = Math.floor(padAlongHeight / 2); - var bottom = padAlongHeight - top_1; - var left = Math.floor(padAlongWidth / 2); - var right = padAlongWidth - left; - padInfo = { top: top_1, bottom: bottom, left: left, right: right, type: 'SAME' }; - } - else if (pad === 'valid') { - padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' }; - outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); - outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); - } - else { - throw Error("Unknown padding parameter: " + pad); - } - return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth }; - } - function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) { - var padInfo; - var outDepth; - var outHeight; - var outWidth; - if (typeof pad === 'number') { - var padType = (pad === 0) ? 'VALID' : 'NUMBER'; - padInfo = { - top: pad, - bottom: pad, - left: pad, - right: pad, - front: pad, - back: pad, - type: padType - }; - var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode); - outDepth = outShape[0]; - outHeight = outShape[1]; - outWidth = outShape[2]; - } - else if (pad === 'same') { - outDepth = Math.ceil(inDepth / strideDepth); - outHeight = Math.ceil(inHeight / strideHeight); - outWidth = Math.ceil(inWidth / strideWidth); - var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; - var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; - var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; - var front = Math.floor(padAlongDepth / 2); - var back = padAlongDepth - front; - var top_2 = Math.floor(padAlongHeight / 2); - var bottom = padAlongHeight - top_2; - var left = Math.floor(padAlongWidth / 2); - var right = padAlongWidth - left; - padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' }; - } - else if (pad === 'valid') { - padInfo = { - top: 0, - bottom: 0, - left: 0, - right: 0, - front: 0, - back: 0, - type: 'VALID' - }; - outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth); - outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); - outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); - } - else { - throw Error("Unknown padding parameter: " + pad); - } - return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth }; - } - /** - * Rounds a value depending on the rounding mode - * @param value - * @param roundingMode - */ - function conditionalRound(value, roundingMode) { - if (!roundingMode) { - return value; - } - switch (roundingMode) { - case 'round': - // used for Caffe Conv - return Math.round(value); - case 'ceil': - // used for Caffe Pool - return Math.ceil(value); - case 'floor': - return Math.floor(value); - default: - throw new Error("Unknown roundingMode " + roundingMode); - } - } - function tupleValuesAreOne(param) { - var _a = parseTupleParam(param), dimA = _a[0], dimB = _a[1], dimC = _a[2]; - return dimA === 1 && dimB === 1 && dimC === 1; - } - function eitherStridesOrDilationsAreOne(strides, dilations) { - return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); - } - /** - * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to - * 'channelsLast'|'channelsFirst' - * @param dataFormat in 'NHWC'|'NCHW' mode - * @return dataFormat in 'channelsLast'|'channelsFirst' mode - * @throws unknown dataFormat - */ - function convertConv2DDataFormat(dataFormat) { - if (dataFormat === 'NHWC') { - return 'channelsLast'; - } - else if (dataFormat === 'NCHW') { - return 'channelsFirst'; - } - else { - throw new Error("Unknown dataFormat " + dataFormat); - } - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function castTensor(x, dtype, backend) { - if (dtype === 'complex64') { - if (x.dtype === 'complex64') { - return x.clone(); - } - var zerosTensor = zeros(x.shape); - var floatX = x.toFloat(); - var result = backend.complex(floatX, zerosTensor); - zerosTensor.dispose(); - floatX.dispose(); - return result; - } - if (!hasEncodingLoss(x.dtype, dtype)) { - // We don't change the underlying data, since we cast to higher - // precision. - return ENGINE.makeTensorFromDataId(x.dataId, x.shape, dtype); - } - if (x.dtype === 'complex64') { - var real = backend.real(x); - var result = real.cast(dtype); - real.dispose(); - return result; - } - if (dtype === 'int32') { - return backend.int(x); - } - else if (dtype === 'bool') { - var zero = scalar(0, x.dtype); - var result = backend.notEqual(x, zero); - zero.dispose(); - return result; - } - else { - throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype); - } - } - function reshapeTensor(x, shape) { - return ENGINE.makeTensorFromDataId(x.dataId, shape, x.dtype); - } - function linspaceImpl(start, stop, num) { - var step = (stop - start) / (num - 1); - var values = makeZerosTypedArray(num, 'float32'); - values[0] = start; - for (var i = 1; i < values.length; i++) { - values[i] = values[i - 1] + step; - } - return tensor1d(values, 'float32'); - } - - var backend_util = /*#__PURE__*/Object.freeze({ - castTensor: castTensor, - reshapeTensor: reshapeTensor, - linspaceImpl: linspaceImpl, - upcastType: upcastType, - axesAreInnerMostDims: axesAreInnerMostDims, - combineLocations: combineLocations, - computeOutAndReduceShapes: computeOutAndReduceShapes, - expandShapeToKeepDim: expandShapeToKeepDim, - assertAxesAreInnerMostDims: assertAxesAreInnerMostDims, - getAxesPermutation: getAxesPermutation, - getUndoAxesPermutation: getUndoAxesPermutation, - getInnerMostAxes: getInnerMostAxes, - getBroadcastDims: getBroadcastDims, - getReductionAxes: getReductionAxes, - assertAndGetBroadcastShape: assertAndGetBroadcastShape, - assertParamsConsistent: assertParamsConsistent, - computeOutShape: computeOutShape, - computePool2DInfo: computePool2DInfo, - computePool3DInfo: computePool3DInfo, - computeConv2DInfo: computeConv2DInfo, - computeConv3DInfo: computeConv3DInfo, - computeDefaultPad: computeDefaultPad, - tupleValuesAreOne: tupleValuesAreOne, - eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne, - convertConv2DDataFormat: convertConv2DDataFormat, - PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD, - computeOptimalWindowSize: computeOptimalWindowSize - }); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Merges real and imaginary Float32Arrays into a single complex Float32Array. - * - * The memory layout is interleaved as follows: - * real: [r0, r1, r2] - * imag: [i0, i1, i2] - * complex: [r0, i0, r1, i1, r2, i2] - * - * This is the inverse of splitRealAndImagArrays. - * - * @param real The real values of the complex tensor values. - * @param imag The imag values of the complex tensor values. - * @returns A complex tensor as a Float32Array with merged values. - */ - function mergeRealAndImagArrays(real, imag) { - if (real.length !== imag.length) { - throw new Error("Cannot merge real and imag arrays of different lengths. real:" + - (real.length + ", imag: " + imag.length + ".")); - } - var result = new Float32Array(real.length * 2); - for (var i = 0; i < result.length; i += 2) { - result[i] = real[i / 2]; - result[i + 1] = imag[i / 2]; - } - return result; - } - /** - * Splits a complex Float32Array into real and imag parts. - * - * The memory layout is interleaved as follows: - * complex: [r0, i0, r1, i1, r2, i2] - * real: [r0, r1, r2] - * imag: [i0, i1, i2] - * - * This is the inverse of mergeRealAndImagArrays. - * - * @param complex The complex tensor values. - * @returns An object with real and imag Float32Array components of the complex - * tensor. - */ - function splitRealAndImagArrays(complex) { - var real = new Float32Array(complex.length / 2); - var imag = new Float32Array(complex.length / 2); - for (var i = 0; i < complex.length; i += 2) { - real[i / 2] = complex[i]; - imag[i / 2] = complex[i + 1]; - } - return { real: real, imag: imag }; - } - /** - * Extracts even indexed complex values in the given array. - * @param complex The complex tensor values - */ - function complexWithEvenIndex(complex) { - var len = Math.ceil(complex.length / 4); - var real = new Float32Array(len); - var imag = new Float32Array(len); - for (var i = 0; i < complex.length; i += 4) { - real[Math.floor(i / 4)] = complex[i]; - imag[Math.floor(i / 4)] = complex[i + 1]; - } - return { real: real, imag: imag }; - } - /** - * Extracts odd indexed comple values in the given array. - * @param complex The complex tensor values - */ - function complexWithOddIndex(complex) { - var len = Math.floor(complex.length / 4); - var real = new Float32Array(len); - var imag = new Float32Array(len); - for (var i = 2; i < complex.length; i += 4) { - real[Math.floor(i / 4)] = complex[i]; - imag[Math.floor(i / 4)] = complex[i + 1]; - } - return { real: real, imag: imag }; - } - /** - * Get the map representing a complex value in the given array. - * @param complex The complex tensor values. - * @param index An index of the target complex value. - */ - function getComplexWithIndex(complex, index) { - var real = complex[index * 2]; - var imag = complex[index * 2 + 1]; - return { real: real, imag: imag }; - } - /** - * Insert a given complex value into the TypedArray. - * @param data The array in which the complex value is inserted. - * @param c The complex value to be inserted. - * @param index An index of the target complex value. - */ - function assignToTypedArray(data, real, imag, index) { - data[index * 2] = real; - data[index * 2 + 1] = imag; - } - /** - * Make the list of exponent terms used by FFT. - */ - function exponents(n, inverse) { - var real = new Float32Array(n / 2); - var imag = new Float32Array(n / 2); - for (var i = 0; i < Math.ceil(n / 2); i++) { - var x = (inverse ? 2 : -2) * Math.PI * (i / n); - real[i] = Math.cos(x); - imag[i] = Math.sin(x); - } - return { real: real, imag: imag }; - } - /** - * Make the exponent term used by FFT. - */ - function exponent(k, n, inverse) { - var x = (inverse ? 2 : -2) * Math.PI * (k / n); - var real = Math.cos(x); - var imag = Math.sin(x); - return { real: real, imag: imag }; - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Inserts a value into a sorted array. This method allows duplicate, meaning it - * allows inserting duplicate value, in which case, the element will be inserted - * at the lowest index of the value. - * @param arr The array to modify. - * @param element The element to insert. - * @param comparator Optional. If no comparator is specified, elements are - * compared using array_util.defaultComparator, which is suitable for Strings - * and Numbers in ascending arrays. If the array contains multiple instances of - * the target value, the left-most instance will be returned. To provide a - * comparator, it should take 2 arguments to compare and return a negative, - * zero, or a positive number. - */ - function binaryInsert(arr, element, comparator) { - var index = binarySearch(arr, element, comparator); - var insertionPoint = index < 0 ? -(index + 1) : index; - arr.splice(insertionPoint, 0, element); - } - /** - * Searches the array for the target using binary search, returns the index - * of the found element, or position to insert if element not found. If no - * comparator is specified, elements are compared using array_ - * util.defaultComparator, which is suitable for Strings and Numbers in - * ascending arrays. If the array contains multiple instances of the target - * value, the left-most instance will be returned. - * @param arr The array to be searched in. - * @param target The target to be searched for. - * @param comparator Should take 2 arguments to compare and return a negative, - * zero, or a positive number. - * @return Lowest index of the target value if found, otherwise the insertion - * point where the target should be inserted, in the form of - * (-insertionPoint - 1). - */ - function binarySearch(arr, target, comparator) { - return binarySearch_(arr, target, comparator || defaultComparator); - } - /** - * Compares its two arguments for order. - * @param a The first element to be compared. - * @param b The second element to be compared. - * @return A negative number, zero, or a positive number as the first - * argument is less than, equal to, or greater than the second. - */ - function defaultComparator(a, b) { - return a > b ? 1 : a < b ? -1 : 0; - } - function binarySearch_(arr, target, comparator) { - var left = 0; - var right = arr.length; - var middle = 0; - var found = false; - while (left < right) { - middle = left + ((right - left) >>> 1); - var compareResult = comparator(target, arr[middle]); - if (compareResult > 0) { - left = middle + 1; - } - else { - right = middle; - // If compareResult is 0, the value is found. We record it is found, - // and then keep looking because there may be duplicate. - found = !compareResult; - } - } - return found ? left : -left - 1; - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function nonMaxSuppressionV3(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { - var dummySoftNmsSigma = 0.0; - return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, dummySoftNmsSigma) - .selectedIndices; - } - function nonMaxSuppressionV5(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) { - // For NonMaxSuppressionV5Op, we always return a second output holding - // corresponding scores. - var returnScoresTensor = true; - var result = nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor); - result.numValidOutputs.dispose(); - return { - selectedIndices: result.selectedIndices, - selectedScores: result.selectedScores - }; - } - function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor, padToMaxOutputSize) { - if (returnScoresTensor === void 0) { returnScoresTensor = false; } - if (padToMaxOutputSize === void 0) { padToMaxOutputSize = false; } - // The list is sorted in ascending order, so that we can always pop the - // candidate with the largest score in O(1) time. - var candidates = Array.from(scores) - .map(function (score, boxIndex) { return ({ score: score, boxIndex: boxIndex, suppressBeginIndex: 0 }); }) - .filter(function (c) { return c.score > scoreThreshold; }) - .sort(ascendingComparator); - // If softNmsSigma is 0, the outcome of this algorithm is exactly same as - // before. - var scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0; - var selectedIndices = []; - var selectedScores = []; - while (selectedIndices.length < maxOutputSize && candidates.length > 0) { - var candidate = candidates.pop(); - var originalScore = candidate.score, boxIndex = candidate.boxIndex, suppressBeginIndex = candidate.suppressBeginIndex; - if (originalScore < scoreThreshold) { - break; - } - // Overlapping boxes are likely to have similar scores, therefore we - // iterate through the previously selected boxes backwards in order to - // see if candidate's score should be suppressed. We use - // suppressBeginIndex to track and ensure a candidate can be suppressed - // by a selected box no more than once. Also, if the overlap exceeds - // iouThreshold, we simply ignore the candidate. - var ignoreCandidate = false; - for (var j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) { - var iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]); - if (iou >= iouThreshold) { - ignoreCandidate = true; - break; - } - candidate.score = - candidate.score * suppressWeight(iouThreshold, scale, iou); - if (candidate.score <= scoreThreshold) { - break; - } - } - // At this point, if `candidate.score` has not dropped below - // `scoreThreshold`, then we know that we went through all of the - // previous selections and can safely update `suppressBeginIndex` to the - // end of the selected array. Then we can re-insert the candidate with - // the updated score and suppressBeginIndex back in the candidate list. - // If on the other hand, `candidate.score` has dropped below the score - // threshold, we will not add it back to the candidates list. - candidate.suppressBeginIndex = selectedIndices.length; - if (!ignoreCandidate) { - // Candidate has passed all the tests, and is not suppressed, so - // select the candidate. - if (candidate.score === originalScore) { - selectedIndices.push(boxIndex); - selectedScores.push(candidate.score); - } - else if (candidate.score > scoreThreshold) { - // Candidate's score is suppressed but is still high enough to be - // considered, so add back to the candidates list. - binaryInsert(candidates, candidate, ascendingComparator); - } - } - } - // NonMaxSuppressionV4 feature: padding output to maxOutputSize. - var numValidOutputs = selectedIndices.length; - if (padToMaxOutputSize) { - selectedIndices.fill(0, numValidOutputs); - selectedScores.fill(0.0, numValidOutputs); - } - return { - selectedIndices: tensor1d(selectedIndices, 'int32'), - selectedScores: tensor1d(selectedScores, 'float32'), - numValidOutputs: scalar(numValidOutputs, 'int32') - }; - } - function intersectionOverUnion(boxes, i, j) { - var iCoord = boxes.subarray(i * 4, i * 4 + 4); - var jCoord = boxes.subarray(j * 4, j * 4 + 4); - var yminI = Math.min(iCoord[0], iCoord[2]); - var xminI = Math.min(iCoord[1], iCoord[3]); - var ymaxI = Math.max(iCoord[0], iCoord[2]); - var xmaxI = Math.max(iCoord[1], iCoord[3]); - var yminJ = Math.min(jCoord[0], jCoord[2]); - var xminJ = Math.min(jCoord[1], jCoord[3]); - var ymaxJ = Math.max(jCoord[0], jCoord[2]); - var xmaxJ = Math.max(jCoord[1], jCoord[3]); - var areaI = (ymaxI - yminI) * (xmaxI - xminI); - var areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); - if (areaI <= 0 || areaJ <= 0) { - return 0.0; - } - var intersectionYmin = Math.max(yminI, yminJ); - var intersectionXmin = Math.max(xminI, xminJ); - var intersectionYmax = Math.min(ymaxI, ymaxJ); - var intersectionXmax = Math.min(xmaxI, xmaxJ); - var intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * - Math.max(intersectionXmax - intersectionXmin, 0.0); - return intersectionArea / (areaI + areaJ - intersectionArea); - } - // A Gaussian penalty function, this method always returns values in [0, 1]. - // The weight is a function of similarity, the more overlap two boxes are, the - // smaller the weight is, meaning highly overlapping boxe will be significantly - // penalized. On the other hand, a non-overlapping box will not be penalized. - function suppressWeight(iouThreshold, scale, iou) { - var weight = Math.exp(scale * iou * iou); - return iou <= iouThreshold ? weight : 0.0; - } - function ascendingComparator(c1, c2) { - // For objects with same scores, we make the object with the larger index go - // first. In an array that pops from the end, this means that the object with - // the smaller index will be popped first. This ensures the same output as - // the TensorFlow python version. - return (c1.score - c2.score) || - ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex)); - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** Shared implementation of the split kernel across WebGL and CPU. */ - function split$1(x, sizeSplits, axis) { - var begin = new Array(x.rank).fill(0); - var size = x.shape.slice(); - return sizeSplits.map(function (s) { - size[axis] = s; - var slice = x.slice(begin, size); - begin[axis] += s; - return slice; - }); - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function tile(xBuf, reps) { - var newShape = new Array(xBuf.rank); - for (var i = 0; i < newShape.length; i++) { - newShape[i] = xBuf.shape[i] * reps[i]; - } - var result = buffer(newShape, xBuf.dtype); - for (var i = 0; i < result.values.length; ++i) { - var newLoc = result.indexToLoc(i); - var originalLoc = new Array(xBuf.rank); - for (var j = 0; j < originalLoc.length; j++) { - originalLoc[j] = newLoc[j] % xBuf.shape[j]; - } - var originalIndex = xBuf.locToIndex(originalLoc); - result.values[i] = xBuf.values[originalIndex]; - } - return result.toTensor(); - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function topkImpl(x, xShape, xDtype, k, sorted) { - // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim. - var lastDim = xShape[xShape.length - 1]; - var _a = [x.length / lastDim, lastDim], batch = _a[0], size = _a[1]; - var allTopKVals = getTypedArrayFromDType(xDtype, batch * k); - var allTopKIndices = getTypedArrayFromDType('int32', batch * k); - for (var b = 0; b < batch; b++) { - var offset = b * size; - var vals = x.subarray(offset, offset + size); - var valAndInd = []; - for (var i = 0; i < vals.length; i++) { - valAndInd.push({ value: vals[i], index: i }); - } - valAndInd.sort(function (a, b) { return b.value - a.value; }); - var outOffset = b * k; - var topKVals = allTopKVals.subarray(outOffset, outOffset + k); - var topKIndices = allTopKIndices.subarray(outOffset, outOffset + k); - for (var i = 0; i < k; i++) { - topKVals[i] = valAndInd[i].value; - topKIndices[i] = valAndInd[i].index; - } - } - // Reshape back to the original input shape, except that the last - // dimension is k. - var outputShape = xShape.slice(); - outputShape[outputShape.length - 1] = k; - return [ - tensor(allTopKVals, outputShape, xDtype), - tensor(allTopKIndices, outputShape, 'int32') - ]; - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function whereImpl(condShape, condVals) { - var indices = []; - for (var i = 0; i < condVals.length; i++) { - if (condVals[i]) { - indices.push(i); - } - } - var inBuffer = buffer(condShape, 'int32'); - var out = buffer([indices.length, condShape.length], 'int32'); - for (var i = 0; i < indices.length; i++) { - var loc = inBuffer.indexToLoc(indices[i]); - var offset = i * condShape.length; - out.values.set(loc, offset); - } - return out.toTensor(); - } - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var AddNProgram = /** @class */ (function () { - function AddNProgram(outputShape, shapes) { - this.outputShape = []; - this.outputShape = outputShape; - this.variableNames = shapes.map(function (_, i) { return "T" + i; }); - var snippets = []; - // Get target elements from every input tensor. - this.variableNames.forEach(function (variable) { - snippets.push("float v" + variable + " = get" + variable + "AtOutCoords();"); - }); - // Calculate the sum of all elements. - var operation = this.variableNames - .map(function (variable) { - return "v" + variable; - }) - .join(' + '); - this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n float result = " + operation + ";\n setOutput(result);\n }\n "; - } - return AddNProgram; - }()); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var AddNPackedProgram = /** @class */ (function () { - function AddNPackedProgram(outputShape, shapes) { - this.outputShape = []; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = outputShape; - this.variableNames = shapes.map(function (_, i) { return "T" + i; }); - var snippets = []; - // Get target elements from every input tensor. - this.variableNames.forEach(function (variable) { - snippets.push("vec4 v" + variable + " = get" + variable + "AtOutCoords();"); - }); - // Calculate the sum of all elements. - var operation = this.variableNames - .map(function (variable) { - return "v" + variable; - }) - .join(' + '); - this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n vec4 result = " + operation + ";\n setOutput(result);\n }\n "; - } - return AddNPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ArgMinMaxProgram = /** @class */ (function () { - function ArgMinMaxProgram(reduceInfo, op, firstPass) { - this.variableNames = ['A']; - var windowSize = reduceInfo.windowSize; - var batchSize = reduceInfo.batchSize; - var inSize = reduceInfo.inSize; - var outSize = Math.ceil(inSize / windowSize); - if (!firstPass) { - this.variableNames.push('bestIndicesA'); - } - this.outputShape = [batchSize, outSize]; - var compOp = (op === 'max') ? '>' : '<'; - var indexSnippet = firstPass ? - 'inOffset + i;' : - 'round(getBestIndicesA(batch, inOffset + i));'; - this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < " + windowSize + "; i++) {\n int inIdx = " + indexSnippet + ";\n float candidate = getA(batch, inIdx);\n if (candidate " + compOp + " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n "; - } - return ArgMinMaxProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function getVecChannels(name, rank) { - return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(function (d) { return name + "." + d; }); - } - function getChannels(name, rank) { - if (rank === 1) { - return [name]; - } - return getVecChannels(name, rank); - } - function getSourceCoords(rank, dims) { - if (rank === 1) { - return 'rc'; - } - var coords = ''; - for (var i = 0; i < rank; i++) { - coords += dims[i]; - if (i < rank - 1) { - coords += ','; - } - } - return coords; - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function getGlslDifferences() { - var version; - var attribute; - var varyingVs; - var varyingFs; - var texture2D; - var output; - var defineOutput; - var defineSpecialNaN; - var defineSpecialInf; - var defineRound; - if (env().getNumber('WEBGL_VERSION') === 2) { - version = '#version 300 es'; - attribute = 'in'; - varyingVs = 'out'; - varyingFs = 'in'; - texture2D = 'texture'; - output = 'outputColor'; - defineOutput = 'out vec4 outputColor;'; - // Use custom isnan definition to work across differences between - // implementations on various platforms. While this should happen in ANGLE - // we still see differences between android and windows (on chrome) when - // using isnan directly. - defineSpecialNaN = "\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n "; - // In webgl 2 we do not need to specify a custom isinf so there is no - // need for a special INFINITY constant. - defineSpecialInf = ""; - defineRound = "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "; - } - else { - version = ''; - attribute = 'attribute'; - varyingVs = 'varying'; - varyingFs = 'varying'; - texture2D = 'texture2D'; - output = 'gl_FragColor'; - defineOutput = ''; - // WebGL1 has no built in isnan so we define one here. - defineSpecialNaN = "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n "; - defineSpecialInf = "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n "; - defineRound = "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "; - } - return { - version: version, - attribute: attribute, - varyingVs: varyingVs, - varyingFs: varyingFs, - texture2D: texture2D, - output: output, - defineOutput: defineOutput, - defineSpecialNaN: defineSpecialNaN, - defineSpecialInf: defineSpecialInf, - defineRound: defineRound - }; - } - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /** - * Produces GLSL code that derives logical coordinates from a flat - * index. The code performs integer division with each stride and decrements - * the index until the index equals the final dimension coordinate. - */ - function getLogicalCoordinatesFromFlatIndex(coords, shape, index) { - if (index === void 0) { index = 'index'; } - var strides = computeStrides(shape); - return strides - .map(function (stride, i) { - var line1 = "int " + coords[i] + " = " + index + " / " + stride; - var line2 = i === strides.length - 1 ? - "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * " + stride : - "index -= " + coords[i] + " * " + stride; - return line1 + "; " + line2 + ";"; - }) - .join(''); - } - /** - * Produces GLSL that computes the flat index from 3D coordinates. - */ - function getFlatIndexFrom3D(shape) { - var strides = computeStrides(shape).map(function (d) { return d.toString(); }); - return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * " + strides[0] + " + coords.y * " + strides[1] + " + coords.z;\n }\n"; - } - var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n"; - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function makeShader(inputsInfo, outputShape, userCode, usesPackedTextures) { - var prefixSnippets = []; - inputsInfo.forEach(function (x) { - var size = sizeFromShape(x.shapeInfo.logicalShape); - // Snippet when we decided to upload the values as uniform. - if (x.shapeInfo.isUniform) { - prefixSnippets.push("uniform float " + x.name + (size > 1 ? "[" + size + "]" : '') + ";"); - } - else { - prefixSnippets.push("uniform sampler2D " + x.name + ";"); - prefixSnippets.push("uniform int offset" + x.name + ";"); - } - }); - var inputPrefixSnippet = prefixSnippets.join('\n'); - var inputSamplingSnippet = inputsInfo - .map(function (x) { return getInputSamplingSnippet(x, outputShape, usesPackedTextures); }) - .join('\n'); - var outTexShape = outputShape.texShape; - var glsl = getGlslDifferences(); - var floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl); - var outputSamplingSnippet; - var floatTextureSetOutputSnippet; - var shaderPrefix = getShaderPrefix(glsl); - if (outputShape.isPacked) { - outputSamplingSnippet = - getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape); - floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl); - } - else { - outputSamplingSnippet = - getOutputSamplingSnippet(outputShape.logicalShape, outTexShape); - floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl); - } - if (usesPackedTextures) { - shaderPrefix += SHADER_PACKED_PREFIX; - } - var source = [ - shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, - inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, userCode - ].join('\n'); - return source; - } - function getSamplerFromInInfo(inInfo) { - var shape = inInfo.shapeInfo.logicalShape; - switch (shape.length) { - case 0: - return getSamplerScalar(inInfo); - case 1: - return getSampler1D(inInfo); - case 2: - return getSampler2D(inInfo); - case 3: - return getSampler3D(inInfo); - case 4: - return getSampler4D(inInfo); - case 5: - return getSampler5D(inInfo); - case 6: - return getSampler6D(inInfo); - default: - throw new Error(shape.length + "-D input sampling" + - " is not yet supported"); - } - } - function getPackedSamplerFromInInfo(inInfo) { - var shape = inInfo.shapeInfo.logicalShape; - switch (shape.length) { - case 0: - return getPackedSamplerScalar(inInfo); - case 1: - return getPackedSampler1D(inInfo); - case 2: - return getPackedSampler2D(inInfo); - case 3: - return getPackedSampler3D(inInfo); - default: - return getPackedSamplerND(inInfo); - } - } - function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures) { - if (usesPackedTextures === void 0) { usesPackedTextures = false; } - var res = ''; - if (usesPackedTextures) { - res += getPackedSamplerFromInInfo(inInfo); - } - else { - res += getSamplerFromInInfo(inInfo); - } - var inShape = inInfo.shapeInfo.logicalShape; - var outShape = outShapeInfo.logicalShape; - if (inShape.length <= outShape.length) { - if (usesPackedTextures) { - res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo); - } - else { - res += getSamplerAtOutputCoords(inInfo, outShapeInfo); - } - } - return res; - } - function getPackedOutputSamplingSnippet(outShape, outTexShape) { - switch (outShape.length) { - case 0: - return getOutputScalarCoords(); - case 1: - return getOutputPacked1DCoords(outShape, outTexShape); - case 2: - return getOutputPacked2DCoords(outShape, outTexShape); - case 3: - return getOutputPacked3DCoords(outShape, outTexShape); - default: - return getOutputPackedNDCoords(outShape, outTexShape); - } - } - function getOutputSamplingSnippet(outShape, outTexShape) { - switch (outShape.length) { - case 0: - return getOutputScalarCoords(); - case 1: - return getOutput1DCoords(outShape, outTexShape); - case 2: - return getOutput2DCoords(outShape, outTexShape); - case 3: - return getOutput3DCoords(outShape, outTexShape); - case 4: - return getOutput4DCoords(outShape, outTexShape); - case 5: - return getOutput5DCoords(outShape, outTexShape); - case 6: - return getOutput6DCoords(outShape, outTexShape); - default: - throw new Error(outShape.length + "-D output sampling is not yet supported"); - } - } - function getFloatTextureSampleSnippet(glsl) { - return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return " + glsl.texture2D + "(textureSampler, uv).r;\n }\n "; - } - function getFloatTextureSetRSnippet(glsl) { - return "\n void setOutput(float val) {\n " + glsl.output + " = vec4(val, 0, 0, 0);\n }\n "; - } - function getFloatTextureSetRGBASnippet(glsl) { - return "\n void setOutput(vec4 val) {\n " + glsl.output + " = val;\n }\n "; - } - function getShaderPrefix(glsl) { - var SHADER_PREFIX = glsl.version + "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n " + glsl.varyingFs + " vec2 resultUV;\n " + glsl.defineOutput + "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n " + glsl.defineSpecialNaN + "\n " + glsl.defineSpecialInf + "\n " + glsl.defineRound + "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n "; - return SHADER_PREFIX; - } - var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; - var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; - var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; - var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n"; - function getOutputScalarCoords() { - return "\n int getOutputCoords() {\n return 0;\n }\n "; - } - function getOutputPacked1DCoords(shape, texShape) { - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - if (packedTexShape[0] === 1) { - return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * " + packedTexShape[1] + ".0);\n }\n "; - } - if (packedTexShape[1] === 1) { - return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * " + packedTexShape[0] + ".0);\n }\n "; - } - return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n return 2 * (resTexRC.x * " + packedTexShape[1] + " + resTexRC.y);\n }\n "; - } - function getOutput1DCoords(shape, texShape) { - if (texShape[0] === 1) { - return "\n int getOutputCoords() {\n return int(resultUV.x * " + texShape[1] + ".0);\n }\n "; - } - if (texShape[1] === 1) { - return "\n int getOutputCoords() {\n return int(resultUV.y * " + texShape[0] + ".0);\n }\n "; - } - return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n return resTexRC.x * " + texShape[1] + " + resTexRC.y;\n }\n "; - } - function getOutputPacked3DCoords(shape, texShape) { - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - var texelsInLogicalRow = Math.ceil(shape[2] / 2); - var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2); - return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec3(b, r, c);\n }\n "; - } - function getOutput3DCoords(shape, texShape) { - var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); - return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n "; - } - function getOutputPackedNDCoords(shape, texShape) { - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - var texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2); - var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2); - var texelsInBatchN = texelsInBatch; - var batches = ""; - var coords = 'b, r, c'; - for (var b = 2; b < shape.length - 1; b++) { - texelsInBatchN *= shape[shape.length - b - 1]; - batches = "\n int b" + b + " = index / " + texelsInBatchN + ";\n index -= b" + b + " * " + texelsInBatchN + ";\n " + batches; - coords = "b" + b + ", " + coords; - } - return "\n ivec" + shape.length + " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n " + batches + "\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec" + shape.length + "(" + coords + ");\n }\n "; - } - function getOutput4DCoords(shape, texShape) { - var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape); - return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec4(r, c, d, d2);\n }\n "; - } - function getOutput5DCoords(shape, texShape) { - var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape); - return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + texShape[0] + ",\n " + texShape[1] + "));\n\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n "; - } - function getOutput6DCoords(shape, texShape) { - var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape); - return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n "; - } - function getOutputPacked2DCoords(shape, texShape) { - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - if (arraysEqual(shape, texShape)) { - return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n }\n "; - } - // texels needed to accommodate a logical row - var texelsInLogicalRow = Math.ceil(shape[1] / 2); - /** - * getOutputCoords - * - * resTexRC: The rows and columns of the texels. If you move over one - * texel to the right in the packed texture, you are moving over one column - * (not two). - * - * index: The texel index - */ - return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec2(r, c);\n }\n "; - } - function getOutput2DCoords(shape, texShape) { - if (arraysEqual(shape, texShape)) { - return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(" + texShape[0] + ", " + texShape[1] + "));\n }\n "; - } - if (shape[1] === 1) { - return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(index, 0);\n }\n "; - } - if (shape[0] === 1) { - return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(0, index);\n }\n "; - } - return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);\n }\n "; - } - function getFlatOffsetUniformName(texName) { - return "offset" + texName; - } - function getPackedSamplerScalar(inputInfo) { - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var glsl = getGlslDifferences(); - return "\n vec4 " + funcName + "() {\n return " + glsl.texture2D + "(" + texName + ", halfCR);\n }\n "; - } - function getSamplerScalar(inputInfo) { - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - if (inputInfo.shapeInfo.isUniform) { - return "float " + funcName + "() {return " + texName + ";}"; - } - var _a = inputInfo.shapeInfo.texShape, texNumR = _a[0], texNumC = _a[1]; - if (texNumR === 1 && texNumC === 1) { - return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", halfCR);\n }\n "; - } - var _b = inputInfo.shapeInfo.texShape, tNumR = _b[0], tNumC = _b[1]; - var offset = getFlatOffsetUniformName(texName); - return "\n float " + funcName + "() {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - function getPackedSampler1D(inputInfo) { - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var texShape = inputInfo.shapeInfo.texShape; - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - var glsl = getGlslDifferences(); - return "\n vec4 " + funcName + "(int index) {\n vec2 uv = packedUVfrom1D(\n " + packedTexShape[0] + ", " + packedTexShape[1] + ", index);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; - } - function getSampler1D(inputInfo) { - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - if (inputInfo.shapeInfo.isUniform) { - // Uniform arrays will be less than 65505 (no risk of float16 overflow). - return "\n float " + funcName + "(int index) {\n " + getUniformSampler(inputInfo) + "\n }\n "; - } - var texShape = inputInfo.shapeInfo.texShape; - var tNumR = texShape[0]; - var tNumC = texShape[1]; - if (tNumC === 1 && tNumR === 1) { - return "\n float " + funcName + "(int index) {\n return sampleTexture(" + texName + ", halfCR);\n }\n "; - } - var offset = getFlatOffsetUniformName(texName); - if (tNumC === 1) { - return "\n float " + funcName + "(int index) {\n vec2 uv = vec2(0.5, (float(index + " + offset + ") + 0.5) / " + tNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - if (tNumR === 1) { - return "\n float " + funcName + "(int index) {\n vec2 uv = vec2((float(index + " + offset + ") + 0.5) / " + tNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - return "\n float " + funcName + "(int index) {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - function getPackedSampler2D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var texShape = inputInfo.shapeInfo.texShape; - var texNumR = texShape[0]; - var texNumC = texShape[1]; - var glsl = getGlslDifferences(); - if (texShape != null && arraysEqual(shape, texShape)) { - return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; - } - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - var valuesPerRow = Math.ceil(shape[1] / 2); - return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = packedUVfrom2D(" + valuesPerRow + ", " + packedTexShape[0] + ", " + packedTexShape[1] + ", row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; - } - function getSampler2D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var texShape = inputInfo.shapeInfo.texShape; - if (texShape != null && arraysEqual(shape, texShape)) { - var texNumR_1 = texShape[0]; - var texNumC_1 = texShape[1]; - return "\n float " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC_1 + ".0, " + texNumR_1 + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; - var squeezedShape = newShape; - if (squeezedShape.length < shape.length) { - var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); - var params = ['row', 'col']; - return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; - } - if (inputInfo.shapeInfo.isUniform) { - // Uniform arrays will be less than 65505 (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(" + shape[1] + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; - } - var texNumR = texShape[0]; - var texNumC = texShape[1]; - var offset = getFlatOffsetUniformName(texName); - if (texNumC === 1) { - // index is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - if (texNumR === 1) { - // index is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2((index + 0.5) / " + texNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - return "\n float " + funcName + "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + shape[1] + " + col + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n"; - } - function getPackedSampler3D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var texShape = inputInfo.shapeInfo.texShape; - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - if (shape[0] === 1) { - var squeezedShape = shape.slice(1); - var keptDims = [1, 2]; - var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); - var params = ['b', 'row', 'col']; - return "\n " + getPackedSamplerFromInInfo(newInputInfo) + "\n vec4 " + funcName + "(int b, int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; - } - var texNumR = packedTexShape[0]; - var texNumC = packedTexShape[1]; - var valuesPerRow = Math.ceil(shape[2] / 2); - var texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2); - var glsl = getGlslDifferences(); - return "\n vec4 " + funcName + "(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n " + texNumR + ", " + texNumC + ", " + texelsInBatch + ", " + valuesPerRow + ", b, row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; - } - function getSampler3D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var stride0 = shape[1] * shape[2]; - var stride1 = shape[2]; - var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; - var squeezedShape = newShape; - if (squeezedShape.length < shape.length) { - var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); - var params = ['row', 'col', 'depth']; - return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; - } - if (inputInfo.shapeInfo.isUniform) { - // Uniform arrays will be less than 65505 (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(" + stride0 + ", " + stride1 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; - } - var texShape = inputInfo.shapeInfo.texShape; - var texNumR = texShape[0]; - var texNumC = texShape[1]; - var flatOffset = inputInfo.shapeInfo.flatOffset; - if (texNumC === stride0 && flatOffset == null) { - // texC is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(" + stride1 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - if (texNumC === stride1 && flatOffset == null) { - // texR is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(" + shape[1] + ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - var offset = getFlatOffsetUniformName(texName); - return "\n float " + funcName + "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - function getPackedSamplerND(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var rank = shape.length; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var texShape = inputInfo.shapeInfo.texShape; - var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; - var texNumR = packedTexShape[0]; - var texNumC = packedTexShape[1]; - var valuesPerRow = Math.ceil(shape[rank - 1] / 2); - var texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2); - var params = "int b, int row, int col"; - var index = "b * " + texelsInBatch + " + (row / 2) * " + valuesPerRow + " + (col / 2)"; - for (var b = 2; b < rank - 1; b++) { - params = "int b" + b + ", " + params; - texelsInBatch *= shape[rank - b - 1]; - index = "b" + b + " * " + texelsInBatch + " + " + index; - } - var glsl = getGlslDifferences(); - return "\n vec4 " + funcName + "(" + params + ") {\n int index = " + index + ";\n int texR = index / " + texNumC + ";\n int texC = index - texR * " + texNumC + ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ", " + texNumR + ");\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; - } - function getSampler4D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var stride2 = shape[3]; - var stride1 = shape[2] * stride2; - var stride0 = shape[1] * stride1; - var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; - if (newShape.length < shape.length) { - var newInputInfo = squeezeInputInfo(inputInfo, newShape); - var params = ['row', 'col', 'depth', 'depth2']; - return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; - } - if (inputInfo.shapeInfo.isUniform) { - // Uniform arrays will be less than 65505 (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; - } - var flatOffset = inputInfo.shapeInfo.flatOffset; - var texShape = inputInfo.shapeInfo.texShape; - var texNumR = texShape[0]; - var texNumC = texShape[1]; - if (texNumC === stride0 && flatOffset == null) { - // texC is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(" + stride1 + ", " + stride2 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - if (texNumC === stride2 && flatOffset == null) { - // texR is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(" + shape[1] * shape[2] + ", " + shape[2] + ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - var offset = getFlatOffsetUniformName(texName); - return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2;\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - function getSampler5D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var stride3 = shape[4]; - var stride2 = shape[3] * stride3; - var stride1 = shape[2] * stride2; - var stride0 = shape[1] * stride1; - var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; - if (newShape.length < shape.length) { - var newInputInfo = squeezeInputInfo(inputInfo, newShape); - var params = ['row', 'col', 'depth', 'depth2', 'depth3']; - return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; - } - if (inputInfo.shapeInfo.isUniform) { - // Uniform arrays will be less than 65505 (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float index = dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n depth3;\n " + getUniformSampler(inputInfo) + "\n }\n "; - } - var flatOffset = inputInfo.shapeInfo.flatOffset; - var texShape = inputInfo.shapeInfo.texShape; - var texNumR = texShape[0]; - var texNumC = texShape[1]; - if (texNumC === stride0 && flatOffset == null) { - // texC is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - if (texNumC === stride3 && flatOffset == null) { - // texR is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] + ",\n " + shape[2] * shape[3] + ", " + shape[3] + ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - var offset = getFlatOffsetUniformName(texName); - return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - function getSampler6D(inputInfo) { - var shape = inputInfo.shapeInfo.logicalShape; - var texName = inputInfo.name; - var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); - var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; - if (newShape.length < shape.length) { - var newInputInfo = squeezeInputInfo(inputInfo, newShape); - var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4']; - return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; - } - var stride4 = shape[5]; - var stride3 = shape[4] * stride4; - var stride2 = shape[3] * stride3; - var stride1 = shape[2] * stride2; - var stride0 = shape[1] * stride1; - if (inputInfo.shapeInfo.isUniform) { - // Uniform arrays will be less than 65505 (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n dot(\n vec2(depth3, depth4),\n vec2(" + stride4 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; - } - var flatOffset = inputInfo.shapeInfo.flatOffset; - var texShape = inputInfo.shapeInfo.texShape; - var texNumR = texShape[0]; - var texNumC = texShape[1]; - if (texNumC === stride0 && flatOffset == null) { - // texC is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", " + stride4 + ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - if (texNumC === stride4 && flatOffset == null) { - // texR is used directly as physical (no risk of float16 overflow). - return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] * shape[4] + ",\n " + shape[2] * shape[3] * shape[4] + ",\n " + shape[3] * shape[4] + ",\n " + shape[4] + ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - var offset = getFlatOffsetUniformName(texName); - return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 * " + stride4 + " + depth4 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n "; - } - function getUniformSampler(inputInfo) { - var texName = inputInfo.name; - var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape); - if (inSize < 2) { - return "return " + texName + ";"; - } - return "\n for (int i = 0; i < " + inSize + "; i++) {\n if (i == index) {\n return " + texName + "[i];\n }\n }\n "; - } - function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) { - var texName = inputInfo.name; - var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); - var funcName = 'get' + texFuncSnippet + 'AtOutCoords'; - var inRank = inputInfo.shapeInfo.logicalShape.length; - var outRank = outShapeInfo.logicalShape.length; - var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); - var type = getCoordsDataType(outRank); - var rankDiff = outRank - inRank; - var coordsSnippet; - var fields = ['x', 'y', 'z', 'w', 'u', 'v']; - if (inRank === 0) { - coordsSnippet = ''; - } - else if (outRank < 2 && broadcastDims.length >= 1) { - coordsSnippet = 'coords = 0;'; - } - else { - coordsSnippet = - broadcastDims.map(function (d) { return "coords." + fields[d + rankDiff] + " = 0;"; }) - .join('\n'); - } - var unpackedCoordsSnippet = ''; - if (outRank < 2 && inRank > 0) { - unpackedCoordsSnippet = 'coords'; - } - else { - unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape - .map(function (s, i) { return "coords." + fields[i + rankDiff]; }) - .join(', '); - } - var output = "return outputValue;"; - var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape); - var isInputScalar = inSize === 1; - var outSize = sizeFromShape(outShapeInfo.logicalShape); - var isOutputScalar = outSize === 1; - if (inRank === 1 && !isInputScalar && !isOutputScalar) { - output = "\n return vec4(outputValue.xy, outputValue.xy);\n "; - } - else if (isInputScalar && !isOutputScalar) { - if (outRank === 1) { - output = "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n "; - } - else { - output = "\n return vec4(outputValue.x);\n "; - } - } - else if (broadcastDims.length) { - var rows = inRank - 2; - var cols = inRank - 1; - if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) { - output = "return vec4(outputValue.x);"; - } - else if (broadcastDims.indexOf(rows) > -1) { - output = "return vec4(outputValue.x, outputValue.y, " + - "outputValue.x, outputValue.y);"; - } - else if (broadcastDims.indexOf(cols) > -1) { - output = "return vec4(outputValue.xx, outputValue.zz);"; - } - } - return "\n vec4 " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n vec4 outputValue = get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n " + output + "\n }\n "; - } - function getSamplerAtOutputCoords(inputInfo, outShapeInfo) { - var texName = inputInfo.name; - var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); - var funcName = 'get' + texFuncSnippet + 'AtOutCoords'; - var outTexShape = outShapeInfo.texShape; - var inTexShape = inputInfo.shapeInfo.texShape; - var inRank = inputInfo.shapeInfo.logicalShape.length; - var outRank = outShapeInfo.logicalShape.length; - if (!inputInfo.shapeInfo.isUniform && inRank === outRank && - inputInfo.shapeInfo.flatOffset == null && - arraysEqual(inTexShape, outTexShape)) { - return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", resultUV);\n }\n "; - } - var type = getCoordsDataType(outRank); - var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); - var rankDiff = outRank - inRank; - var coordsSnippet; - var fields = ['x', 'y', 'z', 'w', 'u', 'v']; - if (inRank === 0) { - coordsSnippet = ''; - } - else if (outRank < 2 && broadcastDims.length >= 1) { - coordsSnippet = 'coords = 0;'; - } - else { - coordsSnippet = - broadcastDims.map(function (d) { return "coords." + fields[d + rankDiff] + " = 0;"; }) - .join('\n'); - } - var unpackedCoordsSnippet = ''; - if (outRank < 2 && inRank > 0) { - unpackedCoordsSnippet = 'coords'; - } - else { - unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape - .map(function (s, i) { return "coords." + fields[i + rankDiff]; }) - .join(', '); - } - return "\n float " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n return get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n }\n "; - } - function getCoordsDataType(rank) { - if (rank <= 1) { - return 'int'; - } - else if (rank === 2) { - return 'ivec2'; - } - else if (rank === 3) { - return 'ivec3'; - } - else if (rank === 4) { - return 'ivec4'; - } - else if (rank === 5) { - return 'ivec5'; - } - else if (rank === 6) { - return 'ivec6'; - } - else { - throw Error("GPU for rank " + rank + " is not yet supported"); - } - } - /** Returns a new input info (a copy) that has a squeezed logical shape. */ - function squeezeInputInfo(inInfo, squeezedShape) { - // Deep copy. - var newInputInfo = JSON.parse(JSON.stringify(inInfo)); - newInputInfo.shapeInfo.logicalShape = squeezedShape; - return newInputInfo; - } - function getSqueezedParams(params, keptDims) { - return keptDims.map(function (d) { return params[d]; }).join(', '); - } - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ArgMinMaxPackedProgram = /** @class */ (function () { - function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = true; - assert(shape.length > 2, function () { return "Packed arg" + (op.charAt(0).toUpperCase() + - op.slice(1)) + " supports only inputs with rank above 2."; }); - var inSize = shape[shape.length - 1]; - var outSize = Math.ceil(inSize / windowSize); - this.outputShape = shape.slice(0, -1); - if (outSize > 1) { - this.outputShape.push(outSize); - } - if (!firstPass) { - this.variableNames.push('bestIndicesA'); - } - var outShape = this.outputShape; - var rank = outShape.length; - var dtype = getCoordsDataType(rank); - var coords = getChannels('coords', rank); - var sourceLocSetup; - var sourceRank; - if (outSize === 1) { - sourceRank = rank + 1; - var sourceLocDType = getCoordsDataType(sourceRank); - sourceLocSetup = "\n " + sourceLocDType + " sourceLocR = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocG = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 2] + ";\n " + sourceLocDType + " sourceLocA = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocB = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 2] + ";"; - } - else { - sourceRank = rank; - sourceLocSetup = "\n " + dtype + " sourceLocR = coords;\n ++" + coords[rank - 1] + ";\n " + dtype + " sourceLocG = coords;\n ++" + coords[rank - 2] + ";\n " + dtype + " sourceLocA = coords;\n --" + coords[rank - 1] + ";\n " + dtype + " sourceLocB = coords;\n --" + coords[rank - 2] + ";"; - } - var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank); - var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3. - var intChannels = channels.map(function (x) { return 'int ' + x; }); - var srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r'); - var srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g'); - var srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b'); - var srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a'); - var compOp = (op === 'max') ? 'greaterThan' : 'lessThan'; - var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(" + srcRCoords.join() + "),\n getBestIndicesAChannel(" + srcGCoords.join() + "),\n getBestIndicesAChannel(" + srcBCoords.join() + "),\n getBestIndicesAChannel(" + srcACoords.join() + ")));"; - var fetchValue = "vec4(\n getAChannel(" + srcRCoords.join() + "),\n hasNextCol ? getAChannel(" + srcGCoords.join() + ") : 0.,\n hasNextRow ? getAChannel(" + srcBCoords.join() + ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(" + srcACoords.join() + ") : 0.)"; - var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(" + intChannels.join() + ") {\n return getChannel(getBestIndicesA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }"; - this.userCode = "\n float getAChannel(" + intChannels.join() + ") {\n return getChannel(getA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }\n " + getBestIndicesAChannelSnippet + "\n void main() {\n " + dtype + " coords = getOutputCoords();\n bool hasNextCol = " + coords[rank - 1] + " < " + (outShape[rank - 1] - 1) + ";\n bool hasNextRow = " + coords[rank - 2] + " < " + (outShape[rank - 2] - 1) + ";\n " + sourceLocSetup + "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + ",\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = " + fetchValue + ";\n\n for (int i = 0; i < " + windowSize + "; i++) {\n inIdx = srcIdx;\n " + fetchCandidateIdx + "\n vec4 candidate = " + fetchValue + ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n "; - } - return ArgMinMaxPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var AvgPool2DBackpropProgram = /** @class */ (function () { - function AvgPool2DBackpropProgram(convInfo) { - this.variableNames = ['dy']; - this.outputShape = convInfo.inShape; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var effectiveFilterHeight = convInfo.effectiveFilterHeight; - var effectiveFilterWidth = convInfo.effectiveFilterWidth; - var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; - var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; - var avgMultiplier = 1 / (filterHeight * filterWidth); - this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC+= " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n "; - } - return AvgPool2DBackpropProgram; - }()); - var AvgPool3DBackpropProgram = /** @class */ (function () { - function AvgPool3DBackpropProgram(convInfo) { - this.variableNames = ['dy']; - this.outputShape = convInfo.inShape; - var filterDepth = convInfo.filterDepth; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var strideDepth = convInfo.strideDepth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationDepth = convInfo.dilationDepth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var effectiveFilterDepth = convInfo.effectiveFilterDepth; - var effectiveFilterHeight = convInfo.effectiveFilterHeight; - var effectiveFilterWidth = convInfo.effectiveFilterWidth; - var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; - var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; - var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; - var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth); - this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return AvgPool3DBackpropProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var BatchNormProgram = /** @class */ (function () { - function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { - this.outputShape = []; - this.variableNames = ['x', 'mean', 'variance']; - assertAndGetBroadcastShape(xShape, meanShape); - assertAndGetBroadcastShape(xShape, varianceShape); - var offsetSnippet = '0.0'; - if (offsetShape != null) { - assertAndGetBroadcastShape(xShape, offsetShape); - this.variableNames.push('offset'); - offsetSnippet = 'getOffsetAtOutCoords()'; - } - var scaleSnippet = '1.0'; - if (scaleShape != null) { - assertAndGetBroadcastShape(xShape, scaleShape); - this.variableNames.push('scale'); - scaleSnippet = 'getScaleAtOutCoords()'; - } - this.outputShape = xShape; - this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = " + offsetSnippet + ";\n float scale = " + scaleSnippet + ";\n float inv = scale * inversesqrt(variance + float(" + varianceEpsilon + "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n "; - } - return BatchNormProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var BatchNormPackedProgram = /** @class */ (function () { - function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { - this.packedInputs = true; - this.packedOutput = true; - this.variableNames = ['x', 'mean', 'variance']; - assertAndGetBroadcastShape(xShape, meanShape); - assertAndGetBroadcastShape(xShape, varianceShape); - var offsetSnippet = 'vec4(0.0)'; - if (offsetShape != null) { - assertAndGetBroadcastShape(xShape, offsetShape); - this.variableNames.push('offset'); - offsetSnippet = 'getOffsetAtOutCoords()'; - } - var scaleSnippet = 'vec4(1.0)'; - if (scaleShape != null) { - assertAndGetBroadcastShape(xShape, scaleShape); - this.variableNames.push('scale'); - scaleSnippet = 'getScaleAtOutCoords()'; - } - this.outputShape = xShape; - this.userCode = "\n void main() {\n vec4 offset = " + offsetSnippet + ";\n vec4 scale = " + scaleSnippet + ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(" + varianceEpsilon + "));\n\n setOutput((x - mean) * inv + offset);\n }\n "; - } - return BatchNormPackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - // (Ar + Ai)(Br + Bi) = - // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr - // Yr = ArBr - AB - // Yi = ArBi + AiBr - var COMPLEX_MULTIPLY = { - REAL: 'return areal * breal - aimag * bimag;', - IMAG: 'return areal * bimag + aimag * breal;' - }; - var BinaryOpComplexProgram = /** @class */ (function () { - function BinaryOpComplexProgram(op, aShape, bShape) { - this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag']; - this.outputShape = - assertAndGetBroadcastShape(aShape, bShape); - this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n " + op + "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n "; - } - return BinaryOpComplexProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var CHECK_NAN_SNIPPET = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n"; - var ADD = 'return a + b;'; - var SUB = 'return a - b;'; - var MUL = 'return a * b;'; - // Without the equality check div produces 0.9999 for a = b, which when - // floored can cause errors. - var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;"; - // We use native integer division to deal with floating point imprecision. Since - // we implement floor division and glsl implements truncated division, we - // correct for this by subtracting 1 from result when the result is negative and - // there is a remainder. - var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n"; - var POW = "\nif(a < 0.0 && floor(b) < b){\n return NAN;\n}\nif (b == 0.0) {\n return 1.0;\n}\nreturn (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n"; - var EQUAL = "return float(a == b);"; - var NOT_EQUAL = "return float(a != b);"; - var LESS = "return float(a < b);"; - var LESS_EQUAL = "return float(a <= b);"; - var GREATER = "return float(a > b);"; - var GREATER_EQUAL = "return float(a >= b);"; - var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);"; - var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);"; - var MAX = CHECK_NAN_SNIPPET + "\n return max(a, b);\n"; - var MIN = CHECK_NAN_SNIPPET + "\n return min(a, b);\n"; - var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);"; - var ATAN2 = CHECK_NAN_SNIPPET + "\n return atan(a, b);\n"; - var ELU_DER = "return (b >= 1.0) ? a : a * (b + 1.0);"; - var PRELU = "return (a < 0.) ? b * a : a;"; - var BinaryOpProgram = /** @class */ (function () { - function BinaryOpProgram(op, aShape, bShape) { - this.variableNames = ['A', 'B']; - this.outputShape = - assertAndGetBroadcastShape(aShape, bShape); - this.userCode = "\n float binaryOperation(float a, float b) {\n " + op + "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n "; - } - return BinaryOpProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var CHECK_NAN_SNIPPET$1 = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n"; - // We do the same as in ./binaryop_gpu, with vec4 and ivec4. - // On Linux, the vectorized implementation produces NaNs when a and b are 0. - var DIV$1 = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n"; - var INT_DIV$1 = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n"; - var POW$1 = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n " + - CHECK_NAN_SNIPPET$1 + "\n return result;\n"; - var PRELU$1 = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n"; - var ELU_DER$1 = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n"; - var ATAN2$1 = "\n vec4 result = atan(a, b);\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + - CHECK_NAN_SNIPPET$1 + "\n return result;\n"; - var EQUAL$1 = "\n return vec4(equal(a, b));\n"; - var NOT_EQUAL$1 = "\n return vec4(notEqual(a, b));\n"; - var LESS$1 = "\n return vec4(lessThan(a, b));\n"; - var LESS_EQUAL$1 = "\n return vec4(lessThanEqual(a, b));\n"; - var GREATER$1 = "\n return vec4(greaterThan(a, b));\n"; - var GREATER_EQUAL$1 = "\n return vec4(greaterThanEqual(a, b));\n"; - var LOGICAL_AND$1 = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n"; - var LOGICAL_OR$1 = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n"; - var MAX$1 = "\n vec4 result = vec4(max(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + - CHECK_NAN_SNIPPET$1 + "\n return result;\n"; - var MIN$1 = "\n vec4 result = vec4(min(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + - CHECK_NAN_SNIPPET$1 + "\n return result;\n"; - var MOD$1 = "\n vec4 result = mod(a, b);\n vec4 isNaN = vec4(equal(b, vec4(0.0)));\n " + - CHECK_NAN_SNIPPET$1 + "\n return result;\n"; - var BinaryOpPackedProgram = /** @class */ (function () { - function BinaryOpPackedProgram(op, aShape, bShape, checkOutOfBounds) { - if (checkOutOfBounds === void 0) { checkOutOfBounds = false; } - this.variableNames = ['A', 'B']; - this.supportsBroadcasting = true; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = - assertAndGetBroadcastShape(aShape, bShape); - var rank = this.outputShape.length; - var checkOutOfBoundsString = ''; - if (checkOutOfBounds) { - if (rank === 0 || sizeFromShape(this.outputShape) === 1) { - checkOutOfBoundsString = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n "; - } - else { - var dtype = getCoordsDataType(rank); - checkOutOfBoundsString = "\n " + dtype + " coords = getOutputCoords();\n "; - if (rank === 1) { - checkOutOfBoundsString += "\n result.y = (coords + 1) >= " + this.outputShape[0] + " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n "; - } - else { - var channels = getChannels('coords', rank); - checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (" + channels[rank - 2] + " + 1) >= " + this.outputShape[rank - 2] + ";\n bool nextColOutOfBounds =\n (" + channels[rank - 1] + " + 1) >= " + this.outputShape[rank - 1] + ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n "; - } - } - } - this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n " + op + "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n " + checkOutOfBoundsString + "\n\n setOutput(result);\n }\n "; - } - return BinaryOpPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ClipProgram = /** @class */ (function () { - function ClipProgram(aShape) { - this.variableNames = ['A']; - this.outputShape = aShape; - this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n "; - } - ClipProgram.prototype.getCustomSetupFunc = function (min, max) { - var _this = this; - return function (gpgpu, webGLProgram) { - if (_this.minLoc == null) { - _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal'); - _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal'); - } - gpgpu.gl.uniform1f(_this.minLoc, min); - gpgpu.gl.uniform1f(_this.maxLoc, max); - }; - }; - return ClipProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ClipPackedProgram = /** @class */ (function () { - function ClipPackedProgram(aShape) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = aShape; - this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n "; - } - ClipPackedProgram.prototype.getCustomSetupFunc = function (min, max) { - var _this = this; - return function (gpgpu, webGLProgram) { - if (_this.minLoc == null) { - _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal'); - _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal'); - } - gpgpu.gl.uniform1f(_this.minLoc, min); - gpgpu.gl.uniform1f(_this.maxLoc, max); - }; - }; - return ClipPackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ComplexAbsProgram = /** @class */ (function () { - function ComplexAbsProgram(shape) { - this.variableNames = ['real', 'imag']; - this.outputShape = shape; - this.userCode = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n "; - } - return ComplexAbsProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ConcatProgram = /** @class */ (function () { - // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat(). - function ConcatProgram(shapes) { - this.outputShape = []; - this.outputShape = computeOutShape(shapes, 1 /* axis */); - this.variableNames = shapes.map(function (_, i) { return "T" + i; }); - var offsets = new Array(shapes.length - 1); - offsets[0] = shapes[0][1]; - for (var i = 1; i < offsets.length; i++) { - offsets[i] = offsets[i - 1] + shapes[i][1]; - } - var snippets = ["if (yC < " + offsets[0] + ") setOutput(getT0(yR, yC));"]; - for (var i = 1; i < offsets.length; i++) { - var shift = offsets[i - 1]; - snippets.push("else if (yC < " + offsets[i] + ") " + - ("setOutput(getT" + i + "(yR, yC-" + shift + "));")); - } - var lastIndex = offsets.length; - var lastShift = offsets[offsets.length - 1]; - snippets.push("else setOutput(getT" + lastIndex + "(yR, yC-" + lastShift + "));"); - this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n " + snippets.join('\n ') + "\n }\n "; - } - return ConcatProgram; - }()); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ConcatPackedProgram = /** @class */ (function () { - function ConcatPackedProgram(shapes, axis) { - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = []; - this.outputShape = computeOutShape(shapes, axis); - var shape = this.outputShape; - var rank = shape.length; - var dtype = getCoordsDataType(rank); - var coords = getChannels('coords', rank); - var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank); - this.variableNames = shapes.map(function (_, i) { return "T" + i; }); - var offsets = new Array(shapes.length - 1); - offsets[0] = shapes[0][axis]; - for (var i = 1; i < offsets.length; i++) { - offsets[i] = offsets[i - 1] + shapes[i][axis]; - } - var channel = channels[axis]; - var lastChannels = channels.slice(-2); - var allChannels = channels.join(); - var getValueSnippet = "if (" + channel + " < " + offsets[0] + ") {\n return getChannel(\n getT0(" + allChannels + "), vec2(" + lastChannels.join() + "));\n }"; - for (var i = 1; i < offsets.length; i++) { - var shift_1 = offsets[i - 1]; - // Note: the >= comparison below may seem unnecessary given the check - // above but is needed to workaround branch execution issues on some - // devices. It makes all the conditions exclusive without relying on - // execution order. - getValueSnippet += "\n if (" + channel + " < " + offsets[i] + " && " + channel + " >= " + offsets[i - 1] + ") {\n return getChannel(\n getT" + i + "(" + shiftedChannels(channels, channel, shift_1) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift_1) + "));\n }"; - } - var lastIndex = offsets.length; - var shift = offsets[offsets.length - 1]; - getValueSnippet += "\n return getChannel(\n getT" + lastIndex + "(" + shiftedChannels(channels, channel, shift) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift) + "));"; - this.userCode = "\n float getValue(" + channels.map(function (x) { return 'int ' + x; }) + ") {\n " + getValueSnippet + "\n }\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n vec4 result = vec4(getValue(" + coords + "), 0., 0., 0.);\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " + 1;\n if (" + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.g = getValue(" + coords + ");\n }\n\n " + coords[rank - 2] + " = " + coords[rank - 2] + " + 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + ") {\n result.a = getValue(" + coords + ");\n }\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " - 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + " &&\n " + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.b = getValue(" + coords + ");\n }\n setOutput(result);\n }\n "; - } - return ConcatPackedProgram; - }()); - /** - * Return an expression for coordinates into a vector where a given channel - * will be offset by [shift]. - * - * @param channels the channels to consider - * @param channel the channel we want shifted - * @param shift the amount to subtract from the channel. - * - * @returns a string of the form 'x, y-[shift], z' where any one channel can - * have the shift applied. - */ - function shiftedChannels(channels, channel, shift) { - var channelIdx = channels.indexOf(channel); - var res = channels.map(function (c, idx) { - if (idx === channelIdx) { - return c + " - " + shift; - } - else { - return c; - } - }); - return res.join(); - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var Conv2DDerFilterProgram = /** @class */ (function () { - function Conv2DDerFilterProgram(convInfo) { - this.variableNames = ['x', 'dy']; - this.outputShape = convInfo.filterShape; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - var isChannelsLast = convInfo.dataFormat === 'channelsLast'; - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n if (" + isChannelsLast + ") {\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n } else {\n float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return Conv2DDerFilterProgram; - }()); - var Conv2DDerInputProgram = /** @class */ (function () { - function Conv2DDerInputProgram(convInfo) { - this.variableNames = ['dy', 'W']; - this.outputShape = convInfo.inShape; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var isChannelsLast = convInfo.dataFormat === 'channelsLast'; - var padTop = filterHeight - 1 - convInfo.padInfo.top; - var padLeft = filterWidth - 1 - convInfo.padInfo.left; - var rowDim = isChannelsLast ? 1 : 2; - var colDim = isChannelsLast ? 2 : 3; - var channelDim = isChannelsLast ? 3 : 1; - this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[" + channelDim + "];\n\n ivec2 dyCorner = ivec2(coords[" + rowDim + "], coords[" + colDim + "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n\n if (" + isChannelsLast + ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return Conv2DDerInputProgram; - }()); - var Conv3DDerFilterProgram = /** @class */ (function () { - function Conv3DDerFilterProgram(convInfo) { - this.variableNames = ['x', 'dy']; - this.outputShape = convInfo.filterShape; - var strideDepth = convInfo.strideDepth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var padFront = convInfo.padInfo.front; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yF = 0; yF < " + convInfo.outDepth + "; yF++) {\n int xF = wF + yF * " + strideDepth + " - " + padFront + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return Conv3DDerFilterProgram; - }()); - var Conv3DDerInputProgram = /** @class */ (function () { - function Conv3DDerInputProgram(convInfo) { - this.variableNames = ['dy', 'W']; - this.outputShape = convInfo.inShape; - var filterDepth = convInfo.filterDepth; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var strideDepth = convInfo.strideDepth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var padFront = filterDepth - 1 - convInfo.padInfo.front; - var padTop = filterHeight - 1 - convInfo.padInfo.top; - var padLeft = filterWidth - 1 - convInfo.padInfo.left; - this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n float dyF = float(dyFCorner + wF) / " + strideDepth + ".0;\n\n if (dyF < 0.0 || dyF >= " + convInfo.outDepth + ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = " + filterDepth + " - 1 - wF;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return Conv3DDerInputProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DepthwiseConv2DDerFilterProgram = /** @class */ (function () { - function DepthwiseConv2DDerFilterProgram(convInfo) { - this.variableNames = ['x', 'dy']; - this.outputShape = convInfo.filterShape; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - var channelMul = convInfo.outChannels / convInfo.inChannels; - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * " + channelMul + " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return DepthwiseConv2DDerFilterProgram; - }()); - var DepthwiseConv2DDerInputProgram = /** @class */ (function () { - function DepthwiseConv2DDerInputProgram(convInfo) { - this.variableNames = ['dy', 'W']; - this.outputShape = convInfo.inShape; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var padTop = filterHeight - 1 - convInfo.padInfo.top; - var padLeft = filterWidth - 1 - convInfo.padInfo.left; - var channelMul = convInfo.outChannels / convInfo.inChannels; - this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < " + channelMul + "; dm++) {\n int d2 = d1 * " + channelMul + " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return DepthwiseConv2DDerInputProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var Conv2DProgram = /** @class */ (function () { - function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights) { - if (addBias === void 0) { addBias = false; } - if (activation === void 0) { activation = null; } - if (hasPreluActivationWeights === void 0) { hasPreluActivationWeights = false; } - this.variableNames = ['x', 'W']; - this.outputShape = convInfo.outShape; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; - var inputDepthVec4Remainder = convInfo.inChannels % 4; - var isChannelsLast = convInfo.dataFormat === 'channelsLast'; - var rowDim = isChannelsLast ? 1 : 2; - var colDim = isChannelsLast ? 2 : 3; - var channelDim = isChannelsLast ? 3 : 1; - var activationSnippet = '', applyActivationSnippet = ''; - if (activation) { - if (hasPreluActivationWeights) { - activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; - } - else { - activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n "; - } - applyActivationSnippet = "result = activation(result);"; - } - var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; - if (addBias) { - this.variableNames.push('bias'); - } - if (hasPreluActivationWeights) { - this.variableNames.push('preluActivationWeights'); - } - this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[" + channelDim + "];\n\n ivec2 xRCCorner =\n ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n\n if (" + isChannelsLast + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else {\n dotProd +=\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n }\n\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 wValues = vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n "; - } - return Conv2DProgram; - }()); - var Conv3DProgram = /** @class */ (function () { - function Conv3DProgram(convInfo) { - this.variableNames = ['x', 'W']; - this.outputShape = convInfo.outShape; - var padFront = convInfo.padInfo.front; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - var strideDepth = convInfo.strideDepth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationDepth = convInfo.dilationDepth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var filterDepth = convInfo.filterDepth; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; - var inputDepthVec4Remainder = convInfo.inChannels % 4; - this.userCode = "\n const ivec3 strides = ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n int xF = xFCorner + wF * " + dilationDepth + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n dotProd +=\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return Conv3DProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DepthwiseConv2DProgram = /** @class */ (function () { - function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation) { - if (addBias === void 0) { addBias = false; } - if (activation === void 0) { activation = null; } - if (hasPreluActivation === void 0) { hasPreluActivation = false; } - this.variableNames = ['x', 'W']; - this.outputShape = convInfo.outShape; - var xNumRows = convInfo.inHeight; - var xNumCols = convInfo.inWidth; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var channelMul = convInfo.outChannels / convInfo.inChannels; - var activationSnippet = '', applyActivationSnippet = ''; - if (activation) { - if (hasPreluActivation) { - activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; - } - else { - activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n "; - } - applyActivationSnippet = "result = activation(result);"; - } - var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; - if (addBias) { - this.variableNames.push('bias'); - } - if (hasPreluActivation) { - this.variableNames.push('preluActivationWeights'); - } - this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + xNumRows + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + xNumCols + ") {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n "; - } - return DepthwiseConv2DProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DepthwiseConvPacked2DProgram = /** @class */ (function () { - function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation) { - if (addBias === void 0) { addBias = false; } - if (activation === void 0) { activation = null; } - if (hasPreluActivation === void 0) { hasPreluActivation = false; } - this.variableNames = ['x', 'W']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = convInfo.outShape; - var xNumRows = convInfo.inHeight; - var xNumCols = convInfo.inWidth; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var filterHeight = convInfo.filterHeight; - var filterWidth = convInfo.filterWidth; - var texelsAcross = filterWidth; - var mainLoop = "int xR; int xC; int xCOffset;"; - for (var r = 0; r < filterHeight; r++) { - for (var c = 0; c < filterWidth; c++) { - mainLoop += "\n vec4 xTexelR" + r + "C" + c * 2 + " = vec4(0.);\n vec4 wR" + r + "C" + c + " = vec4(0.);\n vec4 xR" + r + "C" + c + " = vec4(0.);"; - } - } - /** - * This vectorized implementation works by gathering the values needed for - * each output channel's dot product into vec4's and then multiplying them - * all together (this happens in the final double for-loop below). Most of - * the main loop consists of constructing these vec4's with the minimum - * number of texture2D calls, which means making use of all four returned - * values from a texture2D call at once. - */ - for (var r = 0; r < filterHeight; r++) { - for (var texelC = 0; texelC < texelsAcross; texelC++) { - var c = texelC * 2; - mainLoop += "\n xR = xRCorner + " + r * dilationHeight + ";\n xC = xCCorner + " + c * dilationWidth + ";\n "; - if (strideWidth === 1) { - if (c < filterWidth) { - // If padding is odd, the outer texels have to be composed. - if (padLeft % 2 === 1) { - // TODO: Ensure vec4 previous does not result in redundant sample, - // and avoid setting xTexelRC's that exceed the boundary in the - // first place rather than resetting them to vec4(0)). - // To compute xCOffset: - // - If padding is odd, we must add 1 to ensure we ask for an - // even-numbered row. - // - We subtract 2 to access the previous texel. - mainLoop += "\n xCOffset = xC + 1;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= " + xNumCols + ") {\n xTexelR" + r + "C" + c + ".zw = vec2(0.);\n }\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + 1 - 2;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n vec4 previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if(xCOffset + 1 >= " + xNumCols + ") {\n previous.zw = vec2(0.);\n }\n\n xR" + r + "C" + c + " = vec4(previous.zw, xTexelR" + r + "C" + c + ".xy);\n } else {\n xR" + r + "C" + c + " = vec4(0, 0, xTexelR" + r + "C" + c + ".xy);\n }\n "; - } - else { - // Padding is even, so xRC corresponds to a single texel. - mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + " && xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = xTexelR" + r + "C" + c + ";\n "; - } - if (c + 1 < filterWidth) { - // If dilation is even, the second entry should match the first - // (either both are composed or both are single samples). But if - // dilation is odd, then the second entry should be the opposite - // of the first (if the first is composed, the second is a single - // sample, and vice versa.) - var nextTexelOffset = padLeft % 2 === 0 ? - nearestLargerEven(dilationWidth) : - dilationWidth; - if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) || - (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) { - mainLoop += "\n xCOffset = xC + " + padLeft % 2 + " + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n "; - // If dilation > 1 then the xRC's will not be able to share any - // values, so each xRC will require two unique calls to getX. - if (dilationWidth > 1) { - mainLoop += "\n xCOffset -= 2;\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n "; - } - mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".xy);\n "; - } - else { - mainLoop += "\n xCOffset = xC + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n\n xR" + r + "C" + (c + 1) + " = xTexelR" + r + "C" + (c + 2) + ";\n "; - } - } - } - } - else { // stride > 1 - if (c < filterWidth) { - mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + ") {\n "; - // Depending on whether padLeft is even or odd, we want either the - // xy or zw channels from X texels for xR${r}C${c}. If padLeft is - // even, xR${r}C${c + 1} is simply the zw channels of texels we've - // already sampled. But if padLeft is odd, xR${r}C{$c + 1}.zw will - // need to come from the xy channels of a new texel, hence the `vec4 - // final` initialized below. - if (padLeft % 2 === 1) { - mainLoop += "\n xCOffset = xC + 1 - " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n if(xC + 1 >= 0 && xC + 1 < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xC + 1, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n "; - if (c + 1 < filterWidth) { - mainLoop += "\n vec4 final = vec4(0.);\n xCOffset = xC + 1 + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n final = getX(batch, xR, xCOffset, d1);\n }\n xR" + r + "C" + (c + 1) + " = vec4(xTexelR" + r + "C" + (c + 2) + ".xy, final.xy);\n "; - } - } - else { - mainLoop += "\n if(xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".xy, xTexelR" + r + "C" + (c + 2) + ".xy);\n "; - if (c + 1 < filterWidth) { - mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n "; - } - } - mainLoop += "}"; - } - } - if (c < filterWidth) { - mainLoop += "\n vec4 wTexelR" + r + "C" + c + " = getW(" + r + ", " + c + ", d1, q);\n wR" + r + "C" + c + " = vec4(wTexelR" + r + "C" + c + ".xz, wTexelR" + r + "C" + c + ".xz);\n "; - if (c + 1 < filterWidth) { - mainLoop += "\n vec4 wTexelR" + r + "C" + (c + 1) + " = getW(" + r + ", " + (c + 1) + ", d1, q);\n wR" + r + "C" + (c + 1) + " =\n vec4(wTexelR" + r + "C" + (c + 1) + ".xz, wTexelR" + r + "C" + (c + 1) + ".xz);"; - } - } - } - } - for (var r = 0; r < filterHeight; r++) { - for (var c = 0; c < filterWidth; c++) { - mainLoop += "dotProd += xR" + r + "C" + c + " * wR" + r + "C" + c + ";"; - } - } - var activationSnippet = '', applyActivationSnippet = ''; - if (activation) { - if (hasPreluActivation) { - activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; - } - else { - activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }"; - } - applyActivationSnippet = "result = activation(result);"; - } - var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; - if (addBias) { - this.variableNames.push('bias'); - } - if (hasPreluActivation) { - this.variableNames.push('preluActivationWeights'); - } - this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2;\n int q = 0;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n vec4 dotProd = vec4(0.);\n\n " + mainLoop + "\n\n vec4 result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n "; - } - return DepthwiseConvPacked2DProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var CropAndResizeProgram = /** @class */ (function () { - function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) { - this.variableNames = ['Image', 'Boxes', 'BoxInd']; - this.outputShape = []; - var batch = imageShape[0], imageHeight = imageShape[1], imageWidth = imageShape[2], depth = imageShape[3]; - var numBoxes = boxShape[0]; - var cropHeight = cropSize[0], cropWidth = cropSize[1]; - this.outputShape = [numBoxes, cropHeight, cropWidth, depth]; - var methodId = method === 'bilinear' ? 1 : 0; - var _a = [imageHeight - 1 + ".0", imageWidth - 1 + ".0"], inputHeightFloat = _a[0], inputWidthFloat = _a[1]; - var _b = cropHeight > 1 ? - [ - "" + (imageHeight - 1) / (cropHeight - 1), - '(y2-y1) * height_ratio', - "y1*" + inputHeightFloat + " + float(y)*(height_scale)", - ] : - [ - '0.0', - '0.0', - "0.5 * (y1+y2) * " + inputHeightFloat, - ], heightRatio = _b[0], heightScale = _b[1], inY = _b[2]; - var _c = cropWidth > 1 ? - [ - "" + (imageWidth - 1) / (cropWidth - 1), - '(x2-x1) * width_ratio', - "x1*" + inputWidthFloat + " + float(x)*(width_scale)", - ] : - [ - '0.0', - '0.0', - "0.5 * (x1+x2) * " + inputWidthFloat, - ], widthRatio = _c[0], widthScale = _c[1], inX = _c[2]; - // Reference implementation - // tslint:disable-next-line:max-line-length - // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc - this.userCode = "\n const float height_ratio = float(" + heightRatio + ");\n const float width_ratio = float(" + widthRatio + ");\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= " + batch + ") {\n return;\n }\n\n float height_scale = " + heightScale + ";\n float width_scale = " + widthScale + ";\n\n float in_y = " + inY + ";\n if( in_y < 0.0 || in_y > " + inputHeightFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n float in_x = " + inX + ";\n if( in_x < 0.0 || in_x > " + inputWidthFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(" + methodId + " == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n "; - } - return CropAndResizeProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var CumSumProgram = /** @class */ (function () { - function CumSumProgram(shape, exclusive, reverse) { - this.variableNames = ['x']; - this.outputShape = shape; - var rank = shape.length; - var finalDim = shape[shape.length - 1]; - var comparator = reverse ? '<' : '>'; - this.userCode = "\n int getIndex(int i) {\n " + (reverse ? "return " + finalDim + " -i - 1;" : 'return i;') + "\n }\n\n void main() {\n " + getCoordsDataType(rank) + " coords = getOutputCoords();\n int end = " + getFinalCoord(rank, 'coords') + ";\n float val = 0.0;\n for (int i = " + finalDim + " - 1; i >= 0; i -= 1) {\n int idx = getIndex(i);\n if (idx " + comparator + " end) {\n continue;\n }\n if (idx == end && " + exclusive + ") {\n continue;\n }\n " + getFinalCoord(rank, 'coords') + " = idx;\n val += getX(" + getCoords(rank, 'coords') + ");\n }\n setOutput(val);\n }\n "; - } - return CumSumProgram; - }()); - function getCoords(rank, name) { - if (rank === 1) { - return "" + name; - } - else if (rank === 2) { - return name + ".x, " + name + ".y"; - } - else if (rank === 3) { - return name + ".x, " + name + ".y, " + name + ".z"; - } - else if (rank === 4) { - return name + ".x, " + name + ".y, " + name + ".z, " + name + ".w"; - } - else { - throw Error("Cumulative sum for rank " + rank + " is not yet supported"); - } - } - function getFinalCoord(rank, name) { - if (rank === 1) { - return "" + name; - } - else if (rank === 2) { - return name + ".y"; - } - else if (rank === 3) { - return name + ".z"; - } - else if (rank === 4) { - return name + ".w"; - } - else { - throw Error("Cumulative sum for rank " + rank + " is not yet supported"); - } - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DecodeMatrixProgram = /** @class */ (function () { - function DecodeMatrixProgram(outputShape) { - this.variableNames = ['A']; - this.packedInputs = false; - this.packedOutput = true; - this.outPackingScheme = PackingScheme.DENSE; - var texShape = getDenseTexShape(outputShape); - var glsl = getGlslDifferences(); - this.outputShape = outputShape; - this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = 4 * (resTexRC.x * " + texShape[1] + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " + glsl.output + " = result;\n }\n "; - } - return DecodeMatrixProgram; - }()); - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DecodeMatrixPackedProgram = /** @class */ (function () { - function DecodeMatrixPackedProgram(outputShape) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = true; - this.outPackingScheme = PackingScheme.DENSE; - var texShape = getDenseTexShape(outputShape); - var glsl = getGlslDifferences(); - this.outputShape = outputShape; - this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = 4 * (resTexRC.x * " + texShape[1] + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " + glsl.output + " = result;\n }\n "; - } - return DecodeMatrixPackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DepthToSpaceProgram = /** @class */ (function () { - function DepthToSpaceProgram(outputShape, blockSize, dataFormat) { - this.variableNames = ['x']; - this.outputShape = []; - this.outputShape = outputShape; - this.blockSize = blockSize; - this.dataFormat = dataFormat; - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = " + this.getHeightCoordString() + ";\n int w = " + this.getWidthCoordString() + ";\n int d = " + this.getDepthCoordString() + ";\n\n int in_h = h / " + blockSize + ";\n int offset_h = imod(h, " + blockSize + ");\n int in_w = w / " + blockSize + ";\n int offset_w = imod(w, " + blockSize + ");\n int offset_d = (offset_h * " + blockSize + " + offset_w) *\n " + this.getOutputDepthSize() + ";\n int in_d = d + offset_d;\n\n float result = " + this.getInputSamplingString() + ";\n setOutput(result);\n }\n "; - } - DepthToSpaceProgram.prototype.getHeightCoordString = function () { - if (this.dataFormat === 'NHWC') { - return "coords[1]"; - } - else { - return "coords[2]"; - } - }; - DepthToSpaceProgram.prototype.getWidthCoordString = function () { - if (this.dataFormat === 'NHWC') { - return "coords[2]"; - } - else { - return "coords[3]"; - } - }; - DepthToSpaceProgram.prototype.getDepthCoordString = function () { - if (this.dataFormat === 'NHWC') { - return "coords[3]"; - } - else { - return "coords[1]"; - } - }; - DepthToSpaceProgram.prototype.getOutputDepthSize = function () { - if (this.dataFormat === 'NHWC') { - return this.outputShape[3]; - } - else { - return this.outputShape[1]; - } - }; - DepthToSpaceProgram.prototype.getInputSamplingString = function () { - if (this.dataFormat === 'NHWC') { - return "getX(b, in_h, in_w, in_d)"; - } - else { - return "getX(b, in_d, in_h, in_w)"; - } - }; - return DepthToSpaceProgram; - }()); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var DiagProgram = /** @class */ (function () { - function DiagProgram(size) { - this.variableNames = ['X']; - this.outputShape = [size, size]; - this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;\n setOutput(val);\n }\n "; - } - return DiagProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var EncodeFloatProgram = /** @class */ (function () { - function EncodeFloatProgram(outputShape) { - this.variableNames = ['A']; - this.outTexUsage = TextureUsage.DOWNLOAD; - var glsl = getGlslDifferences(); - this.outputShape = outputShape; - this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n float x = getAAtOutCoords();\n " + glsl.output + " = encode_float(x);\n }\n "; - } - return EncodeFloatProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var EncodeFloatPackedProgram = /** @class */ (function () { - function EncodeFloatPackedProgram(outputShape) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = false; - this.outTexUsage = TextureUsage.DOWNLOAD; - var glsl = getGlslDifferences(); - this.outputShape = outputShape; - this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + glsl.output + " = encode_float(x);\n }\n "; - } - return EncodeFloatPackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var EncodeMatrixProgram = /** @class */ (function () { - function EncodeMatrixProgram(outputShape, texShape, inputIsUnsignedByte) { - if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; } - this.variableNames = ['A']; - var glsl = getGlslDifferences(); - var height = texShape[0], width = texShape[1]; - this.outputShape = outputShape; - var output = "result"; - if (inputIsUnsignedByte) { - output = "floor(result * 255. + 0.5)"; - } - this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n \n int r = flatIndex / " + width + ";\n int c = imod(flatIndex, " + width + ");\n vec2 uv = (vec2(c, r) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n\n float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n " + glsl.output + " = vec4(" + output + ", 0., 0., 0.);\n }\n "; - } - return EncodeMatrixProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - /* - This is how the shader encodes a tensor with shape = [2, 3, 5] - (indices are [batch, row, col]). - - 000|001 002|003 004|xxx 020|021 022|023 024|xxx - ------- ------- ------- ------- ------- ------- - 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx - - 100|101 102|103 104|xxx 120|121 122|123 124|xxx - ------- ------- ------- ------- ------- ------- - 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx - - Single texels contain only values from the same batch, and from adjacent rows - and columns. - */ - var EncodeMatrixPackedProgram = /** @class */ (function () { - function EncodeMatrixPackedProgram(outputShape, texShape, inputIsUnsignedByte) { - if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; } - this.variableNames = ['A']; - this.packedInputs = false; - this.packedOutput = true; - var glsl = getGlslDifferences(); - var height = texShape[0], width = texShape[1]; - this.outputShape = outputShape; - var mainLoop = ''; - var output = 'result'; - if (inputIsUnsignedByte) { - output = 'floor(result * 255. + 0.5)'; - } - for (var row = 0; row <= 1; row++) { - for (var col = 0; col <= 1; col++) { - var channel = row * 2 + col; - mainLoop += "\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + outputShape[2] + ") {\n localCoords[2] += " + col + ";\n if(localCoords[1] + " + row + " < " + outputShape[1] + ") {\n localCoords[1] += " + row + ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n r = flatIndex / " + width + ";\n c = imod(flatIndex, " + width + ");\n uv = (vec2(c, r) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n values = " + glsl.texture2D + "(A, uv);\n\n if(offset == 0) {\n result[" + channel + "] = values[0];\n } else if(offset == 1) {\n result[" + channel + "] = values[1];\n } else if(offset == 2) {\n result[" + channel + "] = values[2];\n } else {\n result[" + channel + "] = values[3];\n }\n }\n }\n "; - } - } - this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n " + mainLoop + "\n\n " + glsl.output + " = " + output + ";\n }\n "; - } - return EncodeMatrixPackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var COMPLEX_FFT = { - REAL: 'return real * expR - imag * expI;', - IMAG: 'return real * expI + imag * expR;' - }; - var FFTProgram = /** @class */ (function () { - function FFTProgram(op, inputShape, inverse) { - this.variableNames = ['real', 'imag']; - var innerDim = inputShape[1]; - this.outputShape = inputShape; - var exponentMultiplierSnippet = inverse ? "2.0 * " + Math.PI : "-2.0 * " + Math.PI; - var resultDenominator = inverse ? innerDim + ".0" : '1.0'; - this.userCode = "\n const float exponentMultiplier = " + exponentMultiplierSnippet + ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n " + op + "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(" + innerDim + ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < " + innerDim + "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / " + resultDenominator + ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n "; - } - return FFTProgram; - }()); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var FillProgram = /** @class */ (function () { - function FillProgram(shape, value) { - this.outputShape = []; - this.variableNames = ['x']; - this.outputShape = shape; - this.userCode = "\n uniform float value;\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n "; - } - FillProgram.prototype.getCustomSetupFunc = function (value) { - var _this = this; - return function (gpgpu, webGLProgram) { - if (_this.valueLoc == null) { - _this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value'); - } - gpgpu.gl.uniform1f(_this.valueLoc, value); - }; - }; - return FillProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var GatherProgram = /** @class */ (function () { - function GatherProgram(aShape, indicesLength, axis) { - this.variableNames = ['A', 'indices']; - var outputShape = aShape.slice(); - outputShape[axis] = indicesLength; - this.outputShape = outputShape; - this.rank = outputShape.length; - var dtype = getCoordsDataType(this.rank); - var sourceCoords = getSourceCoords$1(aShape, axis); - this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n "; - } - return GatherProgram; - }()); - function getSourceCoords$1(aShape, axis) { - var rank = aShape.length; - if (rank > 4) { - throw Error("Gather for rank " + rank + " is not yet supported"); - } - if (rank === 1) { - return "int(getIndices(resRC))"; - } - var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; - var sourceCoords = []; - for (var i = 0; i < aShape.length; i++) { - if (i === axis) { - sourceCoords.push("int(getIndices(" + currentCoords[i] + "))"); - } - else { - sourceCoords.push("" + currentCoords[i]); - } - } - return sourceCoords.join(); - } - - var GatherNDProgram = /** @class */ (function () { - function GatherNDProgram(sliceDim, strides, shape) { - this.sliceDim = sliceDim; - this.strides = strides; - this.variableNames = ['x', 'indices']; - this.outputShape = shape; - var stridesType = getCoordsDataType(strides.length); - var dtype = getCoordsDataType(shape.length); - var strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides'; - this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + this.strides + ");\n void main() {\n " + dtype + " coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < " + this.sliceDim + "; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * " + strideString + ";\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n "; - } - return GatherNDProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function createVertexShader$1(gl, debug) { - var glsl = getGlslDifferences(); - var vertexShaderSource = glsl.version + "\n precision highp float;\n " + glsl.attribute + " vec3 clipSpacePos;\n " + glsl.attribute + " vec2 uv;\n " + glsl.varyingVs + " vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }"; - return createVertexShader(gl, debug, vertexShaderSource); - } - function createVertexBuffer(gl, debug) { - // [x y z u v] * [upper-left, lower-left, upper-right, lower-right] - var vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]); - return createStaticVertexBuffer(gl, debug, vertexArray); - } - function createIndexBuffer(gl, debug) { - // OpenGL (and WebGL) have "CCW == front" winding - var triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]); - return createStaticIndexBuffer(gl, debug, triangleVertexIndices); - } - function createAndConfigureTexture(gl, debug, width, height, internalFormat, textureFormat, textureType) { - validateTextureSize(width, height); - var texture = createTexture(gl, debug); - var tex2d = gl.TEXTURE_2D; - callAndCheck(gl, debug, function () { return gl.bindTexture(tex2d, texture); }); - callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); }); - callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); }); - callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST); }); - callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST); }); - callAndCheck(gl, debug, function () { return gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null); }); - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); - return texture; - } - function createFloat32MatrixTexture(gl, debug, rows, columns, textureConfig) { - var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; - return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatFloat, textureConfig.textureFormatFloat, gl.FLOAT); - } - function createFloat16MatrixTexture(gl, debug, rows, columns, textureConfig) { - var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; - return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatHalfFloat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat); - } - function createUnsignedBytesMatrixTexture(gl, debug, rows, columns, textureConfig) { - var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; - return createAndConfigureTexture(gl, debug, width, height, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE); - } - function createPackedMatrixTexture(gl, debug, rows, columns, textureConfig) { - var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; - return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatPackedFloat, gl.RGBA, gl.FLOAT); - } - function createFloat16PackedMatrixTexture(gl, debug, rows, columns, textureConfig) { - var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; - return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatPackedHalfFloat, gl.RGBA, textureConfig.textureTypeHalfFloat); - } - function bindVertexProgramAttributeStreams(gl, debug, program, vertexBuffer) { - var posOffset = 0; // x is the first buffer element - var uvOffset = 3 * 4; // uv comes after [x y z] - var stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float. - callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); }); - var success = bindVertexBufferToProgramAttribute(gl, debug, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset); - return success && - bindVertexBufferToProgramAttribute(gl, debug, program, 'uv', vertexBuffer, 2, stride, uvOffset); - } - function uploadDenseMatrixToTexture(gl, debug, texture, width, height, data, textureConfig) { - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); - var dataForUpload, texelDataType, internalFormat; - if (data instanceof Uint8Array) { - dataForUpload = new Uint8Array(width * height * 4); - texelDataType = gl.UNSIGNED_BYTE; - internalFormat = gl.RGBA; - } - else { - dataForUpload = new Float32Array(width * height * 4); - texelDataType = gl.FLOAT; - internalFormat = textureConfig.internalFormatPackedFloat; - } - dataForUpload.set(data); - callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload); }); - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); - } - function uploadPixelDataToTexture(gl, debug, texture, pixels) { - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); - if (pixels.data instanceof Uint8Array) { - callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data); }); - } - else { - callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels); }); - } - callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); - } - function createBufferFromOutputTexture(gl2, debug, rows, columns, textureConfig) { - // Create and bind the buffer. - var buffer = gl2.createBuffer(); - callAndCheck(gl2, debug, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); }); - // Initialize the buffer to the size of the texture in bytes. - var bytesPerFloat = 4; - var valuesPerTexel = 4; - var bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns; - callAndCheck(gl2, debug, function () { return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ); }); - // Enqueue a command on the GPU command queue to copy of texture into the - // buffer. - callAndCheck(gl2, debug, function () { return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0); }); - callAndCheck(gl2, debug, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); }); - return buffer; - } - function downloadFloat32MatrixFromBuffer(gl, buffer, size) { - var gl2 = gl; - var downloadTarget = new Float32Array(size); - gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); - gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); - gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); - return downloadTarget; - } - function downloadByteEncodedFloatMatrixFromOutputTexture(gl, debug, rows, columns, textureConfig) { - var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), w = _a[0], h = _a[1]; - var numChannels = 4; - var downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels)); - callAndCheck(gl, debug, function () { return gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget); }); - // By wrapping the buffer in a Float32Array, we use native browser IEEE 754 - // decoding of the 4 bytes that back each 32 bit float. - return new Float32Array(downloadTarget.buffer); - } - function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) { - var gl2 = gl; - var downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols)); - gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); - gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); - gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); - return downloadTarget; - } - function downloadMatrixFromPackedOutputTexture(gl, debug, physicalRows, physicalCols) { - var packedRGBA = new Float32Array(physicalRows * physicalCols * 4); - callAndCheck(gl, debug, function () { return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA); }); - return packedRGBA; - } - - var gpgpu_util = /*#__PURE__*/Object.freeze({ - createVertexShader: createVertexShader$1, - createVertexBuffer: createVertexBuffer, - createIndexBuffer: createIndexBuffer, - createFloat32MatrixTexture: createFloat32MatrixTexture, - createFloat16MatrixTexture: createFloat16MatrixTexture, - createUnsignedBytesMatrixTexture: createUnsignedBytesMatrixTexture, - createPackedMatrixTexture: createPackedMatrixTexture, - createFloat16PackedMatrixTexture: createFloat16PackedMatrixTexture, - bindVertexProgramAttributeStreams: bindVertexProgramAttributeStreams, - uploadDenseMatrixToTexture: uploadDenseMatrixToTexture, - uploadPixelDataToTexture: uploadPixelDataToTexture, - createBufferFromOutputTexture: createBufferFromOutputTexture, - downloadFloat32MatrixFromBuffer: downloadFloat32MatrixFromBuffer, - downloadByteEncodedFloatMatrixFromOutputTexture: downloadByteEncodedFloatMatrixFromOutputTexture, - downloadPackedMatrixFromBuffer: downloadPackedMatrixFromBuffer, - downloadMatrixFromPackedOutputTexture: downloadMatrixFromPackedOutputTexture - }); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var GPGPUContext = /** @class */ (function () { - function GPGPUContext(gl) { - this.outputTexture = null; - this.program = null; - this.disposed = false; - this.vertexAttrsAreBound = false; - this.itemsToPoll = []; - var glVersion = env().getNumber('WEBGL_VERSION'); - if (gl != null) { - this.gl = gl; - setWebGLContext(glVersion, gl); - } - else { - this.gl = getWebGLContext(glVersion); - } - // WebGL 2.0 enables texture floats without an extension. - var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float'; - var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; - if (env().getNumber('WEBGL_VERSION') === 1) { - var TEXTURE_FLOAT = 'OES_texture_float'; - var TEXTURE_HALF_FLOAT = 'OES_texture_half_float'; - this.textureFloatExtension = - getExtensionOrThrow(this.gl, this.debug, TEXTURE_FLOAT); - if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) { - this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, this.debug, TEXTURE_HALF_FLOAT); - } - else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { - throw new Error('GL context does not support half float textures, yet the ' + - 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); - } - this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); - if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { - this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, this.debug, COLOR_BUFFER_HALF_FLOAT); - } - else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { - throw new Error('GL context does not support color renderable half floats, yet ' + - 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); - } - } - else { - COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float'; - if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) { - this.colorBufferFloatExtension = - this.gl.getExtension(COLOR_BUFFER_FLOAT); - } - else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { - this.colorBufferHalfFloatExtension = - this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT); - } - else { - throw new Error('GL context does not support color renderable floats'); - } - } - this.vertexBuffer = createVertexBuffer(this.gl, this.debug); - this.indexBuffer = createIndexBuffer(this.gl, this.debug); - this.framebuffer = createFramebuffer(this.gl, this.debug); - this.textureConfig = - getTextureConfig(this.gl, this.textureHalfFloatExtension); - } - Object.defineProperty(GPGPUContext.prototype, "debug", { - get: function () { - return env().getBool('DEBUG'); - }, - enumerable: true, - configurable: true - }); - GPGPUContext.prototype.dispose = function () { - var _this = this; - if (this.disposed) { - return; - } - if (this.program != null) { - console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' + - ' This is probably a resource leak, delete the program with ' + - 'GPGPUContext.deleteProgram before disposing.'); - } - if (this.outputTexture != null) { - console.warn('Disposing a GPGPUContext that still has a bound output matrix ' + - 'texture. This is probably a resource leak, delete the output ' + - 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + - 'disposing.'); - } - var gl = this.gl; - callAndCheck(gl, this.debug, function () { return gl.finish(); }); - callAndCheck(gl, this.debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); }); - callAndCheck(gl, this.debug, function () { return gl.deleteFramebuffer(_this.framebuffer); }); - callAndCheck(gl, this.debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, null); }); - callAndCheck(gl, this.debug, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null); }); - callAndCheck(gl, this.debug, function () { return gl.deleteBuffer(_this.indexBuffer); }); - this.disposed = true; - }; - GPGPUContext.prototype.createFloat32MatrixTexture = function (rows, columns) { - this.throwIfDisposed(); - return createFloat32MatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); - }; - GPGPUContext.prototype.createFloat16MatrixTexture = function (rows, columns) { - this.throwIfDisposed(); - return createFloat16MatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); - }; - GPGPUContext.prototype.createUnsignedBytesMatrixTexture = function (rows, columns) { - this.throwIfDisposed(); - return createUnsignedBytesMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); - }; - GPGPUContext.prototype.uploadPixelDataToTexture = function (texture, pixels) { - this.throwIfDisposed(); - uploadPixelDataToTexture(this.gl, this.debug, texture, pixels); - }; - GPGPUContext.prototype.uploadDenseMatrixToTexture = function (texture, width, height, data) { - this.throwIfDisposed(); - uploadDenseMatrixToTexture(this.gl, this.debug, texture, width, height, data, this.textureConfig); - }; - GPGPUContext.prototype.createFloat16PackedMatrixTexture = function (rows, columns) { - this.throwIfDisposed(); - return createFloat16PackedMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); - }; - GPGPUContext.prototype.createPackedMatrixTexture = function (rows, columns) { - this.throwIfDisposed(); - return createPackedMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); - }; - GPGPUContext.prototype.deleteMatrixTexture = function (texture) { - var _this = this; - this.throwIfDisposed(); - if (this.outputTexture === texture) { - unbindColorTextureFromFramebuffer(this.gl, this.debug, this.framebuffer); - this.outputTexture = null; - } - callAndCheck(this.gl, this.debug, function () { return _this.gl.deleteTexture(texture); }); - }; - GPGPUContext.prototype.downloadByteEncodedFloatMatrixFromOutputTexture = function (texture, rows, columns) { - var _this = this; - return this.downloadMatrixDriver(texture, function () { return downloadByteEncodedFloatMatrixFromOutputTexture(_this.gl, _this.debug, rows, columns, _this.textureConfig); }); - }; - GPGPUContext.prototype.downloadPackedMatrixFromBuffer = function (buffer, batch, rows, columns, physicalRows, physicalCols) { - return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig); - }; - GPGPUContext.prototype.downloadFloat32MatrixFromBuffer = function (buffer, size) { - return downloadFloat32MatrixFromBuffer(this.gl, buffer, size); - }; - GPGPUContext.prototype.createBufferFromTexture = function (texture, rows, columns) { - this.bindTextureToFrameBuffer(texture); - var result = createBufferFromOutputTexture(this.gl, this.debug, rows, columns, this.textureConfig); - this.unbindTextureToFrameBuffer(); - return result; - }; - GPGPUContext.prototype.createAndWaitForFence = function () { - var fenceContext = this.createFence(this.gl); - return this.pollFence(fenceContext); - }; - GPGPUContext.prototype.createFence = function (gl) { - var _this = this; - var query; - var isFencePassed; - if (env().getBool('WEBGL_FENCE_API_ENABLED')) { - var gl2_1 = gl; - var sync_1 = gl2_1.fenceSync(gl2_1.SYNC_GPU_COMMANDS_COMPLETE, 0); - gl.flush(); - isFencePassed = function () { - var status = gl2_1.clientWaitSync(sync_1, 0, 0); - return status === gl2_1.ALREADY_SIGNALED || - status === gl2_1.CONDITION_SATISFIED; - }; - query = sync_1; - } - else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { - query = this.beginQuery(); - this.endQuery(); - isFencePassed = function () { return _this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); }; - } - else { - // If we have no way to fence, return true immediately. This will fire in - // WebGL 1.0 when there is no disjoint query timer. In this case, because - // the fence passes immediately, we'll immediately ask for a download of - // the texture, which will cause the UI thread to hang. - isFencePassed = function () { return true; }; - } - return { query: query, isFencePassed: isFencePassed }; - }; - GPGPUContext.prototype.downloadMatrixFromPackedTexture = function (texture, physicalRows, physicalCols) { - var _this = this; - return this.downloadMatrixDriver(texture, function () { return downloadMatrixFromPackedOutputTexture(_this.gl, _this.debug, physicalRows, physicalCols); }); - }; - GPGPUContext.prototype.createProgram = function (fragmentShaderSource) { - this.throwIfDisposed(); - var gl = this.gl; - var fragmentShader = createFragmentShader(gl, this.debug, fragmentShaderSource); - var vertexShader = createVertexShader$1(gl, this.debug); - var program = createProgram(gl, this.debug); - callAndCheck(gl, this.debug, function () { return gl.attachShader(program, vertexShader); }); - callAndCheck(gl, this.debug, function () { return gl.attachShader(program, fragmentShader); }); - linkProgram(gl, this.debug, program); - if (this.debug) { - validateProgram(gl, this.debug, program); - } - if (!this.vertexAttrsAreBound) { - this.setProgram(program); - this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.debug, this.program, this.vertexBuffer); - } - return program; - }; - GPGPUContext.prototype.deleteProgram = function (program) { - var _this = this; - this.throwIfDisposed(); - if (program === this.program) { - this.program = null; - } - if (program != null) { - callAndCheck(this.gl, this.debug, function () { return _this.gl.deleteProgram(program); }); - } - }; - GPGPUContext.prototype.setProgram = function (program) { - var _this = this; - this.throwIfDisposed(); - this.program = program; - if ((this.program != null) && this.debug) { - validateProgram(this.gl, this.debug, this.program); - } - callAndCheck(this.gl, this.debug, function () { return _this.gl.useProgram(program); }); - }; - GPGPUContext.prototype.getUniformLocation = function (program, uniformName, shouldThrow) { - if (shouldThrow === void 0) { shouldThrow = true; } - this.throwIfDisposed(); - if (shouldThrow) { - return getProgramUniformLocationOrThrow(this.gl, this.debug, program, uniformName); - } - else { - return getProgramUniformLocation(this.gl, program, uniformName); - } - }; - GPGPUContext.prototype.getAttributeLocation = function (program, attribute) { - var _this = this; - this.throwIfDisposed(); - return callAndCheck(this.gl, this.debug, function () { return _this.gl.getAttribLocation(program, attribute); }); - }; - GPGPUContext.prototype.getUniformLocationNoThrow = function (program, uniformName) { - this.throwIfDisposed(); - return this.gl.getUniformLocation(program, uniformName); - }; - GPGPUContext.prototype.setInputMatrixTexture = function (inputMatrixTexture, uniformLocation, textureUnit) { - this.throwIfDisposed(); - this.throwIfNoProgram(); - bindTextureToProgramUniformSampler(this.gl, this.debug, this.program, inputMatrixTexture, uniformLocation, textureUnit); - }; - GPGPUContext.prototype.setOutputMatrixTexture = function (outputMatrixTexture, rows, columns) { - this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows); - }; - GPGPUContext.prototype.setOutputPackedMatrixTexture = function (outputPackedMatrixTexture, rows, columns) { - this.throwIfDisposed(); - var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; - this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height); - }; - GPGPUContext.prototype.setOutputMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) { - this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows); - }; - GPGPUContext.prototype.setOutputPackedMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) { - throw new Error('setOutputPackedMatrixWriteRegion not implemented.'); - }; - GPGPUContext.prototype.debugValidate = function () { - if (this.program != null) { - validateProgram(this.gl, this.debug, this.program); - } - validateFramebuffer(this.gl); - }; - GPGPUContext.prototype.executeProgram = function () { - this.throwIfDisposed(); - this.throwIfNoProgram(); - var gl = this.gl; - if (this.debug) { - this.debugValidate(); - } - callAndCheck(gl, this.debug, function () { return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0); }); - }; - GPGPUContext.prototype.blockUntilAllProgramsCompleted = function () { - var _this = this; - this.throwIfDisposed(); - callAndCheck(this.gl, this.debug, function () { return _this.gl.finish(); }); - }; - GPGPUContext.prototype.getQueryTimerExtension = function () { - if (this.disjointQueryTimerExtension == null) { - this.disjointQueryTimerExtension = - getExtensionOrThrow(this.gl, this.debug, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? - 'EXT_disjoint_timer_query_webgl2' : - 'EXT_disjoint_timer_query'); - } - return this.disjointQueryTimerExtension; - }; - GPGPUContext.prototype.getQueryTimerExtensionWebGL2 = function () { - return this.getQueryTimerExtension(); - }; - GPGPUContext.prototype.getQueryTimerExtensionWebGL1 = function () { - return this.getQueryTimerExtension(); - }; - GPGPUContext.prototype.beginQuery = function () { - if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { - var gl2 = this.gl; - var ext_1 = this.getQueryTimerExtensionWebGL2(); - var query_1 = gl2.createQuery(); - gl2.beginQuery(ext_1.TIME_ELAPSED_EXT, query_1); - return query_1; - } - var ext = this.getQueryTimerExtensionWebGL1(); - var query = ext.createQueryEXT(); - ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query); - return query; - }; - GPGPUContext.prototype.endQuery = function () { - if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { - var gl2 = this.gl; - var ext_2 = this.getQueryTimerExtensionWebGL2(); - gl2.endQuery(ext_2.TIME_ELAPSED_EXT); - return; - } - var ext = this.getQueryTimerExtensionWebGL1(); - ext.endQueryEXT(ext.TIME_ELAPSED_EXT); - }; - GPGPUContext.prototype.waitForQueryAndGetTime = function (query) { - return __awaiter(this, void 0, void 0, function () { - var _this = this; - return __generator(this, function (_a) { - switch (_a.label) { - case 0: return [4 /*yield*/, repeatedTry(function () { return _this.disposed || // while testing contexts are created / disposed - // in rapid succession, so without this check we - // may poll for the query timer indefinitely - _this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); })]; - case 1: - _a.sent(); - return [2 /*return*/, this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))]; - } - }); - }); - }; - GPGPUContext.prototype.getQueryTime = function (query, queryTimerVersion) { - if (queryTimerVersion === 0) { - return null; - } - if (queryTimerVersion === 2) { - var gl2 = this.gl; - var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT); - // Return milliseconds. - return timeElapsedNanos / 1000000; - } - else { - var ext = this.getQueryTimerExtensionWebGL1(); - var timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); - // Return milliseconds. - return timeElapsedNanos / 1000000; - } - }; - GPGPUContext.prototype.isQueryAvailable = function (query, queryTimerVersion) { - if (queryTimerVersion === 0) { - return true; - } - if (queryTimerVersion === 2) { - var gl2 = this.gl; - var ext = this.getQueryTimerExtensionWebGL2(); - var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE); - if (this.disjoint == null) { - this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); - } - return available && !this.disjoint; - } - else { - var ext = this.getQueryTimerExtensionWebGL1(); - var available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT); - if (this.disjoint == null) { - this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); - } - return available && !this.disjoint; - } - }; - GPGPUContext.prototype.pollFence = function (fenceContext) { - var _this = this; - return new Promise(function (resolve) { - _this.addItemToPoll(function () { return fenceContext.isFencePassed(); }, function () { return resolve(); }); - }); - }; - GPGPUContext.prototype.pollItems = function () { - // Find the last query that has finished. - var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) { return x.isDoneFn; })); - for (var i = 0; i <= index; ++i) { - var resolveFn = this.itemsToPoll[i].resolveFn; - resolveFn(); - } - this.itemsToPoll = this.itemsToPoll.slice(index + 1); - }; - GPGPUContext.prototype.addItemToPoll = function (isDoneFn, resolveFn) { - var _this = this; - this.itemsToPoll.push({ isDoneFn: isDoneFn, resolveFn: resolveFn }); - if (this.itemsToPoll.length > 1) { - // We already have a running loop that polls. - return; - } - // Start a new loop that polls. - repeatedTry(function () { - _this.pollItems(); - // End the loop if no more items to poll. - return _this.itemsToPoll.length === 0; - }); - }; - GPGPUContext.prototype.bindTextureToFrameBuffer = function (texture) { - this.throwIfDisposed(); - bindColorTextureToFramebuffer(this.gl, this.debug, texture, this.framebuffer); - if (this.debug) { - validateFramebuffer(this.gl); - } - }; - GPGPUContext.prototype.unbindTextureToFrameBuffer = function () { - if (this.outputTexture != null) { - bindColorTextureToFramebuffer(this.gl, this.debug, this.outputTexture, this.framebuffer); - if (this.debug) { - validateFramebuffer(this.gl); - } - } - else { - unbindColorTextureFromFramebuffer(this.gl, this.debug, this.framebuffer); - } - }; - GPGPUContext.prototype.downloadMatrixDriver = function (texture, downloadAndDecode) { - this.bindTextureToFrameBuffer(texture); - var result = downloadAndDecode(); - this.unbindTextureToFrameBuffer(); - return result; - }; - GPGPUContext.prototype.setOutputMatrixTextureDriver = function (outputMatrixTextureMaybePacked, width, height) { - this.throwIfDisposed(); - var gl = this.gl; - bindColorTextureToFramebuffer(gl, this.debug, outputMatrixTextureMaybePacked, this.framebuffer); - if (this.debug) { - validateFramebuffer(gl); - } - this.outputTexture = outputMatrixTextureMaybePacked; - callAndCheck(gl, this.debug, function () { return gl.viewport(0, 0, width, height); }); - callAndCheck(gl, this.debug, function () { return gl.scissor(0, 0, width, height); }); - }; - GPGPUContext.prototype.setOutputMatrixWriteRegionDriver = function (x, y, width, height) { - var _this = this; - this.throwIfDisposed(); - callAndCheck(this.gl, this.debug, function () { return _this.gl.scissor(x, y, width, height); }); - }; - GPGPUContext.prototype.throwIfDisposed = function () { - if (this.disposed) { - throw new Error('Attempted to use disposed GPGPUContext.'); - } - }; - GPGPUContext.prototype.throwIfNoProgram = function () { - if (this.program == null) { - throw new Error('No GPU program is currently set.'); - } - }; - return GPGPUContext; - }()); - /** - * Finds the index of the last true element using linear search. - * Note: We can't do binary search because Chrome expects us to explicitly - * test all fences before download: - * https://github.com/tensorflow/tfjs/issues/1145 - */ - function linearSearchLastTrue(arr) { - var i = 0; - for (; i < arr.length; ++i) { - var isDone = arr[i](); - if (!isDone) { - break; - } - } - return i - 1; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - function compileProgram(gpgpu, program, inputs, output) { - var userCode = program.userCode; - var inputInfos = inputs.map(function (input, i) { - var shapeInfo = { - logicalShape: input.shape, - texShape: input.isUniform ? null : input.texData.texShape, - isUniform: input.isUniform, - isPacked: input.isUniform ? false : input.texData.isPacked, - flatOffset: null - }; - if (input.texData != null && input.texData.slice != null && - input.texData.slice.flatOffset > 0) { - shapeInfo.flatOffset = input.texData.slice.flatOffset; - } - return { name: program.variableNames[i], shapeInfo: shapeInfo }; - }); - var inShapeInfos = inputInfos.map(function (x) { return x.shapeInfo; }); - var outShapeInfo = { - logicalShape: output.shape, - texShape: output.texData.texShape, - isUniform: false, - isPacked: output.texData.isPacked, - flatOffset: null - }; - var source = makeShader(inputInfos, outShapeInfo, userCode, program.packedInputs); - var webGLProgram = gpgpu.createProgram(source); - // Add special uniforms (NAN, INFINITY) - var infLoc = null; - var nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); - if (env().getNumber('WEBGL_VERSION') === 1) { - infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); - } - // Add user-defined uniforms - var uniformLocations = {}; - for (var i = 0; i < program.variableNames.length; i++) { - var varName = program.variableNames[i]; - var shouldThrow = false; - uniformLocations[varName] = - gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow); - uniformLocations["offset" + varName] = - gpgpu.getUniformLocation(webGLProgram, "offset" + varName, shouldThrow); - } - return { - program: program, - source: source, - webGLProgram: webGLProgram, - uniformLocations: uniformLocations, - inShapeInfos: inShapeInfos, - outShapeInfo: outShapeInfo, - infLoc: infLoc, - nanLoc: nanLoc, - }; - } - function validateBinaryAndProgram(shapeInfos, inputs) { - if (shapeInfos.length !== inputs.length) { - throw Error("Binary was compiled with " + shapeInfos.length + " inputs, but " + - ("was executed with " + inputs.length + " inputs")); - } - shapeInfos.forEach(function (s, i) { - var shapeA = s.logicalShape; - var input = inputs[i]; - var shapeB = input.shape; - if (!arraysEqual(shapeA, shapeB)) { - throw Error("Binary was compiled with different shapes than " + - ("the current args. Shapes " + shapeA + " and " + shapeB + " must match")); - } - // The input is uploaded as uniform. - if (s.isUniform && input.isUniform) { - return; - } - var texShapeA = s.texShape; - var texShapeB = input.isUniform ? null : input.texData.texShape; - if (!arraysEqual(texShapeA, texShapeB)) { - throw Error("Binary was compiled with different texture shapes than the" + - (" current args. Shape " + texShapeA + " and " + texShapeB + " must match")); - } - }); - } - function runProgram(gpgpu, binary, inputs, output, customSetup) { - validateBinaryAndProgram(binary.inShapeInfos, inputs); - validateBinaryAndProgram([binary.outShapeInfo], [output]); - var outTex = output.texData.texture; - var outTexShape = output.texData.texShape; - if (output.texData.isPacked) { - gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]); - } - else { - gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]); - } - gpgpu.setProgram(binary.webGLProgram); - // Set special uniforms (NAN, INFINITY) - if (env().getNumber('WEBGL_VERSION') === 1) { - if (binary.infLoc !== null) { - gpgpu.gl.uniform1f(binary.infLoc, Infinity); - } - } - if (binary.nanLoc !== null) { - gpgpu.gl.uniform1f(binary.nanLoc, NaN); - } - // Set user-defined inputs - inputs.forEach(function (input, i) { - var varName = binary.program.variableNames[i]; - var varLoc = binary.uniformLocations[varName]; - var varOffsetLoc = binary.uniformLocations["offset" + varName]; - if (varLoc == null) { - // The compiler inferred that this variable is not used in this shader. - return; - } - if (input.isUniform) { - // Upload the values of the tensor as uniform. - if (sizeFromShape(input.shape) < 2) { - gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]); - } - else { - var vals = input.uniformValues; - if (!(vals instanceof Float32Array)) { - vals = new Float32Array(vals); - } - gpgpu.gl.uniform1fv(varLoc, vals); - } - return; - } - // If the input was sliced, upload the flat offset index. - if (input.texData.slice != null && varOffsetLoc != null) { - gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); - } - gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i); - }); - if (customSetup != null) { - customSetup(gpgpu, binary.webGLProgram); - } - gpgpu.executeProgram(); - } - function makeShaderKey(program, inputs, output) { - var keyInputs = ''; - inputs.concat(output).forEach(function (x) { - var hasOffset = x.texData != null && x.texData.slice != null && - x.texData.slice.flatOffset > 0; - var texShape = x.isUniform ? 'uniform' : x.texData.texShape; - keyInputs += x.shape + "_" + texShape + "_" + hasOffset; - }); - var keyUserCode = program.userCode; - var key = program.constructor.name; - // Fast string concat. See https://jsperf.com/string-concatenation/14. - key += '_' + keyInputs + '_' + keyUserCode; - return key; - } - - /** - * @license - * Copyright 2019 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var Im2ColPackedProgram = /** @class */ (function () { - function Im2ColPackedProgram(outputShape, inputShape, convInfo) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = outputShape; - var filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, strideWidth = convInfo.strideWidth, strideHeight = convInfo.strideHeight, padInfo = convInfo.padInfo, outWidth = convInfo.outWidth, dilationWidth = convInfo.dilationWidth, dilationHeight = convInfo.dilationHeight, dataFormat = convInfo.dataFormat; - var left = padInfo.left, top = padInfo.top; - var itemsPerBlockRow = inChannels * filterWidth; - var glsl = getGlslDifferences(); - var isChannelsLast = dataFormat === 'channelsLast'; - var rowDim = isChannelsLast ? 0 : 1; - var colDim = isChannelsLast ? 1 : 2; - var unrolled = ""; - for (var row = 0; row <= 1; row++) { - for (var col = 0; col <= 1; col++) { - unrolled += "\n blockIndex = rc.y + " + col + ";\n pos = rc.x + " + row + ";\n\n if(blockIndex < " + outputShape[1] + " && pos < " + outputShape[0] + ") {\n offsetY = int(blockIndex / (" + outWidth + ")) * " + strideHeight + " - " + top + ";\n d0 = offsetY + " + dilationHeight + " * (pos / " + itemsPerBlockRow + ");\n\n if(d0 < " + inputShape[rowDim] + " && d0 >= 0) {\n\n offsetX = int(mod(float(blockIndex), " + outWidth + ".) * " + strideWidth + ". - " + left + ".);\n d1 = offsetX + " + dilationWidth + " * (int(mod(float(pos), " + itemsPerBlockRow + ".) / " + inChannels + ".));\n\n if(d1 < " + inputShape[colDim] + " && d1 >= 0) {\n\n ch = int(mod(float(pos), " + inChannels + ".));\n\n if (" + isChannelsLast + ") {\n innerDims = vec2(d1, ch);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n "; - } - } - this.userCode = "\n void main() {\n ivec2 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n " + unrolled + "\n\n " + glsl.output + " = result;\n }\n "; - } - return Im2ColPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var LRNProgram = /** @class */ (function () { - function LRNProgram(xShape, radius, bias, alpha, beta) { - this.variableNames = ['x']; - this.outputShape = []; - var rad = radius; - var maxD = xShape[3] - 1; - this.outputShape = xShape; - // optimize pow(bias + alpha * sum, -beta) - // src: https://github.com/tensorflow/tensorflow/.. - // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. - // tensorflow/core/kernels/mkl_lrn_op.cc#L320 - var powOperator; - var basis = "float(" + bias + ") + float(" + alpha + ") * sum"; - if (beta === 0.5) { - powOperator = "inversesqrt(" + basis + ")"; - } - else if (beta === 1.0) { - powOperator = "1.0/(" + basis + ")"; - } - else { - powOperator = "exp(log(" + basis + ") * float(-" + beta + "));"; - } - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -" + rad + "; j <= " + rad + "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= " + maxD + ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * " + powOperator + ";\n setOutput(val);\n }\n "; - } - return LRNProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var LRNGradProgram = /** @class */ (function () { - function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) { - this.variableNames = ['inputImage', 'outputImage', 'dy']; - this.outputShape = []; - this.outputShape = inputShape; - this.depth = inputShape[3]; - this.depthRadius = depthRadius; - this.bias = bias; - this.alpha = alpha; - this.beta = beta; - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < " + this.depth + "; ++d) {\n int depthBegin = int(max(0.0, float(d - " + depthRadius + ")));\n int depthEnd = int(min(float(" + this.depth + "),\n float(d + " + depthRadius + " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = " + this.depth + ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(" + alpha + ") * norm + float(" + bias + ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(" + alpha + ")\n * float(" + beta + ")\n * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * " + beta + ");\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n "; - } - return LRNGradProgram; - }()); - - /** - * @license - * Copyright 2019 Google LLC All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var LRNPackedProgram = /** @class */ (function () { - function LRNPackedProgram(xShape, radius, bias, alpha, beta) { - this.variableNames = ['x']; - this.outputShape = []; - this.packedInputs = true; - this.packedOutput = true; - var rad = radius; - var maxD = xShape[3] - 1; - this.outputShape = xShape; - // optimize pow(bias + alpha * sum, -beta) - // src: https://github.com/tensorflow/tensorflow/.. - // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. - // tensorflow/core/kernels/mkl_lrn_op.cc#L320 - var powOperator; - var basis = "float(" + bias + ") + float(" + alpha + ") * sum"; - if (beta === 0.5) { - powOperator = "inversesqrt(" + basis + ")"; - } - else if (beta === 1.0) { - powOperator = "1.0/(" + basis + ")"; - } - else { - powOperator = "exp(log(" + basis + ") * float(-" + beta + "));"; - } - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < " + this.outputShape[3] + ";\n bool hasNextRow = c < " + this.outputShape[2] + ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - " + rad + ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - " + rad + "; j <= " + rad + "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * " + powOperator + ";\n setOutput(result);\n }\n "; - } - return LRNPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var MaxPool2DBackpropProgram = /** @class */ (function () { - function MaxPool2DBackpropProgram(convInfo) { - this.variableNames = ['dy', 'maxPos']; - this.outputShape = convInfo.inShape; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationHeight = convInfo.dilationHeight; - var effectiveFilterHeight = convInfo.effectiveFilterHeight; - var effectiveFilterWidth = convInfo.effectiveFilterWidth; - var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; - var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; - var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1; - this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = " + lastIndex + " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n "; - } - return MaxPool2DBackpropProgram; - }()); - var MaxPool3DBackpropProgram = /** @class */ (function () { - function MaxPool3DBackpropProgram(convInfo) { - this.variableNames = ['dy', 'maxPos']; - this.outputShape = convInfo.inShape; - var strideDepth = convInfo.strideDepth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationDepth = convInfo.dilationDepth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var effectiveFilterDepth = convInfo.effectiveFilterDepth; - var effectiveFilterHeight = convInfo.effectiveFilterHeight; - var effectiveFilterWidth = convInfo.effectiveFilterWidth; - var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; - var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; - var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; - var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1; - this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = " + lastIndex + " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n "; - } - return MaxPool3DBackpropProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var MatMulPackedProgram = /** @class */ (function () { - function MatMulPackedProgram(aShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation) { - if (transposeA === void 0) { transposeA = false; } - if (transposeB === void 0) { transposeB = false; } - if (addBias === void 0) { addBias = false; } - if (activation === void 0) { activation = null; } - if (hasPreluActivation === void 0) { hasPreluActivation = false; } - this.variableNames = ['matrixA', 'matrixB']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = outputShape; - var sharedDim = transposeA ? aShape[1] : aShape[2]; - var sharedDimensionPacked = Math.ceil(sharedDim / 2); - var aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2'; - var bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z'; - var aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww']; - var bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw']; - var activationSnippet = '', applyActivationSnippet = ''; - if (activation) { - if (hasPreluActivation) { - activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; - } - else { - activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }"; - } - applyActivationSnippet = "result = activation(result);"; - } - var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; - if (addBias) { - this.variableNames.push('bias'); - } - if (hasPreluActivation) { - this.variableNames.push('preluActivationWeights'); - } - this.userCode = "\n " + activationSnippet + "\n\n const float sharedDimension = " + sharedDimensionPacked + ".0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n for (int i = 0; i < " + sharedDimensionPacked + "; i++) {\n vec4 a = getMatrixA(rc.x, " + aSample + ");\n vec4 b = getMatrixB(rc.x, " + bSample + ");\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (" + aSwizzle[0] + " * " + bSwizzle[0] + ");\n result += (" + aSwizzle[1] + " * " + bSwizzle[1] + ");\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n " + addBiasSnippet + "\n\n " + applyActivationSnippet + "\n\n setOutput(result);\n }\n "; - } - return MatMulPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var MultinomialProgram = /** @class */ (function () { - function MultinomialProgram(batchSize, numOutcomes, numSamples) { - this.variableNames = ['probs']; - this.outputShape = [batchSize, numSamples]; - this.userCode = "\n uniform float seed;\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < " + (numOutcomes - 1) + "; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(" + (numOutcomes - 1) + "));\n }\n "; - } - MultinomialProgram.prototype.getCustomSetupFunc = function (seed) { - var _this = this; - return function (gpgpu, webGLProgram) { - if (_this.seedLoc == null) { - _this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed'); - } - gpgpu.gl.uniform1f(_this.seedLoc, seed); - }; - }; - return MultinomialProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var OneHotProgram = /** @class */ (function () { - function OneHotProgram(numIndices, depth, onValue, offValue) { - this.variableNames = ['indices']; - this.outputShape = [numIndices, depth]; - this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(" + offValue + "), float(" + onValue + "),\n float(index == coords.y)));\n }\n "; - } - return OneHotProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var PackProgram = /** @class */ (function () { - function PackProgram(outputShape) { - this.variableNames = ['A']; - this.packedInputs = false; - this.packedOutput = true; - // Only input / output 3D tensors. - this.outputShape = outputShape; - var rank = outputShape.length; - if (rank === 0) { - this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n "; - } - else { - var channels = getChannels('rc', rank); - var dtype = getCoordsDataType(rank); - var outOfBoundsCondition = getOutOfBoundsCondition(rank, outputShape, channels); - var setup = getSetup(rank, outputShape[outputShape.length - 1], outputShape[outputShape.length - 2], channels); - var output = getOutput(outputShape, channels); - this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n\n if(" + outOfBoundsCondition + ") {\n setOutput(vec4(0));\n } else {\n " + setup + "\n\n setOutput(vec4(" + output + "));\n }\n }\n "; - } - } - return PackProgram; - }()); - function getSourceCoordsArr(rank, dims) { - var coords = []; - for (var row = 0; row <= 1; row++) { - for (var col = 0; col <= 1; col++) { - var coord = (row === 0 ? 'r' : 'rp1') + ", " + (col === 0 ? 'c' : 'cp1'); - for (var d = 2; d < rank; d++) { - coord = dims[dims.length - 1 - d] + "," + coord; - } - coords.push(coord); - } - } - return coords; - } - function getOutOfBoundsCondition(rank, shape, dims) { - if (rank === 1) { - return "rc > " + shape[0]; - } - var cond = ''; - for (var i = rank - 2; i < rank; i++) { - cond += dims[i] + " >= " + shape[i]; - if (i < rank - 1) { - cond += '||'; - } - } - return cond; - } - function getSetup(rank, cols, rows, dims) { - if (rank === 1) { - return ''; - } - var innerDims = dims.slice(-2); - return "\n int r = " + innerDims[0] + ";\n int c = " + innerDims[1] + ";\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= " + cols + ";\n bool rEdge = rp1 >= " + rows + ";\n "; - } - function getOutput(shape, dims) { - var rank = shape.length; - var sourceCoords = getSourceCoordsArr(rank, dims); - if (rank === 1) { - return "getA(rc),\n rc + 1 >= " + shape[0] + " ? 0. : getA(rc + 1),\n 0, 0"; - } - return "getA(" + sourceCoords[0] + "),\n cEdge ? 0. : getA(" + sourceCoords[1] + "),\n rEdge ? 0. : getA(" + sourceCoords[2] + "),\n rEdge || cEdge ? 0. : getA(" + sourceCoords[3] + ")"; - } - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var PadProgram = /** @class */ (function () { - function PadProgram(xShape, paddings, constantValue) { - this.variableNames = ['x']; - this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */); - var rank = xShape.length; - var type = getCoordsDataType(rank); - var start = paddings.map(function (p) { return p[0]; }).join(','); - var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(','); - var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank); - if (rank === 1) { - this.userCode = "\n int start = " + start + ";\n int end = " + end + ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(float(" + constantValue + "));\n } else {\n setOutput(getX(outC - start));\n }\n }\n "; - return; - } - this.userCode = "\n " + type + " start = " + type + "(" + start + ");\n " + type + " end = " + type + "(" + end + ");\n\n void main() {\n " + type + " outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(float(" + constantValue + "));\n } else {\n " + type + " coords = outC - start;\n setOutput(getX(" + unpackedCoords + "));\n }\n }\n "; - } - return PadProgram; - }()); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var PadPackedProgram = /** @class */ (function () { - function PadPackedProgram(xShape, paddings, constantValue) { - this.variableNames = ['x']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */); - var rank = xShape.length; - var dtype = getCoordsDataType(rank); - var start = paddings.map(function (p) { return p[0]; }).join(','); - var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(','); - var coords = getChannels('rc', rank); - var source = getChannels('source', rank); - var cLimit = coords[rank - 1] + " < " + this.outputShape[rank - 1]; - var innerDims = rank === 1 ? 'source' : "vec2(" + source.slice(-2).join() + ")"; - var componentSetup = [ - dtype + " rc = outputLoc;", coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n ", - rank === 1 ? '' : "}\n rc = outputLoc;\n " + coords[rank - 2] + " += 1;\n if(" + coords[rank - 2] + " < " + this.outputShape[rank - 2] + ") {", - rank === 1 ? '' : " " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {" - ]; - var paddingArea = rank === 1 ? - 'rc < start || rc >= end' : - 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))'; - var mainLoop = ''; - for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) { - mainLoop += "\n " + componentSetup[i] + "\n if (" + paddingArea + ") {\n result[" + i + "] = float(" + constantValue + ");\n } else {\n " + dtype + " source = rc - start;\n result[" + i + "] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n "; - } - mainLoop += (rank === 1 ? "} " : "}}"); - this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n "; - } - return PadPackedProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var Pool2DProgram = /** @class */ (function () { - function Pool2DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) { - if (flattenPositions === void 0) { flattenPositions = false; } - if (includeBatchInIndex === void 0) { includeBatchInIndex = false; } - this.variableNames = ['x']; - if (poolType === 'avg' && computePositions) { - throw new Error('Cannot compute positions for average pool.'); - } - var filterWidth = convInfo.filterWidth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var effectiveFilterHeight = convInfo.effectiveFilterHeight; - var effectiveFilterWidth = convInfo.effectiveFilterWidth; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - this.outputShape = convInfo.outShape; - var isAvgPool = poolType === 'avg'; - var batchFlattenPositionStr = "((batch * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d"; - var flattenPositionStr = "(xR * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d"; - var initializationValue = '0.0'; - if (!isAvgPool) { - // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. - initializationValue = '-1.0 / 1e-20'; - } - if (computePositions) { - var compareOp_1 = '>='; - this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_1 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr : - flattenPositionStr) : - "wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n "; - return; - } - var compareOp = 'max'; - var returnValue = poolType + "(" + poolType + "(" + poolType + "(" + - 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; - if (poolType === 'avg') { - returnValue = "avgValue / count"; - } - var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; - var filterWidthVec4Remainder = filterWidth % 4; - var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; - this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n getValue(batch, xR, xC + 3 * " + dilationWidth + ", d)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n "; - } - return Pool2DProgram; - }()); - var Pool3DProgram = /** @class */ (function () { - function Pool3DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) { - if (flattenPositions === void 0) { flattenPositions = false; } - if (includeBatchInIndex === void 0) { includeBatchInIndex = false; } - this.variableNames = ['x']; - if (poolType === 'avg' && computePositions) { - throw new Error('Cannot compute positions for average pool.'); - } - var filterWidth = convInfo.filterWidth; - var strideDepth = convInfo.strideDepth; - var strideHeight = convInfo.strideHeight; - var strideWidth = convInfo.strideWidth; - var dilationDepth = convInfo.dilationDepth; - var dilationHeight = convInfo.dilationHeight; - var dilationWidth = convInfo.dilationWidth; - var effectiveFilterDepth = convInfo.effectiveFilterDepth; - var effectiveFilterHeight = convInfo.effectiveFilterHeight; - var effectiveFilterWidth = convInfo.effectiveFilterWidth; - var padFront = convInfo.padInfo.front; - var padTop = convInfo.padInfo.top; - var padLeft = convInfo.padInfo.left; - this.outputShape = convInfo.outShape; - var isAvgPool = poolType === 'avg'; - var initializationValue = '0.0'; - if (!isAvgPool) { - // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. - initializationValue = '-1.0 / 1e-20'; - } - if (computePositions) { - var compareOp_2 = '>='; - this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_2 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ? - (includeBatchInIndex ? - "(((batch * " + convInfo.inDepth + " + xD) * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch" : - "((xD * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch") : - "wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n "; - return; - } - var compareOp = 'max'; - var returnValue = poolType + "(" + poolType + "(" + poolType + "(" + - 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; - if (poolType === 'avg') { - returnValue = "avgValue / count"; - } - var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; - var filterWidthVec4Remainder = filterWidth % 4; - var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; - this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 3 * " + dilationWidth + ", ch)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n }\n "; - } - return Pool3DProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ReduceProgram = /** @class */ (function () { - function ReduceProgram(reduceInfo, reduceType) { - this.variableNames = ['x']; - var windowSize = reduceInfo.windowSize; - var batchSize = reduceInfo.batchSize; - var inSize = reduceInfo.inSize; - var outSize = Math.ceil(inSize / windowSize); - this.outputShape = [batchSize, outSize]; - var initializationValue = '0.0'; - var compareOp = ""; - if (reduceType === 'prod') { - initializationValue = '1.0'; - } - else if (reduceType === 'min') { - // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. - initializationValue = '1.0 / 1e-20'; - compareOp = "min"; - } - else if (reduceType === 'max') { - // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. - initializationValue = '-1.0 / 1e-20'; - compareOp = "max"; - } - var returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" + - 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; - if (reduceType === 'sum') { - returnValue = "sumValue"; - } - else if (reduceType === 'prod') { - returnValue = "prodValue"; - } - else if (reduceType === 'all') { - returnValue = "allValue"; - } - else if (reduceType === 'any') { - returnValue = "anyValue"; - } - var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; - var windowSizeVec4Remainder = windowSize % 4; - var updateSnippet = "\n if (" + (reduceType === 'sum') + ") {\n sumValue += dot(values, ones);\n } else if (" + (reduceType === 'prod') + ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; - var vecType = "vec4"; - if (reduceType === 'all') { - initializationValue = '1.0'; - updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n "; - vecType = "bvec4"; - } - else if (reduceType === 'any') { - initializationValue = '0.0'; - updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n "; - vecType = "bvec4"; - } - var checkOutOfBounds = ''; - if (inSize % windowSize > 0) { - checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n "; - } - this.userCode = "\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n "; - } - return ReduceProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ReshapePackedProgram = /** @class */ (function () { - function ReshapePackedProgram(outputShape, inputShape) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = outputShape; - var mainLoop = ""; - for (var i = 0; i < 4; i++) { - var thisRC = "thisRC = rc;"; - if (i % 2 === 1) { - thisRC += "thisRC.z += 1;"; - } - if (i > 1) { - thisRC += "thisRC.y += 1;"; - } - mainLoop += "\n " + thisRC + "\n " + (i > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : '') + "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[" + i + "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n " + (i > 0 ? '}' : '') + "\n "; - } - this.userCode = "\n " + getReshapedInputCoords(inputShape) + "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = " + outputShape[1] + ";\n int cols = " + outputShape[2] + ";\n\n " + mainLoop + "\n\n setOutput(result);\n }\n "; - } - return ReshapePackedProgram; - }()); - function getReshapedInputCoords(shape) { - var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); - return "\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n "; - } - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ResizeBilinearBackpropProgram = /** @class */ (function () { - function ResizeBilinearBackpropProgram(dy, x, alignCorners) { - this.variableNames = ['dy']; - this.outputShape = []; - this.outputShape = x.shape; - var _a = x.shape, xHeight = _a[1], xWidth = _a[2]; - var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; - // In the backwards pass, we want to find the pixels that were generated for - // each pixel in the input image the forward pass and add the corresponding - // coefficient from dy to the gradient (with some interpolation). - var effectiveXSize = [ - (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, - (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth - ]; - var effectiveYSize = [ - (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, - (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth - ]; - var heightScale = effectiveXSize[0] / effectiveYSize[0]; - var widthScale = effectiveXSize[1] / effectiveYSize[1]; - var invHeightScale = 1 / heightScale; - var invWidthScale = 1 / widthScale; - // This defines the size of the window of values around a particular - // index in dy that we want to search for contributions to dx. - var winHeight = (Math.ceil(invHeightScale) * 2) + 2; - var winWidth = (Math.ceil(invWidthScale) * 2) + 2; - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), " + (xHeight - 1) + ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), " + (xWidth - 1) + ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n "; - } - return ResizeBilinearBackpropProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ResizeBilinearProgram = /** @class */ (function () { - function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners) { - this.variableNames = ['A']; - this.outputShape = []; - var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3]; - this.outputShape = [batch, newHeight, newWidth, depth]; - var effectiveInSize = [ - (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, - (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth - ]; - var effectiveOutSize = [ - (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, - (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth - ]; - this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n "; - } - return ResizeBilinearProgram; - }()); - - /** - * @license - * Copyright 2019 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ResizeBilinearPackedProgram = /** @class */ (function () { - function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners) { - this.variableNames = ['A']; - this.packedInputs = true; - this.packedOutput = true; - this.outputShape = []; - var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3]; - this.outputShape = [batch, newHeight, newWidth, depth]; - var effectiveInSize = [ - (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, - (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth - ]; - var effectiveOutSize = [ - (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, - (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth - ]; - this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n "; - } - return ResizeBilinearPackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google LLC All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ResizeNearestNeigborBackpropProgram = /** @class */ (function () { - function ResizeNearestNeigborBackpropProgram(dy, x, alignCorners) { - this.variableNames = ['dy']; - this.outputShape = []; - this.outputShape = x.shape; - var _a = x.shape, xHeight = _a[1], xWidth = _a[2]; - var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; - // In the backwards pass, we want to find the pixels that were generated for - // each pixel in the input image the forward pass and add the corresponding - // coefficient from dy to the gradient (with some interpolation). - var effectiveXSize = [ - (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, - (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth - ]; - var effectiveYSize = [ - (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, - (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth - ]; - var heightScale = effectiveXSize[0] / effectiveYSize[0]; - var widthScale = effectiveXSize[1] / effectiveYSize[1]; - var invHeightScale = 1 / heightScale; - var invWidthScale = 1 / widthScale; - // This defines the size of the window of values around a particular - // index in dy that we want to search for contributions to dx. - var winHeight = (Math.ceil(invHeightScale) * 2) + 2; - var winWidth = (Math.ceil(invWidthScale) * 2) + 2; - this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float sourceFracRow =\n float(" + effectiveXSize[0] + ") *\n (float(dyR) / float(" + effectiveYSize[0] + "));\n\n float sourceFracCol =\n float(" + effectiveXSize[1] + ") *\n (float(dyC) / float(" + effectiveYSize[1] + "));\n\n int sourceNearestRow = int(min(\n float(int(" + xHeight + ") - 1),\n " + alignCorners + " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(" + xWidth + ") - 1),\n " + alignCorners + " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n "; - } - return ResizeNearestNeigborBackpropProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ResizeNearestNeighborProgram = /** @class */ (function () { - function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners) { - this.variableNames = ['A']; - this.outputShape = []; - var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3]; - this.outputShape = [batch, newHeight, newWidth, depth]; - var effectiveInSize = [ - (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, - (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth - ]; - var effectiveOutSize = [ - (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, - (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth - ]; - // When align corners is false, we rounds the value with floor. - var roundBase = alignCorners ? '0.5' : '0.0'; - this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n "; - } - return ResizeNearestNeighborProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ReverseProgram = /** @class */ (function () { - function ReverseProgram(xShape, axis) { - this.variableNames = ['x']; - var rank = xShape.length; - if (rank > 4) { - throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported"); - } - this.outputShape = xShape; - if (rank === 1) { - this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(" + xShape[0] + " - coord - 1));\n }\n "; - return; - } - var getInCoord = function (i) { - if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { - return xShape[i] + " - coords[" + i + "] - 1"; - } - return "coords[" + i + "]"; - }; - var inCoords = xShape.map(function (_, i) { return getInCoord(i); }).join(','); - var type = getCoordsDataType(rank); - this.userCode = "\n void main() {\n " + type + " coords = getOutputCoords();\n setOutput(getX(" + inCoords + "));\n }\n "; - } - return ReverseProgram; - }()); - - /** - * @license - * Copyright 2019 Google LLC All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ReversePackedProgram = /** @class */ (function () { - function ReversePackedProgram(xShape, axis) { - this.variableNames = ['x']; - this.packedInputs = true; - this.packedOutput = true; - var rank = xShape.length; - if (rank > 4) { - throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported"); - } - this.outputShape = xShape; - var channels = getChannels('rc', rank); - var nextColumn = channels[rank - 1] + " + 1 < " + this.outputShape[rank - 1]; - var nextRow = channels[rank - 2] + " + 1 < " + this.outputShape[rank - 2]; - var type = getCoordsDataType(rank); - if (rank === 1) { - this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(" + xShape[0] + " - rc - 1),\n " + xShape[0] + " - rc - 1);\n if(" + nextColumn + "){\n result.g = getChannel(getX(" + xShape[0] + " - (rc + 1) - 1),\n " + xShape[0] + " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n "; - } - else { - this.userCode = "\n void main() {\n " + type + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = " + getR(channels.slice()) + ";\n if(" + nextColumn + "){\n result.g = " + getG(channels.slice()) + ";\n }\n if(" + nextRow + ") {\n result.b = " + getB(channels.slice()) + ";\n if(" + nextColumn + ") {\n result.a = " + getA(channels.slice()) + ";\n }\n }\n setOutput(result);\n }\n "; - } - function getR(channels) { - return getChannel(channels); - } - function getG(channels) { - channels[rank - 1] = '(' + channels[rank - 1] + " + 1)"; - return getChannel(channels); - } - function getB(channels) { - channels[rank - 2] = '(' + channels[rank - 2] + " + 1)"; - return getChannel(channels); - } - function getA(channels) { - channels[rank - 1] = '(' + channels[rank - 1] + " + 1)"; - channels[rank - 2] = '(' + channels[rank - 2] + " + 1)"; - return getChannel(channels); - } - function getChannel(channels) { - var inCoordsArray = xShape.map(function (_, i) { return getInCoord(i, channels); }); - var inCoords = inCoordsArray.join(','); - var innerDims = inCoordsArray.slice(-2).join(','); - return "getChannel(getX(" + inCoords + "), vec2(" + innerDims + "))"; - } - function getInCoord(i, channels1) { - if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { - return xShape[i] + " - " + channels1[i] + " - 1"; - } - else { - return "" + channels1[i]; - } - } - } - return ReversePackedProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var ScatterProgram = /** @class */ (function () { - function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex) { - if (summingDupeIndex === void 0) { summingDupeIndex = true; } - this.variableNames = ['updates', 'indices', 'defaultValue']; - this.outputShape = shape; - var stridesType = getCoordsDataType(strides.length); - var dtype = getCoordsDataType(shape.length); - var indicesString = ''; - if (indicesRank === 1) { - indicesString = 'i'; - } - else if (indicesRank === 2) { - indicesString = 'i, j'; - } - var indicesSnippet = "getIndices(" + indicesString + ")"; - var updatesString = ''; - if (updatesRank === 1) { - updatesString = 'i'; - } - else if (updatesRank === 2) { - updatesString = 'i, coords[1]'; - } - var updatesSnippet = "getUpdates(" + updatesString + ")"; - var strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; - this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < " + updateSize + "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < " + sliceDim + "; j++) {\n int index = round(" + indicesSnippet + ");\n flattenedIndex += index * " + strideString + ";\n }\n if (flattenedIndex == coords[0]) {\n sum += " + updatesSnippet + ";\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n "; - } - return ScatterProgram; - }()); - - /** - * @license - * Copyright 2018 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var SegmentOpProgram = /** @class */ (function () { - function SegmentOpProgram(segOpInfo, segOpType) { - this.variableNames = ['x', 'segmentIds']; - var windowSize = segOpInfo.windowSize; - var batchSize = segOpInfo.batchSize; - var inSize = segOpInfo.inSize; - var numSegments = segOpInfo.numSegments; - var outSize = numSegments * Math.ceil(inSize / windowSize); - this.outputShape = [batchSize, outSize]; - var initializationValue = '0.0'; - var returnValue = "sumValue"; - var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; - var windowSizeVec4Remainder = windowSize % 4; - var updateSnippet = "\n sumValue += dot(values, segFilter);\n "; - var checkValueOutOfBounds = ''; - if (inSize % windowSize > 0) { - checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n "; - } - var checkSegmentIdOutOfBounds = ''; - if (inSize % windowSize > 0) { - checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return -1.0;\n }\n "; - } - this.userCode = "\n const float initializationValue = " + initializationValue + ";\n\n float getValue(int batch, int inIdx) {\n " + checkValueOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n " + checkSegmentIdOutOfBounds + "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n " + numSegments + ")) * float(" + windowSize + "));\n int currentSeg = int(mod(float(outIdx), float(" + numSegments + ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n "; - } - return SegmentOpProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var SelectProgram = /** @class */ (function () { - function SelectProgram(cRank, shape, rank) { - this.variableNames = ['c', 'a', 'b']; - this.outputShape = shape; - var cCoords; - var abCoords; - if (rank > 4) { - throw Error("Where for rank " + rank + " is not yet supported"); - } - if (rank === 1) { - abCoords = "resRC"; - cCoords = "resRC"; - } - else { - var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; - var cCoordVars = []; - var abCoordVars = []; - for (var i = 0; i < shape.length; i++) { - abCoordVars.push("" + currentCoords[i]); - if (i < cRank) { - cCoordVars.push("" + currentCoords[i]); - } - } - cCoords = cCoordVars.join(); - abCoords = abCoordVars.join(); - } - var dtype = getCoordsDataType(rank); - this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n float cVal = getC(" + cCoords + ");\n if (cVal >= 1.0) {\n setOutput(getA(" + abCoords + "));\n } else {\n setOutput(getB(" + abCoords + "));\n }\n }\n "; - } - return SelectProgram; - }()); - - /** - * @license - * Copyright 2017 Google Inc. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - var SliceProgram = /** @class */ (function () { - function SliceProgram(destSize) { -