diff --git a/rsciio/dm5/_api.py b/rsciio/dm5/_api.py index 64c8eed1..e5bab3b4 100644 --- a/rsciio/dm5/_api.py +++ b/rsciio/dm5/_api.py @@ -337,9 +337,6 @@ def __init__(self, image_group, tags=None, unique_id=None, file=None): self.unique_id = image_group["UniqueID"] self.file = file - def __str__(self): - return f"Image: {self.image_data['Data'].shape}" - @property def ndim(self): return len(self.image_data["Data"].shape) @@ -380,9 +377,6 @@ def signal_dimensions(self): "Meta Data.Format attribute in the ImageTags group." ) - def navigation_dimensions(self): - return self.ndim - self.signal_dimensions() - def get_axis_dict(self, axis): """ Get the calibration data for a given axis. @@ -463,9 +457,13 @@ def update_dimension(self, axis, length=None): """ Update the dimension of the image for a given axis. - This is two places in the DM5 file??? - - Under Calibrations and under Dimension. I think that only the Calibrations should be updated. + Parameters + ---------- + axis : int + The axis to update the dimension for (Starting from 0 in array order). This will be reversed to match DM's + axis order. + length : int, optional + The length of the axis. """ axis = self.ndim - axis - 1 @@ -485,6 +483,12 @@ def _get_dimension(self, axis): def get_data(self, lazy=False): """ Get the image data. + + Parameters + ---------- + lazy : bool, optional + Whether to return a dask array or a numpy + """ if lazy: return da.from_array(self.image_data["Data"]) @@ -494,6 +498,11 @@ def get_data(self, lazy=False): def update_data(self, data): """ Update the image data. + + Parameters + ---------- + data : np.ndarray or da.Array + The new image data. """ HyperspyWriter.overwrite_dataset(self.image_data, data, "Data") self.image_data.attrs.update( @@ -522,7 +531,7 @@ def get_metadata(self): if "Microscope Info" in original_metadata: metadata["Acquisition_instrument"] = {} metadata["Acquisition_instrument"]["TEM"] = {} - metadata["Acquisition_instrument"]["TEM"]["beam_energy "] = ( + metadata["Acquisition_instrument"]["TEM"]["beam_energy"] = ( original_metadata["Microscope Info"].get("Voltage", 0) / 1000 ) metadata["Acquisition_instrument"]["TEM"]["acquisition_mode"] = ( @@ -534,8 +543,6 @@ def get_metadata(self): metadata["Acquisition_instrument"]["TEM"]["camera_length"] = ( original_metadata["Microscope Info"].get("STEM Camera Length", 0) ) - metadata["Acquisition_instrument"] = original_metadata["Microscope Info"] - metadata["Acquisition_instrument"] = {} return metadata, original_metadata def update_metadata( @@ -543,9 +550,21 @@ def update_metadata( ): """ Update the metadata for the image. + + Parameters + ---------- + metadata : dict, optional + The metadata to update. + signal_dimensions : int, optional + The number of signal dimensions. + navigation_dimensions : int, optional + The number of navigation dimensions. """ if metadata is None: metadata = {} + if navigation_dimensions is None and signal_dimensions is None: + signal_dimensions = self.ndim + navigation_dimensions = 0 formatted_metadata = {} formatted_metadata["Acquisition"] = {} @@ -569,26 +588,35 @@ def update_metadata( if navigation_dimensions > 0: formatted_metadata["Meta Data"]["IsSequence"] = "true" dict2group(formatted_metadata, self.image_tags) + + # Update Microscope Info + if ( + "Acquisition_instrument" in metadata + and "TEM" in metadata["Acquisition_instrument"] + ): + self.image_tags.create_group("Microscope Info") + microscope_info_dict = { + "Voltage": metadata["Acquisition_instrument"]["TEM"].get( + "beam_energy", 0 + ) + * 1000, + "Illumination Mode": metadata["Acquisition_instrument"]["TEM"].get( + "acquisition_mode", "Unknown" + ), + "Indicated Magnification": metadata["Acquisition_instrument"][ + "TEM" + ].get("magnification", 0), + "STEM Camera Length": metadata["Acquisition_instrument"]["TEM"].get( + "camera_length", 0 + ), + } + + dict2group(microscope_info_dict, self.image_tags["Microscope Info"]) + self.image_tags.create_group("UserTags") dict2group(metadata, self.image_tags["UserTags"]) return - def to_signal_dict(self): - """ - Convert the image to a Hyperspy signal dictionary. - """ - data = self.get_data() - metadata, original_metadata = self.get_metadata() - axes = [] - for axis in range(len(data.shape)): - axes.append(self.get_axis_dict(axis)) - return { - "data": data, - "metadata": metadata, - "original_metadata": original_metadata, - "axes": axes, - } - def dict2group(dictionary, group): for key, value in dictionary.items(): diff --git a/rsciio/tests/test_dm5.py b/rsciio/tests/test_dm5.py index 7e3adc96..bddfcf35 100644 --- a/rsciio/tests/test_dm5.py +++ b/rsciio/tests/test_dm5.py @@ -24,6 +24,7 @@ from pathlib import Path +import dask.array as da import numpy as np import pytest @@ -83,3 +84,43 @@ def test_save_load_files( assert "nm" in s.axes_manager[i].units assert s.axes_manager[i].scale == 0.1 assert s.axes_manager[i].size == int(original[i][-2:]) + + def test_save_load_undefined_axes(self, tmp_path): + fname = tmp_path / "test_save_undefined.dm5" + + data_shape = [10, 11, 12, 13] + data = np.ones(data_shape, dtype=np.float32) + signal = hs.signals.Signal2D(data) + signal.save(fname, overwrite=True) + s = hs.load(fname) + for i in range(4): + assert s.axes_manager[i].name == "" + assert s.axes_manager[i].units == "" + + def test_save_load_metadata(self, tmp_path): + fname = tmp_path / "test_save_undefined.dm5" + + data_shape = [10, 11, 12, 13] + data = np.ones(data_shape, dtype=np.float32) + signal = hs.signals.Signal2D(data) + signal.metadata.General.title = "test" + signal.metadata.add_node("Acquisition_instrument.TEM") + signal.metadata.Acquisition_instrument.TEM.beam_energy = 200 + signal.metadata.Acquisition_instrument.TEM.magnification = 100 + signal.metadata.Acquisition_instrument.TEM.camera_length = 10 + + signal.save(fname, overwrite=True) + s = hs.load(fname) + assert s.metadata.Acquisition_instrument.TEM.beam_energy == 200 + assert s.metadata.Acquisition_instrument.TEM.camera_length == 10 + assert s.metadata.Acquisition_instrument.TEM.magnification == 100 + + def test_save_load_lazy(self, tmp_path): + fname = tmp_path / "test_save_lazy.dm5" + + data_shape = [10, 11, 12, 13] + data = np.ones(data_shape, dtype=np.float32) + signal = hs.signals.Signal2D(data) + signal.save(fname, overwrite=True) + s = hs.load(fname, lazy=True) + assert isinstance(s.data, da.Array)