Variational Autoencoder (VAE)
This example is an implementation of *Auto-Encoding Variational Bayes* by Diederik P. Kingma and Max Welling (arXiv:1312.6114).
Load the necessary packages.
using Base.Iterators: partition
using BSON
using CUDAapi: has_cuda_gpu
using DrWatson: struct2dict
using Flux
using Flux: logitbinarycrossentropy, chunk
using Flux.Data: DataLoader
using Images
using Logging: with_logger
using MLDatasets
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using TensorBoardLogger: TBLogger, tb_overwrite
using Random
Load MNIST images and return loader.
# Load the MNIST training set, flatten each 28×28 image into a 784-vector,
# and wrap images and labels in a shuffling `DataLoader` with the given batch size.
function get_data(batch_size)
    images, labels = MLDatasets.MNIST.traindata(Float32)
    flattened = reshape(images, 28^2, :)
    return DataLoader(flattened, labels, batchsize=batch_size, shuffle=true)
end
Define functions for encoding and decoding.
# Gaussian encoder q(z|x): a shared hidden layer feeding two linear heads
# that produce the mean μ and log standard deviation logσ of the latent code.
struct Encoder
    linear
    μ
    logσ
    # Build the three layers and move each to `device` (cpu or gpu).
    function Encoder(input_dim, latent_dim, hidden_dim, device)
        hidden = Dense(input_dim, hidden_dim, tanh) |> device
        mean_head = Dense(hidden_dim, latent_dim) |> device
        logstd_head = Dense(hidden_dim, latent_dim) |> device
        return new(hidden, mean_head, logstd_head)
    end
end
# Apply the encoder to a batch `x`, returning the pair `(μ, logσ)`
# of latent-distribution parameters.
function (encoder::Encoder)(x)
    hidden = encoder.linear(x)
    return encoder.μ(hidden), encoder.logσ(hidden)
end
# Decoder p(x|z): maps a latent code back to logits over the 784 pixels.
# Output is left as logits (no sigmoid) to pair with `logitbinarycrossentropy`.
function Decoder(input_dim, latent_dim, hidden_dim, device)
    net = Chain(
        Dense(latent_dim, hidden_dim, tanh),
        Dense(hidden_dim, input_dim),
    )
    return net |> device
end
# Encode `x`, sample a latent code with the reparameterization trick
# (z = μ + ε·σ, ε ~ N(0, I)), and decode it.
# Returns `(μ, logσ, decoder(z))`.
# NOTE: the name is misspelled ("reconstuct") but kept for compatibility
# with existing callers.
function reconstuct(encoder, decoder, x, device)
    μ, logσ = encoder(x)
    ε = device(randn(Float32, size(logσ)))
    z = @. μ + ε * exp(logσ)
    return μ, logσ, decoder(z)
end
# Negative evidence lower bound (ELBO) plus an L2 penalty on the decoder:
#   loss = -E[log p(x|z)] + KL(q(z|x) ‖ N(0, I)) + λ‖θ_dec‖²
# All terms are averaged over the batch (last dimension of `x`).
function model_loss(encoder, decoder, λ, x, device)
    μ, logσ, decoder_z = reconstuct(encoder, decoder, x, device)
    batch = size(x)[end]
    # closed-form KL divergence between N(μ, σ²) and the standard normal prior
    kl_q_p = 0.5f0 * sum(@. exp(2f0 * logσ) + μ^2 - 1f0 - 2f0 * logσ) / batch
    # Bernoulli log-likelihood of the data under the decoder's logits
    logp_x_z = -sum(logitbinarycrossentropy.(decoder_z, x)) / batch
    # L2 regularization over all decoder parameters
    reg = λ * sum(p -> sum(abs2, p), Flux.params(decoder))
    return -logp_x_z + kl_q_p + reg
end
# Turn a batch of decoder logits `x` (784 × y_size²) into a single grayscale
# image: sigmoid to pixel probabilities, split into `y_size` chunks, lay the
# 28-pixel-wide digits out in a grid, and wrap as `Gray` pixels.
function convert_to_image(x, y_size)
    probs = sigmoid.(x |> cpu)
    tiles = reshape.(chunk(probs, y_size), 28, :)
    grid = permutedims(vcat(tiles...), (2, 1))
    return Gray.(grid)
end
Set arguments for the `train` function.
# Hyper-parameters and configuration for training. Every field can be
# overridden via keyword arguments, e.g. `train(η=1e-4, epochs=5)`.
@with_kw mutable struct Args
η = 1e-3 # learning rate
λ = 0.01f0 # regularization parameter (L2 penalty on decoder weights)
batch_size = 128 # batch size
sample_size = 10 # sampling size for output (a sample_size × sample_size grid)
epochs = 20 # number of epochs
seed = 0 # random seed (0 means: do not seed)
cuda = true # use GPU when one is available
input_dim = 28^2 # image size (flattened 28×28 MNIST digits)
latent_dim = 2 # latent dimension
hidden_dim = 500 # hidden dimension
verbose_freq = 10 # logging for every verbose_freq iterations
tblogger = false # log training with tensorboard
save_path = "output" # results path
end
Define the `train` function.
# Train the VAE on MNIST. All `Args` fields can be overridden via keyword
# arguments, e.g. `train(epochs=5, cuda=false)`. Writes the original sample
# grid, one reconstruction image per epoch, and the final model (as BSON)
# under `args.save_path`.
function train(; kws...)
# load hyperparameters
args = Args(; kws...)
# seed only when a positive seed was requested; default run stays stochastic
args.seed > 0 && Random.seed!(args.seed)
# GPU config
if args.cuda && has_cuda_gpu()
device = gpu
@info "Training on GPU"
else
device = cpu
@info "Training on CPU"
end
# load MNIST images
loader = get_data(args.batch_size)
# initialize encoder and decoder
encoder = Encoder(args.input_dim, args.latent_dim, args.hidden_dim, device)
decoder = Decoder(args.input_dim, args.latent_dim, args.hidden_dim, device)
# ADAM optimizer
opt = ADAM(args.η)
# trainable parameters: all three encoder layers plus the decoder chain
ps = Flux.params(encoder.linear, encoder.μ, encoder.logσ, decoder)
!ispath(args.save_path) && mkpath(args.save_path)
# logging by TensorBoard.jl (only constructed when enabled)
if args.tblogger
tblogger = TBLogger(args.save_path, tb_overwrite)
end
# fixed input: one sample_size² batch, reused every epoch so the saved
# reconstructions are directly comparable across epochs
original, _ = first(get_data(args.sample_size^2))
original = original |> device
image = convert_to_image(original, args.sample_size)
image_path = joinpath(args.save_path, "original.png")
save(image_path, image)
# training
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
@info "Epoch $(epoch)"
progress = Progress(length(loader))
for (x, _) in loader
# forward pass under pullback so gradients can be taken afterwards
loss, back = Flux.pullback(ps) do
model_loss(encoder, decoder, args.λ, x |> device, device)
end
# backpropagate a unit sensitivity through the scalar loss
grad = back(1f0)
Flux.Optimise.update!(opt, ps, grad)
# progress meter
next!(progress; showvalues=[(:loss, loss)])
# logging with TensorBoard every verbose_freq steps
if args.tblogger && train_steps % args.verbose_freq == 0
with_logger(tblogger) do
@info "train" loss=loss
end
end
train_steps += 1
end
# save a reconstruction of the fixed batch for this epoch
_, _, rec_original = reconstuct(encoder, decoder, original, device)
image = convert_to_image(rec_original, args.sample_size)
image_path = joinpath(args.save_path, "epoch_$(epoch).png")
save(image_path, image)
@info "Image saved: $(image_path)"
end
# save model on the CPU so it can be reloaded without a GPU
model_path = joinpath(args.save_path, "model.bson")
let encoder = cpu(encoder), decoder = cpu(decoder), args=struct2dict(args)
BSON.@save model_path encoder decoder args
@info "Model saved: $(model_path)"
end
end
# Run training only when this file is executed as a script
# (not when it is `include`d, e.g. by vae_plot.jl).
if abspath(PROGRAM_FILE) == @__FILE__
train()
end
To visualize latent-space clustering, save the following code as `vae_plot.jl`:
include("vae_mnist.jl")
using Plots
# Load the trained model from `output/model.bson`, then:
#   1. scatter-plot the first two latent dimensions of encoded MNIST batches,
#      colored by digit label, to `output/clustering.png`;
#   2. decode a regular 11×11 grid over the 2-D latent space and save the
#      resulting digit manifold to `output/manifold.png`.
function plot_result()
    BSON.@load "output/model.bson" encoder decoder args
    args = Args(; args...)
    device = args.cuda && has_cuda_gpu() ? gpu : cpu
    encoder, decoder = encoder |> device, decoder |> device
    # load MNIST images
    loader = get_data(args.batch_size)
    # clustering in the latent space: visualize the first two latent dims
    plt = scatter(palette=:rainbow)
    for (i, (x, y)) in enumerate(loader)
        i < 20 || break  # only plot the first 19 batches
        μ, logσ = encoder(x |> device)
        # bring μ back to the CPU: Plots needs host arrays, and slicing a GPU
        # array for plotting would fall back to slow scalar indexing
        μ = μ |> cpu
        scatter!(μ[1, :], μ[2, :],
            markerstrokewidth=0, markeralpha=0.8,
            aspect_ratio=1,
            markercolor=y, label="")
    end
    savefig(plt, "output/clustering.png")
    # decode a regular grid over the first two latent dimensions
    z = range(-2.0, stop=2.0, length=11)
    len = length(z)
    z1 = repeat(z, len)
    z2 = sort(z1)
    # build the grid on the CPU first: `x[1, :] = ...` on a GPU array relies
    # on unsupported/slow scalar `setindex!`; move to `device` once filled
    x = zeros(Float32, args.latent_dim, len^2)
    x[1, :] = z1
    x[2, :] = z2
    samples = decoder(x |> device)
    image = convert_to_image(samples, len)
    save("output/manifold.png", image)
end
# Generate the plots only when this file is executed as a script.
if abspath(PROGRAM_FILE) == @__FILE__
plot_result()
end
Run `vae_plot.jl`:
julia vae_plot.jl