Seq2Seq phoneme detection on CMUDict

This example is based on Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau, Cho and Bengio, 2014): an attention-based sequence-to-sequence model, used here to map the letters of an English word to its phonemes in the CMU Pronouncing Dictionary.

First, save the following code as 0-data.jl; the model script below includes it to load the CMU dictionary and batch the data.

using Flux, Flux.Data.CMUDict
using Flux: onehot, batchseq
using Base.Iterators: partition

tokenise(s, α) = [onehot(c, α) for c in s]

function getdata(args)
    dict = cmudict()
    alphabet = [:end, CMUDict.alphabet()...]
    args.Nin = length(alphabet)

    phones = [:start, :end, CMUDict.symbols()...]
    args.phones_len = length(phones)

    # Turn a word into a sequence of one-hot vectors...
    tokenise("PHYLOGENY", alphabet)
    # ...and do the same for a phoneme list. (These two calls are purely
    # illustrative; their results are not used.)
    tokenise(dict["PHYLOGENY"], phones)

    words = sort(collect(keys(dict)), by = length)

    # Finally, create iterators for our inputs and outputs.
    batches(xs, p) = [batchseq(b, p) for b in partition(xs, 50)]

    # Encoder inputs: each word as a sequence of one-hot letter vectors.
    Xs = batches([tokenise(word, alphabet) for word in words],
             onehot(:end, alphabet))

    # Targets: each word's phoneme sequence followed by :end.
    Ys = batches([tokenise([dict[word]..., :end], phones) for word in words],
             onehot(:end, phones))

    # Decoder inputs: the same phonemes shifted right behind a :start token
    # (teacher forcing).
    Yo = batches([tokenise([:start, dict[word]...], phones) for word in words],
             onehot(:end, phones))

    data = collect(zip(Xs, Yo, Ys))
    return data, alphabet, phones
end
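

As a quick sanity check, we can look at the shapes of one batch. The sketch below assumes the Args struct defined in the next section has been loaded; the exact lengths depend on the dictionary contents.

args = Args()
data, alphabet, phones = getdata(args)

x, yo, y = first(data)   # one batch: encoder inputs, decoder inputs, targets
length(x)                # time steps in this batch (longest word, after padding)
size(x[1])               # (length(alphabet), 50): one-hot letters for 50 words
size(y[1])               # (length(phones), 50): one-hot target phonemes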


Now, we define the actual model and training code.

Load the necessary packages.

include("0-data.jl")
using Flux: flip, logitcrossentropy, reset!, throttle
using Parameters: @with_kw
using StatsBase: wsample


Define the hyperparameter arguments.

@with_kw mutable struct Args
    lr::Float64 = 1e-3      # learning rate
    Nin::Int = 0            # size of the input layer; set to length(alphabet) by getdata
    Nh::Int = 30            # size of the hidden layer
    phones_len::Int = 0     # number of phoneme symbols; set to length(phones) by getdata
    throttle::Int = 30      # throttle timeout (seconds) for the callback
end
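

Because of @with_kw, any field can be overridden by keyword when constructing Args, and train below simply forwards its keyword arguments to this constructor. For example (the values here are arbitrary):

args = Args(lr = 5e-4, Nh = 64)   # override selected hyperparameters
args.throttle                     # untouched fields keep their defaults (here 30)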


Define the function that builds the model.

function build_model(args)
    # A recurrent model which takes a token and returns a context-dependent
    # annotation.
    forward  = LSTM(args.Nin, args.Nh÷2)
    backward = LSTM(args.Nin, args.Nh÷2)
    encode(tokens) = vcat.(forward.(tokens), flip(backward, tokens))

    # The alignment model scores how well an annotation matches the current
    # decoder state.
    alignnet = Dense(2*args.Nh, 1)

    # A recurrent model which takes a sequence of annotations, attends, and returns
    # a predicted output token.
    recur   = LSTM(args.Nh+args.phones_len, args.Nh)
    toalpha = Dense(args.Nh, args.phones_len)
    return (forward, backward, alignnet, recur, toalpha), encode
end
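

To see what the encoder produces, we can build a model with toy sizes and run encode over a few random one-hot tokens; each annotation concatenates the forward and backward LSTM outputs, so it has length Nh. The sizes below are made up purely for illustration (normally Nin and phones_len are set by getdata).

args = Args(Nin = 5, phones_len = 7)            # toy sizes, illustration only
state, encode = build_model(args)
tokens = [onehot(rand(1:5), 1:5) for _ in 1:4]  # a "word" of 4 random letters
annotations = encode(tokens)
length(annotations), length(annotations[1])     # (4, args.Nh) == (4, 30)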


Define the attention (alignment) mechanism and the decoder.

# Score how well one annotation `t` matches the current decoder state `s`,
# broadcasting `s` across the batch dimension of `t`.
align(s, t, alignnet) = alignnet(vcat(t, s .* Int.(ones(1, size(t, 2)))))

# Softmax over a list of score matrices, normalising across the tokens.
function asoftmax(xs)
    xs = [exp.(x) for x in xs]
    s = sum(xs)
    return [x ./ s for x in xs]
end

function decode1(tokens, phone, state)
    # Unpack models
    forward, backward, alignnet, recur, toalpha = state
    # Attention weights over the encoder annotations
    weights = asoftmax([align(recur.state[2], t, alignnet) for t in tokens])
    # Context vector: the weighted sum of the annotations
    context = sum(map((a, b) -> a .* b, weights, tokens))
    y = recur(vcat(Float32.(phone), context))
    return toalpha(y)
end

# Decode a whole phoneme sequence, one output token at a time
decode(tokens, phones, state) = [decode1(tokens, phone, state) for phone in phones]
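

asoftmax normalises a list of alignment scores across the token dimension, so that for every column of the batch the attention weights over the input letters sum to one. A quick check with made-up scores:

scores = [rand(Float32, 1, 3) for _ in 1:4]   # scores for 4 tokens, batch of 3
weights = asoftmax(scores)
sum(weights)                                  # ≈ [1.0 1.0 1.0]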


Define the model.

function model(x, y, state, encode)
    # Unpack models
    forward, backward, alignnet, recur, toalpha = state
    ŷ = decode(encode(x), y, state)
    reset!(state)
    return ŷ
end
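

For one batch from data, the model returns one score matrix per output time step, aligned with the targets in Ys; the loss in train below compares the two with logitcrossentropy. A rough sketch, assuming data, state and encode as constructed in train:

x, yo, y = first(data)
ŷ = model(x, yo, state, encode)
length(ŷ) == length(y)     # one prediction per target time step
size(ŷ[1]) == size(y[1])   # both are (length(phones), batch size)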


Define the predict function.

function predict(s, state, encode, alphabet, phones)
    ts = encode(tokenise(s, alphabet))
    ps = Any[:start]
    # Sample up to 50 phonemes, stopping early when :end is drawn
    for i = 1:50
        dist = softmax(decode1(ts, onehot(ps[end], phones), state))
        next = wsample(phones, vec(dist))
        next == :end && break
        push!(ps, next)
    end
    reset!(state)
    return ps[2:end]   # drop the leading :start token
end
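

Since predict draws each phoneme with wsample, repeated calls on the same word can give different pronunciations. Below is a minimal sketch of a greedy alternative; predict_greedy is a hypothetical helper, not part of the original script, which simply takes the most likely phoneme at every step.

# Hypothetical greedy counterpart of `predict` (illustrative only)
function predict_greedy(s, state, encode, alphabet, phones)
    ts = encode(tokenise(s, alphabet))
    ps = Any[:start]
    for i = 1:50
        dist = softmax(decode1(ts, onehot(ps[end], phones), state))
        next = phones[argmax(vec(dist))]   # most likely phoneme instead of sampling
        next == :end && break
        push!(ps, next)
    end
    reset!(state)
    return ps[2:end]
end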


Define the train function.

function train(; kws...)
    # Initialize Hyperparameters
    args = Args(; kws...)
    @info("Loading Data...")
    data, alphabet, phones = getdata(args)

    # The full model
    # state = (forward, backward, alignnet, recur, toalpha)
    @info("Constructing Model...")
    state, encode = build_model(args)

    # Cross-entropy loss, summed over the time steps of a batch
    loss(x, yo, y) = sum(logitcrossentropy.(model(x, yo, state, encode), y))
    evalcb = () -> @show loss(data[500]...)
    opt = ADAM(args.lr)
    @info("Training...")
    Flux.train!(loss, params(state), data, opt, cb = throttle(evalcb, args.throttle))
    return state, encode, alphabet, phones
end


Finally, train the model and test it on an example word.

cd(@__DIR__)
state, encode, alphabet, phones = train()
@info("Testing...")
predict("PHYLOGENY", state, encode, alphabet, phones)


Output:

[ Info: Loading Data...
[ Info: Constructing Model...
[ Info: Training...
loss(data[500]...) = 30.91739f0
loss(data[500]...) = 17.47304f0
loss(data[500]...) = 14.833056f0
loss(data[500]...) = 13.803842f0
loss(data[500]...) = 14.142663f0
loss(data[500]...) = 15.463973f0
loss(data[500]...) = 15.767911f0
loss(data[500]...) = 17.431313f0
[ Info: Testing...

– Adarsh Kumar, Mike J Innes