Skip to content

Commit

Permalink
Add LUGP (#5)
Browse files Browse the repository at this point in the history
* Add LUGP

* Add support for multiple parameters in LUGP
  • Loading branch information
eliascarv authored Oct 26, 2023
1 parent 514b659 commit 03a0a59
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/GeoStatsProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ include("interface.jl")
include("spde.jl")
include("seq.jl")
include("fft.jl")
include("lu.jl")

export SPDEGP, SEQ, SGP, FFTGP
export SPDEGP, SEQ, SGP, FFTGP, LUGP

end
145 changes: 145 additions & 0 deletions src/lu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

@kwdef struct LUGP{V,T,F,C,I} <: GeoStatsProcess
variogram::V = GaussianVariogram()
mean::T = nothing
factorization::F = cholesky
correlation::C = 0.0
init::I = NearestInit()
end

function randprep(::AbstractRNG, process::LUGP, setup::RandSetup)
# retrieve setup paramaters
(; domain, geotable, varnames, vartypes) = setup

# check the number of variables
nvars = length(varnames)
@assert nvars (1, 2) "only 1 or 2 variables can be simulated simultaneously"

# check process paramaters
_checkparam(process.variogram, nvars)
_checkparam(process.mean, nvars)

# retrieve process parameters
fact = process.factorization
init = process.init

# initialize buffers for realizations and simulation mask
vars = Dict(zip(varnames, vartypes))
buff, mask = initbuff(domain, vars, init, data=geotable)

# preprocess parameters for individual variables
pairs = map(enumerate(varnames)) do (i, var)
# get variable specific parameters
γ = _getparam(process.variogram, i)
vmean = _getparam(process.mean, i)

# check stationarity
@assert isstationary(γ) "variogram model must be stationary"

# retrieve data locations and data values in domain
dlocs = findall(mask[var])
z₁ = view(buff[var], dlocs)

# retrieve simulation locations
slocs = setdiff(1:nelements(domain), dlocs)

# create views of the domain
𝒟d = [centroid(domain, i) for i in dlocs]
𝒟s = [centroid(domain, i) for i in slocs]

# covariance between simulation locations
C₂₂ = sill(γ) .- Variography.pairwise(γ, 𝒟s)

if isempty(dlocs)
d₂ = zero(eltype(z₁))
L₂₂ = fact(Symmetric(C₂₂)).L
else
# covariance beween data locations
C₁₁ = sill(γ) .- Variography.pairwise(γ, 𝒟d)
C₁₂ = sill(γ) .- Variography.pairwise(γ, 𝒟d, 𝒟s)

L₁₁ = fact(Symmetric(C₁₁)).L
B₁₂ = L₁₁ \ C₁₂
A₂₁ = B₁₂'

d₂ = A₂₁ * (L₁₁ \ z₁)
L₂₂ = fact(Symmetric(C₂₂ - A₂₁ * B₁₂)).L
end

if !isnothing(vmean) && !isempty(dlocs)
@warn "mean can only be specified in unconditional simulation"
end

# mean for unconditional simulation
μ = isnothing(vmean) ? zero(eltype(z₁)) : vmean

# save preprocessed parameters for variable
var => (z₁, d₂, L₂₂, μ, dlocs, slocs)
end

Dict(pairs)
end

function randsingle(rng::AbstractRNG, process::LUGP, setup::RandSetup, prep)
# list of variable names
vars = setup.varnames

# simulate first variable
v₁ = first(vars)
Y₁, w₁ = _lusim(rng, prep[v₁])
varreal = Dict(v₁ => Y₁)

# simulate second variable
if length(vars) == 2
ρ = process.correlation
v₂ = last(vars)
Y₂, _ = _lusim(rng, prep[v₂], ρ, w₁)
push!(varreal, v₂ => Y₂)
end

varreal
end

#-----------
# UTILITIES
#-----------

function _checkparam(param, nvars)
if param isa Tuple
@assert length(param) == nvars "the number of parameters must be equal to the number of variables"
end
end

_getparam(param, i) = param
_getparam(params::Tuple, i) = params[i]

function _lusim(rng, params, ρ=nothing, w₁=nothing)
# unpack parameters
z₁, d₂, L₂₂, μ, dlocs, slocs = params

# number of points in domain
npts = length(dlocs) + length(slocs)

# allocate memory for result
y = Vector{eltype(z₁)}(undef, npts)

# conditional simulation
w₂ = randn(rng, size(L₂₂, 2))
if isnothing(ρ)
y₂ = d₂ .+ L₂₂ * w₂
else
y₂ = d₂ .+ L₂₂ ** w₁ + (1 - ρ^2) * w₂)
end

# hard data and simulated values
y[dlocs] = z₁
y[slocs] = y₂

# adjust mean in case of unconditional simulation
isempty(dlocs) && (y .+= μ)

y, w₂
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Meshes = "eacbb407-ea5a-433e-ab97-5258b1ca43fa"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
76 changes: 70 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using GeoStatsProcesses
using Variography
using GeoTables
using Meshes
using LinearAlgebra
using Random
using Test

Expand All @@ -12,13 +13,13 @@ using Test
dom = CartesianGrid(100, 100)
process = FFTGP(variogram=GaussianVariogram(range=10.0))
sims = rand(process, dom, [:z => Float64], 3)

# anisotropic simulation
Random.seed!(2019)
dom = CartesianGrid(100, 100)
process = FFTGP(variogram=GaussianVariogram(MetricBall((20.0, 5.0))))
sims = rand(process, dom, [:z => Float64], 3)

# simulation on view of grid
Random.seed!(2022)
grid = CartesianGrid(100, 100)
Expand All @@ -27,7 +28,7 @@ using Test
sim = rand(process, vgrid, [:z => Float64])
@test domain(sim) == vgrid
@test length(sim.geometry) == 5000

# conditional simulation
Random.seed!(2022)
table = (; z=[1.0, -1.0, 1.0])
Expand All @@ -42,18 +43,81 @@ using Test
𝒮 = georef((; z=[1.0, 0.0, 1.0]), [25.0 50.0 75.0; 25.0 75.0 50.0])
𝒟 = CartesianGrid((100, 100), (0.5, 0.5), (1.0, 1.0))
N = 3

process = SGP(variogram=SphericalVariogram(range=35.0), neighborhood=MetricBall(30.0))

Random.seed!(2017)
sims₁ = rand(process, 𝒟, 𝒮, 3)
sims₂ = rand(process, 𝒟, [:z => Float64], 3)

# basic checks
reals = sims₁[:z]
inds = LinearIndices(size(𝒟))
@test all(reals[i][inds[25, 25]] == 1.0 for i in 1:N)
@test all(reals[i][inds[50, 75]] == 0.0 for i in 1:N)
@test all(reals[i][inds[75, 50]] == 1.0 for i in 1:N)
end

@testset "LUGP" begin
𝒮 = georef((; z=[0.0, 1.0, 0.0, 1.0, 0.0]), [0.0 25.0 50.0 75.0 100.0])
𝒟 = CartesianGrid(100)

# ----------------------
# conditional simulation
# ----------------------
rng = MersenneTwister(123)
process = LUGP(variogram=SphericalVariogram(range=10.0))
sims = rand(rng, process, 𝒟, 𝒮, 2)

# ------------------------
# unconditional simulation
# ------------------------
rng = MersenneTwister(123)
process = LUGP(variogram=SphericalVariogram(range=10.0))
sims = rand(rng, process, 𝒟, [:z => Float64], 2)

# -------------
# co-simulation
# -------------
𝒟 = CartesianGrid(500)
rng = MersenneTwister(123)
process = LUGP(variogram=(SphericalVariogram(range=10.0), GaussianVariogram(range=10.0)), correlation=0.95)
sim = rand(rng, process, 𝒟, [:a => Float64, :b => Float64])

# -----------
# 2D example
# -----------
𝒟 = CartesianGrid(100, 100)
rng = MersenneTwister(123)
process = LUGP(variogram=GaussianVariogram(range=10.0))
sims = rand(rng, process, 𝒟, [:z => Float64], 3)

# -------------------
# anisotropy example
# -------------------
𝒟 = CartesianGrid(100, 100)
rng = MersenneTwister(123)
ball = MetricBall((20.0, 5.0))
process = LUGP(variogram=GaussianVariogram(ball))
sims = rand(rng, process, 𝒟, [:z => Float64], 3)

# ---------------------
# custom factorization
# ---------------------
𝒟 = CartesianGrid(100)
rng = MersenneTwister(123)
process1 = LUGP(variogram=SphericalVariogram(range=10.0), factorization=lu)
process2 = LUGP(variogram=SphericalVariogram(range=10.0), factorization=cholesky)
sim1 = rand(rng, process1, 𝒟, 𝒮, 2)
sim2 = rand(rng, process2, 𝒟, 𝒮, 2)

# throws
𝒟 = CartesianGrid(100, 100)
process = LUGP(variogram=GaussianVariogram(range=10.0))
# only 1 or 2 variables can be simulated simultaneously
@test_throws AssertionError rand(process, 𝒟, [:a => Float64, :b => Float64, :c => Float64])
process = LUGP(variogram=(GaussianVariogram(range=10.0),))
# the number of parameters must be equal to the number of variables
@test_throws AssertionError rand(process, 𝒟, [:a => Float64, :b => Float64])
end
end

0 comments on commit 03a0a59

Please sign in to comment.