Skip to content

Commit

Permalink
another test
Browse files Browse the repository at this point in the history
  • Loading branch information
Laura Cabayol-Garcia committed Sep 30, 2024
1 parent d9e3ea8 commit dce7da5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
timeout-minutes: 30
run: |
pytest tests/test_lace.py
pytest tests/plot_mpg.py --output_dir=tmp/validation_figures/
pytest tests/plot_mpg.py
# Archive the generated plots
- name: Archive generated plots
Expand Down
58 changes: 36 additions & 22 deletions tests/plot_mpg.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,53 @@
import pytest
import argparse
import os
# Import necessary modules
## General python modules
import numpy as np
import matplotlib
matplotlib.use('Agg')
matplotlib.use('Agg') # Use a non-interactive backend for rendering plots
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import os
import argparse # Used for parsing command-line arguments

## LaCE specific modules
import lace
from lace.emulator.nn_emulator import NNEmulator
from lace.archive import gadget_archive
from lace.archive import nyx_archive, gadget_archive
from lace.utils import poly_p1d
from lace.utils.plotting_functions import plot_p1d_vs_emulator

@pytest.fixture
def output_dir(pytestconfig):
# Get the custom argument passed with pytest
return pytestconfig.getoption("output_dir")

def test(output_dir):
def test():
"""
Function to plot emulated P1D using specified archive and save plots to output directory.
Function to plot emulated P1D using specified archive (Nyx or Gadget).
Parameters:
output_dir (str): Directory to save the generated plots.
archive (str): Archive to use for data ('Nyx' or 'Gadget')
"""
archive_name = 'Gadget'
# Get the base directory of the lace module
repo = os.path.dirname(lace.__path__[0]) + "/"

# Define the parameters for the emulator specific to Gadget
emu_params = ['Delta2_p', 'n_p', 'mF', 'sigT_Mpc', 'gamma', 'kF_Mpc']
training_set = 'Cabayol23'
emulator_label = 'Cabayol23+'
training_set='Cabayol23'
emulator_label='Cabayol23+'
model_path = f'{repo}data/NNmodels/Cabayol23+/Cabayol23+_drop_sim'

# Initialize a GadgetArchive instance for postprocessing data
archive = gadget_archive.GadgetArchive(postproc="Cabayol23")

os.makedirs(output_dir, exist_ok=True)
# Directory for saving plots
#save_dir = f'{repo}data/validation_figures/{archive_name}/'
save_dir = '{repo}tmp/validation_figures'
# Create the directory if it does not exist
os.makedirs(save_dir, exist_ok=True)

for ii, sim in enumerate(['mpg_1', 'mpg_central']):
#for ii, sim in enumerate(archive.list_sim):
for ii, sim in enumerate(['mpg_1','mpg_central']):

if sim == 'mpg_central':
model_path_central = f'{repo}data/NNmodels/Cabayol23+/Cabayol23+.pt'

emulator = NNEmulator(
training_set=training_set,
emulator_label=emulator_label,
Expand All @@ -52,15 +65,16 @@ def test(output_dir):
train=False,
)

# Get testing data for the current simulation
testing_data = archive.get_testing_data(sim_label=f'{sim}')
if sim != 'nyx_central':
testing_data = [d for d in testing_data if d['val_scaling'] == 1]

save_path = os.path.join(output_dir, f'{sim}.png')

# Plot and save the emulated P1D
save_path = f'{save_dir}{sim}.png'
plot_p1d_vs_emulator(testing_data, emulator, save_path=save_path)

return
return

def pytest_addoption(parser):
# Add custom options to pytest command
parser.addoption("--output_dir", action="store", default="tmp/validation_figures", help="Directory to save plots")
# Call the function to execute the test
test()

0 comments on commit dce7da5

Please sign in to comment.