Skip to content

Commit

Permalink
Apply black
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Mar 31, 2024
1 parent b60897b commit bbdb2e7
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 32 deletions.
12 changes: 6 additions & 6 deletions kuibit/cactus_grid_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,9 @@ def current_data_to_UniformGridData(

# Write iterations_to_time
if current_iteration not in self._iterations_to_times:
self._iterations_to_times[
current_iteration
] = current_time
self._iterations_to_times[current_iteration] = (
current_time
)

# Reset everything
current_iteration = line_data[0]
Expand Down Expand Up @@ -1169,9 +1169,9 @@ def _read_component_as_uniform_grid_data(
)
data = np.transpose(dataset[()])

self.alldata[path][iteration][ref_level][
component
] = grid_data.UniformGridData(grid, data)
self.alldata[path][iteration][ref_level][component] = (
grid_data.UniformGridData(grid, data)
)

return self.alldata[path][iteration][ref_level][component]

Expand Down
12 changes: 3 additions & 9 deletions kuibit/cactus_waves.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,14 +928,10 @@ def d_lm(el, em):
)

if direction == 0:
return (
self.dist**2 / (8 * np.pi) * (Pp_int1 * Pp_int2).real()
)
return self.dist**2 / (8 * np.pi) * (Pp_int1 * Pp_int2).real()

if direction == 1:
return (
self.dist**2 / (8 * np.pi) * (Pp_int1 * Pp_int2).imag()
)
return self.dist**2 / (8 * np.pi) * (Pp_int1 * Pp_int2).imag()

# This is direction == 2

Expand Down Expand Up @@ -1288,9 +1284,7 @@ def get_power_lm(self, mult_l, mult_m):
time.
:rtype: :py:class:`~TimeSeries`
"""
return (
self.dist**2 / (4 * np.pi) * np.abs(self[(mult_l, mult_m)]) ** 2
)
return self.dist**2 / (4 * np.pi) * np.abs(self[(mult_l, mult_m)]) ** 2

def get_energy_lm(self, mult_l, mult_m):
"""Return the cumulative energy lost in the mode (l, m).
Expand Down
1 change: 1 addition & 0 deletions kuibit/visualize_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def save(
"""
if os.path.splitext(outputpath)[-1] == ".tikz":
import tikzplotlib

# If clean_figure is True, we extract from kwargs those argument
# that tikzplotlib.clean_figure would take. For this, we need to
# know what argument that function takes.
Expand Down
8 changes: 2 additions & 6 deletions tests/test_cactus_waves.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,7 @@ def test_get_force_linear_momentum_em(self):

phi2lm_2 = c_lm * np.conj(phi2lm)

force_lm = (
self.phi2.dist**2 / (4 * np.pi) * (phi2lm * phi2lm_2).real()
)
force_lm = self.phi2.dist**2 / (4 * np.pi) * (phi2lm * phi2lm_2).real()

self.assertEqual(force_lm, self.phi2.get_force_z_lm(2, 2))

Expand Down Expand Up @@ -650,9 +648,7 @@ def test_get_force_linear_momentum(self):
)

Pp_lm = (
psi4_234.dist**2
/ (8 * np.pi)
* (psi4lm_xy_int1 * psi4lm_xy_int2)
psi4_234.dist**2 / (8 * np.pi) * (psi4lm_xy_int1 * psi4lm_xy_int2)
)

self.assertEqual(Pp_lm.real(), psi4_234.get_force_x_lm(3, 1, 0.1))
Expand Down
3 changes: 1 addition & 2 deletions tests/test_gw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,7 @@ def test_effective_amplitude_spectral_density(self):
freqs = strain_plus_fft.f
vals = (
strain_plus_fft.f
* (0.5 * (strain_plus_fft.amp**2 + strain_cross_fft.amp**2))
** 0.5
* (0.5 * (strain_plus_fft.amp**2 + strain_cross_fft.amp**2)) ** 0.5
)

expected_heff = FrequencySeries(freqs, vals)
Expand Down
8 changes: 2 additions & 6 deletions tests/test_sensitivity_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def test_ETB(self):

self.assertEqual(
etb,
fs.FrequencySeries(
freqs, [4.8012536e-21**2, 6.5667816e-25**2]
),
fs.FrequencySeries(freqs, [4.8012536e-21**2, 6.5667816e-25**2]),
)

def test_ETD(self):
Expand All @@ -52,9 +50,7 @@ def test_ETD(self):

self.assertEqual(
etd,
fs.FrequencySeries(
freqs, [3.4959517e-17**2, 5.7819941e-25**2]
),
fs.FrequencySeries(freqs, [3.4959517e-17**2, 5.7819941e-25**2]),
)

def test_CE1(self):
Expand Down
4 changes: 1 addition & 3 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,7 @@ def test_dot(self):
self.assertEqual(dot, expected)

def test_norm(self):
self.assertEqual(
self.vec.norm(), (sum(t**2 for t in self.ts)).sqrt()
)
self.assertEqual(self.vec.norm(), (sum(t**2 for t in self.ts)).sqrt())


class TestMatrix(unittest.TestCase):
Expand Down

0 comments on commit bbdb2e7

Please sign in to comment.