Multi-layer Recurrent Neural Network for a character-level language model
This is an implementation of a multi-layer recurrent neural network (two stacked LSTM layers) for training and sampling from character-level language models.
Load the necessary packages.
using Flux
using Flux: onehot, chunk, batchseq, throttle, logitcrossentropy
using StatsBase: wsample
using Base.Iterators: partition
using Parameters: @with_kw
Hyperparameter arguments (any of them can be overridden by passing keywords to `train`, e.g. `train(lr = 1e-3)`).
# Training hyperparameters. `@with_kw` (Parameters.jl) provides a keyword constructor
# with the defaults below, so `Args(; kws...)` in `train` overrides individual fields.
@with_kw mutable struct Args
    lr::Float64 = 1e-2      # Learning rate passed to the ADAM optimiser
    seqlen::Int = 50        # Length (in time steps) of each subsequence window the batched text is partitioned into
    nbatch::Int = 50        # Number of chunks the text is divided into (parallel sequences per batch)
    throttle::Int = 30      # Throttle timeout (seconds) for the evaluation callback
end
Load the data.
"""
    getdata(args)

Download (if needed) and load the Shakespeare corpus as `input.txt`, one-hot encode it,
and split it into input/target batch sequences for next-character prediction.

Returns `(Xs, Ys, N, alphabet)`:
- `Xs[k]` / `Ys[k]`: the k-th window of at most `args.seqlen` time steps; each time step
  is a batch of one-hot columns, one per text chunk (`args.nbatch` chunks requested).
  `Ys` is `Xs` shifted forward by one character.
- `N`: alphabet size.
- `alphabet`: the character vocabulary, ending with `'_'`, which is used as padding.
"""
function getdata(args)
    # Download the data if not downloaded as 'input.txt'
    isfile("input.txt") ||
        download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt",
                 "input.txt")
    text = collect(String(read("input.txt")))
    # All unique characters plus '_' as the padding/stop symbol. `unique` keeps '_' from
    # appearing twice when the corpus already contains it (otherwise identical to
    # appending '_' to `unique(text)`).
    alphabet = unique([text; '_'])
    text = map(ch -> onehot(ch, alphabet), text)
    stop = onehot('_', alphabet)
    N = length(alphabet)
    # Inputs are text[1:end-1] and targets text[2:end]. Both have the same length, so
    # `chunk` splits them at identical offsets, and every target column is exactly its
    # input column shifted by one character. Chunking `text` and `text[2:end]` (lengths
    # differing by one) can yield different chunk sizes and misaligned pairs.
    Xs = collect(partition(batchseq(chunk(text[1:end-1], args.nbatch), stop), args.seqlen))
    Ys = collect(partition(batchseq(chunk(text[2:end], args.nbatch), stop), args.seqlen))
    return Xs, Ys, N, alphabet
end
Function to construct the model.
"""
    build_model(N; hidden = 128)

Build the character-level language model: two stacked LSTM layers of width `hidden`,
followed by a dense layer producing `N` unnormalised scores (logits), one per alphabet
character. `N` is the alphabet size. The default `hidden = 128` reproduces the original
architecture.
"""
function build_model(N; hidden::Integer = 128)
    return Chain(
        LSTM(N, hidden),
        LSTM(hidden, hidden),
        Dense(hidden, N))
end
Function to train the model.
"""
    train(; kws...)

Train the character-level LSTM language model.

Keyword arguments override the fields of `Args` (`lr`, `seqlen`, `nbatch`, `throttle`).
Makes one pass over the data with ADAM. At most once every `args.throttle` seconds, it
prints the loss on a fixed training window. Returns `(m, alphabet)`: the trained model and
its character vocabulary.
"""
function train(; kws...)
    # Initialize the parameters
    args = Args(; kws...)
    # Get Data
    Xs, Ys, N, alphabet = getdata(args)
    # Constructing Model
    m = build_model(N)

    # Sum of the per-time-step cross-entropy losses over one window of the sequence.
    # The LSTM state is reset first so every window starts from the same state. This
    # includes the evaluation window used by the callback. Without the reset, the
    # callback's forward pass would leak its hidden state into the next training step,
    # and the printed loss would depend on whatever state training left behind.
    # Time steps are fed with a comprehension rather than `m.(xs)`: Flux discourages
    # broadcasting over stateful layers because evaluation order is not guaranteed.
    function loss(xs, ys)
        Flux.reset!(m)
        ŷs = [m(x) for x in xs]
        return sum(logitcrossentropy.(ŷs, ys))
    end

    ## Training
    opt = ADAM(args.lr)
    # Fixed window for monitoring progress. Clamp the index so that small corpora with
    # fewer than 5 windows do not throw a BoundsError.
    k = min(5, length(Xs), length(Ys))
    tx, ty = (Xs[k], Ys[k])
    evalcb = () -> @show loss(tx, ty)
    Flux.train!(loss, params(m), zip(Xs, Ys), opt, cb = throttle(evalcb, args.throttle))
    return m, alphabet
end
Sampling: generate text from a trained model.
"""
    sample(m, alphabet, len; seed = "")

Generate text from a trained character model `m`.

The model is moved to the CPU and its recurrent state is reset. The characters of `seed`
are then fed in order to warm up that state. After that, `len` characters are drawn one at
a time by weighted sampling from the softmax of the model output, and each sampled
character is fed back as the next input. If `seed` is empty, a single random character
from `alphabet` is used as the seed. Every character of `seed` must be in `alphabet`.

Returns `seed` followed by the `len` generated characters.
"""
function sample(m, alphabet, len; seed="")
    m = cpu(m)
    Flux.reset!(m)
    buf = IOBuffer()
    if isempty(seed)
        seed = string(rand(alphabet))
    end
    write(buf, seed)
    # Warm up the state on the seed, one character at a time. An explicit loop is used
    # because broadcasting (`m.(...)`) over a stateful model does not guarantee evaluation
    # order. Only the output after the last seed character is kept.
    local logits
    for ch in seed
        logits = m(onehot(ch, alphabet))
    end
    c = wsample(alphabet, softmax(logits))
    for _ in 1:len
        write(buf, c)
        c = wsample(alphabet, softmax(m(onehot(c, alphabet))))
    end
    return String(take!(buf))
end
Train the model and print a 1000-character sample.
# Work from the script's own directory so "input.txt" is read (or downloaded) next to it,
# then train and print a 1000-character sample from the trained model.
cd(@__DIR__)
m, alphabet = train()
println(sample(m, alphabet, 1000))
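Example output (the evaluation loss printed periodically during training, followed by the generated sample):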
loss(tx, ty) = 188.493f0
loss(tx, ty) = 164.71947f0
loss(tx, ty) = 163.34908f0
loss(tx, ty) = 141.63608f0
loss(tx, ty) = 135.5761f0
loss(tx, ty) = 134.25943f0
loss(tx, ty) = 133.33817f0
loss(tx, ty) = 133.78276f0
loss(tx, ty) = 131.65105f0
loss(tx, ty) = 129.55223f0
loss(tx, ty) = 129.3795f0
loss(tx, ty) = 128.73051f0
loss(tx, ty) = 126.728775f0
loss(tx, ty) = 125.84408f0
loss(tx, ty) = 125.60294f0
loss(tx, ty) = 124.5169f0
loss(tx, ty) = 126.114624f0
loss(tx, ty) = 126.629616f0
loss(tx, ty) = 123.28575f0
loss(tx, ty) = 123.2743f0
loss(tx, ty) = 121.49012f0
loss(tx, ty) = 120.70581f0
loss(tx, ty) = 119.71306f0
loss(tx, ty) = 119.45561f0
loss(tx, ty) = 118.88265f0
loss(tx, ty) = 118.57439f0
loss(tx, ty) = 119.1112f0
loss(tx, ty) = 117.25467f0
loss(tx, ty) = 117.20366f0
loss(tx, ty) = 116.40781f0
loss(tx, ty) = 116.5148f0
loss(tx, ty) = 115.36537f0
loss(tx, ty) = 115.88582f0
loss(tx, ty) = 115.18013f0
loss(tx, ty) = 115.0249f0
loss(tx, ty) = 114.61435f0
loss(tx, ty) = 115.03795f0
loss(tx, ty) = 115.97866f0
loss(tx, ty) = 115.85941f0
loss(tx, ty) = 115.91748f0
loss(tx, ty) = 114.87762f0
loss(tx, ty) = 113.778244f0
loss(tx, ty) = 114.05181f0
loss(tx, ty) = 114.4699f0
loss(tx, ty) = 113.46146f0
loss(tx, ty) = 114.079956f0
loss(tx, ty) = 113.89629f0
loss(tx, ty) = 113.31437f0
loss(tx, ty) = 114.2162f0
loss(tx, ty) = 113.37099f0
loss(tx, ty) = 113.31416f0
loss(tx, ty) = 112.205414f0
loss(tx, ty) = 112.91714f0
$venght heed ofve oo me comoin? Gad! deainsaon: by heave Lo incle mate hled seet by diloile, sreunk of aralons leever to too ressirioa Leh ghorD P: the varn I him thizow yow bunsinceluolo su netfessnd obll kint teets; goudint trere,
Goe rver
A mevaeamenll dinth thth the s kn the gull, and rio seamoly; Hehy the ssint: ul yo, and an chat in ofr do doie you, That:
Hellns, wwereunteir a shil oleaty; I daprey luch you noillt: IORIA AHIOrighe
Thmanwsinirill I gan,
Aleadienuath heres lista!
With hitheaeed ty high abe yooak doveshalithl you ssii is hime radm, sale the rarareteke uo, did am,.
NoRATONULELELEYF:
As youariom, neuersamie for espring to he eorio?,
O memelltece ridgimilf death saows to he then uipahanf am the priuse in of aletemial our mve rown 't, st the kingthe be or them reart am, I hace it beint aevel bed and piso, ms sienint and tn.
Why ded reowaid;
Theare mu-s ther, sage, shr will:
Whichge 'ts hoor:
Roingiedienmine ho thee a
UGARIreasiths sto nnencen mend and wey bast hith soo