Skip to content

Commit

Permalink
Fixes for reading from numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
franzpoeschel committed May 31, 2024
1 parent 432a07b commit a07f9a2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
6 changes: 4 additions & 2 deletions examples/basic/ex07_convert_numpy_openpmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@

for snapshot in range(2):
data_converter.add_snapshot(
descriptor_input_type="numpy",
descriptor_input_path="Be_snapshot{}.in.npy".format(snapshot),
target_input_type='numpy',
target_input_path="Be_snapshot{}.out.npy".format(snapshot),
additional_info_input_type=None,
additional_info_input_path=None,
Expand All @@ -45,14 +47,14 @@
descriptor_save_path="./",
target_save_path="./",
additional_info_save_path="./",
naming_scheme="Be_snapshot*.bp4",
naming_scheme="Be_snapshot_from_numpy*.bp4",
descriptor_calculation_kwargs={"working_directory": "./"},
)

data_converter.convert_snapshots(
descriptor_save_path="./",
target_save_path="./",
additional_info_save_path="./",
naming_scheme="Be_snapshot*.npy",
naming_scheme="Be_snapshot_from_numpy*.npy",
descriptor_calculation_kwargs={"working_directory": "./"},
)
10 changes: 5 additions & 5 deletions mala/datahandling/data_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from mala.targets.target import Target
from mala.version import __version__ as mala_version

descriptor_input_types = ["espresso-out", "openpmd"]
target_input_types = [".cube", ".xsf", "openpmd"]
descriptor_input_types = ["espresso-out", "openpmd", "numpy"]
target_input_types = [".cube", ".xsf", "openpmd", "numpy"]
additional_info_input_types = ["espresso-out"]


Expand Down Expand Up @@ -632,7 +632,7 @@ def __convert_single_snapshot(
)
elif description["output"] == "numpy":
tmp_output = (
self.target_calculator.read_dimensions_from_numpy_file(
self.target_calculator.read_from_numpy_file(
snapshot["output"], units=original_units["output"]
)
)
Expand Down Expand Up @@ -683,8 +683,8 @@ def __convert_single_snapshot(
)
elif description["output"] == "numpy":
tmp_output = (
self.target_calculator.read_dimensions_from_numpy_file(
snapshot["output"], units=original_units["output"]
self.target_calculator.read_from_numpy_file(
snapshot["output"]
)
)

Expand Down

0 comments on commit a07f9a2

Please sign in to comment.