From 977f08507da71964364c104bef0a4343205e2cb7 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 21 Jan 2025 19:50:25 +0100 Subject: [PATCH] test: add tests for zygote rules with matrix inputs --- test/zygote_tests.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/zygote_tests.jl b/test/zygote_tests.jl index da4bf381..08151bf4 100644 --- a/test/zygote_tests.jl +++ b/test/zygote_tests.jl @@ -51,12 +51,18 @@ end t = collect(1.0:10.0) test_zygote( LinearInterpolation, u, t; name = "Linear Interpolation") + u2 = Matrix(hcat(u, u)') + test_zygote( + LinearInterpolation, u2, t; name = "Linear Interpolation with matrix input") end @testset "Quadratic Interpolation" begin u = [1.0, 4.0, 9.0, 16.0] t = [1.0, 2.0, 3.0, 4.0] test_zygote(QuadraticInterpolation, u, t; name = "Quadratic Interpolation") + u2 = Matrix(hcat(u, u)') + test_zygote( + QuadraticInterpolation, u2, t; name = "Quadratic Interpolation with matrix input") end @testset "Constant Interpolation" begin