Skip to content

Commit

Permalink
Merge pull request #43 from awesome-spectral-indices/fm/types
Browse files Browse the repository at this point in the history
Extending indices functions to accept custom types
  • Loading branch information
MartinuzziFrancesco authored Mar 1, 2024
2 parents bc941e0 + 15c83f2 commit 4a71c14
Show file tree
Hide file tree
Showing 24 changed files with 2,455 additions and 653 deletions.
9 changes: 5 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ makedocs(;
doctest=true,
linkcheck=true,
warnonly=[:missing_docs],
format = Documenter.HTML(
size_threshold = nothing,
prettyurls = get(ENV, "CI", nothing) == "true",
assets = ["assets/docs.css"]),
format=Documenter.HTML(;
size_threshold=nothing,
prettyurls=get(ENV, "CI", nothing) == "true",
assets=["assets/docs.css"],
),
pages=pages,
)

Expand Down
16 changes: 14 additions & 2 deletions docs/src/tutorials/basic_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ NDVI
This outputs the NDVI struct, containing all necessary information. The struct can also be used as a function to compute NDVI:

```@example basic
NDVI(nir, red)
NDVI(Float64, nir, red)
```

This method is direct but not the recommended approach for computing indices. When using this method, ensure the parameter order matches the `bands` field of the `SpectralIndex`:
Expand Down Expand Up @@ -144,7 +144,19 @@ savi = compute_index("SAVI", params)
savi = compute_index("SAVI"; N=nir, R=red, L=0.5)
```

### Computing Multiple Indices
## Float32, Float16
The package can compute indices at custom precision

```@example basic
T = Float32
savi = compute_index(T, "SAVI"; N=T(nir), R=T(red), L=T(0.5))
```
```@example basic
T = Float16
savi = compute_index(T, "SAVI"; N=T(nir), R=T(red), L=T(0.5))
```

## Computing Multiple Indices

Now that we have added more indices we can explore how to compute multiple indices at the same time. All is needed is to pass a Vector of `String`s to the `compute_index` function with the chosen spectral indices inside, as well as the chosen parameters of course:

Expand Down
56 changes: 41 additions & 15 deletions ext/SpectralIndicesDataFramesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,52 +15,78 @@ function SpectralIndices._create_params(kw_args::Pair{Symbol,DataFrame}...)
end

function SpectralIndices.compute_index(
index::String, params::DataFrame; indices=SpectralIndices._create_indices()
)
::Type{T}, index::String, params::DataFrame; indices=SpectralIndices._create_indices()
) where {T<:Number}
# Convert DataFrame to a dictionary for each row and compute the index
results = [
compute_index(index, Dict(zip(names(params), row)); indices=indices) for
row in eachrow(params)
SpectralIndices.compute_index(
T, index, Dict(zip(names(params), row)); indices=indices
) for row in eachrow(params)
]

# Return the results as a DataFrame with the column named after the index
return DataFrame(Symbol(index) => results)
end

function SpectralIndices.compute_index(
index::Vector{String}, params::DataFrame; indices=SpectralIndices._create_indices()
index::String, params::DataFrame; indices=SpectralIndices._create_indices()
)
return SpectralIndices.compute_index(Float64, index, params; indices=indices)
end

function SpectralIndices.compute_index(
::Type{T},
index::Vector{String},
params::DataFrame;
indices=SpectralIndices._create_indices(),
) where {T<:Number}
# Similar conversion and computation for a vector of indices
result_dfs = DataFrame()
for idx in index
result_df = compute_index(idx, params; indices=indices)
result_df = SpectralIndices.compute_index(T, idx, params; indices=indices)
result_dfs[!, Symbol(idx)] = result_df[!, 1]
end
# Return the combined DataFrame with columns named after each index
return result_dfs
end

function SpectralIndices.linear(params::DataFrame)
result = linear(params[!, "a"], params[!, "b"])
function SpectralIndices.compute_index(
index::Vector{String}, params::DataFrame; indices=SpectralIndices._create_indices()
)
return SpectralIndices.compute_index(Float64, index, params; indices=indices)
end

function SpectralIndices.linear(::Type{T}, params::DataFrame) where {T<:Number}
result = linear(T, params[!, "a"], params[!, "b"])
result_df = DataFrame(; linear=result)
return result_df
end

function SpectralIndices.poly(params::DataFrame)
result = poly(params[!, "a"], params[!, "b"], params[!, "c"], params[!, "p"])
function SpectralIndices.linear(params::DataFrame)
return linear(Float64, params)
end

function SpectralIndices.poly(::Type{T}, params::DataFrame) where {T<:Number}
result = poly(T, params[!, "a"], params[!, "b"], params[!, "c"], params[!, "p"])
result_df = DataFrame(; poly=result)
return result_df
end

function SpectralIndices.RBF(params::DataFrame)
result = RBF(params[!, "a"], params[!, "b"], params[!, "sigma"])
function SpectralIndices.poly(params::DataFrame)
return poly(Float64, params)
end

function SpectralIndices.RBF(::Type{T}, params::DataFrame) where {T<:Number}
result = RBF(T, params[!, "a"], params[!, "b"], params[!, "sigma"])
result_df = DataFrame(; RBF=result)
return result_df
end

function SpectralIndices.load_dataset(
dataset::String, ::Type{T}
) where {T<:DataFrame}
function SpectralIndices.RBF(params::DataFrame)
return RBF(Float64, params)
end

function SpectralIndices.load_dataset(dataset::String, ::Type{T}) where {T<:DataFrame}
datasets = Dict("spectral" => "spectral.json")

if dataset in keys(datasets)
Expand Down
65 changes: 45 additions & 20 deletions ext/SpectralIndicesYAXArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,58 +40,83 @@ end
## TODO: simplify even further
# this is same function contente as dispatch on Dict
function SpectralIndices.compute_index(
index::String, params::YAXArray; indices=SpectralIndices._create_indices()
)
::Type{T}, index::String, params::YAXArray; indices=SpectralIndices._create_indices()
) where {T<:Number}
SpectralIndices._check_params(indices[index], params)
params = SpectralIndices._order_params(indices[index], params)
result = SpectralIndices._compute_index(indices[index], params...)
result = SpectralIndices._compute_index(T, indices[index], params...)
return result
end

function SpectralIndices.compute_index(
index::Vector{String}, params::YAXArray; indices=SpectralIndices._create_indices()
index::String, params::YAXArray; indices=SpectralIndices._create_indices()
)
return SpectralIndices.compute_index(Float64, index, params; indices=indices)
end

function SpectralIndices.compute_index(
::Type{T},
index::Vector{String},
params::YAXArray;
indices=SpectralIndices._create_indices(),
) where {T<:Number}
results = []
for (nidx, idx) in enumerate(index)
res_tmp = compute_index(idx, params; indices=indices)
res_tmp = compute_index(T, idx, params; indices=indices)
push!(results, res_tmp)
end
result = concatenatecubes(results, Dim{:Variables}(index))

return result
end

function SpectralIndices._compute_index(
idx::SpectralIndices.AbstractSpectralIndex, prms::YAXArray...
function SpectralIndices.compute_index(
index::Vector{String}, params::YAXArray; indices=SpectralIndices._create_indices()
)
return idx.(prms...)
return SpectralIndices.compute_index(Float64, index, params; indices=indices)
end

function SpectralIndices._compute_index(
::Type{T}, idx::SpectralIndices.AbstractSpectralIndex, prms::YAXArray...
) where {T<:Number}
return idx.(T, prms...)
end

function SpectralIndices.linear(::Type{T}, params::YAXArray) where {T<:Number}
return SpectralIndices.linear(T, params[Variable=At("a")], params[Variable=At("b")])
end

function SpectralIndices.linear(params::YAXArray)
result = linear(params[Variable=At("a")], params[Variable=At("b")])
return result
return SpectralIndices.linear(
Float64, params[Variable=At("a")], params[Variable=At("b")]
)
end

function SpectralIndices.poly(params::YAXArray)
result = poly(
function SpectralIndices.poly(::Type{T}, params::YAXArray) where {T<:Number}
return SpectralIndices.poly(
T,
params[Variable=At("a")],
params[Variable=At("b")],
params[Variable=At("c")],
params[Variable=At("p")],
)
return result
end

function SpectralIndices.RBF(params::YAXArray)
result = RBF(
params[Variable=At("a")], params[Variable=At("b")], params[Variable=At("sigma")]
function SpectralIndices.poly(params::YAXArray)
return SpectralIndices.poly(Float64, params)
end

function SpectralIndices.RBF(::Type{T}, params::YAXArray) where {T<:Number}
return SpectralIndices.RBF(
T, params[Variable=At("a")], params[Variable=At("b")], params[Variable=At("sigma")]
)
return result
end

function SpectralIndices.load_dataset(
dataset::String, ::Type{T}
) where {T<:YAXArray}
function SpectralIndices.RBF(params::YAXArray)
return SpectralIndices.RBF(Float64, params)
end

function SpectralIndices.load_dataset(dataset::String, ::Type{T}) where {T<:YAXArray}
datasets = Dict("sentinel" => "S2_10m.json")

if dataset in keys(datasets)
Expand Down
2 changes: 1 addition & 1 deletion src/bands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ function _create_bands()
end

return bands_class
end
end
Loading

0 comments on commit 4a71c14

Please sign in to comment.