How to Run Machine-Learning Models in the Browser using ONNX

Written by mxkrn | Published 2021/08/25

TLDR ONNX Runtime Web (onnxruntime-web) is a JavaScript library for running ONNX models in the browser and on Node.js. It's now easier than ever to deploy machine-learning models natively. We will be using a pre-trained PyTorch model to deploy an image classifier to the browser. At the end of this tutorial, we will have built a bundled web app that can be run as a static web page or integrated into your JavaScript framework of choice. The official package is hosted on npm under the name onnxruntime-web.

Deploying machine-learning models outside of a Python environment used to be difficult. When the target platform is the browser, the de facto standard for serving predictions has been an API call to a server-side inference engine. For many reasons, server-side inference APIs are a suboptimal solution, and machine-learning models are increasingly being deployed natively. TensorFlow has done a good job of supporting this movement by providing cross-platform APIs. However, many of us do not want to be married to a single ecosystem.

In comes the Open Neural Network Exchange (ONNX) project which, since being picked up by Microsoft, has been seeing massive development efforts and is approaching a stable state. It's now easier than ever to deploy machine-learning models trained using your framework of choice, on your platform of choice, with hardware acceleration out of the box.

In April 2021, onnxruntime-web was introduced (see this Pull Request). onnxruntime-web ships the onnxruntime inference engine compiled to WebAssembly (wasm) - it's about time WebAssembly started to flex its muscles. Paired with WebGL, this suddenly gives us GPU-powered machine learning in the browser, which is pretty cool.

In this tutorial we will dive into onnxruntime-web by deploying a pre-trained PyTorch model to the browser. We will be using AlexNet as our deployment target. AlexNet has been trained as an image classifier on the ImageNet dataset, so we will be building an image classifier - nothing better than re-inventing the wheel. At the end of this tutorial, we will have built a bundled web app that can be run as a standalone static web page or integrated into your JavaScript framework of choice.

Jump to code: onnxruntime-web-tutorial

Prerequisite

You will need a trained machine-learning model exported as an ONNX binary protobuf file. There are many ways to achieve this using a number of different deep-learning frameworks. For the sake of this tutorial, I will be using the exported model from the AlexNet example in the PyTorch documentation. The Python code snippet below will help you generate your own model. You can also follow the documentation to export your own PyTorch model. If you're coming from TensorFlow, this tutorial will help you with exporting your model to ONNX. Lastly, ONNX doesn't just pride itself on cross-platform deployment, but also on allowing exports from all major deep-learning frameworks. Those of you using another deep-learning framework should be able to find support for exporting to ONNX in the docs of your framework.

import torch
import torchvision

dummy_input = torch.randn(1, 3, 224, 224)
model = torchvision.models.alexnet(pretrained=True)

input_names = ["input1"]
output_names = ["output1"]

torch.onnx.export(
  model, 
  dummy_input, 
  "alexnet.onnx", 
  verbose=True, 
  input_names=input_names,
  output_names=output_names
)

Running this script creates a file, alexnet.onnx, a binary protobuf file which contains both the network structure and parameters of the model you exported (in this case, AlexNet).

ONNX Runtime Web

ONNX Runtime Web is a JavaScript library for running ONNX models in the browser and on Node.js. ONNX Runtime Web has adopted WebAssembly and WebGL technologies to provide an optimized ONNX model inference runtime for both CPUs and GPUs.

The official package is hosted on npm under the name onnxruntime-web. When using a bundler or working server-side, this package can be installed using npm install. However, it's also possible to deliver the code via a CDN using a script tag. The bundling process is a bit involved so we will start with the script tag approach and come back to using the npm package later.

Inference Session

Let's start with the core application logic: model inference. onnxruntime exposes a runtime object called an InferenceSession with a method .run() which is used to initiate the forward pass with the desired inputs. Both InferenceSession.create() and the accompanying .run() method return a Promise, so we will run the entire process inside an async context. Before implementing any browser elements, we will check that our model runs with a dummy input tensor, remembering the input and output names and sizes that we defined earlier when exporting the model.

async function run() {
  try {
    // create a new session and load the AlexNet model.
    const session = await ort.InferenceSession.create('./alexnet.onnx');

    // prepare dummy input data
    const dims = [1, 3, 224, 224];
    const size = dims[0] * dims[1] * dims[2] * dims[3];
    const inputData = Float32Array.from({ length: size }, () => Math.random());

    // prepare feeds. use model input names as keys.
    const feeds = { input1: new ort.Tensor('float32', inputData, dims) };

    // feed inputs and run
    const results = await session.run(feeds);
    console.log(results.output1.data);
  } catch (e) {
    console.log(e);
  }
}
run();
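
The size bookkeeping in the snippet above can be factored into a small helper. This is only a minimal sketch: makeDummyInput is a hypothetical name, not part of onnxruntime-web, and it builds just the raw data so that the result can be wrapped in new ort.Tensor('float32', data, dims) exactly as above.

```javascript
// Hypothetical helper (not part of onnxruntime-web): build random
// float32 data whose length matches a given tensor shape.
function makeDummyInput(dims) {
  // The number of elements is the product of all dimensions.
  const size = dims.reduce((acc, d) => acc * d, 1);
  const data = Float32Array.from({ length: size }, () => Math.random());
  return { data, dims, size };
}

const dummy = makeDummyInput([1, 3, 224, 224]);
console.log(dummy.size); // 150528
```

Deriving the size from dims keeps the data length and the tensor shape in sync if you later swap in a model with a different input shape.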

We then implement a simple HTML template, index.html, which should load both the pre-compiled onnxruntime-web package and main.js, containing our code.

<!DOCTYPE html>
<html>
  <head>
    <title>ONNX Runtime Web - Tutorial</title>
  </head>
  <body>
    <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js">
    </script>
    <script src="main.js"></script>
  </body>
</html>

To run this, we can use light-server. If you haven't started an npm project by now, please do so by running npm init in your current working directory. Once you've completed the setup, install light-server (npm install light-server) and serve the static HTML page using npx light-server -s . -p 8080.

You’re now running a machine-learning model natively in the browser! To check that everything is running fine, go to your web console and make sure that the output tensor is logged (AlexNet is bulky, so it's normal that inference takes a few seconds).

Bundled deployment

Next we will use webpack to bundle our dependencies, as would be the case if we wanted to deploy the model in a JavaScript app powered by frameworks like React or Vue. Usually bundling is a relatively simple procedure. However, onnxruntime-web requires a slightly more involved webpack configuration because the runtime is delivered as separate WebAssembly (.wasm) files that must be copied alongside the bundle.

Browser support is the classic pitfall, especially when working with cutting-edge web technology. If your intended users are not using one of the four major browsers (Chrome, Edge, Firefox, Safari), you might want to hold off on integrating WebAssembly components. More information on WebAssembly support and the roadmap can be found here.
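
If you want to guard against unsupported browsers at runtime, one option is a small feature check before loading the model. The sketch below is an assumption about how you might do this, not something onnxruntime-web requires; supportsWebAssembly is a hypothetical helper name. It relies only on the standard WebAssembly JavaScript API.

```javascript
// Returns true when the current environment exposes a usable WebAssembly API.
function supportsWebAssembly() {
  try {
    if (typeof WebAssembly !== "object" ||
        typeof WebAssembly.instantiate !== "function") {
      return false;
    }
    // Smallest valid module: the "\0asm" magic number followed by version 1.
    const emptyModule = new Uint8Array([
      0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00,
    ]);
    return WebAssembly.validate(emptyModule);
  } catch (e) {
    return false;
  }
}

console.log(supportsWebAssembly());
```

When the check fails, the app could show a message or fall back to a server-side API instead of attempting to create an inference session.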

The following steps are based on the examples provided by the official ONNX documentation. We’re assuming you've already started an npm project.

1. Install the dependencies.

npm install onnxruntime-web && npm install -D webpack webpack-cli copy-webpack-plugin

2. Instead of loading the onnxruntime-web module via a CDN, we should update main.js to require the package at the top of the script.

const ort = require('onnxruntime-web');

3. Save the configuration file below, defined by the ONNX Runtime team, as webpack.config.js.

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

const path = require('path');
const CopyPlugin = require("copy-webpack-plugin");

module.exports = () => {
    return {
        target: ['web'],
        entry: path.resolve(__dirname, 'main.js'),
        output: {
            path: path.resolve(__dirname, 'dist'),
            filename: 'bundle.min.js',
            library: {
                type: 'umd'
            }
        },
        plugins: [new CopyPlugin({
            // Use copy plugin to copy *.wasm to output folder.
            patterns: [{ from: 'node_modules/onnxruntime-web/dist/*.wasm', to: '[name][ext]' }]
        })],
        mode: 'production'
    }
};

4. Run npx webpack to compile the bundle.

5. Finally, before reloading the server, we need to update index.html.

  • Remove the ort.min.js script tag to stop loading the compiled package from the CDN.
  • Load the code dependencies from bundle.min.js (which contains all our dependencies bundled and minified by webpack) instead of main.js

index.html should now look something like this.

<!DOCTYPE html>
<html>
  <head>
    <title>ONNX Runtime Web Tutorial</title>
  </head>
  <body>
    <script src="bundle.min.js"></script>
  </body>
</html>

To make building and launching the live server easier, you could define build and serve scripts in package.json.

"scripts": {
    "build": "npx webpack",
    "serve": "npm run build && npx light-server -s . -p 8080"
  }

Image Classifier

Let's put this model to work and implement the image classification pipeline.

We will need some utility functions to load, resize, and display the image - the canvas object is perfect for this. In addition, image classification systems typically have lots of magic built into the pre-processing pipeline. This is quite trivial to implement in Python using libraries like NumPy, but unfortunately that is not the case with JavaScript. It follows that we will have to implement our pre-processing from scratch to transform the image data into the correct input format.

1. DOM Elements

We will need some HTML elements to interact with and display the data.

File input, for loading a file from disk.

<label for="file-in"><h2>What am I?</h2></label>
<input type="file" id="file-in" name="file-in">

Image displays, we want to display both the input and re-scaled image.

<img id="input-image" class="input-image">
<img id="scaled-image" class="scaled-image">

Classification target, to display our inference results.

<h3 id="target"></h3>

2. Image load and display

We want to load an image from file and display it. Back in main.js, we will get the file input element from the DOM and use FileReader to read the data into memory. Following this, the image data will be passed to handleImage which will draw the image using the 2D canvas context.

const canvas = document.createElement("canvas"),
  ctx = canvas.getContext("2d");

document.getElementById("file-in").onchange = function (evt) {
  let target = evt.target || window.event.srcElement,
    files = target.files;

  if (FileReader && files && files.length) {
      var fileReader = new FileReader();
      fileReader.onload = () => onLoadImage(fileReader);
      fileReader.readAsDataURL(files[0]);
  }
}

function onLoadImage(fileReader) {
    var img = document.getElementById("input-image");
    img.onload = () => handleImage(img);
    img.src = fileReader.result;
}

function handleImage(img) {
  ctx.drawImage(img, 0, 0)
}

3. Preprocess and convert image to tensor

Now that we can load and display an image, we want to move on to extracting and processing the data. Remember that our model takes in a tensor of shape [1, 3, 224, 224]. This means we will have to resize the image so that any input image is supported, and perhaps also transpose the dimensions depending on how we extract the image data.

To resize and extract image data, we will use the canvas context again. Let's define a function processImage that does this. processImage has the necessary elements in scope to immediately draw the scaled image so we will also do that here.

function processImage(img, width) {
  const canvas = document.createElement("canvas"),
    ctx = canvas.getContext("2d")

  // resize image
  canvas.width = width;
  canvas.height = canvas.width * (img.height / img.width);

  // draw scaled image
  ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
  document.getElementById("scaled-image").src = canvas.toDataURL();

  // return data
  return ctx.getImageData(0, 0, width, width).data;
}
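
Note that processImage preserves the aspect ratio when setting the canvas height but then reads a width x width square. For a portrait image the bottom part is cut off, and for a landscape image the square extends past the canvas, where getImageData returns transparent black pixels. One possible refinement, sketched below under the assumption that a centered square crop is acceptable, is to compute the crop rectangle first. centerCropRect is a hypothetical helper, not part of the original code.

```javascript
// Hypothetical helper: compute the largest centered square inside an image.
function centerCropRect(imgWidth, imgHeight) {
  const side = Math.min(imgWidth, imgHeight);
  return {
    sx: Math.floor((imgWidth - side) / 2), // left edge of the square
    sy: Math.floor((imgHeight - side) / 2), // top edge of the square
    side,
  };
}

const rect = centerCropRect(640, 480);
console.log(rect); // { sx: 80, sy: 0, side: 480 }

// In the browser, the rectangle could then be passed to the nine-argument
// form of drawImage (assuming `ctx`, `img`, and `width` as in processImage,
// with canvas.height also set to `width`):
// ctx.drawImage(img, rect.sx, rect.sy, rect.side, rect.side, 0, 0, width, width);
```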

We can now add a line to the function handleImage which calls processImage.

const resizedImageData = processImage(img, targetWidth);

Finally, let's implement a function called imageDataToTensor which applies the transforms needed to get the image data ready to be used as input to the model. imageDataToTensor should apply three transforms:

  • Filter out the alpha channel. Our input tensor should contain 3 channels corresponding to the RGB channels.

  • ctx.getImageData returns a flat array in [224, 224, 4] (height, width, RGBA) order, so after dropping the alpha channel we need to transpose the data from the shape [224, 224, 3] to the shape [3, 224, 224].

  • ctx.getImageData returns a Uint8ClampedArray with integer values ranging from 0 to 255. We need to convert the values to float32 and store them in a Float32Array to construct our tensor input. The code below also scales them to the range [0, 1].

function imageDataToTensor(data, dims) {
  // 1. Extract the R, G, and B channels, skipping data[i + 3] (alpha)
  const [R, G, B] = [[], [], []];
  for (let i = 0; i < data.length; i += 4) {
    R.push(data[i]);
    G.push(data[i + 1]);
    B.push(data[i + 2]);
  }
  // 2. Concatenate R, G, B ~= transpose [224, 224, 3] -> [3, 224, 224]
  const transposedData = R.concat(G).concat(B);

  // 3. Convert to float32, scaling values from [0, 255] to [0, 1]
  const float32Data = new Float32Array(transposedData.length);
  for (let i = 0; i < transposedData.length; i++) {
    float32Data[i] = transposedData[i] / 255.0;
  }

  const inputTensor = new ort.Tensor("float32", float32Data, dims);
  return inputTensor;
}
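
To see the layout change on a tiny input, here is a standalone sketch of the same transform that does not depend on ort. rgbaToChw is a hypothetical name, and writing directly into a preallocated Float32Array is a design choice of this sketch rather than what the tutorial code does.

```javascript
// Hypothetical standalone version of the transform: RGBA bytes in
// [height, width, 4] order -> float32 values in [3, height, width] order.
function rgbaToChw(data) {
  const pixels = data.length / 4;
  const out = new Float32Array(3 * pixels);
  for (let p = 0; p < pixels; p++) {
    out[p] = data[4 * p] / 255; // R plane
    out[pixels + p] = data[4 * p + 1] / 255; // G plane
    out[2 * pixels + p] = data[4 * p + 2] / 255; // B plane
    // data[4 * p + 3] (alpha) is dropped
  }
  return out;
}

// Two pixels: pure red, then pure blue (both fully opaque).
const rgba = new Uint8ClampedArray([255, 0, 0, 255, 0, 0, 255, 255]);
console.log(Array.from(rgbaToChw(rgba))); // [1, 0, 0, 0, 0, 1]
```

The output lists all red values first, then all green values, then all blue values, which is the channel-first order the model expects.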

4. Display classification result

Almost there, let’s wrap up some loose ends to get the full inference pipeline up and running.

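First, stitch together the image processing and inference pipeline in handleImage. Here targetWidth should be 224, and DIMS is assumed to be a constant holding the model input shape [1, 3, 224, 224].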

function handleImage(img, targetWidth) {
  ctx.drawImage(img, 0, 0);
  const resizedImageData = processImage(img, targetWidth);
  const inputTensor = imageDataToTensor(resizedImageData, DIMS);
  run(inputTensor);
}

The output of the model is a list of raw activation values, one per class, where a higher value means the class is more likely to be present in the image. We need to get the most likely classification result by finding the index of the maximum value in the output data. This is done using an argMax function.

function argMax(arr) {
  let max = arr[0];
  let maxIndex = 0;
  for (var i = 1; i < arr.length; i++) {
      if (arr[i] > max) {
          maxIndex = i;
          max = arr[i];
      }
  }
  return [max, maxIndex];
}
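
If you would rather display probabilities than raw activation values, a softmax can be applied to the output first. The sketch below is an optional addition, not part of the original pipeline; softmax and topK are hypothetical helper names, and the three-element input is made-up demonstration data.

```javascript
// Numerically stable softmax: subtract the maximum before exponentiating.
function softmax(scores) {
  let max = -Infinity;
  for (const s of scores) {
    if (s > max) max = s;
  }
  const exps = Array.from(scores, (s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Indices and probabilities of the k highest-scoring classes.
function topK(probs, k) {
  return probs
    .map((p, index) => ({ index, p }))
    .sort((a, b) => b.p - a.p)
    .slice(0, k);
}

const probs = softmax(new Float32Array([1.0, 3.0, 2.0]));
console.log(topK(probs, 2).map((r) => r.index)); // [1, 2]
```

In the app, softmax would be called on the output tensor's data, and the indices from topK could be used to look up class names in the same way as the index from argMax.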

Finally, run() needs to be refactored to accept a tensor input. We also need to use the max index to actually retrieve the result from a list of ImageNet classes. I've pre-converted this list to JSON and we will load it into our script using require - you can find the JSON file in the code repository linked at the start and end of the tutorial.

const classes = require("./imagenet_classes.json").data;

async function run(inputTensor) {
  try {
    const session = await ort.InferenceSession.create('./alexnet.onnx');

    const feeds = { input1: inputTensor };
    const results = await session.run(feeds);

    const [maxValue, maxIndex] = argMax(results.output1.data);
    document.getElementById("target").innerHTML = `${classes[maxIndex]}`;
  } catch (e) {
    console.error(e);  // non-fatal error handling
  }
}
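
Note that run() creates a new InferenceSession for every image, so the model is loaded again on each call. One possible refinement is to cache the session promise. The sketch below is an assumption about how this could be done; memoizeAsync is a hypothetical helper, and a stand-in loader is used so the example runs outside the browser.

```javascript
// Hypothetical helper: wrap an async loader so that it runs at most once.
function memoizeAsync(loader) {
  let promise = null;
  return () => {
    if (promise === null) {
      promise = loader();
    }
    return promise;
  };
}

// Stand-in loader for demonstration. In the app this would be
// () => ort.InferenceSession.create('./alexnet.onnx').
let loadCount = 0;
const getSession = memoizeAsync(async () => {
  loadCount += 1;
  return { name: "fake-session" };
});

// Both calls share the same underlying load.
Promise.all([getSession(), getSession()]).then(([a, b]) => {
  console.log(a === b, loadCount); // true 1
});
```

Inside run(), the session would then be obtained with await getSession(). Keep in mind that this simple version also caches a rejected promise, so a failed load is not retried.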

That’s it! All that’s left is to re-build our bundle, serve the app, and start classifying some images.

As you test the app, you will notice that prediction quality is not as good as it could be. This is primarily because the current image processing pipeline is still rather rudimentary and can be improved in a number of ways. For example, we could implement improved resizing, center-cropping, and/or normalization. Maybe food for a next tutorial, or I’ll just leave it up to you to explore!
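
As a starting point for the normalization step, torchvision documents that its pre-trained ImageNet models expect inputs scaled to [0, 1] and then normalized per channel with mean [0.485, 0.456, 0.406] and standard deviation [0.229, 0.224, 0.225]. The sketch below applies this to channel-first data; normalizeChw is a hypothetical helper, not part of the tutorial code.

```javascript
// Per-channel statistics documented by torchvision for its ImageNet models.
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];

// Hypothetical helper: normalize [3, H, W] float data (values in [0, 1]) in place.
function normalizeChw(data, mean = IMAGENET_MEAN, std = IMAGENET_STD) {
  const plane = data.length / 3; // number of pixels per channel
  for (let c = 0; c < 3; c++) {
    for (let p = 0; p < plane; p++) {
      const i = c * plane + p;
      data[i] = (data[i] - mean[c]) / std[c];
    }
  }
  return data;
}

// One pixel whose channels equal the channel means normalizes to (almost) zero.
const pixel = normalizeChw(new Float32Array([0.485, 0.456, 0.406]));
console.log(Array.from(pixel, (v) => Math.abs(v) < 1e-6)); // [true, true, true]
```

In imageDataToTensor, this could be applied to the Float32Array after the division by 255 and before the tensor is constructed.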

Conclusion

That’s it, we’ve built a web app with a machine-learning model running natively in the browser! You can find the full code (including styles and layout) in this code repository on GitHub. I appreciate any and all feedback so feel free to share any Issues or Stars.

Thank you for reading!

Also published on: https://rekoil.io/blog/onnxruntime-web-tutorial


Written by mxkrn | Data scientist and software engineer for audio and music.
Published by HackerNoon on 2021/08/25