Skip to content

Commit

Permalink
fix the solve_triu:
Browse files Browse the repository at this point in the history
 - add unipotent
 - remove the solve_triu_right from the tests
 - (add support to allow override of snf similar to hnf)
  • Loading branch information
fieker committed Dec 20, 2024
1 parent bf480c5 commit d1e186f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
29 changes: 23 additions & 6 deletions src/flint/fmpz_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,11 @@ end
Compute the Smith normal form of $x$.
"""
function snf(x::ZZMatrix)
@inline snf(x::ZZMatrix) = _snf(x)

@inline _snf(x) = __snf(x)

function __snf(x::ZZMatrix)
z = similar(x)
@ccall libflint.fmpz_mat_snf(z::Ref{ZZMatrix}, x::Ref{ZZMatrix})::Nothing
return z
Expand Down Expand Up @@ -1538,7 +1542,7 @@ function _solve_dixon(a::ZZMatrix, b::ZZMatrix)
end

#XU = B. only the upper triangular part of U is used
function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix)
function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix; unipotent::Bool = false)
n = ncols(U)
m = nrows(b)
R = base_ring(U)
Expand All @@ -1565,7 +1569,11 @@ function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix)
tmp_p += sizeof(ZZRingElem)
end
sub!(s, mat_entry_ptr(b, i, j), s)
divexact!(mat_entry_ptr(tmp, 1, j), s, mat_entry_ptr(U, j, j))
if unipotent
set!(mat_entry_ptr(tmp, 1, j), s)
else
divexact!(mat_entry_ptr(tmp, 1, j), s, mat_entry_ptr(U, j, j))
end
end
tmp_p = mat_entry_ptr(tmp, 1, 1)
X_p = mat_entry_ptr(X, i, 1)
Expand All @@ -1582,9 +1590,9 @@ end
#UX = B, U has to be upper triangular
#I think due to the Strassen calling path, where Strasse.solve(side = :left)
#call directly AA.solve_left, this has to be in AA and cannot be independent.
function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:left)
function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:left, unipotent::Bool = false)
if side == :left
return AbstractAlgebra._solve_triu_left(U, b)
return AbstractAlgebra._solve_triu_left(U, b; unipotent)
end
@assert side == :right
n = nrows(U)
Expand Down Expand Up @@ -1614,7 +1622,11 @@ function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:lef
# s = b[j, i] - s
tmp_ptr = mat_entry_ptr(tmp, 1, j)
U_ptr = mat_entry_ptr(U, j, j)
divexact!(tmp_ptr, s, U_ptr)
if unipotent
set!(tmp_ptr, s)
else
divexact!(tmp_ptr, s, U_ptr)
end
# tmp[j] = divexact(s, U[j,j])
end
tmp_ptr = mat_entry_ptr(tmp, 1, 1)
Expand Down Expand Up @@ -1744,6 +1756,11 @@ function mul!(z::ZZMatrixOrPtr, x::ZZMatrixOrPtr, y::ZZMatrixOrPtr)
return z
end

function mul_classical!(z::ZZMatrixOrPtr, x::ZZMatrixOrPtr, y::ZZMatrixOrPtr)
@ccall libflint.fmpz_mat_mul_classical(z::Ref{ZZMatrix}, x::Ref{ZZMatrix}, y::Ref{ZZMatrix})::Nothing
return z
end

function mul!(z::ZZMatrixOrPtr, a::ZZMatrixOrPtr, b::Int)
@ccall libflint.fmpz_mat_scalar_mul_si(z::Ref{ZZMatrix}, a::Ref{ZZMatrix}, b::Int)::Nothing
return z
Expand Down
2 changes: 1 addition & 1 deletion test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ end
M = randmat_triu(S, -100:100)
b = rand(U, -100:100)

x = AbstractAlgebra._solve_triu_right(M, b; unipotent = false)
x = AbstractAlgebra._solve_triu(M, b; unipotent = false, side = :right)

@test M*x == b
end
Expand Down

0 comments on commit d1e186f

Please sign in to comment.