diff --git a/src/plans/vectorial_plan.jl b/src/plans/vectorial_plan.jl index c35729cd24..839a83a570 100644 --- a/src/plans/vectorial_plan.jl +++ b/src/plans/vectorial_plan.jl @@ -798,12 +798,11 @@ function get_jacobian!( return JF end -get_jacobian_basis(vgf::AbstractVectorGradientFunction) = DefaultOrthonormalBasis() -function get_jacobian_basis( - vgf::AbstractVectorGradientFunction{F,G,<:CoordinateVectorialType} -) where {F,G} - return vgf.jacobian_type.basis +function get_jacobian_basis(vgf::AbstractVectorGradientFunction) + return _get_jacobian_basis(vgf.jacobian_type) end +_get_jacobian_basis(jt::AbstractVectorialType) = DefaultOrthonormalBasis() +_get_jacobian_basis(jt::CoordinateVectorialType) = jt.basis # # diff --git a/test/plans/test_vectorial_plan.jl b/test/plans/test_vectorial_plan.jl index 745ad25982..c136dd2f70 100644 --- a/test/plans/test_vectorial_plan.jl +++ b/test/plans/test_vectorial_plan.jl @@ -59,6 +59,16 @@ using Manopt: get_value, get_value_function, get_gradient_function ) @test Manopt.get_jacobian_basis(vgf_ji) == vgf_ji.jacobian_type.basis @test Manopt.get_jacobian_basis(vgf_vi) == DefaultOrthonormalBasis() + vgf_jib = VectorGradientFunction( + g!, + jac_g!, + 2; + jacobian_type=CoordinateVectorialType(DefaultBasis()), + evaluation=InplaceEvaluation(), + ) + @test Manopt.get_jacobian_basis(vgf_ji) == vgf_ji.jacobian_type.basis + @test Manopt.get_jacobian_basis(vgf_jib) == DefaultBasis() + @test Manopt.get_jacobian_basis(vgf_vi) == DefaultOrthonormalBasis() p = [1.0, 2.0, 3.0] c = [0.0, -3.0] gg = [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0]]