Image Classification

This use case shows how we can use our beyondml (formerly MANN) and aisquared libraries to build and configure a .air file, which can then be dragged and dropped into a widget created using air JS for use in the browser.

In this example, we will show how to create a .air file to perform image classification in the browser using a neural network. To do this, we will utilize the CIFAR-10 dataset to build the initial model, and then package the model using the aisquared Python SDK. As an added bonus, this tutorial shows how the beyondml library can be used to prune a model, making it run more efficiently in the browser.

Dependencies

For this example, the following dependencies are required:

  • beyondml

  • aisquared

Both of these are available on PyPI via pip. The following cell runs the commands to install these dependencies and imports them into the notebook environment, along with TensorFlow (which is a dependency of the beyondml package).

! pip install beyondml
! pip install aisquared

import tensorflow as tf
import aisquared
import beyondml.tflow as mann

Model Creation

Now that the required packages have been installed and imported, it is time to create the image classification model. To do this, we first download and preprocess the data, create the model, prune the model so that it performs well in the browser, and then package it in the .air format.

# Loading the data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train/255
x_test = x_test/255

label_map = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
]
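
As a quick sanity check, you can confirm the shapes of the data you just loaded. CIFAR-10 contains 50,000 training images and 10,000 test images, each 32 x 32 pixels with 3 color channels:

# Check the shapes of the training and test data
print(x_train.shape, y_train.shape)   # (50000, 32, 32, 3) (50000, 1)
print(x_test.shape, y_test.shape)     # (10000, 32, 32, 3) (10000, 1)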

# Creating the model
input_layer = tf.keras.layers.Input(x_train.shape[1:])
x = mann.layers.MaskedConv2D(
    32,
    activation = 'relu'
)(input_layer)
x = mann.layers.MaskedConv2D(
    32,
    activation = 'relu'
)(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = mann.layers.MaskedConv2D(
    64,
    activation = 'relu'
)(x)
x = mann.layers.MaskedConv2D(
    64,
    activation = 'relu'
)(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = mann.layers.MaskedConv2D(
    128,
    activation = 'relu'
)(x)
x = mann.layers.MaskedConv2D(
    128,
    activation = 'relu'
)(x)
x = tf.keras.layers.MaxPool2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Flatten()(x)
x = mann.layers.MaskedDense(512, activation = 'relu')(x)
x = mann.layers.MaskedDense(512, activation = 'relu')(x)
x = mann.layers.MaskedDense(512, activation = 'relu')(x)
output_layer = mann.layers.MaskedDense(10, activation = 'softmax')(x)

model = tf.keras.models.Model(input_layer, output_layer)
model.compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = ['accuracy']
)
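
If you'd like to see how large the network is before any pruning is applied, you can print the Keras model summary:

# Display the model architecture and parameter counts prior to pruning
model.summary()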

Model Pruning, Training, and Evaluation

Now that you've created a model, you can get it ready for production. Since we are going to be using this model in an AI Squared-powered application in the browser, we'll want to ensure that it is as lightweight as possible. We accomplish this by using active sparsification: as the model trains, we reduce the parameter count without sacrificing accuracy, yielding a model that requires fewer computational resources once it's deployed. We then test the model's accuracy to ensure that it'll be up to the task of your production workloads.

# Perform initial sparsification
model = mann.utils.mask_model(
    model,
    40,
    x = x_train[:500],
    y = y_train[:500]
)

model.compile(
    loss = 'sparse_categorical_crossentropy',
    optimizer = 'adam',
    metrics = ['accuracy']
)

# Create a pruning callback that will increase pruning rate as performance improves
callback = mann.utils.ActiveSparsification(
    performance_cutoff = 0.65,
    starting_sparsification = 40,
    max_sparsification = 80,
    sparsification_rate = 5
)

# Train the model with the sparsification callback
model.fit(
    x_train,
    y_train,
    epochs = 1000,
    batch_size = 512,
    validation_split = 0.2,
    verbose = 2,
    callbacks = [callback]
)

# Now that the model has been trained, convert all model layers to base TensorFlow layers
model = mann.utils.remove_layer_masks(model)

# Save the converted model so it can be referenced when packaging the .air file below
model.save('cifar10.h5')
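
To confirm that the pruned model is still accurate enough for production, you can evaluate it on the held-out test set using Keras' built-in evaluate method:

# Evaluate the pruned model on the CIFAR-10 test set
loss, accuracy = model.evaluate(x_test, y_test, verbose = 0)
print(f'Test accuracy: {accuracy:.3f}')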

Package the Model

Now that the model has been created, we can package it into a single .air file that enables integration into the browser.

To perform this packaging, we will be utilizing the configuration classes in the aisquared.config module. These classes define how data is harvested from the browser, preprocessed, passed to the model, postprocessed, and rendered back to the user.

# Configure the model for integration via the browser

# Harvester
harvester = aisquared.config.harvesting.ImageHarvester()

# Preprocessing steps
resize_step = aisquared.config.preprocessing.image.Resize([32, 32])
divide_step = aisquared.config.preprocessing.image.DivideValue(255)

preprocesser = aisquared.config.preprocessing.image.ImagePreprocessor(
    [
        resize_step,
        divide_step
    ]
)

# Analytic Step - point to the saved model
analytic = aisquared.config.analytic.LocalModel('cifar10.h5', 'cv')

# Postprocessing Step
postprocesser = aisquared.config.postprocessing.MulticlassClassification(label_map)

# Rendering
renderer = aisquared.config.rendering.ImageRendering(
    thickness = '5',
    font_size = '20',
    include_probability = True
)

# Feedback
feedback = aisquared.config.feedback.MulticlassFeedback(label_map)

# Put all of the steps together into a configuration object
config = aisquared.config.ModelConfiguration(
    name = 'CIFAR10Classifier',
    harvesting_steps = harvester,
    preprocessing_steps = preprocesser,
    analytic = analytic,
    postprocessing_steps = postprocesser,
    rendering_steps = renderer,
    feedback_steps = feedback
)

# Compile the entirety of the configuration and the model into a .air file
config.compile(dtype = 'float16')

The output of the compilation step (config.compile(...)) is a .air file that is saved to your working directory.
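
If you'd like to confirm that the file was written, you can list the .air files in your working directory; the compiled file takes its name from the ModelConfiguration's name parameter, so here it should be CIFAR10Classifier.air:

import os

# List any .air files produced in the current working directory
print([f for f in os.listdir('.') if f.endswith('.air')])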

Congratulations! Now you are ready to drag & drop your model into the browser with the AI Squared Extension!
