From 392d21c97f0577964da88f520eec7b3c838b84a2 Mon Sep 17 00:00:00 2001 From: Gabriel Gerlero Date: Wed, 20 Dec 2023 10:43:13 -0300 Subject: [PATCH] Use AbstractDifferentiation package --- Project.toml | 2 ++ src/_Diff.jl | 24 ++++++++++++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 0c09b2d4..30f9a5aa 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Gabriel S. Gerlero "] version = "2.5.2" [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -20,6 +21,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24" [compat] +AbstractDifferentiation = "0.6.2" ArgCheck = "2" ForwardDiff = "0.10" LinearAlgebra = "1" diff --git a/src/_Diff.jl b/src/_Diff.jl index 3b354c8b..526c2d6e 100644 --- a/src/_Diff.jl +++ b/src/_Diff.jl @@ -1,18 +1,26 @@ module _Diff -using ForwardDiff: derivative -using ForwardDiff: Dual, Tag, value, extract_derivative +import AbstractDifferentiation +import ForwardDiff + +@inline function derivative(f, x::Real) + return only(AbstractDifferentiation.derivative(AbstractDifferentiation.ForwardDiffBackend(), + f, + x)) +end @inline function value_and_derivative(f, x::Real) - T = typeof(Tag(f, typeof(x))) - ydual = f(Dual{T}(x, oneunit(x))) - return value(T, ydual), extract_derivative(T, ydual) + a, b = AbstractDifferentiation.value_and_derivative(AbstractDifferentiation.ForwardDiffBackend(), + f, + x) + return a, only(b) end @inline function value_and_derivatives(f, x::Real) - T = typeof(Tag(f, typeof(x))) - ydual, ddual = value_and_derivative(f, Dual{T}(x, oneunit(x))) - return value(T, ydual), value(T, ddual), extract_derivative(T, ddual) + a, b, c = AbstractDifferentiation.value_derivative_and_second_derivative(AbstractDifferentiation.ForwardDiffBackend(), + f, + x) + return a, only(b), only(c) end export derivative, value_and_derivative, value_and_derivatives