BitString Parity Challenge

This is an implementation of the warm-up problem stated below:

⭐ Train an LSTM to solve the XOR problem: that is, given a sequence of bits, determine its parity. The LSTM should consume the sequence, one bit at a time, and then output the correct answer at the sequence’s end. Test the two approaches below:

First, we obtain the data. Save the following code as data.jl.

using Flux: onehot, onehotbatch
using Random

# The two symbols a bit string is made of; `gendata` passes this as the label set
# when one-hot encoding both the input bits and the parity target.
const alphabet = [false, true]  # 0, 1

"""
    parity(x)

Return the parity of the bit collection `x`, i.e. the XOR of all its elements.

The fold is seeded with `false`, so an empty collection has parity `false` instead of
raising an error; for non-empty input the result is the plain XOR of the elements.
"""
parity(x) = reduce(xor, x; init=false)

"""
    gendata(n::Int, k::Int)
    gendata(n::Int, k::UnitRange{Int})

Generate `n` random bit strings together with their parity.

When `k` is an integer every string has exactly `k` bits; when `k` is a range the length
of each string is drawn independently from it. Each sample is returned as a tuple
`(onehotbatch(bits, alphabet), onehot(parity(bits), alphabet))`.
"""
function gendata(n::Int, k::Int)
    # A fixed length is just the degenerate one-element range.
    return gendata(n, k:k)
end

function gendata(n::Int, k::UnitRange{Int})
    # Draw the lengths first, then one random bit vector per length.
    lengths = rand(k, n)
    samples = bitrand.(lengths)
    targets = parity.(samples)
    return [
        (onehotbatch(bits, alphabet), onehot(target, alphabet)) for
        (bits, target) in zip(samples, targets)
    ]
end


Model for 2-bit strings

include("data.jl")
using Flux, Statistics
using Flux: onehot, onehotbatch, throttle, logitcrossentropy, reset!, onecold
using Parameters: @with_kw

# Hyperparameters for this script. `@with_kw` generates a keyword constructor using the
# defaults below; `train` builds an instance with `Args(; kws...)`.
@with_kw mutable struct Args
    lr::Float64 = 1e-3    # Learning rate passed to the ADAM optimiser
    epochs::Int = 20      # Number of passes over the training data
    train_len::Int = 100  # Number of training samples to generate
    val_len::Int = 10     # Number of validation samples to generate
    throttle::Int = 10    # Timeout passed to `throttle` for the validation callback (presumably seconds)
end

"""
    getdata(args)

Generate the training and validation sets for the 2-bit experiment.

Reads `args.train_len` and `args.val_len` for the number of samples and returns the
tuple `(train, val)`.
"""
function getdata(args)
    # `gendata` comes from data.jl; an integer length means every string has exactly
    # that many bits.
    bitlen = 2
    train_set = gendata(args.train_len, bitlen)
    val_set = gendata(args.val_len, bitlen)
    return train_set, val_set
end

"""
    build_model()

Create the two parts of the network and return them as `(scanner, encoder)`.

`scanner` is an LSTM that reads one one-hot encoded bit at a time; `encoder` is a dense
layer mapping the LSTM output to one logit per symbol in `alphabet`.
"""
function build_model()
    nsymbols = length(alphabet)
    hidden = 20  # width of the LSTM output fed into the dense layer
    scanner = LSTM(nsymbols, hidden)
    encoder = Dense(hidden, nsymbols)
    return scanner, encoder
end

"""
    model(x, scanner, encoder)

Run `scanner` over every element of `x.data` in order, reset it, and return `encoder`
applied to the output of the final step.
"""
function model(x, scanner, encoder)
    # assumes `x.data` holds the per-bit one-hot vectors of the sequence — TODO confirm
    # against the installed Flux version
    outputs = scanner.(x.data)
    final = outputs[end]
    # Reset after the pass so the recurrent state does not carry over to the next call.
    reset!(scanner)
    return encoder(final)
end

"""
    train(; kws...)

Train the parity model and return the trained `(scanner, encoder)` pair.

Keyword arguments are forwarded to `Args` and override its defaults.
"""
function train(; kws...)
    # Hyperparameters: defaults from `Args`, overridden by the caller's keywords.
    args = Args(; kws...)

    # Generate the training and validation samples.
    train_data, val_data = getdata(args)

    @info "Constructing Model..."
    scanner, encoder = build_model()

    # Cross-entropy between the logits for one sequence and its one-hot parity label.
    function loss(x, y)
        return logitcrossentropy(model(x, scanner, encoder), y)
    end
    # Average per-sample loss over a whole data set.
    batch_loss(data) = mean(loss(x, y) for (x, y) in data)

    ps = params(scanner, encoder)
    opt = ADAM(args.lr)
    # Prints the validation loss when invoked by the training loop.
    evalcb() = @show batch_loss(val_data)

    @info "Training..."
    for _ in 1:args.epochs
        # A fresh throttle is built for every epoch, as in a plain per-epoch call.
        Flux.train!(loss, ps, train_data, opt; cb = throttle(evalcb, args.throttle))
    end
    return scanner, encoder
end

"""
    test(scanner, encoder)

Sanity-check the trained model on all four 2-bit strings and print one
`input => prediction` line for each.
"""
function test(scanner, encoder)
    # Every 2-bit string; the trailing comment is the expected parity.
    cases = [
        [false, true],   # 01 -> 1
        [true, false],   # 10 -> 1
        [false, false],  # 00 -> 0
        [true, true],    # 11 -> 0
    ]
    tx = [onehotbatch(bits, alphabet) for bits in cases]
    @info "Test..."
    # `onecold` yields a 1-based index; subtracting 1 turns it into the bit value 0 or 1.
    out = [onecold(model(x, scanner, encoder)) - 1 for x in tx]
    input = [[0, 1], [1, 0], [0, 0], [1, 1]]
    for (shown, predicted) in zip(input, out)
        print(shown, " => ", predicted, "\n")
    end
    return nothing
end

# Script entry point: switch to this file's directory, train with the default
# hyperparameters, then print the sanity check for the trained model.
cd(@__DIR__)
scanner, encoder = train()
test(scanner, encoder)


Save the above as xor1.jl and run it as:

julia xor1.jl

Model for 2,000 strings of length 1 to 10

include("data.jl")
using Flux, Statistics
using Flux: onehot, onehotbatch, throttle, logitcrossentropy, reset!, onecold
using Parameters: @with_kw

# Hyperparameters for this script. `@with_kw` generates a keyword constructor using the
# defaults below; `train` builds an instance with `Args(; kws...)`.
@with_kw mutable struct Args
    lr::Float64 = 1e-3    # Learning rate passed to the ADAM optimiser
    epochs::Int = 20      # Number of passes over the training data
    train_len::Int = 2000  # Number of training samples to generate
    val_len::Int = 100     # Number of validation samples to generate
    throttle::Int = 10    # Timeout passed to `throttle` for the validation callback (presumably seconds)
end

"""
    getdata(args)

Generate the training and validation sets for the variable-length experiment.

Reads `args.train_len` and `args.val_len` for the number of samples and returns the
tuple `(train, val)`.
"""
function getdata(args)
    # Training strings: each length is drawn from 1 to 10 bits.
    train_set = gendata(args.train_len, 1:10)
    # Validation strings: all exactly 10 bits long.
    val_set = gendata(args.val_len, 10)
    return train_set, val_set
end

"""
    build_model()

Create the two parts of the network and return them as `(scanner, encoder)`.

`scanner` is an LSTM that reads one one-hot encoded bit at a time; `encoder` is a dense
layer mapping the LSTM output to one logit per symbol in `alphabet`.
"""
function build_model()
    nsymbols = length(alphabet)
    hidden = 20  # width of the LSTM output fed into the dense layer
    scanner = LSTM(nsymbols, hidden)
    encoder = Dense(hidden, nsymbols)
    return scanner, encoder
end

"""
    model(x, scanner, encoder)

Run `scanner` over every element of `x.data` in order, reset it, and return `encoder`
applied to the output of the final step.
"""
function model(x, scanner, encoder)
    # assumes `x.data` holds the per-bit one-hot vectors of the sequence — TODO confirm
    # against the installed Flux version
    outputs = scanner.(x.data)
    final = outputs[end]
    # Reset after the pass so the recurrent state does not carry over to the next call.
    reset!(scanner)
    return encoder(final)
end

"""
    train(; kws...)

Train the parity model, report its loss on length-50 strings, and return the trained
`(scanner, encoder)` pair.

Keyword arguments are forwarded to `Args` and override its defaults.
"""
function train(; kws...)
    # Hyperparameters: defaults from `Args`, overridden by the caller's keywords.
    args = Args(; kws...)

    # Generate the training and validation samples.
    train_data, val_data = getdata(args)

    @info "Constructing Model..."
    scanner, encoder = build_model()

    # Cross-entropy between the logits for one sequence and its one-hot parity label.
    function loss(x, y)
        return logitcrossentropy(model(x, scanner, encoder), y)
    end
    # Average per-sample loss over a whole data set.
    batch_loss(data) = mean(loss(x, y) for (x, y) in data)

    ps = params(scanner, encoder)
    opt = ADAM(args.lr)
    # Prints the validation loss when invoked by the training loop.
    evalcb() = @show batch_loss(val_data)

    @info "Training..."
    for _ in 1:args.epochs
        # A fresh throttle is built for every epoch, as in a plain per-epoch call.
        Flux.train!(loss, ps, train_data, opt; cb = throttle(evalcb, args.throttle))
    end

    # Evaluate on 1000 freshly generated strings of length 50. The model was only
    # trained on much shorter strings, so a low loss here indicates it has learned the
    # parity function itself and will be accurate on longer strings.
    long_loss = batch_loss(gendata(1000, 50))
    println("Batch_loss for length 50 string: ", long_loss, "\n")

    return scanner, encoder
end

"""
    test(scanner, encoder)

Sanity-check the trained model on all four 2-bit strings and print one
`input => prediction` line for each.
"""
function test(scanner, encoder)
    # Every 2-bit string; the trailing comment is the expected parity.
    cases = [
        [false, true],   # 01 -> 1
        [true, false],   # 10 -> 1
        [false, false],  # 00 -> 0
        [true, true],    # 11 -> 0
    ]
    tx = [onehotbatch(bits, alphabet) for bits in cases]
    @info "Test..."
    # `onecold` yields a 1-based index; subtracting 1 turns it into the bit value 0 or 1.
    out = [onecold(model(x, scanner, encoder)) - 1 for x in tx]
    input = [[0, 1], [1, 0], [0, 0], [1, 1]]
    for (shown, predicted) in zip(input, out)
        print(shown, " => ", predicted, "\n")
    end
    return nothing
end

# Script entry point: switch to this file's directory, train with the default
# hyperparameters, then print the sanity check for the trained model.
cd(@__DIR__)
scanner, encoder = train()
test(scanner, encoder)


Save the above as xor2.jl and run it as:

julia xor2.jl

Model for 100,000 strings of length 1 to 50

include("data.jl")
using Flux, Statistics
using Flux: onehot, onehotbatch, throttle, logitcrossentropy, reset!, onecold
using Parameters: @with_kw

# Hyperparameters for this script. `@with_kw` generates a keyword constructor using the
# defaults below; `train` builds an instance with `Args(; kws...)`.
@with_kw mutable struct Args
    lr::Float64 = 1e-3    # Learning rate passed to the ADAM optimiser
    epochs::Int = 20      # Number of passes over the training data
    train_len::Int = 100000  # Number of training samples to generate
    val_len::Int = 1000     # Number of validation samples to generate
    throttle::Int = 10    # Timeout passed to `throttle` for the validation callback (presumably seconds)
end

"""
    getdata(args)

Generate the training and validation sets for the large variable-length experiment.

Reads `args.train_len` and `args.val_len` for the number of samples and returns the
tuple `(train, val)`.
"""
function getdata(args)
    # Training strings: each length is drawn from 1 to 50 bits.
    train_set = gendata(args.train_len, 1:50)
    # Validation strings: all exactly 50 bits long.
    val_set = gendata(args.val_len, 50)
    return train_set, val_set
end

"""
    build_model()

Create the two parts of the network and return them as `(scanner, encoder)`.

`scanner` is an LSTM that reads one one-hot encoded bit at a time; `encoder` is a dense
layer mapping the LSTM output to one logit per symbol in `alphabet`.
"""
function build_model()
    nsymbols = length(alphabet)
    hidden = 20  # width of the LSTM output fed into the dense layer
    scanner = LSTM(nsymbols, hidden)
    encoder = Dense(hidden, nsymbols)
    return scanner, encoder
end

"""
    model(x, scanner, encoder)

Run `scanner` over every element of `x.data` in order, reset it, and return `encoder`
applied to the output of the final step.
"""
function model(x, scanner, encoder)
    # assumes `x.data` holds the per-bit one-hot vectors of the sequence — TODO confirm
    # against the installed Flux version
    outputs = scanner.(x.data)
    final = outputs[end]
    # Reset after the pass so the recurrent state does not carry over to the next call.
    reset!(scanner)
    return encoder(final)
end

"""
    train(; kws...)

Train the parity model and return the trained `(scanner, encoder)` pair.

Keyword arguments are forwarded to `Args` and override its defaults.
"""
function train(; kws...)
    # Hyperparameters: defaults from `Args`, overridden by the caller's keywords.
    args = Args(; kws...)

    # Generate the training and validation samples.
    train_data, val_data = getdata(args)

    @info "Constructing Model..."
    scanner, encoder = build_model()

    # Cross-entropy between the logits for one sequence and its one-hot parity label.
    function loss(x, y)
        return logitcrossentropy(model(x, scanner, encoder), y)
    end
    # Average per-sample loss over a whole data set.
    batch_loss(data) = mean(loss(x, y) for (x, y) in data)

    ps = params(scanner, encoder)
    opt = ADAM(args.lr)
    # Prints the validation loss when invoked by the training loop.
    evalcb() = @show batch_loss(val_data)

    @info "Training..."
    for _ in 1:args.epochs
        # A fresh throttle is built for every epoch, as in a plain per-epoch call.
        Flux.train!(loss, ps, train_data, opt; cb = throttle(evalcb, args.throttle))
    end

    return scanner, encoder
end

"""
    test(scanner, encoder)

Sanity-check the trained model on all four 2-bit strings and print one
`input => prediction` line for each.
"""
function test(scanner, encoder)
    # Every 2-bit string; the trailing comment is the expected parity.
    cases = [
        [false, true],   # 01 -> 1
        [true, false],   # 10 -> 1
        [false, false],  # 00 -> 0
        [true, true],    # 11 -> 0
    ]
    tx = [onehotbatch(bits, alphabet) for bits in cases]
    @info "Test..."
    # `onecold` yields a 1-based index; subtracting 1 turns it into the bit value 0 or 1.
    out = [onecold(model(x, scanner, encoder)) - 1 for x in tx]
    input = [[0, 1], [1, 0], [0, 0], [1, 1]]
    for (shown, predicted) in zip(input, out)
        print(shown, " => ", predicted, "\n")
    end
    return nothing
end

# Script entry point: switch to this file's directory, train with the default
# hyperparameters, then print the sanity check for the trained model.
cd(@__DIR__)
scanner, encoder = train()
test(scanner, encoder)

Save the above as xor3.jl and run it as:

julia xor3.jl

– Adarsh Kumar, Mike J Innes,