From 5a6f1006bb7c44659369cf6428143597bfee01e0 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Tue, 3 Dec 2024 13:38:40 -0700 Subject: [PATCH 01/13] calculate solarcc model content loss on daily mean fields - 8hrs for actuals, 24hrs for synthetic, to encourage realism in nighttime hours for synthetic --- sup3r/models/solar_cc.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index 6ba3d1de1..dd3152e26 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -16,8 +16,9 @@ class SolarCC(Sup3rGan): Note ---- *Modifications to standard Sup3rGan* - - Content loss is only on the n_days of the center 8 daylight hours of - the daily true+synthetic high res samples + - Content loss is only on the n_days of the temporal-mean of the center + 8 daylight hours of the daily true + 24hours of synthetic high res + samples - Discriminator only sees n_days of the center 8 daylight hours of the daily true high res sample. - Discriminator sees random n_days of 8-hour samples of the daily @@ -142,7 +143,9 @@ def calc_loss( t_len = hi_res_true.shape[3] n_days = int(t_len // 24) - day_slices = [ + + # slices for middle-daylight-hours + sub_day_slices = [ slice( self.STARTING_HOUR + x, self.STARTING_HOUR + x + self.DAYLIGHT_HOURS, @@ -150,24 +153,30 @@ def calc_loss( for x in range(0, 24 * n_days, 24) ] + # slices for 24-hour full days + day_24h_slices = [slice(x, x + 24) for x in range(0, 24 * n_days, 24)] + # sample only daylight hours for disc training and gen content loss disc_out_true = [] disc_out_gen = [] loss_gen_content = 0.0 - for tslice in day_slices: - disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice, :]) - gen_c = self.calc_loss_gen_content( - hi_res_true[:, :, :, tslice, :], hi_res_gen[:, :, :, tslice, :] - ) - disc_out_true.append(disc_t) + for tslice_sub, tslice_24h in zip(sub_day_slices, day_24h_slices): + hr_true_slice = hi_res_true[:, :, :, tslice_sub, :] + hr_gen_slice = hi_res_gen[:, :, :, tslice_24h, :] + hr_true_slice = tf.math.reduce_mean(hr_true_slice, axis=3) + hr_gen_slice = tf.math.reduce_mean(hr_gen_slice, axis=3) + gen_c = self.calc_loss_gen_content(hr_true_slice, hr_gen_slice) loss_gen_content += gen_c + disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice_sub, :]) + disc_out_true.append(disc_t) + # Randomly sample daylight windows from generated data. Better than # strided samples covering full day because the random samples will # provide an evenly balanced training set for the disc logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS)] - time_samples = tf.random.categorical(logits, len(day_slices)) - for i in range(len(day_slices)): + time_samples = tf.random.categorical(logits, len(sub_day_slices)) + for i in range(len(sub_day_slices)): t0 = time_samples[0, i] t1 = t0 + self.DAYLIGHT_HOURS disc_g = self._tf_discriminate(hi_res_gen[:, :, :, t0:t1, :]) @@ -177,7 +186,7 @@ def calc_loss( disc_out_gen = tf.concat([disc_out_gen], axis=0) loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) - loss_gen_content /= len(day_slices) + loss_gen_content /= len(sub_day_slices) loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers From 31a8dbc0b439dc371f2cd8a49d19b8dfe06c8f80 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 4 Dec 2024 14:41:49 -0700 Subject: [PATCH 02/13] solarcc content loss needs to be a mix of 24h average and 8 hour pointwise --- sup3r/models/solar_cc.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index dd3152e26..de4f4c8fa 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -16,9 +16,9 @@ class SolarCC(Sup3rGan): Note ---- *Modifications to standard Sup3rGan* - - Content loss is only on the n_days of the temporal-mean of the center - 8 daylight hours of the daily true + 24hours of synthetic high res - samples + - Content loss is only on the n_days of the center 8 daylight hours of + the daily true + synthetic and the temporal mean of the 24hours of + synthetic - Discriminator only sees n_days of the center 8 daylight hours of the daily true high res sample. - Discriminator sees random n_days of 8-hour samples of the daily @@ -161,12 +161,16 @@ def calc_loss( disc_out_gen = [] loss_gen_content = 0.0 for tslice_sub, tslice_24h in zip(sub_day_slices, day_24h_slices): - hr_true_slice = hi_res_true[:, :, :, tslice_sub, :] - hr_gen_slice = hi_res_gen[:, :, :, tslice_24h, :] - hr_true_slice = tf.math.reduce_mean(hr_true_slice, axis=3) - hr_gen_slice = tf.math.reduce_mean(hr_gen_slice, axis=3) - gen_c = self.calc_loss_gen_content(hr_true_slice, hr_gen_slice) - loss_gen_content += gen_c + hr_true_sub = hi_res_true[:, :, :, tslice_sub, :] + hr_gen_sub = hi_res_gen[:, :, :, tslice_sub, :] + hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :] + + hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3) + hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3) + + gen_c_sub = self.calc_loss_gen_content(hr_true_sub, hr_gen_sub) + gen_c_24h = self.calc_loss_gen_content(hr_true_mean, hr_gen_mean) + loss_gen_content += gen_c_24h + gen_c_sub disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice_sub, :]) disc_out_true.append(disc_t) From 34fc8ca080129afef02feb7e9b58606a851d0d7d Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 4 Dec 2024 16:29:11 -0700 Subject: [PATCH 03/13] fix solar loss test - no longer zero --- tests/training/test_train_solar.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py index 7888124f9..167c493fd 100644 --- a/tests/training/test_train_solar.py +++ b/tests/training/test_train_solar.py @@ -246,4 +246,3 @@ def test_solar_custom_loss(): ) assert loss1 > loss2 - assert loss2 == 0 From 139c966d999dea1e79a2621f99ec75611ba73f04 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Wed, 4 Dec 2024 17:03:41 -0700 Subject: [PATCH 04/13] pointwise loss should only encourage the central few hours so that sup3rcc-solar dynamics are more dynamic --- sup3r/models/solar_cc.py | 63 ++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index de4f4c8fa..c44dea1d0 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -16,28 +16,38 @@ class SolarCC(Sup3rGan): Note ---- *Modifications to standard Sup3rGan* - - Content loss is only on the n_days of the center 8 daylight hours of - the daily true + synthetic and the temporal mean of the 24hours of - synthetic - - Discriminator only sees n_days of the center 8 daylight hours of the - daily true high res sample. - - Discriminator sees random n_days of 8-hour samples of the daily - synthetic high res sample. + - Pointwise content loss (MAE/MSE) is only on the center 2 daylight + hours (POINT_LOSS_HOURS) of the daily true + synthetic days and the + temporal mean of the 24hours of synthetic for n_days + (usually just 1 day) + - Discriminator only sees n_days of the center 8 daylight hours + (DAYLIGHT_HOURS and STARTING_HOUR) of the daily true high res sample. + - Discriminator sees random n_days of 8-hour samples (DAYLIGHT_HOURS) + of the daily synthetic high res sample. - Includes padding on high resolution output of :meth:`generate` so that forward pass always outputs a multiple of 24 hours. """ - # starting hour is the hour that daylight starts at, daylight hours is the - # number of daylight hours to sample, so for example if 8 and 8, the - # daylight slice will be slice(8, 16). The stride length is the step size - # for sampling the temporal axis of the generated data to send to the - # discriminator for the adversarial loss component of the generator. For - # example, if the generator produces 24 timesteps and stride is 4 and the - # daylight hours is 8, slices of (0, 8) (4, 12), (8, 16), (12, 20), and - # (16, 24) will be sent to the disc. STARTING_HOUR = 8 + """Starting hour is the hour that daylight starts at, typically + zero-indexed and rolled to local time""" + DAYLIGHT_HOURS = 8 + """Daylight hours is the number of daylight hours to sample, so for example + if 8 and 8, the daylight slice will be slice(8, 16). """ + STRIDE_LEN = 4 + """The stride length is the step size for sampling the temporal axis of the + generated data to send to the discriminator for the adversarial loss + component of the generator. For example, if the generator produces 24 + timesteps and stride is 4 and the daylight hours is 8, + slices of (0, 8) (4, 12), (8, 16), (12, 20), and (16, 24) will be sent to + the disc.""" + + POINT_LOSS_HOURS = 2 + """Number of hours from the center of the day to calculate pointwise loss + from, e.g., MAE/MSE based on data from the true 4km hourly high res + field.""" def __init__(self, *args, t_enhance=None, **kwargs): """Add optional t_enhance adjustment. @@ -144,6 +154,9 @@ def calc_loss( t_len = hi_res_true.shape[3] n_days = int(t_len // 24) + # slices for 24-hour full days + day_24h_slices = [slice(x, x + 24) for x in range(0, 24 * n_days, 24)] + # slices for middle-daylight-hours sub_day_slices = [ slice( @@ -153,26 +166,34 @@ def calc_loss( for x in range(0, 24 * n_days, 24) ] - # slices for 24-hour full days - day_24h_slices = [slice(x, x + 24) for x in range(0, 24 * n_days, 24)] + # slices for middle-pointwise-loss-hours + point_loss_slices = [ + slice( + (24 - self.POINT_LOSS_HOURS) // 2 + x, + (24 - self.POINT_LOSS_HOURS) // 2 + x + self.POINT_LOSS_HOURS, + ) + for x in range(0, 24 * n_days, 24) + ] # sample only daylight hours for disc training and gen content loss disc_out_true = [] disc_out_gen = [] loss_gen_content = 0.0 - for tslice_sub, tslice_24h in zip(sub_day_slices, day_24h_slices): + ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices) + for tslice_sub, tslice_ploss, tslice_24h in ziter: hr_true_sub = hi_res_true[:, :, :, tslice_sub, :] - hr_gen_sub = hi_res_gen[:, :, :, tslice_sub, :] hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :] + hr_true_ploss = hi_res_true[:, :, :, tslice_ploss, :] + hr_gen_ploss = hi_res_gen[:, :, :, tslice_ploss, :] hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3) hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3) - gen_c_sub = self.calc_loss_gen_content(hr_true_sub, hr_gen_sub) + gen_c_sub = self.calc_loss_gen_content(hr_true_ploss, hr_gen_ploss) gen_c_24h = self.calc_loss_gen_content(hr_true_mean, hr_gen_mean) loss_gen_content += gen_c_24h + gen_c_sub - disc_t = self._tf_discriminate(hi_res_true[:, :, :, tslice_sub, :]) + disc_t = self._tf_discriminate(hr_true_sub) disc_out_true.append(disc_t) # Randomly sample daylight windows from generated data. Better than From 5e43b052ebef1f2017d60b3736e6d20f24839de2 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Sat, 14 Dec 2024 13:35:48 -0700 Subject: [PATCH 05/13] bug fix - broken syntax with flatten method not called, but flatten actually shouldn't even be used here --- sup3r/preprocessing/rasterizers/extended.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/rasterizers/extended.py b/sup3r/preprocessing/rasterizers/extended.py index 8aca23b88..794847ae7 100644 --- a/sup3r/preprocessing/rasterizers/extended.py +++ b/sup3r/preprocessing/rasterizers/extended.py @@ -193,7 +193,8 @@ def get_lat_lon(self): return self._get_flat_data_lat_lon() def _get_flat_data_lat_lon(self): - """Get lat lon for flattened source data.""" + """Get lat lon for flattened source data. Output is shape (y, x, 2) + where 2 is (lat, lon)""" if hasattr(self.full_lat_lon, 'vindex'): return self.full_lat_lon.vindex[self.raster_index] - return self.full_lat_lon[self.raster_index.flatten] + return self.full_lat_lon[self.raster_index] From 7862646a6eb5f811e71d2b7a24d09015db2024e4 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Sun, 15 Dec 2024 14:30:54 -0700 Subject: [PATCH 06/13] updated number of pointwise loss hours based on successful 100kmdaily->100kmhourly training experiments --- sup3r/models/solar_cc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index c44dea1d0..0e45068fc 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -34,7 +34,8 @@ class SolarCC(Sup3rGan): DAYLIGHT_HOURS = 8 """Daylight hours is the number of daylight hours to sample, so for example - if 8 and 8, the daylight slice will be slice(8, 16). """ + if STARTING_HOUR is 8 and DAYLIGHT_HOURS is 8, the daylight slice will be + slice(8, 16). """ STRIDE_LEN = 4 """The stride length is the step size for sampling the temporal axis of the @@ -44,7 +45,7 @@ class SolarCC(Sup3rGan): slices of (0, 8) (4, 12), (8, 16), (12, 20), and (16, 24) will be sent to the disc.""" - POINT_LOSS_HOURS = 2 + POINT_LOSS_HOURS = 8 """Number of hours from the center of the day to calculate pointwise loss from, e.g., MAE/MSE based on data from the true 4km hourly high res field.""" From 505ab5fc769342b442719d68db1c368abc5ab6a0 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Sun, 15 Dec 2024 14:31:21 -0700 Subject: [PATCH 07/13] default timezone roll in sup3rcc NSRDB training data is -7 for MST, not CST as was stated --- sup3r/solar/solar.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sup3r/solar/solar.py b/sup3r/solar/solar.py index 5f40381f3..3a701a8e4 100644 --- a/sup3r/solar/solar.py +++ b/sup3r/solar/solar.py @@ -34,7 +34,7 @@ def __init__( sup3r_fps, nsrdb_fp, t_slice=slice(None), - tz=-6, + tz=-7, agg_factor=1, nn_threshold=0.5, cloud_threshold=0.99, @@ -64,8 +64,8 @@ def __init__( tz : int The timezone offset for the data in sup3r_fps. It is assumed that the GAN is trained on data in local time and therefore the output - in sup3r_fps should be treated as local time. For example, -6 is - CST which is default for CONUS training data. + in sup3r_fps should be treated as local time. For example, -7 is + MST which is default for CONUS training data. agg_factor : int Spatial aggregation factor for nsrdb-to-GAN-meta e.g. the number of NSRDB spatial pixels to average for a single sup3r GAN output site. @@ -585,7 +585,7 @@ def run_temporal_chunks( fp_pattern, nsrdb_fp, fp_out_suffix='irradiance', - tz=-6, + tz=-7, agg_factor=1, nn_threshold=0.5, cloud_threshold=0.99, @@ -610,8 +610,8 @@ def run_temporal_chunks( tz : int The timezone offset for the data in sup3r_fps. It is assumed that the GAN is trained on data in local time and therefore the output - in sup3r_fps should be treated as local time. For example, -6 is - CST which is default for CONUS training data. + in sup3r_fps should be treated as local time. For example, -7 is + MST which is default for CONUS training data. agg_factor : int Spatial aggregation factor for nsrdb-to-GAN-meta e.g. the number of NSRDB spatial pixels to average for a single sup3r GAN output site. @@ -663,7 +663,7 @@ def _run_temporal_chunk( fp_pattern, nsrdb_fp, fp_out_suffix='irradiance', - tz=-6, + tz=-7, agg_factor=1, nn_threshold=0.5, cloud_threshold=0.99, From c197b9b378753c261dd4ca850f0bd62342dcd09d Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Dec 2024 08:44:31 -0700 Subject: [PATCH 08/13] added latitude/longitude features to the base registry --- sup3r/preprocessing/derivers/methods.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/derivers/methods.py b/sup3r/preprocessing/derivers/methods.py index 50471664a..e2774eebf 100644 --- a/sup3r/preprocessing/derivers/methods.py +++ b/sup3r/preprocessing/derivers/methods.py @@ -384,8 +384,6 @@ class TasMax(Tas): class Sza(DerivedFeature): """Solar zenith angle derived feature.""" - inputs = () - @classmethod def compute(cls, data): """Compute method for sza.""" @@ -393,6 +391,24 @@ def compute(cls, data): return sza.astype(np.float32) +class LatitudeFeature(DerivedFeature): + """Latitude treated as a training feature""" + + @classmethod + def compute(cls, data): + """Method to compute latitude feature array""" + return data.coords['latitude'] + + +class LongitudeFeature(DerivedFeature): + """Longitude treated as a training feature""" + + @classmethod + def compute(cls, data): + """Method to compute longitude feature array""" + return data.coords['longitude'] + + RegistryBase = { 'u_(.*)': UWind, 'v_(.*)': VWind, @@ -402,6 +418,8 @@ def compute(cls, data): 'cloud_mask': CloudMask, 'clearsky_ratio': ClearSkyRatio, 'sza': Sza, + 'latitude_feature': LatitudeFeature, + 'longitude_feature': LongitudeFeature, } RegistryH5WindCC = { From 4daf683c6d0aca50fbdfb896eaa69c00f6e31815 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Dec 2024 08:55:41 -0700 Subject: [PATCH 09/13] updated pointwise loss hours based on latest successful experiment, removed stride var no longer used --- sup3r/models/solar_cc.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index 0e45068fc..748694018 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -37,15 +37,7 @@ class SolarCC(Sup3rGan): if STARTING_HOUR is 8 and DAYLIGHT_HOURS is 8, the daylight slice will be slice(8, 16). """ - STRIDE_LEN = 4 - """The stride length is the step size for sampling the temporal axis of the - generated data to send to the discriminator for the adversarial loss - component of the generator. For example, if the generator produces 24 - timesteps and stride is 4 and the daylight hours is 8, - slices of (0, 8) (4, 12), (8, 16), (12, 20), and (16, 24) will be sent to - the disc.""" - - POINT_LOSS_HOURS = 8 + POINT_LOSS_HOURS = 2 """Number of hours from the center of the day to calculate pointwise loss from, e.g., MAE/MSE based on data from the true 4km hourly high res field.""" From da48b1469ba3e9a2afe32b251894cb864e338245 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Dec 2024 10:20:36 -0700 Subject: [PATCH 10/13] added lat/lon_feature to data handler test for cc data --- tests/data_handlers/test_dh_nc_cc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data_handlers/test_dh_nc_cc.py b/tests/data_handlers/test_dh_nc_cc.py index 19100ec0b..9dd868023 100644 --- a/tests/data_handlers/test_dh_nc_cc.py +++ b/tests/data_handlers/test_dh_nc_cc.py @@ -122,11 +122,11 @@ def test_data_handling_nc_cc(): handler = DataHandlerNCforCC( pytest.FPS_GCM, - features=['u_100m', 'v_100m'], + features=['u_100m', 'v_100m', 'latitude_feature', 'longitude_feature'], target=target, shape=(20, 20), ) - assert handler.data.shape == (20, 20, 20, 2) + assert handler.data.shape == (20, 20, 20, 4) # upper case features warning features = [f'U_{int(plevel)}pa', f'V_{int(plevel)}pa'] From 9f5dfcd800a120f0c27bdb0edb67d2821aa5b003 Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Dec 2024 11:00:56 -0700 Subject: [PATCH 11/13] simplify lat/lon_feature registry --- sup3r/preprocessing/derivers/methods.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/sup3r/preprocessing/derivers/methods.py b/sup3r/preprocessing/derivers/methods.py index e2774eebf..5c4f60f36 100644 --- a/sup3r/preprocessing/derivers/methods.py +++ b/sup3r/preprocessing/derivers/methods.py @@ -391,24 +391,6 @@ def compute(cls, data): return sza.astype(np.float32) -class LatitudeFeature(DerivedFeature): - """Latitude treated as a training feature""" - - @classmethod - def compute(cls, data): - """Method to compute latitude feature array""" - return data.coords['latitude'] - - -class LongitudeFeature(DerivedFeature): - """Longitude treated as a training feature""" - - @classmethod - def compute(cls, data): - """Method to compute longitude feature array""" - return data.coords['longitude'] - - RegistryBase = { 'u_(.*)': UWind, 'v_(.*)': VWind, @@ -418,8 +400,8 @@ def compute(cls, data): 'cloud_mask': CloudMask, 'clearsky_ratio': ClearSkyRatio, 'sza': Sza, - 'latitude_feature': LatitudeFeature, - 'longitude_feature': LongitudeFeature, + 'latitude_feature': 'latitude', + 'longitude_feature': 'longitude', } RegistryH5WindCC = { From 9c51beccd1245ae083eb9cbe78e09b742bfc7e45 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 19 Dec 2024 13:37:54 -0700 Subject: [PATCH 12/13] preload of coords added to compute and h5 preload test added. --- sup3r/preprocessing/accessor.py | 5 +++-- tests/rasterizers/test_rasterizer_general.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 598bdeac1..8eaa28d90 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -230,8 +230,9 @@ def compute(self, **kwargs): logger.debug(f'Loading dataset into memory: {self._ds}') logger.debug(f'Pre-loading: {_mem_check()}') - for f in self._ds.data_vars: - self._ds[f] = self._ds[f].compute(**kwargs) + for f in list(self._ds.data_vars) + list(self._ds.coords): + if hasattr(self._ds[f], 'compute'): + self._ds[f] = self._ds[f].compute(**kwargs) logger.debug( f'Loaded {f} into memory with shape ' f'{self._ds[f].shape}. {_mem_check()}' diff --git a/tests/rasterizers/test_rasterizer_general.py b/tests/rasterizers/test_rasterizer_general.py index e2e1be93a..0b7230da1 100644 --- a/tests/rasterizers/test_rasterizer_general.py +++ b/tests/rasterizers/test_rasterizer_general.py @@ -80,3 +80,15 @@ def test_topography_h5(): topo = res.get_meta_arr('elevation')[ri.flatten(),] topo = topo.reshape((ri.shape[0], ri.shape[1])) assert np.allclose(topo, rasterizer['topography'][..., 0]) + + +def test_preloaded_h5(): + """Test preload of h5 file""" + rasterizer = Rasterizer( + file_paths=pytest.FP_WTK, + target=(39.01, -105.15), + shape=(20, 20), + chunks=None, + ) + for f in list(rasterizer.data.data_vars) + list(Dimension.coords_2d()): + assert isinstance(rasterizer[f].data, np.ndarray) From ad62263d1d927315f62254ca7bf9cfb4ffab1eab Mon Sep 17 00:00:00 2001 From: grantbuster Date: Thu, 19 Dec 2024 14:22:08 -0700 Subject: [PATCH 13/13] minor bug: disc never saw last timestep because of a +1 index error --- sup3r/models/solar_cc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index 748694018..f38ccffe5 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -192,9 +192,9 @@ def calc_loss( # Randomly sample daylight windows from generated data. Better than # strided samples covering full day because the random samples will # provide an evenly balanced training set for the disc - logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS)] - time_samples = tf.random.categorical(logits, len(sub_day_slices)) - for i in range(len(sub_day_slices)): + logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS + 1)] + time_samples = tf.random.categorical(logits, n_days) + for i in range(n_days): t0 = time_samples[0, i] t1 = t0 + self.DAYLIGHT_HOURS disc_g = self._tf_discriminate(hi_res_gen[:, :, :, t0:t1, :])