Skip to content

Commit

Permalink
Merge pull request #21 from JuliaQX/add-rejection-sampling-method
Browse files Browse the repository at this point in the history
Add rejection sampling method
  • Loading branch information
brenjohn authored May 12, 2021
2 parents 8e6089d + 3e7919f commit aac5bd6
Show file tree
Hide file tree
Showing 12 changed files with 420 additions and 65 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "QXContexts"
uuid = "04c26001-d4a1-49d2-b090-1d469cf06784"
authors = ["QuantEx team"]
version = "0.1.8"
version = "0.1.9"

[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
Expand Down
17 changes: 10 additions & 7 deletions examples/ghz/ghz_5.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
amplitudes:
- "01000"
- "01110"
- "10101"
- "10001"
- "10010"
- "11111"
output:
method: List
params:
bitstrings:
- "01000"
- "01110"
- "10101"
- "10001"
- "10010"
- "11111"
partitions:
parameters:
v1: 2
Expand Down
12 changes: 12 additions & 0 deletions examples/ghz/ghz_5_rejection.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
output:
method: Rejection
params:
num_qubits: 5
num_samples: 10
M: 0.001
fix_M: false
seed: 42
partitions:
parameters:
v1: 2
v2: 2
10 changes: 10 additions & 0 deletions examples/ghz/ghz_5_uniform.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
output:
method: Uniform
params:
num_samples: 10
seed: 42
num_qubits: 5
partitions:
parameters:
v1: 2
v2: 2
2 changes: 2 additions & 0 deletions src/QXContexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ include("parameters.jl")
include("dsl.jl")
include("execution.jl")
include("sysimage/sysimage.jl")
include("sampling.jl")

@reexport using QXContexts.Logger
@reexport using QXContexts.Param
@reexport using QXContexts.DSL
@reexport using QXContexts.Execution
@reexport using QXContexts.Sampling

end
66 changes: 51 additions & 15 deletions src/execution.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Execution

export QXContext, execute!, timer_output, reduce_nodes
export QXContext, QXMPIContext, execute!, timer_output, reduce_nodes
export Samples
export execute
export set_open_bonds!, set_slice_vals!
export compute_amplitude!
Expand All @@ -13,9 +14,11 @@ import LinearAlgebra
import TensorOperations
import QXContexts.Logger: @debug
using TimerOutputs
using Random
using OMEinsum
using QXContexts.DSL
using QXContexts.Param
import QXContexts

const timer_output = TimerOutput()
if haskey(ENV, "QXRUN_TIMER")
Expand Down Expand Up @@ -140,13 +143,23 @@ Base.iterate(a::SliceIterator, state) = length(a) <= (state + 1 - a.start) ? not
Base.length(a::SliceIterator) = a.stop - a.start + 1
Base.eltype(::SliceIterator) = Vector{Int}

"""
Struct to hold the results of a simulation.
"""
struct Samples{T}
bitstrings_counts::DefaultDict{String, <:Integer}
amplitudes::Dict{String, T}
end

Samples() = Samples(DefaultDict{String, Int}(0), Dict{String, ComplexF32}())

"""Function for reducing over amplitude contributions from each slice. For non-serial
contexts, contributions are summed over"""
reduce_slices(::QXContext, a) = a

"""Function for reducing over calculated amplitudes. For non-serial contexts, contributions
are gathered"""
reduce_amplitudes(::QXContext, a) = a
"""Function for reducing over calculated amplitudes and samples. For non-serial contexts,
contributions are gathered"""
reduce_results(::QXContext, results::Samples) = results

"""Function for reducing over calculated amplitudes. For non-serial contexts, contributions
are gathered"""
Expand All @@ -158,7 +171,9 @@ BitstringIterator(::QXContext, bitstrings) = bitstrings
Save results from calculations for the given
"""
function write_results(::QXContext, results, output_file)
JLD2.@save output_file results
amplitudes = results.amplitudes
bitstrings_counts = results.bitstrings_counts
JLD2.@save output_file amplitudes bitstrings_counts
end

"""Function to create a scalar with zero value of appropriate data-type for given contexct"""
Expand Down Expand Up @@ -276,6 +291,19 @@ function compute_amplitude!(ctx, bitstring::String)
reduce_slices(ctx, amplitude)
end

"""
compute_amplitude!(results::Samples{T}, ctx, bitstring::String) where T<:Complex
Calculate a single amplitude with the given context and bitstring. Update `results` to hold
the new amplitude.
"""
function compute_amplitude!(results::Samples, ctx, bitstring::String)
if !haskey(results.amplitudes, bitstring)
results.amplitudes[bitstring] = compute_amplitude!(ctx, bitstring)
end
results
end

include("mpi_execution.jl")

"""
Expand All @@ -299,26 +327,34 @@ function execute(dsl_file::String,
sub_comm_size::Int=1,
max_amplitudes::Union{Int, Nothing}=nothing,
max_parameters::Union{Int, Nothing}=nothing)

# Parse the dsl file to create a list of commands to execute in a context. Also parse
# the parameter file to get partition parameters and to create a sampler object to
# produce bitstrings to compute amplitudes for.
commands = parse_dsl(dsl_file)
sampler_args, partition_params = parse_parameters(param_file,
max_parameters=max_parameters)

bitstrings, partition_params = parse_parameters(param_file,
max_amplitudes=max_amplitudes, max_parameters=max_parameters)

# Create a context to execute the commands in.
ctx = QXContext(commands, partition_params, input_file)
if comm !== nothing
ctx = QXMPIContext(ctx, comm, sub_comm_size)
end

bitstring_iter = BitstringIterator(ctx, bitstrings)
results = Array{ComplexF32, 1}(undef, length(bitstring_iter))
for (i, bitstring) in enumerate(bitstring_iter)
results[i] = compute_amplitude!(ctx, bitstring)
# Create a sampler to produce bitstrings to get amplitudes for and a variable to store
# the results.
sampler = QXContexts.Sampling.create_sampler(sampler_args, ctx, max_amplitudes)
results = Samples()

# For each bitstring produced by the sampler, compute its amplitude and accept or reject
# it as sample in accordance with the sampler.
for bitstring in sampler
results = compute_amplitude!(results, ctx, bitstring)
QXContexts.Sampling.accept!(results, sampler, bitstring)
end
results = reduce_amplitudes(ctx, results)

# Collect, save and return the results.
results = reduce_results(ctx, results)
write_results(ctx, results, output_file)

return results
end

Expand Down
30 changes: 25 additions & 5 deletions src/mpi_execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ General utility functions for working with MPI and partitions

export get_rank_size
export get_rank_start
export get_rank_range

"""
get_rank_size(n::Integer, size::Integer, rank::Integer)
Expand Down Expand Up @@ -101,14 +102,29 @@ function reduce_slices(ctx::QXMPIContext, a)
end

"""
reduce_amplitudes(ctx::QXMPIContext, a)
reduce_results(ctx::QXMPIContext, results::Samples)
Function to gather amplitudes from sub-communicators
Function to gather amplitudes and samples from sub-communicators.
"""
function reduce_amplitudes(ctx::QXMPIContext, a)
function reduce_results(ctx::QXMPIContext, results::Samples)
if MPI.Comm_rank(ctx.sub_comm) == 0
return MPI.Gather(a, 0, ctx.root_comm)
bitstrings = keys(results.amplitudes)
num_qubits = length(first(bitstrings))

bitstrings_as_ints = parse.(UInt64, bitstrings, base=2)
amplitudes = [results.amplitudes[bitstring] for bitstring in bitstrings]
samples = [results.bitstrings_counts[bitstring] for bitstring in bitstrings]

bitstrings_as_ints = MPI.Gather(bitstrings_as_ints, 0, ctx.root_comm)
amplitudes = MPI.Gather(amplitudes, 0, ctx.root_comm)
samples = MPI.Gather(samples, 0, ctx.root_comm)

bitstrings = reverse.(digits.(bitstrings_as_ints, base=2, pad=num_qubits))
bitstrings = [prod(string.(bits)) for bits in bitstrings]
amplitudes = Dict{String, eltype(amplitudes)}(bitstrings .=> amplitudes)
bitstrings = DefaultDict(0, Dict{String, Int}(bitstrings .=> samples))
end
Samples(bitstrings, amplitudes)
end

"""
Expand All @@ -133,7 +149,11 @@ end
Function write results for QXMPIContext. Only writes from root process
"""
function write_results(ctx::QXMPIContext, results, output_file)
if MPI.Comm_rank(ctx.comm) == 0 JLD2.@save output_file results end
if MPI.Comm_rank(ctx.comm) == 0
amplitudes = results.amplitudes
bitstrings_counts = results.bitstrings_counts
JLD2.@save output_file amplitudes bitstrings_counts
end
end

Base.zero(ctx::QXMPIContext) = zero(ctx.serial_ctx)
46 changes: 20 additions & 26 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,39 @@ import Base.Iterators: take

"""
parse_parameters(filename::String;
max_parameters::Union{Int, Nothing}=nothing,
max_amplitudes::Union{Int, Nothing}=nothing)
max_parameters::Union{Int, Nothing}=nothing)
Parse the parameters yml file to read information about partition parameters and their
dimensions as well as how the sampling of amplitudes will work.
Parse the parameters yml file to read information about partition parameters and output
sampling method.
Example Parameter file
======================
output:
method: List
params:
bitstrings:
- "01000"
- "01110"
- "10101"
partitions:
parameters:
v1: 2
v2: 2
v3: 4
amplitudes:
- "0000"
- "0001"
- "1111"
parameters:
v1: 2
v2: 2
"""
function parse_parameters(filename::String;
max_parameters::Union{Int, Nothing}=nothing,
max_amplitudes::Union{Int, Nothing}=nothing)
max_parameters::Union{Int, Nothing}=nothing)
param_dict = YAML.load_file(filename, dicttype=OrderedDict{String, Any})

#TODO: Revise the way amplitudes are described
amplitudes = unique([amplitude for amplitude in param_dict["amplitudes"]])
if max_amplitudes !== nothing && length(amplitudes) > max_amplitudes
amplitudes = amplitudes[1:max_amplitudes]
end

# parse the paramters section of the parameter file
# parse the partition paramters section of the parameter file
partition_params = param_dict["partitions"]["parameters"]
max_parameters = max_parameters === nothing ? length(partition_params) : max_parameters
partition_params = OrderedDict{Symbol, Int}(Symbol(x[1]) => x[2] for x in take(partition_params, max_parameters))
# variables_symbols = [Symbol("\$$v") for v in take(keys(bond_info), max_parameters)]
# variable_values = CartesianIndices(Tuple(take(values(bond_info), max_parameters)))

# return Parameters(amplitudes, variable_symbols, variable_values)
return amplitudes, partition_params
# parse the output method section of the parameter file
method_params = OrderedDict{Symbol, Any}(Symbol(x[1]) => x[2] for x in param_dict["output"])
method_params[:params] = OrderedDict{Symbol, Any}(Symbol(x[1]) => x[2] for x in method_params[:params])

return method_params, partition_params
end

end
Loading

0 comments on commit aac5bd6

Please sign in to comment.