FizzBuzz
This example was inspired by the "Fizz Buzz in Tensorflow" blog post by Joel Grus.
Load the necessary packages.
using Flux: Chain, Dense, params, logitcrossentropy, onehotbatch, ADAM, train!, softmax
using Test
Data preparation.
"""
    fizzbuzz(x::Integer) -> String

Return the FizzBuzz label for `x`:
- `"fizzbuzz"` when `x` is divisible by both 3 and 5;
- `"fizz"` when `x` is divisible by 3 only;
- `"buzz"` when `x` is divisible by 5 only;
- `"else"` otherwise.

Accepts any `Integer` subtype (e.g. `Int32`, `UInt8`, `BigInt`), not just `Int`.
"""
function fizzbuzz(x::Integer)
    divisible_by_three = x % 3 == 0
    divisible_by_five = x % 5 == 0
    # Check the combined case first so it is not shadowed by the single-divisor branches.
    if divisible_by_three && divisible_by_five
        return "fizzbuzz"
    elseif divisible_by_three
        return "fizz"
    elseif divisible_by_five
        return "buzz"
    else
        return "else"
    end
end
# Class labels, in the order used both for the one-hot targets and for decoding
# predictions. The 4th class ("else") is decoded as "echo the number itself".
const LABELS = String["fizz", "buzz", "fizzbuzz", "else"]
Feature engineering: encode each integer by its remainders modulo 3, 5 and 15.
"""
    features(x)
    features(xs::AbstractArray)

Encode a single integer `x` as the floating-point vector `[x % 3, x % 5, x % 15]`.
These remainders are exactly what determines its FizzBuzz label.

For an array `xs`, return a matrix with one such 3-element column per element of `xs`,
in iteration order. An empty `xs` yields a `3 × 0` matrix.
"""
features(x) = float.([x % 3, x % 5, x % 15])

function features(xs::AbstractArray)
    # Keep the 3-row shape even when there is nothing to stack.
    isempty(xs) && return zeros(3, 0)
    # `reduce(hcat, …)` avoids splatting a potentially long argument list into `hcat`
    # and has a specialized, allocation-efficient method for vectors of vectors.
    return reduce(hcat, features.(xs))
end
Load the data.
"""
    getdata()

Build the full training set for the integers 1 through 100. Return a tuple `(X, y)`:
- `X` is the feature matrix produced by `features`, one column per integer;
- `y` is the one-hot encoding of each integer's `fizzbuzz` label over `LABELS`.

Before building anything, assert that `fizzbuzz` of 3, 5, 15 and 98 reproduces
`LABELS` in order.
"""
function getdata()
    # Guard: the class order in LABELS must line up with what fizzbuzz produces.
    @test fizzbuzz.([3, 5, 15, 98]) == LABELS
    inputs = 1:100
    targets = onehotbatch(map(fizzbuzz, inputs), LABELS)
    return features(inputs), targets
end
Define the model and the training function.
"""
    train(; epochs::Integer=500, log_every::Integer=50)

Fit a small dense network to the FizzBuzz data from `getdata`. Training uses
`logitcrossentropy` loss and the `ADAM` optimiser, with one full-batch update per
epoch.

Epochs are numbered `0:epochs`, so `epochs + 1` updates are performed. At every epoch
divisible by `log_every`, the function prints:
- the current training loss;
- the model's answers for 3, 5, 15 and 98.

Returns `nothing`.

# Throws
- `ArgumentError` if `log_every` is not positive.
"""
function train(; epochs::Integer=500, log_every::Integer=50)
    log_every > 0 || throw(ArgumentError("log_every must be positive, got $log_every"))

    X, y = getdata()

    # 3 input features -> 10 hidden units -> 4 class logits (one per entry of LABELS).
    m = Chain(Dense(3, 10), Dense(10, 4))
    # Parameter names differ from the outer `X, y` so the closure does not shadow them.
    loss(xb, yb) = logitcrossentropy(m(xb), yb)

    # Predict an answer for integer `x`. Class 4 ("else") means "echo the number itself",
    # so the result is either a label string or `x`.
    deepbuzz(x) = (a = argmax(m(features(x))); a == 4 ? x : LABELS[a])

    function monitor(e)
        print("epoch $(lpad(e, 4)): loss = $(round(loss(X, y); digits=4)) | ")
        @show deepbuzz.([3, 5, 15, 98])
    end

    # NOTE(review): recent Flux releases rename `ADAM` to `Adam` and deprecate the
    # implicit-`params` form of `train!`; revisit if the Flux version is upgraded.
    opt = ADAM()
    data = [(X, y)]  # the whole dataset as a single batch, reused every epoch
    for e in 0:epochs
        train!(loss, params(m), data, opt)
        if e % log_every == 0
            monitor(e)
        end
    end
    return nothing
end
Finally, train the model.
# Switch the working directory to this script's folder, then run training
# (progress is printed by `train`).
(@__DIR__) |> cd
train()
Output:
epoch 0: loss = 3.5546 | deepbuzz.([3, 5, 15, 98]) = ["buzz", "buzz", "fizz", "buzz"]
epoch 50: loss = 1.3087 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, 15, "fizz"]
epoch 100: loss = 1.0117 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, 15, 98]
epoch 150: loss = 0.8925 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, 15, 98]
epoch 200: loss = 0.7895 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, 15, 98]
epoch 250: loss = 0.7017 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch 300: loss = 0.6269 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch 350: loss = 0.563 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch 400: loss = 0.5075 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch 450: loss = 0.4589 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch 500: loss = 0.4158 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]