diff --git a/Project.toml b/Project.toml index 250d47a..0ed446b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QXContexts" uuid = "04c26001-d4a1-49d2-b090-1d469cf06784" authors = ["QuantEx team"] -version = "0.1.8" +version = "0.1.9" [deps] ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" diff --git a/examples/ghz/ghz_5.yml b/examples/ghz/ghz_5.yml index 3021152..5a0bf2b 100644 --- a/examples/ghz/ghz_5.yml +++ b/examples/ghz/ghz_5.yml @@ -1,10 +1,13 @@ -amplitudes: - - "01000" - - "01110" - - "10101" - - "10001" - - "10010" - - "11111" +output: + method: List + params: + bitstrings: + - "01000" + - "01110" + - "10101" + - "10001" + - "10010" + - "11111" partitions: parameters: v1: 2 diff --git a/examples/ghz/ghz_5_rejection.yml b/examples/ghz/ghz_5_rejection.yml new file mode 100644 index 0000000..971aa92 --- /dev/null +++ b/examples/ghz/ghz_5_rejection.yml @@ -0,0 +1,12 @@ +output: + method: Rejection + params: + num_qubits: 5 + num_samples: 10 + M: 0.001 + fix_M: false + seed: 42 +partitions: + parameters: + v1: 2 + v2: 2 diff --git a/examples/ghz/ghz_5_uniform.yml b/examples/ghz/ghz_5_uniform.yml new file mode 100644 index 0000000..b07563d --- /dev/null +++ b/examples/ghz/ghz_5_uniform.yml @@ -0,0 +1,10 @@ +output: + method: Uniform + params: + num_samples: 10 + seed: 42 + num_qubits: 5 +partitions: + parameters: + v1: 2 + v2: 2 \ No newline at end of file diff --git a/src/QXContexts.jl b/src/QXContexts.jl index 72a3879..f554cf5 100644 --- a/src/QXContexts.jl +++ b/src/QXContexts.jl @@ -7,10 +7,12 @@ include("parameters.jl") include("dsl.jl") include("execution.jl") include("sysimage/sysimage.jl") +include("sampling.jl") @reexport using QXContexts.Logger @reexport using QXContexts.Param @reexport using QXContexts.DSL @reexport using QXContexts.Execution +@reexport using QXContexts.Sampling end diff --git a/src/execution.jl b/src/execution.jl index b73f18e..7465afd 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -1,6 +1,7 @@ module Execution -export QXContext, execute!, timer_output, reduce_nodes +export QXContext, QXMPIContext, execute!, timer_output, reduce_nodes +export Samples export execute export set_open_bonds!, set_slice_vals! export compute_amplitude! @@ -13,9 +14,11 @@ import LinearAlgebra import TensorOperations import QXContexts.Logger: @debug using TimerOutputs +using Random using OMEinsum using QXContexts.DSL using QXContexts.Param +import QXContexts const timer_output = TimerOutput() if haskey(ENV, "QXRUN_TIMER") @@ -140,13 +143,23 @@ Base.iterate(a::SliceIterator, state) = length(a) <= (state + 1 - a.start) ? not Base.length(a::SliceIterator) = a.stop - a.start + 1 Base.eltype(::SliceIterator) = Vector{Int} +""" +Struct to hold the results of a simulation. +""" +struct Samples{T} + bitstrings_counts::DefaultDict{String, <:Integer} + amplitudes::Dict{String, T} +end + +Samples() = Samples(DefaultDict{String, Int}(0), Dict{String, ComplexF32}()) + """Function for reducing over amplitude contributions from each slice. For non-serial contexts, contributions are summed over""" reduce_slices(::QXContext, a) = a -"""Function for reducing over calculated amplitudes. For non-serial contexts, contributions -are gathered""" -reduce_amplitudes(::QXContext, a) = a +"""Function for reducing over calculated amplitudes and samples. For non-serial contexts, +contributions are gathered""" +reduce_results(::QXContext, results::Samples) = results """Function for reducing over calculated amplitudes. For non-serial contexts, contributions are gathered""" @@ -158,7 +171,9 @@ BitstringIterator(::QXContext, bitstrings) = bitstrings Save results from calculations for the given """ function write_results(::QXContext, results, output_file) - JLD2.@save output_file results + amplitudes = results.amplitudes + bitstrings_counts = results.bitstrings_counts + JLD2.@save output_file amplitudes bitstrings_counts end """Function to create a scalar with zero value of appropriate data-type for given contexct""" @@ -276,6 +291,19 @@ function compute_amplitude!(ctx, bitstring::String) reduce_slices(ctx, amplitude) end +""" + compute_amplitude!(results::Samples{T}, ctx, bitstring::String) where T<:Complex + +Calculate a single amplitude with the given context and bitstring. Update `results` to hold +the new amplitude. +""" +function compute_amplitude!(results::Samples, ctx, bitstring::String) + if !haskey(results.amplitudes, bitstring) + results.amplitudes[bitstring] = compute_amplitude!(ctx, bitstring) + end + results +end + include("mpi_execution.jl") """ @@ -299,26 +327,34 @@ function execute(dsl_file::String, sub_comm_size::Int=1, max_amplitudes::Union{Int, Nothing}=nothing, max_parameters::Union{Int, Nothing}=nothing) - + # Parse the dsl file to create a list of commands to execute in a context. Also parse + # the parameter file to get partition parameters and to create a sampler object to + # produce bitstrings to compute amplitudes for. commands = parse_dsl(dsl_file) + sampler_args, partition_params = parse_parameters(param_file, + max_parameters=max_parameters) - bitstrings, partition_params = parse_parameters(param_file, - max_amplitudes=max_amplitudes, max_parameters=max_parameters) - + # Create a context to execute the commands in. ctx = QXContext(commands, partition_params, input_file) if comm !== nothing ctx = QXMPIContext(ctx, comm, sub_comm_size) end - bitstring_iter = BitstringIterator(ctx, bitstrings) - results = Array{ComplexF32, 1}(undef, length(bitstring_iter)) - for (i, bitstring) in enumerate(bitstring_iter) - results[i] = compute_amplitude!(ctx, bitstring) + # Create a sampler to produce bitstrings to get amplitudes for and a variable to store + # the results. + sampler = QXContexts.Sampling.create_sampler(sampler_args, ctx, max_amplitudes) + results = Samples() + + # For each bitstring produced by the sampler, compute its amplitude and accept or reject + # it as sample in accordance with the sampler. + for bitstring in sampler + results = compute_amplitude!(results, ctx, bitstring) + QXContexts.Sampling.accept!(results, sampler, bitstring) end - results = reduce_amplitudes(ctx, results) + # Collect, save and return the results. + results = reduce_results(ctx, results) write_results(ctx, results, output_file) - return results end diff --git a/src/mpi_execution.jl b/src/mpi_execution.jl index b709fdc..edc25a6 100644 --- a/src/mpi_execution.jl +++ b/src/mpi_execution.jl @@ -4,6 +4,7 @@ General utility functions for working with MPI and partitions export get_rank_size export get_rank_start +export get_rank_range """ get_rank_size(n::Integer, size::Integer, rank::Integer) @@ -101,14 +102,29 @@ function reduce_slices(ctx::QXMPIContext, a) end """ - reduce_amplitudes(ctx::QXMPIContext, a) + reduce_results(ctx::QXMPIContext, results::Samples) -Function to gather amplitudes from sub-communicators +Function to gather amplitudes and samples from sub-communicators. """ -function reduce_amplitudes(ctx::QXMPIContext, a) +function reduce_results(ctx::QXMPIContext, results::Samples) if MPI.Comm_rank(ctx.sub_comm) == 0 - return MPI.Gather(a, 0, ctx.root_comm) + bitstrings = keys(results.amplitudes) + num_qubits = length(first(bitstrings)) + + bitstrings_as_ints = parse.(UInt64, bitstrings, base=2) + amplitudes = [results.amplitudes[bitstring] for bitstring in bitstrings] + samples = [results.bitstrings_counts[bitstring] for bitstring in bitstrings] + + bitstrings_as_ints = MPI.Gather(bitstrings_as_ints, 0, ctx.root_comm) + amplitudes = MPI.Gather(amplitudes, 0, ctx.root_comm) + samples = MPI.Gather(samples, 0, ctx.root_comm) + + bitstrings = reverse.(digits.(bitstrings_as_ints, base=2, pad=num_qubits)) + bitstrings = [prod(string.(bits)) for bits in bitstrings] + amplitudes = Dict{String, eltype(amplitudes)}(bitstrings .=> amplitudes) + bitstrings = DefaultDict(0, Dict{String, Int}(bitstrings .=> samples)) end + Samples(bitstrings, amplitudes) end """ @@ -133,7 +149,11 @@ end Function write results for QXMPIContext. Only writes from root process """ function write_results(ctx::QXMPIContext, results, output_file) - if MPI.Comm_rank(ctx.comm) == 0 JLD2.@save output_file results end + if MPI.Comm_rank(ctx.comm) == 0 + amplitudes = results.amplitudes + bitstrings_counts = results.bitstrings_counts + JLD2.@save output_file amplitudes bitstrings_counts + end end Base.zero(ctx::QXMPIContext) = zero(ctx.serial_ctx) \ No newline at end of file diff --git a/src/parameters.jl b/src/parameters.jl index 5f1e9a0..c1410ce 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -8,45 +8,39 @@ import Base.Iterators: take """ parse_parameters(filename::String; - max_parameters::Union{Int, Nothing}=nothing, - max_amplitudes::Union{Int, Nothing}=nothing) + max_parameters::Union{Int, Nothing}=nothing) -Parse the parameters yml file to read information about partition parameters and their -dimensions as well as how the sampling of amplitudes will work. +Parse the parameters yml file to read information about partition parameters and output +sampling method. Example Parameter file ====================== +output: + method: List + params: + bitstrings: + - "01000" + - "01110" + - "10101" partitions: - parameters: - v1: 2 - v2: 2 - v3: 4 -amplitudes: - - "0000" - - "0001" - - "1111" - + parameters: + v1: 2 + v2: 2 """ function parse_parameters(filename::String; - max_parameters::Union{Int, Nothing}=nothing, - max_amplitudes::Union{Int, Nothing}=nothing) + max_parameters::Union{Int, Nothing}=nothing) param_dict = YAML.load_file(filename, dicttype=OrderedDict{String, Any}) - #TODO: Revise the way amplitudes are described - amplitudes = unique([amplitude for amplitude in param_dict["amplitudes"]]) - if max_amplitudes !== nothing && length(amplitudes) > max_amplitudes - amplitudes = amplitudes[1:max_amplitudes] - end - - # parse the paramters section of the parameter file + # parse the partition paramters section of the parameter file partition_params = param_dict["partitions"]["parameters"] max_parameters = max_parameters === nothing ? length(partition_params) : max_parameters partition_params = OrderedDict{Symbol, Int}(Symbol(x[1]) => x[2] for x in take(partition_params, max_parameters)) - # variables_symbols = [Symbol("\$$v") for v in take(keys(bond_info), max_parameters)] - # variable_values = CartesianIndices(Tuple(take(values(bond_info), max_parameters))) - # return Parameters(amplitudes, variable_symbols, variable_values) - return amplitudes, partition_params + # parse the output method section of the parameter file + method_params = OrderedDict{Symbol, Any}(Symbol(x[1]) => x[2] for x in param_dict["output"]) + method_params[:params] = OrderedDict{Symbol, Any}(Symbol(x[1]) => x[2] for x in method_params[:params]) + + return method_params, partition_params end end \ No newline at end of file diff --git a/src/sampling.jl b/src/sampling.jl new file mode 100644 index 0000000..b6ee7cd --- /dev/null +++ b/src/sampling.jl @@ -0,0 +1,221 @@ +module Sampling + +export ListSampler, RejectionSampler, UniformSampler +export create_sampler, accept! + +import MPI + +using Random +using DataStructures +using QXContexts.Execution + +"""Abstract type for samplers""" +abstract type AbstractSampler end + +############################################################################### +# ListSampler +############################################################################### + +""" +A Sampler struct to compute the amplitudes for a list of bitstrings. +""" +struct ListSampler <: AbstractSampler + list::Vector{String} +end + +""" + ListSampler(;bitstrings::Vector{String}=String[], + rank::Integer=0, + comm_size::Integer=1, + kwargs...) + +Constructor for a ListSampler to produce a portion of the given `bitstrings` determined by +the given `rank` and `comm_size`. +""" +function ListSampler(;bitstrings::Vector{String}=String[], + rank::Integer=0, + comm_size::Integer=1, + kwargs...) + range = get_rank_range(length(bitstrings), comm_size, rank) + bitstrings = bitstrings[range] + if haskey(kwargs, :num_samples) + num_amplitudes = kwargs[:num_samples] + num_amplitudes = get_rank_size(num_amplitudes, comm_size, rank) + num_amplitudes = min(num_amplitudes, length(bitstrings)) + else + num_amplitudes = length(bitstrings) + end + ListSampler(bitstrings[1:num_amplitudes]) +end + +"""Iterator interface functions for ListSampler""" +Base.iterate(sampler::ListSampler) = Base.iterate(sampler.list) +Base.iterate(sampler::ListSampler, state) = Base.iterate(sampler.list, state) + +""" + accept!(results::Samples{T}, ::ListSampler, bitstring::String) where T<:Complex + +Does nothing as a ListSampler is not for collecting samples. +""" +function accept!(::Samples{T}, ::ListSampler, ::String) where T<:Complex + nothing +end + +############################################################################### +# RejectionSampler +############################################################################### + +""" +A Sampler struct to use rejection sampling to produce output. +""" +mutable struct RejectionSampler <: AbstractSampler + num_qubits::Integer + num_samples::Integer + accepted::Integer + M::Real + fix_M::Bool + rng::MersenneTwister +end + +""" + function RejectionSampler(;num_qubits::Integer, + num_samples::Integer, + M::Real=0.0001, + fix_M::Bool=false, + seed::Integer=42, + rank::Integer=0, + comm_size::Integer=1, + kwargs...) + +Constructor for a RejectionSampler to produce and accept a number of bitstrings. +""" +function RejectionSampler(;num_qubits::Integer, + num_samples::Integer, + M::Real=0.0001, + fix_M::Bool=false, + seed::Integer=42, + rank::Integer=0, + comm_size::Integer=1, + kwargs...) + # Evenly divide the number of bitstrings to be sampled amongst the subgroups of ranks. + num_amplitudes = get_rank_size(num_samples, comm_size, rank) + rng = MersenneTwister(seed + rank) + RejectionSampler(num_qubits, num_samples, 0, M, fix_M, rng) +end + +"""Iterator interface functions for RejectionSampler""" +Base.iterate(sampler::RejectionSampler, ::Nothing) = iterate(sampler) + +function Base.iterate(sampler::RejectionSampler) + if sampler.accepted >= sampler.num_samples + return nothing + else + return prod(rand(sampler.rng, ["0", "1"], sampler.num_qubits)), nothing + end +end + +""" + accept!(results::Samples{T}, sampler::RejectionSampler, bitstring::String) where T<:Complex + +Accept or reject the given bitstring as a sample using the rejection method and update +`results` accordingly. +""" +function accept!(results::Samples{T}, sampler::RejectionSampler, bitstring::String) where T<:Complex + # Get the amplitude for the given bitstring and the parameters for the rejection method. + amp = results.amplitudes[bitstring] + Np = 2^sampler.num_qubits * abs(amp)^2 + sampler.fix_M && (sampler.M = max(Np, sampler.M)) + + # Accept the given bitstring as a sample with probability Np/M. + u = rand(sampler.rng) + if u < Np / sampler.M + sampler.accepted += 1 + results.bitstrings_counts[bitstring] += 1 + end +end + +############################################################################### +# UniformSampler +############################################################################### + +""" +A Sampler struct to uniformly sample bitstrings and compute their amplitudes. +""" +mutable struct UniformSampler <: AbstractSampler + num_qubits::Integer + num_samples::Integer + rng::MersenneTwister +end + +""" + UniformSampler(;num_qubits::Integer, + num_samples::Integer, + seed::Integer=42, + rank::Integer=0, + comm_size::Integer=1, + kwargs...) + +Constructor for a UniformSampler to uniformly sample bitstrings. +""" +function UniformSampler(;num_qubits::Integer, + num_samples::Integer, + seed::Integer=42, + rank::Integer=0, + comm_size::Integer=1, + kwargs...) + # Evenly divide the number of bitstrings to be sampled amongst the subgroups of ranks. + num_samples = (num_samples ÷ comm_size) + (rank < num_samples % comm_size) + rng = MersenneTwister(seed + rank) + UniformSampler(num_qubits, num_samples, rng) +end + +"""Iterator interface functions for UniformSampler""" +Base.iterate(sampler::UniformSampler) = iterate(sampler, 0) + +function Base.iterate(sampler::UniformSampler, samples_produced::Integer) + if samples_produced < sampler.num_samples + new_bitstring = prod(rand(sampler.rng, ["0", "1"], sampler.num_qubits)) + return new_bitstring, samples_produced + 1 + else + return nothing + end +end + +""" + accept!(results::Samples{T}, ::UniformSampler, bitstring::String) where T<:Complex + +Accept the given bitstring as a sample and update `results` accordingly. +""" +function accept!(results::Samples{T}, ::UniformSampler, bitstring::String) where T<:Complex + results.bitstrings_counts[bitstring] += 1 +end + +############################################################################### +# Sampler Constructor +############################################################################### + +""" + create_sampler(params) + +Returns a sampler whose type and parameters are specified in the Dict `params`. + +Additional parameters that determine load balancing and totale amout of work to be done +are set by `max_amplitudes` and the Context `ctx`. +""" +function create_sampler(params, ctx, max_amplitudes=nothing) + max_amplitudes === nothing || (params[:params][:num_samples] = max_amplitudes) + create_sampler(params, ctx) +end + +function create_sampler(params, ctx::QXMPIContext) + params[:rank] = MPI.Comm_rank(ctx.comm) ÷ MPI.Comm_size(ctx.sub_comm) + params[:comm_size] = MPI.Comm_size(ctx.comm) ÷ MPI.Comm_size(ctx.sub_comm) + create_sampler(params) +end + +create_sampler(params, ctx::QXContext{T}) where T = create_sampler(params) +create_sampler(params) = get_constructor(params[:method])(;params[:params]...) + +get_constructor(func_name::String) = getfield(Main, Symbol(func_name*"Sampler")) + +end \ No newline at end of file diff --git a/test/bin_tests.jl b/test/bin_tests.jl index 39d5c97..533f9b0 100644 --- a/test/bin_tests.jl +++ b/test/bin_tests.jl @@ -24,9 +24,9 @@ include("../bin/qxrun.jl") "-o", output_fname] main(args) @test isfile(output_fname) - output = load(output_fname, "results") - expected = collect(values(ghz_results)) - @test output ≈ expected + + output = load(output_fname, "amplitudes") + @test all([output[x] ≈ ghz_results[x] for x in keys(output)]) end mktempdir() do path @@ -36,9 +36,10 @@ include("../bin/qxrun.jl") "--number-amplitudes", "1"] main(args) @test isfile(output_fname) - output = load(output_fname, "results") - expected = [ghz_results["01000"]] - @test output ≈ expected + + output = load(output_fname, "amplitudes") + @test length(output) == 1 + @test output["01000"] ≈ ghz_results["01000"] end mktempdir() do path @@ -49,9 +50,10 @@ include("../bin/qxrun.jl") "--number-slices", "1"] main(args) @test isfile(output_fname) - output = load(output_fname, "results") - expected = [ghz_results["01000"], ghz_results["01110"]] - @test output ≈ expected + + output = load(output_fname, "amplitudes") + @test length(output) == 2 + @test all([output[x] ≈ ghz_results[x] for x in ["01000", "01110"]]) end end diff --git a/test/ghzexample_tests.jl b/test/ghzexample_tests.jl index 1572ba3..f5c09e5 100644 --- a/test/ghzexample_tests.jl +++ b/test/ghzexample_tests.jl @@ -28,8 +28,22 @@ using DataStructures execute(dsl_file, param_file, input_data_file, output_data_file) # ensure all dictionary entries match - output = FileIO.load(output_data_file, "results") - @test output ≈ expected_vals + output = FileIO.load(output_data_file, "amplitudes") + output_vals = [output[bitstring] for bitstring in keys(expected)] + @test output_vals ≈ expected_vals + end + + + param_file = joinpath(test_path, "examples/ghz/ghz_5_rejection.yml") + + mktempdir() do path + output_data_file = joinpath(path, "out.jld2") + execute(dsl_file, param_file, input_data_file, output_data_file) + + # ensure all dictionary entries match + output = FileIO.load(output_data_file, "bitstrings_counts") + @test length(output) == 2 + @test output["11111"] + output["00000"] == 10 end end diff --git a/test/sampling_tests.jl b/test/sampling_tests.jl new file mode 100644 index 0000000..54d3a88 --- /dev/null +++ b/test/sampling_tests.jl @@ -0,0 +1,41 @@ +module SamplingTests + +using QXContexts +using JLD2 +using FileIO +using Test +using DataStructures + +@testset "Sampling tests" begin + + test_path = dirname(@__DIR__) + dsl_file = joinpath(test_path, "examples/ghz/ghz_5.qx") + input_data_file = joinpath(test_path, "examples/ghz/ghz_5.jld2") + + + param_file = joinpath(test_path, "examples/ghz/ghz_5_uniform.yml") + + mktempdir() do path + output_data_file = joinpath(path, "out.jld2") + execute(dsl_file, param_file, input_data_file, output_data_file) + + # ensure all dictionary entries match + output = FileIO.load(output_data_file, "bitstrings_counts") + @test sum(values(output)) ≈ 10 + end + + + param_file = joinpath(test_path, "examples/ghz/ghz_5_rejection.yml") + + mktempdir() do path + output_data_file = joinpath(path, "out.jld2") + execute(dsl_file, param_file, input_data_file, output_data_file) + + # ensure all dictionary entries match + output = FileIO.load(output_data_file, "bitstrings_counts") + @test length(output) == 2 + @test output["11111"] + output["00000"] == 10 + end +end + +end \ No newline at end of file