This model classifies MNIST digits with a convolutional network (ConvNet) and writes the saved model to the file “mnist_conv.bson”. It also demonstrates basic model construction, training, saving, conditional early exit, and learning-rate scheduling.

Note that this model, while simple, should hit around 99% test accuracy after training for approximately 20 epochs.

Load the necessary packages.

using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, logitcrossentropy
using Base.Iterators: partition
using Printf, BSON
using Parameters: @with_kw
using CUDA
if has_cuda()
    @info "CUDA is on"
end

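Set the hyperparameters: learning rate, batch size, number of epochs, and the directory where the model file is saved. Data and model are moved to the GPU later with `gpu`, which leaves them on the CPU when no GPU is available.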

@with_kw mutable struct Args
    lr::Float64 = 3e-3
    epochs::Int = 20
    batch_size::Int = 128
    savepath::String = "./"
end
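Because `Args` has keyword defaults, any field can be overridden by passing keywords to `train` or `test` (for example `train(epochs = 10)`). The sketch below shows the same pattern without package dependencies, assuming Base's `Base.@kwdef` as a stand-in for `Parameters.@with_kw`; the names `ArgsSketch` and `make_args` are hypothetical.

```julia
# Base.@kwdef is used here instead of Parameters.@with_kw so the sketch has
# no package dependencies; both generate a keyword-argument constructor.
Base.@kwdef mutable struct ArgsSketch
    lr::Float64 = 3e-3
    epochs::Int = 20
    batch_size::Int = 128
    savepath::String = "./"
end

# Mimics how `train(; kws...)` forwards its keywords into the constructor.
make_args(; kws...) = ArgsSketch(; kws...)

defaults = make_args()
custom = make_args(lr = 1e-3, epochs = 5)   # unspecified fields keep their defaults
println(defaults.lr, " ", defaults.epochs)
println(custom.lr, " ", custom.epochs, " ", custom.batch_size)
```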



Bundle images together with their labels and group them into minibatches.

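function make_minibatch(X, Y, idxs)
    # Allocate a WHCN array: width × height × 1 channel × batch size
    X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
    for i in 1:length(idxs)
        X_batch[:, :, :, i] = Float32.(X[idxs[i]])
    end
    # One-hot encode the labels over the digit classes 0 to 9
    Y_batch = onehotbatch(Y[idxs], 0:9)
    return (X_batch, Y_batch)
end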
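To make the resulting shapes concrete, here is a dependency-free sketch of what `make_minibatch` produces: images stacked into a width × height × channels × batch array and labels turned into a 10 × batch one-hot matrix. The helpers `stack_whcn` and `onehot_sketch` are hypothetical stand-ins, not Flux functions, and random matrices replace real MNIST images.

```julia
# Hypothetical helper: stack 2-D images into a W × H × 1 × N Float32 array.
function stack_whcn(imgs::Vector{<:AbstractMatrix})
    w, h = size(imgs[1])
    batch = Array{Float32}(undef, w, h, 1, length(imgs))
    for (i, img) in enumerate(imgs)
        batch[:, :, 1, i] = Float32.(img)
    end
    return batch
end

# Hypothetical helper: one column per label, with a single `true` in the
# row of that label's position within `classes`.
function onehot_sketch(labels::AbstractVector{<:Integer}, classes = 0:9)
    Y = falses(length(classes), length(labels))
    for (j, l) in enumerate(labels)
        Y[findfirst(==(l), classes), j] = true
    end
    return Y
end

imgs = [rand(28, 28) for _ in 1:3]   # three fake 28×28 "images"
X = stack_whcn(imgs)
Y = onehot_sketch([7, 0, 9])
println(size(X), " ", size(Y))
```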



Load the MNIST dataset from Flux.Data.MNIST, split the training set into minibatches, and prepare the test set as a single large minibatch.

function get_processed_data(args)
    # Load labels and images
    train_labels = MNIST.labels()
    train_imgs = MNIST.images()
    # Split the training indices into consecutive chunks of `batch_size`
    mb_idxs = partition(1:length(train_imgs), args.batch_size)
    train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]
    # Prepare test set as one giant minibatch:
    test_imgs = MNIST.images(:test)
    test_labels = MNIST.labels(:test)
    test_set = make_minibatch(test_imgs, test_labels, 1:length(test_imgs))

    return train_set, test_set
end
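`partition` cuts the index range into consecutive chunks, so the last minibatch is shorter whenever the dataset size is not a multiple of the batch size. A quick check of the numbers, assuming the standard 60,000-image MNIST training split and the default `batch_size = 128`:

```julia
using Base.Iterators: partition

n_train = 60_000   # size of the MNIST training split
batch_size = 128   # default from Args
mb_idxs = collect(partition(1:n_train, batch_size))

println(length(mb_idxs))          # 469 minibatches
println(length(first(mb_idxs)))   # 128 indices in a full minibatch
println(length(last(mb_idxs)))    # 96 indices in the final, partial minibatch
```

Since 468 × 128 = 59,904, the 469th minibatch holds the remaining 96 images.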




Build the ConvNet model: three convolution and max-pooling stages followed by a dense classification layer.

function build_model(args; imgsize = (28,28,1), nclasses = 10)
    # Three 2×2 poolings shrink each spatial dimension by a factor of 8 (rounded down)
    cnn_output_size = Int.(floor.([imgsize[1]/8, imgsize[2]/8, 32]))

    return Chain(
    # First convolution, operating upon a 28x28 image
    Conv((3, 3), imgsize[3]=>16, pad=(1,1), relu),
    MaxPool((2,2)),

    # Second convolution, operating upon a 14x14 image
    Conv((3, 3), 16=>32, pad=(1,1), relu),
    MaxPool((2,2)),

    # Third convolution, operating upon a 7x7 image
    Conv((3, 3), 32=>32, pad=(1,1), relu),
    MaxPool((2,2)),

    # Reshape 3d tensor into a 2d one using `Flux.flatten`, at this point it should be (3, 3, 32, N)
    Flux.flatten,
    Dense(prod(cnn_output_size), nclasses))
end
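The comments in `build_model` track the spatial size through the network. The arithmetic can be checked with the usual output-length formula, floor((n + 2·pad − k) / stride) + 1, assuming (as in Flux) that `MaxPool((2,2))` uses a stride equal to its window. The helper names `out_len` and `trace_shapes` are hypothetical.

```julia
# Output length of a convolution or pooling window along one dimension.
out_len(n; k, pad = 0, stride = 1) = div(n + 2pad - k, stride) + 1

function trace_shapes(n = 28)
    sizes = Int[]
    for _ in 1:3
        n = out_len(n; k = 3, pad = 1)      # 3×3 conv with pad 1 keeps the size
        n = out_len(n; k = 2, stride = 2)   # 2×2 max pooling with stride 2 halves it
        push!(sizes, n)
    end
    return sizes
end

sizes = trace_shapes()
println(sizes)                 # [14, 7, 3]
println(last(sizes)^2 * 32)    # 288 inputs to the Dense layer
```

The final 3 × 3 × 32 = 288 features match `prod(cnn_output_size)`, since floor(28 / 8) = 3.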



To make the model more robust, we augment the input x by adding Gaussian noise with standard deviation 0.1.

augment(x) = x .+ gpu(0.1f0*randn(eltype(x), size(x)))
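The `gpu(...)` call only moves the noise to the GPU when one is available; the noise itself is standard normal, scaled by 0.1 and drawn in the input's element type. A CPU-only sketch (the name `augment_cpu` is hypothetical) that checks those properties on a fake all-zero batch:

```julia
using Random, Statistics

# CPU-only version of `augment`: zero-mean Gaussian noise with std 0.1,
# drawn in the same element type and shape as the input.
augment_cpu(x) = x .+ 0.1f0 .* randn(eltype(x), size(x))

Random.seed!(0)
x = zeros(Float32, 28, 28, 1, 64)   # a fake WHCN batch of 64 blank images
x̂ = augment_cpu(x)
noise = x̂ .- x
println(eltype(x̂), " ", size(x̂))
println(round(std(noise), digits = 3))
```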



paramvec returns a single flat vector containing all of the model's parameters.

paramvec(m) = vcat(map(p->reshape(p, :), params(m))...)
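Flattening every parameter into one vector makes a whole-model check such as "is any parameter NaN?" a single call. A sketch on a hypothetical parameter list shaped like the final `Dense(288, 10)` layer (a 10 × 288 weight matrix plus a length-10 bias); `paramvec_sketch` is a hypothetical name that takes the array list directly instead of a model.

```julia
# Hypothetical stand-in for params(model): weight matrix and bias of Dense(288, 10).
ps = [randn(Float32, 10, 288), zeros(Float32, 10)]

# Same idea as `paramvec`: reshape each array to a vector, then concatenate.
paramvec_sketch(ps) = vcat(map(p -> reshape(p, :), ps)...)

anynan(x) = any(isnan.(x))

v = paramvec_sketch(ps)
println(length(v))                     # 2890 = 10*288 + 10
println(anynan(v))                     # false

ps[2][1] = NaN32                       # corrupt one bias entry
println(anynan(paramvec_sketch(ps)))   # true
```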



anynan checks whether any element of an array is NaN, and accuracy computes the fraction of samples whose predicted class matches the label.

anynan(x) = any(isnan.(x))

accuracy(x, y, model) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
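`onecold` picks the class with the largest entry in each column. Because softmax is monotonic, the argmax of the raw logits returned by the model is the same as the argmax of the probabilities, so `accuracy` needs no softmax. A sketch of the same computation in plain Julia, with hypothetical names and hand-made scores for three samples over four classes:

```julia
using Statistics

# Index of the largest entry in each column (what onecold does for labels 1:k).
onecold_sketch(M) = [argmax(c) for c in eachcol(M)]

accuracy_sketch(ŷ, y) = mean(onecold_sketch(ŷ) .== onecold_sketch(y))

scores = [0.1 2.0 0.3;    # hypothetical model outputs, one column per sample
          3.0 0.5 0.2;
          0.2 0.1 1.5;
          0.0 0.4 0.1]
targets = [0 1 0;         # one-hot ground truth
           1 0 0;
           0 0 0;
           0 0 1]
println(onecold_sketch(scores))            # [2, 1, 3]
println(accuracy_sketch(scores, targets))  # 2 of 3 correct
```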

Train the model.

function train(; kws...)
    args = Args(; kws...)

    @info("Loading data set")
    train_set, test_set = get_processed_data(args)

    # Define our model.  We will use a simple convolutional architecture with
    # three iterations of Conv -> ReLU -> MaxPool, followed by a final Dense layer.
    @info("Building model...")
    model = build_model(args)

    # Load model and datasets onto GPU, if enabled
    train_set = gpu.(train_set)
    test_set = gpu.(test_set)
    model = gpu(model)

    # Make sure our model is nicely precompiled before starting our training loop
    model(train_set[1][1])

    # `loss()` calculates the crossentropy loss between our prediction `y_hat`
    # (calculated from `model(x)`) and the ground truth `y`.  We augment the data
    # a bit, adding gaussian random noise to our image to make it more robust.
    function loss(x, y)
        x̂ = augment(x)
        ŷ = model(x̂)
        return logitcrossentropy(ŷ, y)
    end

    # Train our model with the given training set using the ADAM optimizer and
    # printing out performance against the test set as we go.
    opt = ADAM(args.lr)

    @info("Beginning training loop...")
    best_acc = 0.0
    last_improvement = 0
    for epoch_idx in 1:args.epochs
        # Train for a single epoch
        Flux.train!(loss, params(model), train_set, opt)

        # Terminate on NaN
        if anynan(paramvec(model))
            @error "NaN params"
            break
        end

        # Calculate accuracy:
        acc = accuracy(test_set..., model)

        @info(@sprintf("[%d]: Test accuracy: %.4f", epoch_idx, acc))
        # If our accuracy is good enough, quit out.
        if acc >= 0.999
            @info(" -> Early-exiting: We reached our target accuracy of 99.9%")
            break
        end

        # If this is the best accuracy we've seen so far, save the model out
        if acc >= best_acc
            @info(" -> New best accuracy! Saving model out to mnist_conv.bson")
            BSON.@save joinpath(args.savepath, "mnist_conv.bson") params=cpu.(params(model)) epoch_idx acc
            best_acc = acc
            last_improvement = epoch_idx
        end

        # If we haven't seen improvement in 5 epochs, drop our learning rate:
        if epoch_idx - last_improvement >= 5 && opt.eta > 1e-6
            opt.eta /= 10.0
            @warn(" -> Haven't improved in a while, dropping learning rate to $(opt.eta)!")

            # After dropping learning rate, give it a few epochs to improve
            last_improvement = epoch_idx
        end

        if epoch_idx - last_improvement >= 10
            @warn(" -> We're calling this converged.")
            break
        end
    end
end
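The loss inside `train` uses `logitcrossentropy`, which applies log-softmax to the raw model outputs and then takes the cross-entropy against the one-hot targets, averaged over the batch. This is why `build_model` ends in a plain `Dense` layer with no `softmax`. The sketch below re-derives that computation in plain Julia rather than calling Flux; the names `logsoftmax_cols` and `logitce_sketch` are hypothetical.

```julia
using Statistics

# Numerically stable log-softmax over each column (one column per sample).
function logsoftmax_cols(ŷ)
    m = maximum(ŷ, dims = 1)
    return ŷ .- m .- log.(sum(exp.(ŷ .- m), dims = 1))
end

# Cross-entropy on raw logits, averaged over the samples in the batch.
logitce_sketch(ŷ, y) = mean(-sum(y .* logsoftmax_cols(ŷ), dims = 1))

y = Float64[i == j for i in 1:10, j in 1:4]   # one-hot targets for 4 samples
uniform = zeros(10, 4)                        # equal logits for every class
confident = 10 .* y                           # large logit on the true class

println(logitce_sketch(uniform, y))     # log(10), about 2.3026
println(logitce_sketch(confident, y))   # close to zero
```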

Test the model by loading the parameters from the saved file.

function test(; kws...)
    args = Args(; kws...)

    # Loading the test data
    _, test_set = get_processed_data(args)

    # Re-constructing the model with random initial weights
    model = build_model(args)

    # Loading the saved parameters
    BSON.@load joinpath(args.savepath, "mnist_conv.bson") params

    # Loading parameters onto the model
    Flux.loadparams!(model, params)

    test_set = gpu.(test_set)
    model = gpu(model)
    @show accuracy(test_set..., model)
end
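Saving and restoring work because the parameters are stored as a plain list of arrays (moved to the CPU with `cpu` before saving, so the file does not depend on a GPU) and then copied back, in order, into a freshly built model of the same architecture. The sketch below illustrates that round trip under stated substitutions: Julia's standard-library `Serialization` replaces BSON, an in-memory `IOBuffer` replaces the file, and two small hypothetical arrays replace the model's parameters. The in-order, shape-checked copy mirrors what `Flux.loadparams!` does.

```julia
using Serialization

# Hypothetical stand-ins for trained parameters and a freshly built model's parameters.
trained = [randn(Float32, 16, 9), randn(Float32, 16)]
fresh   = [zeros(Float32, 16, 9), zeros(Float32, 16)]

# Save: write the parameter arrays to an in-memory buffer.
io = IOBuffer()
serialize(io, trained)

# Load: read them back and copy each one into the matching fresh array, in order.
seekstart(io)
loaded = deserialize(io)
for (dst, src) in zip(fresh, loaded)
    size(dst) == size(src) || error("shape mismatch: $(size(dst)) vs $(size(src))")
    copyto!(dst, src)
end
```

Matching by position means the rebuilt model must have exactly the same layers in the same order as the one that was saved.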

Finally, we train the model and then evaluate the saved model on the test set.

train()
test()

[ Info: Loading data set
[ Info: Building model...
[ Info: Beginning training loop...
[ Info: [1]: Test accuracy: 0.9794
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [2]: Test accuracy: 0.9867
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [3]: Test accuracy: 0.9884
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [4]: Test accuracy: 0.9868
[ Info: [5]: Test accuracy: 0.9870
[ Info: [6]: Test accuracy: 0.9900
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [7]: Test accuracy: 0.9901
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [8]: Test accuracy: 0.9904
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [9]: Test accuracy: 0.9856
[ Info: [10]: Test accuracy: 0.9885
[ Info: [11]: Test accuracy: 0.9894
[ Info: [12]: Test accuracy: 0.9888
[ Info: [13]: Test accuracy: 0.9876
[ Warning:  -> Haven't improved in a while, dropping learning rate to 0.00030000000000000003!
[ Info: [14]: Test accuracy: 0.9924
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [15]: Test accuracy: 0.9924
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [16]: Test accuracy: 0.9921
[ Info: [17]: Test accuracy: 0.9924
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [18]: Test accuracy: 0.9919
[ Info: [19]: Test accuracy: 0.9924
[ Info:  -> New best accuracy! Saving model out to mnist_conv.bson
[ Info: [20]: Test accuracy: 0.9919
accuracy(test_set..., model) = 0.9924
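As a sanity check on the saving and learning-rate rules, the per-epoch accuracies from the training log can be fed back through the same bookkeeping. This sketch uses the rounded values printed in the log (not the exact accuracies), and the function `replay_schedule` is hypothetical. It reproduces the ten "New best accuracy" saves and the single learning-rate drop after epoch 13, when five epochs had passed since the last improvement at epoch 8.

```julia
# Test accuracies copied from the training log, one per epoch.
accs = [0.9794, 0.9867, 0.9884, 0.9868, 0.9870, 0.9900, 0.9901, 0.9904,
        0.9856, 0.9885, 0.9894, 0.9888, 0.9876, 0.9924, 0.9924, 0.9921,
        0.9924, 0.9919, 0.9924, 0.9919]

# Same rules as the training loop: early exit at 99.9%, save on >= best,
# divide eta by 10 after 5 epochs without improvement, stop after 10.
function replay_schedule(accs; eta = 3e-3)
    best_acc, last_improvement = 0.0, 0
    saves, drops = Int[], Int[]
    for (epoch_idx, acc) in enumerate(accs)
        acc >= 0.999 && break
        if acc >= best_acc
            push!(saves, epoch_idx)
            best_acc, last_improvement = acc, epoch_idx
        end
        if epoch_idx - last_improvement >= 5 && eta > 1e-6
            eta /= 10.0
            push!(drops, epoch_idx)
            last_improvement = epoch_idx
        end
        epoch_idx - last_improvement >= 10 && break
    end
    return saves, drops, eta
end

saves, drops, eta = replay_schedule(accs)
println(saves)   # [1, 2, 3, 6, 7, 8, 14, 15, 17, 19]
println(drops)   # [13]
```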

– Elliot Saba, Adarsh Kumar, Mike J Innes, Dhairya Gandhi, Sudhanshu Agrawal, Sambit Kumar Dash, Carlo Lucibello, Andrew Dinhobl