Skip to content

Commit

Permalink
Better multifield implementation for interpolate and FEFunction, that…
Browse files Browse the repository at this point in the history
… avoids repeating work and works better with ZeroMeanFESpaces
  • Loading branch information
JordiManyer committed Aug 7, 2024
1 parent 46e40fa commit 18b0000
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 108 deletions.
57 changes: 41 additions & 16 deletions src/FESpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,8 @@ function _add_distributed_constraint(F::DistributedFESpace,order::Integer,constr
V = F
elseif constraint == :zeromean
_trian = get_triangulation(F)
trian = remove_ghost_cells(_trian,get_free_dof_ids(F))
model = get_background_model(_trian)
trian = remove_ghost_cells(_trian,get_cell_gids(model))
= Measure(trian,order)
V = ZeroMeanFESpace(F,dΩ)
else
Expand Down Expand Up @@ -765,12 +766,14 @@ end

const DistributedZeroMeanFESpace{A,B,C,D,E,F} = DistributedSingleFieldFESpace{A,B,C,D,DistributedZeroMeanCache{E,F}}

function FESpaces.ZeroMeanFESpace(space::DistributedFESpace,dΩ::DistributedMeasure)
function FESpaces.FESpaceWithConstantFixed(
space::DistributedSingleFieldFESpace,
gid_to_fix::Int = num_free_dofs(space)
)
# Find the gid within the processors
gids = get_free_dof_ids(space)

# We always fix the first gid
lid_to_fix = map(partition(gids)) do gids
Int(global_to_local(gids)[1]) # returns 0 if not found in the processor
Int(global_to_local(gids)[gid_to_fix]) # returns 0 if not found in the processor
end

# Create local spaces
Expand All @@ -779,22 +782,29 @@ function FESpaces.ZeroMeanFESpace(space::DistributedFESpace,dΩ::DistributedMeas
FESpaceWithConstantFixed(lspace,fix_constant,lid_to_fix)
end

trian = get_triangulation(space)
model = get_background_model(trian)
gids = generate_gids(model,spaces)
vector_type = _find_vector_type(spaces,gids)
return DistributedSingleFieldFESpace(spaces,gids,trian,vector_type)
end

function FESpaces.ZeroMeanFESpace(space::DistributedSingleFieldFESpace,dΩ::DistributedMeasure)
# Create underlying space
_space = FESpaceWithConstantFixed(space,num_free_dofs(space))

# Setup volume integration
_vol, _dvol = map(local_views(space),local_views(dΩ)) do lspace, dΩ
_vol, dvol = map(local_views(space),local_views(dΩ)) do lspace, dΩ
dvol = assemble_vector(v -> (v)dΩ, lspace)
vol = sum(dvol)
return vol, dvol
end |> tuple_of_arrays
vol = reduce(+,_vol,init=zero(eltype(vol)))
dvol = PVector(_dvol,partition(gids))
metadata = DistributedZeroMeanCache(dvol,vol)

# Create the new global FESpace
trian = get_triangulation(space)
model = get_background_model(trian)
gids = generate_gids(model,spaces)
vector_type = _find_vector_type(spaces,gids)
return DistributedSingleFieldFESpace(spaces,gids,trian,vector_type,metadata)
return DistributedSingleFieldFESpace(
_space.spaces,_space.gids,_space.trian,_space.vector_type,metadata
)
end

function FESpaces.FEFunction(
Expand All @@ -817,30 +827,45 @@ function FESpaces.FEFunction(
c = _compute_new_distributed_fixedval(
f,free_values,dirichlet_values
)
fv = free_values .+ c
fv = free_values .+ c # TODO: Do we need to copy, or can we just modify?
dv = map(dirichlet_values) do dv
dv .+ c
end

fields = map(FEFunction,f.spaces,partition(fv),dv)
trian = get_triangulation(f)
metadata = DistributedFEFunctionData(free_values)
metadata = DistributedFEFunctionData(fv)
DistributedCellField(fields,trian,metadata)
end

# This is required, otherwise we end up calling `FEFunction` with a fixed value of zero,
# which does not properly interpolate the function provided.
# With this change, we are interpolating in the unconstrained space and then
# substracting the mean.
function FESpaces.interpolate!(u,free_values::AbstractVector,f::DistributedZeroMeanFESpace)
dirichlet_values = get_dirichlet_dof_values(f)
interpolate_everywhere!(u,free_values,dirichlet_values,f)
end
function FESpaces.interpolate!(u::DistributedCellField,free_values::AbstractVector,f::DistributedZeroMeanFESpace)
dirichlet_values = get_dirichlet_dof_values(f)
interpolate_everywhere!(u,free_values,dirichlet_values,f)
end

function _compute_new_distributed_fixedval(
f::DistributedZeroMeanFESpace,fv,dv
)
dvol = f.metadata.dvol
vol = f.metadata.vol

c_i = map(local_views(f),partition(fv),dv,partition(dvol)) do space,fv,dv,dvol
c_i = map(local_views(f),partition(fv),dv,dvol) do space,fv,dv,dvol
if isa(FESpaces.ConstantApproach(space),FESpaces.FixConstant)
lid_to_fix = space.dof_to_fix
c = FESpaces._compute_new_fixedval(fv,dv,dvol,vol,lid_to_fix)
else
c = - dot(fv,dvol)/vol
end
println("c = $c")
c
end
c = reduce(+,c_i,init=zero(eltype(c_i)))
return c
Expand Down
168 changes: 76 additions & 92 deletions src/MultiField.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ end
function MultiField.restrict_to_field(
f::DistributedMultiFieldFESpace,free_values::AbstractVector,field::Integer
)
values = map(local_views(f),partition(free_values)) do u,x
restrict_to_field(u,x,field)
values = map(local_views(f),partition(free_values)) do u,fv
restrict_to_field(u,fv,field)
end
gids = f.field_fe_space[field].gids
gids = get_free_dof_ids(f[field])
PVector(values,partition(gids))
end

Expand All @@ -118,26 +118,22 @@ function FESpaces.get_dirichlet_dof_values(f::DistributedMultiFieldFESpace)
return map(get_dirichlet_dof_values,f.field_fe_space)
end

function FESpaces.FEFunction(f::DistributedMultiFieldFESpace,x::AbstractVector,isconsistent=false)
free_values = change_ghost(x,f.gids;is_consistent=isconsistent,make_consistent=true)
part_fe_fun = map(FEFunction,f.part_fe_space,partition(free_values))
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(f)
free_values_i = restrict_to_field(f,free_values,i)
fe_space_i = f.field_fe_space[i]
fe_fun_i = FEFunction(fe_space_i,free_values_i,true)
push!(field_fe_fun,fe_fun_i)
end
DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
function FESpaces.FEFunction(
f::DistributedMultiFieldFESpace,free_values::AbstractVector,isconsistent=false
)
dirichlet_values = get_dirichlet_dof_values(f)
return FEFunction(f,free_values,dirichlet_values,isconsistent)
end

function FESpaces.FEFunction(
f::DistributedMultiFieldFESpace,x::AbstractVector,
dirichlet_values::AbstractArray{<:AbstractVector},isconsistent=false
f::DistributedMultiFieldFESpace,
_free_values::AbstractVector,
dirichlet_values::AbstractArray{<:AbstractVector},
isconsistent=false
)
free_values = GridapDistributed.change_ghost(x,f.gids;is_consistent=isconsistent,make_consistent=true)
part_dirvals = to_parray_of_arrays(dirichlet_values)
part_fe_fun = map(FEFunction,f.part_fe_space,partition(free_values),part_dirvals)
free_values = change_ghost(_free_values,f.gids;is_consistent=isconsistent,make_consistent=true)

# Create distributed single field functions
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(f)
free_values_i = restrict_to_field(f,free_values,i)
Expand All @@ -146,112 +142,99 @@ function FESpaces.FEFunction(
fe_fun_i = FEFunction(fe_space_i,free_values_i,dirichlet_values_i,true)
push!(field_fe_fun,fe_fun_i)
end

# Retrieve the local multifield views
part_sf_fe_funs = map(local_views,field_fe_fun)
part_fe_fun = map(local_views(f),partition(free_values),part_sf_fe_funs...) do space,fv,part_sf_fe_funs...
MultiFieldFEFunction(fv,space,[part_sf_fe_funs...])
end

DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
end

function FESpaces.EvaluationFunction(f::DistributedMultiFieldFESpace,x::AbstractVector,isconsistent=false)
free_values = change_ghost(x,f.gids;is_consistent=isconsistent,make_consistent=true)
part_fe_fun = map(EvaluationFunction,f.part_fe_space,partition(free_values))
function FESpaces.EvaluationFunction(
f::DistributedMultiFieldFESpace,
_free_values::AbstractVector,
isconsistent=false
)
free_values = change_ghost(_free_values,f.gids;is_consistent=isconsistent,make_consistent=true)

# Create distributed single field functions
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(f)
free_values_i = restrict_to_field(f,free_values,i)
fe_space_i = f.field_fe_space[i]
fe_fun_i = EvaluationFunction(fe_space_i,free_values_i)
push!(field_fe_fun,fe_fun_i)
end

# Retrieve the local multifield views
part_sf_fe_funs = map(local_views,field_fe_fun)
part_fe_fun = map(local_views(f),partition(free_values),part_sf_fe_funs...) do space,fv,part_sf_fe_funs...
MultiFieldFEFunction(fv,space,[part_sf_fe_funs...])
end

DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
end

function FESpaces.interpolate(objects,fe::DistributedMultiFieldFESpace)
free_values = zero_free_values(fe)
interpolate!(objects,free_values,fe)
function FESpaces.interpolate(objects,space::DistributedMultiFieldFESpace)
free_values = zero_free_values(space)
interpolate!(objects,free_values,space)
end

function FESpaces.interpolate!(objects,free_values::AbstractVector,fe::DistributedMultiFieldFESpace)
part_fe_fun = map(partition(free_values),local_views(fe)) do x,f
interpolate!(objects,x,f)
end
function FESpaces.interpolate!(objects,free_values::AbstractVector,space::DistributedMultiFieldFESpace)
msg = "free_values and FESpace have incompatible index partitions."
@check partition(axes(free_values,1)) === partition(space.gids) msg

# Interpolate each field
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(fe)
free_values_i = restrict_to_field(fe,free_values,i)
fe_space_i = fe.field_fe_space[i]
fe_fun_i = FEFunction(fe_space_i,free_values_i)
for i in 1:num_fields(space)
free_values_i = restrict_to_field(space,free_values,i)
fe_space_i = space.field_fe_space[i]
fe_fun_i = interpolate!(objects[i], free_values_i, fe_space_i)
push!(field_fe_fun,fe_fun_i)
end
DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
end

function Gridap.FESpaces.interpolate!(
objects::Union{<:DistributedMultiFieldCellField,<:DistributedCellField},
free_values::AbstractVector,
fe::DistributedMultiFieldFESpace
)
part_fe_fun = map(local_views(objects),partition(free_values),local_views(fe)) do objects,x,f
interpolate!(objects,x,f)
# Retrieve the local multifield views
part_sf_fe_funs = map(local_views,field_fe_fun)
part_fe_fun = map(local_views(space),partition(free_values),part_sf_fe_funs...) do space,fv,part_sf_fe_funs...
MultiFieldFEFunction(fv,space,[part_sf_fe_funs...])
end
field_fe_fun = GridapDistributed.DistributedSingleFieldFEFunction[]
for i in 1:num_fields(fe)
free_values_i = Gridap.MultiField.restrict_to_field(fe,free_values,i)
fe_space_i = fe.field_fe_space[i]
fe_fun_i = FEFunction(fe_space_i,free_values_i)
push!(field_fe_fun,fe_fun_i)
end
GridapDistributed.DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)

DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
end

function FESpaces.interpolate_everywhere(objects,fe::DistributedMultiFieldFESpace)
free_values = zero_free_values(fe)
part_fe_fun = map(partition(free_values),local_views(fe)) do x,f
interpolate!(objects,x,f)
end
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(fe)
free_values_i = restrict_to_field(fe,free_values,i)
fe_space_i = fe.field_fe_space[i]
dirichlet_values_i = zero_dirichlet_values(fe_space_i)
fe_fun_i = interpolate_everywhere!(objects[i], free_values_i,dirichlet_values_i,fe_space_i)
push!(field_fe_fun,fe_fun_i)
end
DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
dirichlet_values = zero_dirichlet_values(fe)
return interpolate_everywhere!(objects,free_values,dirichlet_values,fe)
end

function FESpaces.interpolate_everywhere!(
objects,free_values::AbstractVector,
objects,
free_values::AbstractVector,
dirichlet_values::Vector{AbstractArray{<:AbstractVector}},
fe::DistributedMultiFieldFESpace)
msg = "free_values and fe have incompatible index partitions."
@check partition(axes(free_values,1)) === partition(fe.gids) msg
space::DistributedMultiFieldFESpace
)
msg = "free_values and FESpace have incompatible index partitions."
@check partition(axes(free_values,1)) === partition(space.gids) msg

part_fe_fun = map(partition(free_values),local_views(fe)) do x,f
interpolate!(objects,x,f)
end
# Interpolate each field
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(fe)
free_values_i = restrict_to_field(fe,free_values,i)
for i in 1:num_fields(space)
free_values_i = restrict_to_field(space,free_values,i)
dirichlet_values_i = dirichlet_values[i]
fe_space_i = fe.field_fe_space[i]
fe_fun_i = interpolate_everywhere!(objects[i], free_values_i,dirichlet_values_i,fe_space_i)
fe_space_i = space.field_fe_space[i]
fe_fun_i = interpolate_everywhere!(objects[i], free_values_i, dirichlet_values_i,fe_space_i)
push!(field_fe_fun,fe_fun_i)
end
DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
end

function FESpaces.interpolate_everywhere(
objects::Vector{<:DistributedCellField},fe::DistributedMultiFieldFESpace)
local_objects = map(local_views,objects)
local_spaces = local_views(fe)
part_fe_fun = map(local_spaces,local_objects...) do f,o...
interpolate_everywhere(o,f)
end
free_values = zero_free_values(fe)
field_fe_fun = DistributedSingleFieldFEFunction[]
for i in 1:num_fields(fe)
free_values_i = restrict_to_field(fe,free_values,i)
fe_space_i = fe.field_fe_space[i]
dirichlet_values_i = get_dirichlet_dof_values(fe_space_i)
fe_fun_i = interpolate_everywhere!(objects[i], free_values_i,dirichlet_values_i,fe_space_i)
push!(field_fe_fun,fe_fun_i)
# Retrieve the local multifield views
part_sf_fe_funs = map(local_views,field_fe_fun)
part_fe_fun = map(local_views(space),partition(free_values),part_sf_fe_funs...) do space,fv,part_sf_fe_funs...
MultiFieldFEFunction(fv,space,[part_sf_fe_funs...])
end

DistributedMultiFieldFEFunction(field_fe_fun,part_fe_fun,free_values)
end

Expand All @@ -278,7 +261,7 @@ const DistributedMultiFieldFEBasis{A} = DistributedMultiFieldCellField{A,<:Abstr
function FESpaces.get_fe_basis(f::DistributedMultiFieldFESpace)
part_mbasis = map(get_fe_basis,f.part_fe_space)
field_fe_basis = map(1:num_fields(f)) do i
space_i = f.field_fe_space[i]
space_i = f.field_fe_space[i]
basis_i = map(b->b[i],part_mbasis)
DistributedCellField(basis_i,get_triangulation(space_i))
end
Expand All @@ -288,7 +271,7 @@ end
function FESpaces.get_trial_fe_basis(f::DistributedMultiFieldFESpace)
part_mbasis = map(get_trial_fe_basis,f.part_fe_space)
field_fe_basis = map(1:num_fields(f)) do i
space_i = f.field_fe_space[i]
space_i = f.field_fe_space[i]
basis_i = map(b->b[i],part_mbasis)
DistributedCellField(basis_i,get_triangulation(space_i))
end
Expand Down Expand Up @@ -328,6 +311,7 @@ function generate_multi_field_gids(
end
f_frange = map(get_free_dof_ids,f_dspace)
gids = generate_multi_field_gids(f_p_flid_lid,f_frange)
return gids
end

function generate_multi_field_gids(
Expand Down

0 comments on commit 18b0000

Please sign in to comment.