diff --git a/src/gradients.jl b/src/gradients.jl index 3ec0c94..54fec28 100644 --- a/src/gradients.jl +++ b/src/gradients.jl @@ -247,7 +247,7 @@ function finite_difference_gradient!( if ArrayInterface.fast_scalar_indexing(c2) epsilon = ArrayInterface.allowed_getindex(c2, i) * dir else - epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir) * dir + epsilon = compute_epsilon(fdtype, one(eltype(x)), relstep, absstep, dir) * dir end c1_old = ArrayInterface.allowed_getindex(c1, i) ArrayInterface.allowed_setindex!(c1, c1_old + epsilon, i) @@ -277,7 +277,7 @@ function finite_difference_gradient!( if ArrayInterface.fast_scalar_indexing(c2) epsilon = ArrayInterface.allowed_getindex(c2, i) * dir else - epsilon = compute_epsilon(fdtype, x, relstep, absstep, dir) * dir + epsilon = compute_epsilon(fdtype, one(eltype(x)), relstep, absstep, dir) * dir end c1_old = ArrayInterface.allowed_getindex(c1, i) ArrayInterface.allowed_setindex!(c1, c1_old + epsilon, i)