Skip to content

Commit

Permalink
Fix NormalOp weighting behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Jul 4, 2024
1 parent 08c3ff7 commit 8b9451a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 18 deletions.
22 changes: 21 additions & 1 deletion src/NormalOp.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
export normalOperator

"""
    NormalOp(T::Type; parent, weights)

Lazy normal operator of `parent` with an optional weighting operator `weights`.
Computes `adjoint(parent) * weights * parent`.

# Required Argument
* `T` - type of elements, e.g. `Float64` or `ComplexF32`

# Required Keyword argument
* `parent` - Base operator

# Optional Keyword argument
* `weights` - Optional weights for normal operator. Must already be of form `weights = adjoint.(w) .* w`
"""
function LinearOperatorCollection.NormalOp(::Type{T}; parent, weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent))) where T <: Number
  return NormalOp(T, parent, weights)
end
Expand Down Expand Up @@ -47,7 +63,6 @@ function NormalOpImpl(parent, weights, tmp)
# Apply the weighted normal operator in-place: y = adjoint(parent) * weights * parent * x.
# `tmp` is a preallocated buffer of length size(parent, 1) for the intermediate product.
# Fix: the weighting was applied twice (duplicated `mul!(tmp, weights, tmp)`),
# which computed adjoint(parent) * weights^2 * parent * x instead.
function produ!(y, parent, weights, tmp, x)
  mul!(tmp, parent, x)
  # Apply the (already squared, i.e. adjoint.(w) .* w) weighting exactly once.
  # NOTE(review): `tmp` aliases input and output here — safe for diagonal/scaling
  # weights, but a second buffer would be needed for a general `weights` operator.
  mul!(tmp, weights, tmp)
  return mul!(y, adjoint(parent), tmp)
end

Expand All @@ -63,6 +78,11 @@ function Base.copy(S::NormalOpImpl)
return NormalOpImpl(copy(S.parent), S.weights, copy(S.tmp))
end

"""
    normalOperator(parent (, weights); kwargs...)

Constructs a normal operator `adjoint(parent) * weights * parent` of the parent in
an opinionated way, i.e. it tries to apply optimisations to the resulting operator.

`weights` defaults to an identity operator and must already be of the squared
form `adjoint.(w) .* w`.
"""
function normalOperator(parent, weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent)); kwargs...)
  # NOTE(review): `kwargs` is accepted for interface compatibility with the
  # specialised methods (e.g. the ProdOp overload) but is not forwarded here,
  # since `NormalOp` takes no further keyword arguments — confirm intended.
  return NormalOp(eltype(storage_type(parent)); parent = parent, weights = weights)
end
7 changes: 6 additions & 1 deletion src/ProdOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ end
# In this case we are converting the left argument into a
# weighting matrix, that is passed to normalOperator
# TODO Port from MRIOperators drops the given weighting matrix, I just left it out for now
normalOperator(S::ProdOp{T, <:WeightingOp, matT}; kwargs...) where {T, matT} = normalOperator(S.B, S.A; kwargs...)
"""
    normalOperator(prod::ProdOp{T, <:WeightingOp, matT}; kwargs...)

Fuses the weights of the `WeightingOp` by computing `adjoint.(weights) .* weights`
and forwards the squared weighting to `normalOperator` of the wrapped operator.
"""
normalOperator(S::ProdOp{T, <:WeightingOp, matT}; kwargs...) where {T, matT} = normalOperator(S.B, WeightingOp(adjoint.(S.A.weights) .* S.A.weights); kwargs...)
function normalOperator(S::ProdOp, W=opEye(eltype(S),size(S,1), S = storage_type(S)); kwargs...)
arrayType = storage_type(S)
tmp = arrayType(undef, size(S.A, 2))
Expand Down
38 changes: 22 additions & 16 deletions test/testNormalOp.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@

# Reconstructed post-commit version of the testset: the scraped diff interleaved
# old/new lines (duplicate assignments and stale asserts) and dropped the `≈`
# operators from every @test; both are restored here.
@testset "Normal Operator" begin
  for arrayType in arrayTypes
    for elType in [Float32, ComplexF32]
      @testset "$arrayType" begin
        N = 512

        Random.seed!(1234)
        x = arrayType(rand(elType, N))
        A = arrayType(rand(elType, N, N))
        A_adj = arrayType(collect(adjoint(A))) # LinearOperators can't resolve storage_type otherwise
        W = WeightingOp(arrayType(rand(elType, N)))
        WA = W * A
        WHW = adjoint.(W.weights) .* W.weights # squared weights, as expected by normalOperator
        prod = ProdOp(W, A)

        # Reference computations of the weighted normal operator A' * W' * W * A
        y1 = Array(A_adj * adjoint(W) * W * A * x)
        y2 = Array(adjoint(WA) * WA * x)
        y3 = Array(normalOperator(prod) * x)
        y4 = Array(normalOperator(A, WHW) * x)

        @test norm(y1 - y4) / norm(y4) ≈ 0 atol=0.01
        @test norm(y2 - y4) / norm(y4) ≈ 0 atol=0.01
        @test norm(y3 - y4) / norm(y4) ≈ 0 atol=0.01

        # Unweighted normal operator
        y1 = Array(adjoint(A) * A * x)
        y = Array(normalOperator(A) * x)

        @test norm(y1 - y) / norm(y) ≈ 0 atol=0.01
      end
    end
  end
end

0 comments on commit 8b9451a

Please sign in to comment.