Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexibilize the definition of motion time spans #531

Merged
merged 11 commits into from
Jan 6, 2025
2 changes: 1 addition & 1 deletion KomaMRIBase/src/KomaMRIBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export MotionList, NoMotion, Motion
export Translate, TranslateX, TranslateY, TranslateZ
export Rotate, RotateX, RotateY, RotateZ
export HeartBeat, Path, FlowPath
export TimeRange, Periodic
export TimeRange, Periodic, TimeCurve
export SpinRange, AllSpins
export get_spin_coords
# Secondary
Expand Down
66 changes: 66 additions & 0 deletions KomaMRIBase/src/motion/Interpolation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# We defined two types of Interpolation objects: Interpolator1D and Interpolator2D
# 1D is for interpolating for 1 spin
# 2D is for interpolating for 2 or more spins
# This dispatch based on the number of spins wouldn't be necessary if it weren't for this:
# https://github.com/JuliaMath/Interpolations.jl/issues/603
#
# Once this issue is solved, this file should be simpler.
# We should then be able to define a single method for functions:
# - interpolate
# - resample
# and delete the Interpolator1D and Interpolator2D definitions

const Interpolator1D = Interpolations.GriddedInterpolation{
TCoefs,1,V,Itp,K
} where {
TCoefs<:Real,
TNodes<:Real,
V<:AbstractArray{TCoefs},
Itp<:Interpolations.Gridded,
K<:Tuple{AbstractVector{TNodes}},
}

const Interpolator2D = Interpolations.GriddedInterpolation{
TCoefs,2,V,Itp,K
} where {
TCoefs<:Real,
TNodes<:Real,
V<:AbstractArray{TCoefs},
Itp<:Interpolations.Gridded,
K<:Tuple{AbstractVector{TNodes}, AbstractVector{TNodes}},
}
function GriddedInterpolation(nodes, A, ITP)
return Interpolations.GriddedInterpolation{eltype(A), length(nodes), typeof(A), typeof(ITP), typeof(nodes)}(nodes, A, ITP)
end

function interpolate(d, ITPType, Ns::Val{1}, t)
_, Nt = size(d)
t_knots = _similar(t, Nt); copyto!(t_knots, collect(range(zero(eltype(t)), oneunit(eltype(t)), Nt)))
return GriddedInterpolation((t_knots, ), d[:], ITPType)
end

function interpolate(d, ITPType, Ns::Val, t)
Ns, Nt = size(d)
id_knots = _similar(t, Ns); copyto!(id_knots, collect(range(oneunit(eltype(t)), eltype(t)(Ns), Ns)))
t_knots = _similar(t, Nt); copyto!(t_knots, collect(range(zero(eltype(t)), oneunit(eltype(t)), Nt)))
return GriddedInterpolation((id_knots, t_knots), d, ITPType)
end

function resample(itp::Interpolator1D, t)
return itp.(t)
end

function resample(itp::Interpolator2D, t)
Ns = size(itp.coefs, 1)
id = _similar(t, Ns)
copyto!(id, collect(range(oneunit(eltype(t)), eltype(t)(Ns), Ns)))
return itp.(id, t)
end

function interpolate_times(t, t_unit, periodic, tq)
itp = GriddedInterpolation((t, ), t_unit, Gridded(Linear()))
return extrapolate(itp, periodic ? Interpolations.Periodic() : Flat()).(tq)
end

_similar(a, N) = similar(a, N)
_similar(a::Real, N) = zeros(typeof(a), N)
38 changes: 19 additions & 19 deletions KomaMRIBase/src/motion/Motion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ that are affected by that motion.

# Arguments
- `action`: (`::AbstractAction{T<:Real}`) action, such as [`Translate`](@ref) or [`Rotate`](@ref)
- `time`: (`::AbstractTimeSpan{T<:Real}`, `=TimeRange(0.0)`) time information about the motion
- `time`: (`::TimeCurve{T<:Real}`, `=TimeRange(0.0)`) time information about the motion
- `spins`: (`::AbstractSpinSpan`, `=AllSpins()`) spin indexes affected by the motion

# Returns
Expand All @@ -28,22 +28,22 @@ julia> motion = Motion(
"""
@with_kw mutable struct Motion{T<:Real}
action::AbstractAction{T}
time ::AbstractTimeSpan{T} = TimeRange(zero(typeof(action).parameters[1]))
spins ::AbstractSpinSpan = AllSpins()
time ::TimeCurve{T} = TimeRange(t_start=-oneunit(typeof(action).parameters[1]), t_end=zero(typeof(action).parameters[1]))
spins ::AbstractSpinSpan = AllSpins()
end

# Main constructors
function Motion(action)
T = first(typeof(action).parameters)
return Motion(action, TimeRange(zero(T)), AllSpins())
return Motion(action, TimeRange(t_start=-oneunit(T), t_end=zero(T)), AllSpins())
end
function Motion(action, time::AbstractTimeSpan)
function Motion(action, time::TimeCurve)
T = first(typeof(action).parameters)
return Motion(action, time, AllSpins())
end
function Motion(action, spins::AbstractSpinSpan)
T = first(typeof(action).parameters)
return Motion(action, TimeRange(zero(T)), spins)
return Motion(action, TimeRange(t_start=-oneunit(T), t_end=zero(T)), spins)
end

# Custom constructors
Expand All @@ -54,7 +54,7 @@ end
- `dx`: (`::Real`, `[m]`) translation in x
- `dy`: (`::Real`, `[m]`) translation in y
- `dz`: (`::Real`, `[m]`) translation in z
- `time`: (`::AbstractTimeSpan{T<:Real}`) time information about the motion
- `time`: (`::TimeCurve{T<:Real}`) time information about the motion
- `spins`: (`::AbstractSpinSpan`) spin indexes affected by the motion

# Returns
Expand All @@ -65,7 +65,7 @@ end
julia> translate = Translate(0.01, 0.02, 0.03, TimeRange(0.0, 1.0), SpinRange(1:10))
```
"""
function Translate(dx, dy, dz, time=TimeRange(zero(eltype(dx))), spins=AllSpins())
function Translate(dx, dy, dz, time=TimeRange(t_start=-oneunit(eltype(dx)), t_end=zero(eltype(dx))), spins=AllSpins())
pvillacorta marked this conversation as resolved.
Show resolved Hide resolved
return Motion(Translate(dx, dy, dz), time, spins)
end

Expand All @@ -76,7 +76,7 @@ end
- `pitch`: (`::Real`, `[º]`) rotation in x
- `roll`: (`::Real`, `[º]`) rotation in y
- `yaw`: (`::Real`, `[º]`) rotation in z
- `time`: (`::AbstractTimeSpan{T<:Real}`) time information about the motion
- `time`: (`::TimeCurve{T<:Real}`) time information about the motion
- `spins`: (`::AbstractSpinSpan`) spin indexes affected by the motion

# Returns
Expand All @@ -87,7 +87,7 @@ end
julia> rotate = Rotate(15.0, 0.0, 20.0, TimeRange(0.0, 1.0), SpinRange(1:10))
```
"""
function Rotate(pitch, roll, yaw, time=TimeRange(zero(eltype(pitch))), spins=AllSpins())
function Rotate(pitch, roll, yaw, time=TimeRange(t_start=-oneunit(eltype(pitch)), t_end=zero(eltype(pitch))), spins=AllSpins())
return Motion(Rotate(pitch, roll, yaw), time, spins)
end

Expand All @@ -98,7 +98,7 @@ end
- `circumferential_strain`: (`::Real`) contraction parameter
- `radial_strain`: (`::Real`) contraction parameter
- `longitudinal_strain`: (`::Real`) contraction parameter
- `time`: (`::AbstractTimeSpan{T<:Real}`) time information about the motion
- `time`: (`::TimeCurve{T<:Real}`) time information about the motion
- `spins`: (`::AbstractSpinSpan`) spin indexes affected by the motion

# Returns
Expand All @@ -109,7 +109,7 @@ end
julia> heartbeat = HeartBeat(-0.3, -0.2, 0.0, TimeRange(0.0, 1.0), SpinRange(1:10))
```
"""
function HeartBeat(circumferential_strain, radial_strain, longitudinal_strain, time=TimeRange(zero(eltype(circumferential_strain))), spins=AllSpins())
function HeartBeat(circumferential_strain, radial_strain, longitudinal_strain, time=TimeRange(t_start=-oneunit(eltype(circumferential_strain)), t_end=zero(eltype(circumferential_strain))), spins=AllSpins())
return Motion(HeartBeat(circumferential_strain, radial_strain, longitudinal_strain), time, spins)
end

Expand All @@ -120,7 +120,7 @@ end
- `dx`: (`::AbstractArray{T<:Real}`, `[m]`) displacements in x
- `dy`: (`::AbstractArray{T<:Real}`, `[m]`) displacements in y
- `dz`: (`::AbstractArray{T<:Real}`, `[m]`) displacements in z
- `time`: (`::AbstractTimeSpan{T<:Real}`) time information about the motion
- `time`: (`::TimeCurve{T<:Real}`) time information about the motion
- `spins`: (`::AbstractSpinSpan`) spin indexes affected by the motion

# Returns
Expand All @@ -137,7 +137,7 @@ julia> path = Path(
)
```
"""
function Path(dx, dy, dz, time=TimeRange(zero(eltype(dx))), spins=AllSpins())
function Path(dx, dy, dz, time=TimeRange(t_start=-oneunit(eltype(dx)), t_end=zero(eltype(dx))), spins=AllSpins())
return Motion(Path(dx, dy, dz), time, spins)
end

Expand All @@ -149,7 +149,7 @@ end
- `dy`: (`::AbstractArray{T<:Real}`, `[m]`) displacements in y
- `dz`: (`::AbstractArray{T<:Real}`, `[m]`) displacements in z
- `spin_reset`: (`::AbstractArray{Bool}`) reset spin state flags
- `time`: (`::AbstractTimeSpan{T<:Real}`) time information about the motion
- `time`: (`::TimeCurve{T<:Real}`) time information about the motion
- `spins`: (`::AbstractSpinSpan`) spin indexes affected by the motion

# Returns
Expand All @@ -167,7 +167,7 @@ julia> flowpath = FlowPath(
)
```
"""
function FlowPath(dx, dy, dz, spin_reset, time=TimeRange(zero(eltype(dx))), spins=AllSpins())
function FlowPath(dx, dy, dz, spin_reset, time=TimeRange(t_start=-oneunit(eltype(dx)), t_end=zero(eltype(dx))), spins=AllSpins())
return Motion(FlowPath(dx, dy, dz, spin_reset), time, spins)
end

Expand All @@ -192,7 +192,7 @@ function get_spin_coords(
m::Motion{T}, x::AbstractVector{T}, y::AbstractVector{T}, z::AbstractVector{T}, t
) where {T<:Real}
ux, uy, uz = x .* (0*t), y .* (0*t), z .* (0*t) # Buffers for displacements
t_unit = unit_time(t, m.time)
t_unit = unit_time(t, m.time.t, m.time.t_unit, m.time.periodic, m.time.duration)
idx = get_indexing_range(m.spins)
displacement_x!(@view(ux[idx, :]), m.action, @view(x[idx]), @view(y[idx]), @view(z[idx]), t_unit)
displacement_y!(@view(uy[idx, :]), m.action, @view(x[idx]), @view(y[idx]), @view(z[idx]), t_unit)
Expand All @@ -201,7 +201,7 @@ function get_spin_coords(
end

# Auxiliary functions
times(m::Motion) = times(m.time)
times(m::Motion) = times(m.time.t, m.time.duration)
is_composable(m::Motion) = is_composable(m.action)
add_jump_times!(t, m::Motion) = add_jump_times!(t, m.action, m.time)
add_jump_times!(t, ::AbstractAction, ::AbstractTimeSpan) = nothing
add_jump_times!(t, ::AbstractAction, ::TimeCurve) = nothing
9 changes: 5 additions & 4 deletions KomaMRIBase/src/motion/MotionList.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include("Interpolation.jl")
include("SpinSpan.jl")
include("TimeSpan.jl")
include("TimeCurve.jl")
include("Action.jl")
include("Motion.jl")

Expand Down Expand Up @@ -156,7 +157,7 @@ function get_spin_coords(
ux, uy, uz = xt .* zero(T), yt .* zero(T), zt .* zero(T)
# Composable motions: they need to be run sequentially. Note that they depend on xt, yt, and zt
for m in Iterators.filter(is_composable, ml.motions)
t_unit = unit_time(t, m.time)
t_unit = unit_time(t, m.time.t, m.time.t_unit, m.time.periodic, m.time.duration)
idx = get_indexing_range(m.spins)
displacement_x!(@view(ux[idx, :]), m.action, @view(xt[idx, :]), @view(yt[idx, :]), @view(zt[idx, :]), t_unit)
displacement_y!(@view(uy[idx, :]), m.action, @view(xt[idx, :]), @view(yt[idx, :]), @view(zt[idx, :]), t_unit)
Expand All @@ -166,7 +167,7 @@ function get_spin_coords(
end
# Additive motions: these motions can be run in parallel
for m in Iterators.filter(!is_composable, ml.motions)
t_unit = unit_time(t, m.time)
t_unit = unit_time(t, m.time.t, m.time.t_unit, m.time.periodic, m.time.duration)
idx = get_indexing_range(m.spins)
displacement_x!(@view(ux[idx, :]), m.action, @view(x[idx]), @view(y[idx]), @view(z[idx]), t_unit)
displacement_y!(@view(uy[idx, :]), m.action, @view(x[idx]), @view(y[idx]), @view(z[idx]), t_unit)
Expand Down Expand Up @@ -199,7 +200,7 @@ If `motionset::MotionList`, this function sorts its motions.
- `nothing`
"""
function sort_motions!(m::MotionList)
sort!(m.motions; by=m -> times(m)[1])
sort!(m.motions; by=m -> m.time.t_start)
return nothing
end

Expand Down
67 changes: 67 additions & 0 deletions KomaMRIBase/src/motion/TimeCurve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
timecurve = TimeCurve(t, t_unit, periodic, duration)

TimeCurve struct. It is a specialized type that defines a time curve.
(...)

# Arguments
- `t`: (`::AbstractVector{<:Real}`, `[s]`) time vector
- `t_unit`: (`::AbstractVector{<:Real}`) y vector, it needs to be scaled between 0 and 1
- `periodic`: (`::Bool`, `=false`) indicates whether the time curve should be periodically repeated
- `duration`: (`::Union{<:Real,AbstractVector{<:Real}}`, `=1.0`)

# Returns
- `timecurve`: (`::TimeCurve`) TimeCurve struct

# Examples
```julia-repl
julia> timecurve = TimeCurve(t=[0.0, 0.1, 0.3, 0.4], t_unit=[0.0, 0.6, 0.2, 0.0], periodic=true)
```
"""
@with_kw struct TimeCurve{T<:Real}
t::AbstractVector{T}
t_unit::AbstractVector{T}
periodic::Bool = false
duration::Union{T,AbstractVector{T}} = oneunit(eltype(t))
t_start::T = t[1]
t_end::T = t[end]
@assert check_unique(t) "Vector t=$(t) contains duplicate elements. Please ensure all elements in t are unique and try again"
end

check_unique(t) = true
check_unique(t::Vector) = length(t) == length(unique(t))

# Main Constructors
TimeCurve(t, t_unit, periodic, duration) = TimeCurve(t=t, t_unit=t_unit, periodic=periodic, duration=duration)
TimeCurve(t, t_unit) = TimeCurve(t=t, t_unit=t_unit)
# Custom constructors
# --- TimeRange
TimeRange(t_start::T, t_end::T) where T = TimeCurve(t=[t_start, t_end], t_unit=[zero(T), oneunit(T)])
TimeRange(; t_start=0.0, t_end=1.0) = TimeRange(t_start, t_end)
# --- Periodic
Periodic(period::T, asymmetry::T) where T = TimeCurve(t=[zero(T), period*asymmetry, period], t_unit=[zero(T), oneunit(T), zero(T)])
Periodic(; period=1.0, asymmetry=0.5) = Periodic(period, asymmetry)

""" Compare two TimeCurves """
Base.:(==)(t1::TimeCurve, t2::TimeCurve) = reduce(&, [getfield(t1, field) == getfield(t2, field) for field in fieldnames(typeof(t1))])
Base.:(≈)(t1::TimeCurve, t2::TimeCurve) = reduce(&, [getfield(t1, field) ≈ getfield(t2, field) for field in fieldnames(typeof(t1))])

""" times """
function times(t, dur::AbstractVector)
tr = repeat(t, length(dur))
scale = repeat(dur, inner=[length(t)])
offsets = repeat(vcat(0, cumsum(dur)[1:end-1]), inner=[length(t)])
tr .= (tr .* scale) .+ offsets
return tr
end
function times(t, dur::Real)
return dur .* t
end

""" unit_time """
function unit_time(tq, t, t_unit, periodic, dur::Real)
return interpolate_times(t .* dur, t_unit, periodic, tq)
end
function unit_time(tq, t, t_unit, periodic, dur)
return interpolate_times(times(t, dur), repeat(t_unit, length(dur)), periodic, tq)
end
Loading
Loading