Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature: Temporal interpolation #168

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft

Conversation

Magnus-SI
Copy link
Contributor

@Magnus-SI Magnus-SI commented Nov 27, 2024

Adds temporal interpolation functionality to anemoi. The idea is that a 6 or 12 hour forecaster might yield better predictions going days out than a 1 hour forecaster, as it has to make fewer auto-regressive steps. To produce the hourly predictions still, we can use the information available from the forecaster, e.g. hours 12 and 18 as input to predict hours 13-17. These predictions are made individually, assisted by some information about the target time as input.

This is a work in progress, parts of the implementation can be found on the corresponding branch of anemoi-models.

Implemented

  • A prototype that runs (with decreasing loss, but haven't done any full scale training to verify yet). The interpolator itself is implemented as GraphInterpolator in interpolator.py, with train.py generalized to call a config-based model. If nothing is specified, it defaults to GraphForecaster.
  • The config options multistep and rollout are not applicable to the interpolation case, so I added the option to explicitly state which time steps to use as input, and as a target for the model. This also enables non-regular input for the forecaster, e.g. using input at 0, -1, and -6 hours to make a 1 hour forecast.
  • A corresponding change in the valid dates function, which also includes the possibility for data to be extracted across missing dates if the requested indices can cross the gap. E.g. with a missing date at index 8, and requested dates at 0,3,6. Indices 3,4 and 6,7 are still usable, which would not have been represented correctly before.
  • Option to add temporal forcings at the target time as an input to the model, this is necessary for interpolation at multiple distinct hours with the same model weights.

To do

  • Add support for custom forcing parameters beyond those in the dataset. As of now, only the target time as a fractional difference between the input times is possible and used by default. I will instead move this to a config instantiated object.
  • Run a full scale training and compare with previous results (from the old aifs-mono where I first implemented this). DONE
  • Also do a short training with the forecaster and see that the results are as expected. DONE
  • Implement a way to avoid crossing model runs. When training on data with analysis less frequent than the target interpolation frequency, one has to switch between different model runs at certain points in the dataset. If we for instance use 18 continuous hours of a run before switching, the change from hour 29 of one run to hour 12 of the next run will not represent a physical change in the state of the atmosphere. To avoid training on this, we thus need to consider all sets of inputs and targets that will cross this gap as invalid. I will add this to the usable_indices function. DONE

Questions

  • In train.py I could not get instantiate to work as an instance of the forecaster/interpolator because "config" is a kwarg of instantiate itself, and thus cannot be included in the additional kwargs. Instead I used a combination of importlib and getattr, but if there is a way to use instantiate here, I can replace it.
  • Does anemoi datasets not support extracting dates with irregular indices? I could not get open_dataset to work with indices that do not follow a regular sliceable pattern. Either with an array or a tuple/list of indices. Is there a way around this?
    Although a simple interpolation setup like using hours 0 and 6 to predict hours 1-5 yields a regular range from 0 to 6, irregular ranges would enable more complex setups for both the forecaster and interpolator.

@Magnus-SI Magnus-SI self-assigned this Nov 27, 2024
@FussyDuck
Copy link

FussyDuck commented Nov 27, 2024

CLA assistant check
All committers have signed the CLA.

Comment on lines +147 to 154
train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
train_func = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
#NOTE: instantiate would be preferable, but I run into issues with "config" being the first kwarg of instantiate itself.
if self.load_weights_only:
LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs)
return GraphForecaster(**kwargs)
return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return train_func(**kwargs)

Copy link
Member

@HCookie HCookie Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the instantiate would be preferable. If we were to delay the instantiatation of the model within the Forecaster, it may be possible to mimic a hydra instantiate call.

The delay will be neccessary to support loading weights only

model = instantiate({'_target_':self.config.get('forecaster'), **kwargs)
if self.load_weights_only:
            LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
            return train_func.load_from_checkpoint(self.last_checkpoint, **kwargs)
return model

Copy link
Contributor Author

@Magnus-SI Magnus-SI Nov 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, when adding recursive = False as an argument as well, that works to instantiate the model. However, after an epoch is complete I get "TypeError: Object of type DictConfig is not JSON serializable" during saving of metadata for the checkpoint. That should be fixable though.
As for loading weights only, it seems https://github.com/ecmwf/anemoi-training/tree/feature/ckpo_loading_skip_mismatched moves this to train.py, so the model can be instantiated beforehand without problem. I will wait until this reaches develop and pull it to this branch, then add the instantiation.

Comment on lines +46 to +48
class GraphInterpolator(GraphForecaster):
"""Graph neural network interpolator for PyTorch Lightning."""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this work on the Interpolator. It's a good example that the GraphForecaster class needs some work and to be broken into a proper class structure.
What are your thoughts on which components are reusable and then in counter, which parts are typical to override?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a mix of both, as well as some components that are needed only for the forecaster and some only for the interpolator.

Reusable

  • All of the init function, except for rollout and multistep.
  • All of the instantiable objects: loss, metrics, the model, etc.
  • The scheduler and optimizers, which should maybe become an instantiated object anyway.
  • The training/validation_step functions
  • calculate_val_metrics: by reusing the rollout_step label as interp_step instead.

Overwritten

  • _step and forward

Only for the forecaster/interpolator

  • advance_input and rollout_step
  • target forcings (although these could also be useful for the forecaster)

To avoid inheriting unused components with the Interpolator, we could consider using a framework class containing only the common components between the forecaster and interpolator, then have both inherit this class. However, that might be a bit too much when there are only two options thus far.
In fact, the forecaster can be seen as a special case of the interpolator, since the boundary can be specified as the multistep input, and the target can be any time, including the future. If I implement rollout functionality to the interpolator and make the target forcings optional, I think it should be able to do anything the forecaster can.

In my opinion, it would be the best approach to merge the two this way. It also enables the option to train a combined forecaster/interpolator, instead of having two separate models.
Do you agree with merging the two, or should I make a base framework class for both to inherit, or just keep them as is?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would lean towards making a base framework class. There are other use cases coming down the pipeline that would need this.
Although I am intrigued by the idea of have a class that can do both together.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants