Conditional deep convolutional generative adversarial networks

This example shows and implementation of a Conditional deep convolutional generative adversarial network.

Load the necessary packages.

using Base.Iterators: partition
using Flux
using Flux.Optimise: update!
using Flux: logitbinarycrossentropy
using Images
using MLDatasets
using Statistics
using Parameters: @with_kw
using Random
using Printf
using CUDAapi
using Zygote
if has_cuda()		# Check if CUDA is available
    @info "CUDA is on"
    import CuArrays		# If CUDA is available, import CuArrays
    CuArrays.allowscalar(false)
end


Set the hyperparameters.

@with_kw struct HyperParams
    batch_size::Int = 128
    latent_dim::Int = 100
    nclasses::Int = 10
    epochs::Int = 25
    verbose_freq::Int = 1000
    output_x::Int = 6        # No. of sample images to concatenate along x-axis 
    output_y::Int = 6        # No. of sample images to concatenate along y-axis
    lr_dscr::Float64 = 0.0002
    lr_gen::Float64 = 0.0002
end


Discriminator and generator functions.

struct Discriminator
    d_labels		# Submodel to take labels as input and convert them to the shape of image ie. (28, 28, 1, batch_size)
    d_common   
end

function discriminator(args)
    d_labels = Chain(Dense(args.nclasses,784), x-> reshape(x, 28, 28, 1, size(x, 2))) |> gpu
    d_common = Chain(Conv((3,3), 2=>128, pad=(1,1), stride=(2,2)),
                  x-> leakyrelu.(x, 0.2f0),
                  Dropout(0.4),
                  Conv((3,3), 128=>128, pad=(1,1), stride=(2,2), leakyrelu),
                  x-> leakyrelu.(x, 0.2f0),
                  x-> reshape(x, :, size(x, 4)),
                  Dropout(0.4),
                  Dense(6272, 1)) |> gpu
    Discriminator(d_labels, d_common)
end

function (m::Discriminator)(x, y)
    t = cat(m.d_labels(x), y, dims=3)
    return m.d_common(t)
end

struct Generator
    g_labels          # Submodel to take labels as input and convert it to the shape of (7, 7, 1, batch_size) 
    g_latent          # Submodel to take latent_dims as input and convert it to shape of (7, 7, 128, batch_size)
    g_common    
end

function generator(args)
    g_labels = Chain(Dense(args.nclasses, 49), x-> reshape(x, 7 , 7 , 1 , size(x, 2))) |> gpu
    g_latent = Chain(Dense(args.latent_dim, 6272), x-> leakyrelu.(x, 0.2f0), x-> reshape(x, 7, 7, 128, size(x, 2))) |> gpu
    g_common = Chain(ConvTranspose((4, 4), 129=>128; stride=2, pad=1),
            BatchNorm(128, leakyrelu),
            Dropout(0.25),
            ConvTranspose((4, 4), 128=>64; stride=2, pad=1),
            BatchNorm(64, leakyrelu),
            Conv((7, 7), 64=>1, tanh; stride=1, pad=3)) |> gpu
    Generator(g_labels, g_latent, g_common)
end

function (m::Generator)(x, y)
    t = cat(m.g_labels(x), m.g_latent(y), dims=3)
    return m.g_common(t)
end


Load the MNIST dataset.

function load_data(hparams)
    # Load MNIST dataset
    images, labels = MLDatasets.MNIST.traindata(Float32)
    # Normalize to [-1, 1]
    image_tensor = reshape(@.(2f0 * images - 1f0), 28, 28, 1, :)
    y = float.(Flux.onehotbatch(labels, 0:hparams.nclasses-1))
    # Partition into batches
    data = [(image_tensor[:, :, :, r], y[:, r]) |> gpu for r in partition(1:60000, hparams.batch_size)]
    return data
end


Loss functions.

function discr_loss(real_output, fake_output)
    real_loss = mean(logitbinarycrossentropy.(real_output, 1f0))
    fake_loss = mean(logitbinarycrossentropy.(fake_output, 0f0))
    return (real_loss + fake_loss)
end

generator_loss(fake_output) = mean(logitbinarycrossentropy.(fake_output, 1f0))


Train functions.

function train_discr(discr, fake_data, fake_labels, original_data, label, opt_discr)
    ps = params(discr.d_labels, discr.d_common)
    loss, back = Zygote.pullback(ps) do
            discr_loss(discr(label, original_data), discr(fake_labels, fake_data))
    end
    grads = back(1f0)
    update!(opt_discr, ps, grads)
    return loss
end

Zygote.@nograd train_discr

function train_gan(gen, discr, original_data, label, opt_gen, opt_discr, hparams)
    # Random Gaussian Noise and Labels as input for the generator
    noise = randn!(similar(original_data, (hparams.latent_dim, hparams.batch_size)))
    labels = rand(0:hparams.nclasses-1, hparams.batch_size)
    y = Flux.onehotbatch(labels, 0:hparams.nclasses-1)
    noise , y  = noise, float.(y) |> gpu

    ps = params(gen.g_labels, gen.g_latent, gen.g_common)
    loss = Dict()
    loss["gen"], back = Zygote.pullback(ps) do
            fake = gen(y, noise)
            loss["discr"] = train_discr(discr, fake, y, original_data, label, opt_discr)
            generator_loss(discr(y, fake))
    end
    grads = back(1f0)
    update!(opt_gen, ps, grads)
    return loss
end

function create_output_image(gen, fixed_noise, fixed_labels, hparams)
    @eval Flux.istraining() = false
    fake_images = @. cpu(gen(fixed_labels, fixed_noise))
    @eval Flux.istraining() = true
    image_array = permutedims(dropdims(reduce(vcat, reduce.(hcat, partition(fake_images, hparams.output_y))); dims=(3, 4)), (2, 1))
    image_array = @. Gray(image_array + 1f0) / 2f0
    return image_array
end

function train(; kws...)
    hparams = HyperParams(kws...)

    # Load the data
    data = load_data(hparams)

    fixed_noise = [randn(hparams.latent_dim, 1) |> gpu for _=1:hparams.output_x * hparams.output_y]

    fixed_labels = [float.(Flux.onehotbatch(rand(0:hparams.nclasses-1, 1), 0:hparams.nclasses-1)) |> gpu 
                             for _ =1:hparams.output_x * hparams.output_y]

    # Discriminator
    dscr = discriminator(hparams) 

    # Generator
    gen =  generator(hparams)

    # Optimizers
    opt_dscr = ADAM(hparams.lr_dscr, (0.5, 0.99))
    opt_gen = ADAM(hparams.lr_gen, (0.5, 0.99))

    # Check if the `output` directory exists or needed to be created
    isdir("output")||mkdir("output")

    # Training
    train_steps = 0
    for ep in 1:hparams.epochs
        @info "Epoch $ep"
        for (x, y) in data
            # Update discriminator and generator
            loss = train_gan(gen, dscr, x, y, opt_gen, opt_dscr, hparams)

            if train_steps % hparams.verbose_freq == 0
                @info("Train step $(train_steps), Discriminator loss = $(loss["discr"]), Generator loss = $(loss["gen"])")
                # Save generated fake image
                output_image = create_output_image(gen, fixed_noise, fixed_labels, hparams)
                save(@sprintf("output/cgan_steps_%06d.png", train_steps), output_image)
            end

            train_steps += 1
        end
    end

    output_image = create_output_image(gen, fixed_noise, fixed_labels, hparams)
    save(@sprintf("output/cgan_steps_%06d.png", train_steps), output_image)
    return Flux.onecold.(cpu(fixed_labels))
end    


Train the model.

cd(@__DIR__)
fixed_labels = train()

– Adarsh Kumar