Skip to content

Commit

Permalink
feat(WasmTransformIO): support Composite transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
thewtex committed Jun 29, 2024
1 parent cdd70ed commit 7c608e5
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 131 deletions.
2 changes: 2 additions & 0 deletions include/itkTransformJSON.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace itk
enum class JSONTransformParameterizationEnum
{
Identity,
Composite,
Translation,
Euler2D,
Euler3D,
Expand Down Expand Up @@ -101,6 +102,7 @@ template <>
struct glz::meta<itk::JSONTransformParameterizationEnum> {
using enum itk::JSONTransformParameterizationEnum;
static constexpr auto value = glz::enumerate(Identity,
Composite,
Translation,
Euler2D,
Euler3D,
Expand Down
1 change: 1 addition & 0 deletions model/itk-wasm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ enums:
A detailed description of each transform type can be found in the ITK Software Guide: https://itk.org/ITKSoftwareGuide/html/Book1/ITKSoftwareGuide-Book1ch4.html
permissible_values:
Identity:
Composite:
Translation:
Euler2D:
Euler3D:
Expand Down
173 changes: 43 additions & 130 deletions src/itkWasmTransformIO.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ WasmTransformIOTemplate<TParametersValueType>::ReadCBOR()
{
transformJSON.transformType.transformParameterization = JSONTransformParameterizationEnum::Identity;
}
else if (transformParameterization == "Composite")
{
transformJSON.transformType.transformParameterization = JSONTransformParameterizationEnum::Composite;
}
else if (transformParameterization == "Translation")
{
transformJSON.transformType.transformParameterization = JSONTransformParameterizationEnum::Translation;
Expand Down Expand Up @@ -359,6 +363,7 @@ WasmTransformIOTemplate<TParametersValueType>::ReadCBOR()
}
transformListJSON.push_back(transformJSON);
}

this->SetJSON(transformListJSON);

auto readTransformList = this->GetReadTransformList();
Expand All @@ -368,6 +373,11 @@ WasmTransformIOTemplate<TParametersValueType>::ReadCBOR()
++transformIt, ++jsonIt)
{
const auto transformJSON = *jsonIt;
if (transformJSON.transformType.transformParameterization == JSONTransformParameterizationEnum::Composite)
{
++count;
continue;
}
FixedParametersType fixedParams(transformJSON.numberOfFixedParameters);
const SizeValueType numberOfFixedBytes = transformJSON.numberOfFixedParameters * sizeof(FixedParametersValueType);
ParametersType params(transformJSON.numberOfParameters);
Expand Down Expand Up @@ -425,136 +435,6 @@ WasmTransformIOTemplate<TParametersValueType>::ReadCBOR()
}
++count;
}


// const struct cbor_pair *indexHandle = cbor_map_handle(index);
// for (size_t ii = 0; ii < indexCount; ++ii)
// {
// const std::string_view key(reinterpret_cast<char *>(cbor_string_handle(indexHandle[ii].key)),
// cbor_string_length(indexHandle[ii].key)); if (key == "transformType")
// {
// const cbor_item_t *transformTypeItem = indexHandle[ii].value;
// const size_t transformPropertyCount = cbor_map_size(transformTypeItem);
// const struct cbor_pair *transformTypeHandle = cbor_map_handle(transformTypeItem);
// for (size_t jj = 0; jj < transformPropertyCount; ++jj)
// {
// const std::string_view transformTypeKey(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].key)), cbor_string_length(transformTypeHandle[jj].key)); if
// (transformTypeKey == "transformType")
// {
// const std::string transformType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOComponent pointIOComponentType =
// IOComponentEnumFromWasmComponentType(pointComponentType); this->SetPointDimension(dimension);
// }
// else if (transformTypeKey == "inputDimension")
// {
// const auto inputDimension = cbor_get_uint32(transformTypeHandle[jj].value);
// this->SetPointDimension(inputDimension);
// }
// else if (transformTypeKey == "pointComponentType")
// {
// const std::string pointComponentType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOComponent pointIOComponentType =
// IOComponentEnumFromWasmComponentType(pointComponentType);
// this->SetPointComponentType(pointIOComponentType);
// }
// else if (transformTypeKey == "pointPixelType")
// {
// const std::string pointPixelType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOPixel pointIOPixelType = IOPixelEnumFromWasmPixelType(pointPixelType);
// this->SetPointPixelType(pointIOPixelType);
// }
// else if (transformTypeKey == "pointPixelComponentType")
// {
// const std::string pointPixelComponentType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOComponent pointPixelIOComponentType =
// IOComponentEnumFromWasmComponentType(pointPixelComponentType);
// this->SetPointPixelComponentType(pointPixelIOComponentType);
// }
// else if (transformTypeKey == "pointPixelComponents")
// {
// const auto components = cbor_get_uint32(transformTypeHandle[jj].value);
// this->SetNumberOfPointPixelComponents(components);
// }
// else if (transformTypeKey == "cellComponentType")
// {
// const std::string cellComponentType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOComponent cellIOComponentType =
// IOComponentEnumFromWasmComponentType(cellComponentType); this->SetCellComponentType(cellIOComponentType);
// }
// else if (transformTypeKey == "cellPixelType")
// {
// const std::string cellPixelType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOPixel cellIOPixelType = IOPixelEnumFromWasmPixelType(cellPixelType);
// this->SetCellPixelType(cellIOPixelType);
// }
// else if (transformTypeKey == "cellPixelComponentType")
// {
// const std::string cellPixelComponentType(reinterpret_cast<char
// *>(cbor_string_handle(transformTypeHandle[jj].value)), cbor_string_length(transformTypeHandle[jj].value));
// const CommonEnums::IOComponent cellPixelIOComponentType =
// IOComponentEnumFromWasmComponentType(cellPixelComponentType);
// this->SetCellPixelComponentType(cellPixelIOComponentType);
// }
// else if (transformTypeKey == "cellPixelComponents")
// {
// const auto components = cbor_get_uint32(transformTypeHandle[jj].value);
// this->SetNumberOfCellPixelComponents(components);
// }
// else
// {
// itkExceptionMacro("Unexpected transformType cbor map key: " << transformTypeKey);
// }
// }
// }
// else if (key == "numberOfPoints")
// {
// const auto components = cbor_get_uint64(indexHandle[ii].value);
// this->SetNumberOfPoints(components);
// if (components)
// {
// this->m_UpdatePoints = true;
// }
// }
// else if (key == "numberOfPointPixels")
// {
// const auto components = cbor_get_uint64(indexHandle[ii].value);
// this->SetNumberOfPointPixels(components);
// if (components)
// {
// this->m_UpdatePointData = true;
// }
// }
// else if (key == "numberOfCells")
// {
// const auto components = cbor_get_uint64(indexHandle[ii].value);
// this->SetNumberOfCells(components);
// if (components)
// {
// this->m_UpdateCells = true;
// }
// }
// else if (key == "numberOfCellPixels")
// {
// const auto components = cbor_get_uint64(indexHandle[ii].value);
// this->SetNumberOfCellPixels(components);
// if (components)
// {
// this->m_UpdateCellData = true;
// }
// }
// else if (key == "cellBufferSize")
// {
// const auto components = cbor_get_uint64(indexHandle[ii].value);
// this->SetCellBufferSize(components);
// }
// }
}

template <typename TParametersValueType>
Expand Down Expand Up @@ -597,6 +477,10 @@ WasmTransformIOTemplate<TParametersValueType>::GetJSON() -> TransformListJSON
{
transformJSON.transformType.transformParameterization = JSONTransformParameterizationEnum::Identity;
}
else if (pString == "CompositeTransform")
{
transformJSON.transformType.transformParameterization = JSONTransformParameterizationEnum::Composite;
}
else if (pString == "TranslationTransform")
{
transformJSON.transformType.transformParameterization = JSONTransformParameterizationEnum::Translation;
Expand Down Expand Up @@ -752,6 +636,11 @@ WasmTransformIOTemplate<TParametersValueType>::TransformParameterizationString(c
transformParameterization = "Identity";
break;
}
case JSONTransformParameterizationEnum::Composite:
{
transformParameterization = "Composite";
break;
}
case JSONTransformParameterizationEnum::Translation:
{
transformParameterization = "Translation";
Expand Down Expand Up @@ -1015,6 +904,12 @@ WasmTransformIOTemplate<TParametersValueType>::WriteCBOR()
cbor_pair{ cbor_move(cbor_build_string("outputSpaceName")),
cbor_move(cbor_build_string(transformJSON.outputSpaceName.c_str())) });

if (transformJSON.transformType.transformParameterization == JSONTransformParameterizationEnum::Composite)
{
cbor_array_push(index, cbor_move(transformItem));
++count;
continue;
}
const auto fixedNumberOfBytes = transformJSON.numberOfFixedParameters * sizeof(FixedParametersValueType);
const auto fixedParams = (*transformIt)->GetFixedParameters();
writeCBORBuffer(transformItem,
Expand Down Expand Up @@ -1092,6 +987,11 @@ WasmTransformIOTemplate<TParametersValueType>::ReadFixedParameters(const Transfo
++transformIt, ++jsonIt)
{
const auto transformJSON = *jsonIt;
if ((*jsonIt).transformType.transformParameterization == itk::JSONTransformParameterizationEnum::Composite)
{
++count;
continue;
}
FixedParametersType fixedParams(transformJSON.numberOfFixedParameters);
const SizeValueType numberOfBytes = transformJSON.numberOfFixedParameters * sizeof(FixedParametersValueType);

Expand Down Expand Up @@ -1121,6 +1021,11 @@ WasmTransformIOTemplate<TParametersValueType>::ReadParameters(const TransformLis
transformIt != readTransformList.end();
++transformIt, ++jsonIt)
{
if ((*jsonIt).transformType.transformParameterization == itk::JSONTransformParameterizationEnum::Composite)
{
++count;
continue;
}
const auto transformJSON = *jsonIt;
ParametersType params(transformJSON.numberOfParameters);
const auto valueBytes = sizeof(ParametersValueType);
Expand Down Expand Up @@ -1262,6 +1167,10 @@ WasmTransformIOTemplate<TParametersValueType>::WriteFixedParameters()
for (typename ConstTransformListType::const_iterator it = transformList.begin(); it != end; ++it, ++count)
{
const TransformType * currentTransform = it->GetPointer();
if (currentTransform->GetTransformTypeAsString().find("CompositeTransform") != std::string::npos)
{
continue;
}
auto fixedParams = currentTransform->GetFixedParameters();
// Fixed parameters are always double per itk::TransformBaseTemplate
const SizeValueType numberOfBytes = fixedParams.Size() * sizeof(FixedParametersValueType);
Expand Down Expand Up @@ -1307,6 +1216,10 @@ WasmTransformIOTemplate<TParametersValueType>::WriteParameters()
for (typename ConstTransformListType::const_iterator it = transformList.begin(); it != end; ++it, ++count)
{
const TransformType * currentTransform = it->GetPointer();
if (currentTransform->GetTransformTypeAsString().find("CompositeTransform") != std::string::npos)
{
continue;
}
auto params = currentTransform->GetParameters();
const SizeValueType numberOfBytes = params.Size() * sizeof(ParametersValueType);

Expand Down
20 changes: 20 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,26 @@ itk_add_test(NAME itkWasmTransformIOTest
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOTest.cbor.h5
)

itk_add_test(NAME itkWasmTransformIOSequenceTest
COMMAND WebAssemblyInterfaceTestDriver
itkWasmTransformIOTest
DATA{Input/TransformSequence.h5}
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOSequenceTest.iwt
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOSequenceTest.h5
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOSequenceTest.iwt.cbor
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOSequenceTest.cbor.h5
)

itk_add_test(NAME itkWasmTransformIOCompositeTest
COMMAND WebAssemblyInterfaceTestDriver
itkWasmTransformIOTest
DATA{Input/CompositeTransform.h5}
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOCompositeTest.iwt
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOCompositeTest.h5
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOCompositeTest.iwt.cbor
${ITK_TEST_OUTPUT_DIR}/itkWasmTransformIOCompositeTest.cbor.h5
)

itk_add_test(NAME itkPipelineTest
COMMAND WebAssemblyInterfaceTestDriver
itkPipelineTest
Expand Down
1 change: 1 addition & 0 deletions test/Input/CompositeTransform.h5.cid
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bafkreibq75hrj3xoemt2lof5fla3vjysmkiq4yy4s74k3q54wf264n6prm
1 change: 1 addition & 0 deletions test/Input/TransformSequence.h5.cid
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bafkreidoegsttlbmdcgg5rhodbaq52urxkirgzbehzrs2jxhrynlwal64u
2 changes: 1 addition & 1 deletion test/itkWasmTransformIOTest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ int itkWasmTransformIOTest(int argc, char *argv[])
wasmWriter->SetFileName(transformDirectory);
for (auto transform : *inputTransforms)
{
std::cout << "Input back transform:" << std::endl;
std::cout << "Input transform:" << std::endl;
transform->Print(std::cout);
wasmWriter->AddTransform(transform);
}
Expand Down

0 comments on commit 7c608e5

Please sign in to comment.