Skip to content

Commit

Permalink
updating how units are propagated.
Browse files Browse the repository at this point in the history
  • Loading branch information
kujaku11 committed Oct 21, 2024
1 parent eca1588 commit 91e6153
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 10 deletions.
25 changes: 15 additions & 10 deletions mtpy/core/mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def clone_empty(self):
new_mt_obj.model_elevation = self.model_elevation
new_mt_obj._rotation_angle = self._rotation_angle
new_mt_obj.profile_offset = self.profile_offset
new_mt_obj.impedance_units = self.impedance_units

return new_mt_obj

Expand Down Expand Up @@ -307,16 +308,17 @@ def rotate(self, theta_r, inplace=True):

@property
def Z(self):
r"""Mtpy.core.z.Z object to hold impedance tenso."""
r"""Mtpy.core.z.Z object to hold impedance tensor."""

if self.has_impedance():
return Z(
z=self.impedance.to_numpy(),
z_error=self.impedance_error.to_numpy(),
frequency=self.frequency,
z_model_error=self.impedance_model_error.to_numpy(),
units=self.impedance_units,
)
z_object = Z(
z=self.impedance.to_numpy(),
z_error=self.impedance_error.to_numpy(),
frequency=self.frequency,
z_model_error=self.impedance_model_error.to_numpy(),
)
z_object.units = self.impedance_units
return z_object
return Z()

@Z.setter
Expand All @@ -325,16 +327,20 @@ def Z(self, z_object):
recalculate phase tensor and invariants, which shouldn't change except
for strike angle.
Be sure to have appropriate units set
"""
# if a z object is given the underlying data is in mt units, even
# if the units are set to ohm.
z_object.units = "mt"
self.impedance_units = z_object.units
if not isinstance(z_object.frequency, type(None)):
if self.frequency.size != z_object.frequency.size:
self.frequency = z_object.frequency

elif not (self.frequency == z_object.frequency).all():
self.frequency = z_object.frequency
# set underlying data to units of mt
z_object.units = "mt"
self.impedance = z_object.z
self.impedance_error = z_object.z_error
self.impedance_model_error = z_object.z_model_error
Expand Down Expand Up @@ -704,7 +710,6 @@ def to_dataframe(self, utm_crs=None, cols=None, impedance_units="mt"):
z_object = self.Z
z_object.units = impedance_units
mt_df.from_z_object(z_object)
mt_df.from_z_object(z_object)
if self.has_tipper():
mt_df.from_t_object(self.Tipper)

Expand Down
3 changes: 3 additions & 0 deletions mtpy/core/transfer_function/z.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def _validate_ss_input(factor):
z_error=self.z_error,
frequency=self.frequency,
z_model_error=self.z_model_error,
units=self.units,
)

def remove_distortion(
Expand Down Expand Up @@ -365,6 +366,7 @@ def remove_distortion(
z_error=z_corrected_error,
frequency=self.frequency,
z_model_error=self.z_model_error,
units=self.units,
)

@property
Expand Down Expand Up @@ -804,6 +806,7 @@ def estimate_distortion(
new_z_object = Z(
z=self.z[0:nf, :, :],
frequency=self.frequency[0:nf],
units=self.units,
)
if self._has_tf_error():
new_z_object.z_error = self.z_error[0:nf]
Expand Down
183 changes: 183 additions & 0 deletions tests/core/test_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
from mtpy import MT
from mtpy.core.mt_dataframe import MTDataFrame
from mtpy.core.transfer_function import MT_TO_OHM_FACTOR, Z

from mt_metadata import TF_EDI_CGG

Expand Down Expand Up @@ -71,6 +72,16 @@ def test_copy(self):

self.assertEqual(self.mt, mt_copy)

def test_impedance_units(self):

def set_units(unit):
self.mt.impedance_units = unit

with self.subTest("bad type"):
self.assertRaises(TypeError, set_units, 4)
with self.subTest("bad choice"):
self.assertRaises(ValueError, set_units, "ants")


class TestMTFromKWARGS(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -240,6 +251,178 @@ def test_remove_component(self):
self.assertTrue(np.all(np.isnan(new_mt.Z.z[:, 0, 0])))


class TestMTSetImpedanceOhm(unittest.TestCase):
@classmethod
def setUpClass(self):
self.z = np.array(
[[0.1 - 0.1j, 10 + 10j], [-10 - 10j, -0.1 + 0.1j]]
).reshape((1, 2, 2))
self.z_ohm = self.z / MT_TO_OHM_FACTOR
self.z_err = np.array([[0.1, 0.05], [0.05, 0.1]]).reshape((1, 2, 2))
self.z_err_ohm = self.z_err / MT_TO_OHM_FACTOR
self.res = np.array([[[4.0e-03, 4.0e01], [4.0e01, 4.0e-03]]])
self.res_err = np.array(
[[[0.00565685, 0.28284271], [0.28284271, 0.00565685]]]
)
self.phase = np.array([[[-45.0, 45.0], [-135.0, 135.0]]])
self.phase_err = np.array(
[[[35.26438968, 0.20257033], [0.20257033, 35.26438968]]]
)

self.pt = np.array([[[1.00020002, -0.020002], [-0.020002, 1.00020002]]])
self.pt_error = np.array(
[[[0.01040308, 0.02020604], [0.02020604, 0.01040308]]]
)
self.pt_azimuth = np.array([315.0])
self.pt_azimuth_error = np.array([3.30832308])
self.pt_skew = np.array([0])
self.pt_skew_error = np.array([0.40923428])

self.z_object = Z(
z=self.z_ohm,
z_error=self.z_err_ohm,
z_model_error=self.z_err_ohm,
units="ohm",
)
self.mt = MT()
self.mt.station = "mt001"
self.mt.Z = self.z_object
self.z_object.units = "ohm"

def test_impedance_units(self):
self.assertEqual(self.mt.impedance_units, "ohm")

def test_period(self):
self.assertTrue((np.array([1]) == self.mt.period).all())

def test_impedance(self):
self.assertTrue((self.mt.impedance == self.z).all())

def test_z_impedance_ohm(self):
self.assertTrue((self.mt.Z.z == self.z_ohm).all())

def test_impedance_error(self):
self.assertTrue(np.allclose(self.mt.impedance_error, self.z_err))

def test_z_impedance_error_ohm(self):
self.assertTrue(np.allclose(self.mt.Z.z_error, self.z_err_ohm))

def test_impedance_model_error(self):
self.assertTrue(np.allclose(self.mt.impedance_model_error, self.z_err))

def test_resistivity(self):
self.assertTrue(np.allclose(self.mt.Z.resistivity, self.res))

def test_resistivity_error(self):
self.assertTrue(np.allclose(self.mt.Z.resistivity_error, self.res_err))

def test_resistivity_model_error(self):
self.assertTrue(
np.allclose(self.mt.Z.resistivity_model_error, self.res_err)
)

def test_phase(self):
self.assertTrue(np.allclose(self.mt.Z.phase, self.phase))

def test_phase_error(self):
self.assertTrue(np.allclose(self.mt.Z.phase_error, self.phase_err))

def test_phase_model_error(self):
self.assertTrue(
np.allclose(self.mt.Z.phase_model_error, self.phase_err)
)

def test_phase_tensor(self):
self.assertTrue(np.allclose(self.pt, self.mt.pt.pt))

def test_phase_tensor_error(self):
self.assertTrue(np.allclose(self.pt_error, self.mt.pt.pt_error))

def test_phase_tensor_model_error(self):
self.assertTrue(np.allclose(self.pt_error, self.mt.pt.pt_model_error))

def test_phase_tensor_azimuth(self):
self.assertTrue(np.allclose(self.pt_azimuth, self.mt.pt.azimuth))

def test_phase_tensor_azimuth_error(self):
self.assertTrue(
np.allclose(self.pt_azimuth_error, self.mt.pt.azimuth_error)
)

def test_phase_tensor_azimuth_model_error(self):
self.assertTrue(
np.allclose(self.pt_azimuth_error, self.mt.pt.azimuth_model_error)
)

def test_phase_tensor_skew(self):
self.assertTrue(np.allclose(self.pt_skew, self.mt.pt.skew))

def test_phase_tensor_skew_error(self):
self.assertTrue(np.allclose(self.pt_skew_error, self.mt.pt.skew_error))

def test_phase_tensor_skew_model_error(self):
self.assertTrue(
np.allclose(self.pt_skew_error, self.mt.pt.skew_model_error)
)

def test_remove_static_shift(self):
new_mt = self.mt.remove_static_shift(ss_x=0.5, ss_y=1.5, inplace=False)

self.assertTrue(
np.allclose(
(self.mt.impedance.data / new_mt.impedance.data) ** 2,
np.array(
[[[0.5 + 0.0j, 0.5 + 0.0j], [1.5 - 0.0j, 1.5 - 0.0j]]]
),
)
)

def test_remove_distortion(self):
new_mt = self.mt.remove_distortion()

self.assertTrue(
np.all(
np.isclose(
new_mt.Z.z,
np.array(
[
[
[
0.099995 - 0.099995j,
9.99949999 + 9.99949999j,
],
[
-9.99949999 - 9.99949999j,
-0.099995 + 0.099995j,
],
]
]
),
)
)
)

def test_interpolate_fail_bad_f_type(self):
self.assertRaises(
ValueError, self.mt.interpolate, [0, 1], f_type="wrong"
)

def test_interpolate_fail_bad_periods(self):
self.assertRaises(ValueError, self.mt.interpolate, [0.1, 2])

def test_phase_flip(self):
new_mt = self.mt.flip_phase(zxy=True, inplace=False)

self.assertTrue(
np.all(np.isclose(new_mt.Z.phase_xy % 180, self.mt.Z.phase_xy))
)

def test_remove_component(self):
new_mt = self.mt.remove_component(zxx=True, inplace=False)

self.assertTrue(np.all(np.isnan(new_mt.Z.z[:, 0, 0])))


class TestMTComputeModelError(unittest.TestCase):
def setUp(self):
self.z = np.array(
Expand Down
17 changes: 17 additions & 0 deletions tests/core/transfer_function/test_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,23 @@ def set_units(unit):
with self.subTest("bad choice"):
self.assertRaises(ValueError, set_units, "ants")

def test_phase_tensor_equal(self):
z_ohm = Z(z=self.z_in_ohms, units="ohm")
z_mt = Z(z=self.z, units="mt")

self.assertTrue(
np.allclose(z_ohm.phase_tensor.pt, z_mt.phase_tensor.pt)
)

def test_resistivity_phase_equal(self):
z_ohm = Z(z=self.z_in_ohms, units="ohm")
z_mt = Z(z=self.z, units="mt")

with self.subTest("resistivity"):
self.assertTrue(np.allclose(z_ohm.resistivity, z_mt.resistivity))
with self.subTest("phase"):
self.assertTrue(np.allclose(z_ohm.phase, z_mt.phase))


# =============================================================================
# Run
Expand Down

0 comments on commit 91e6153

Please sign in to comment.