From 952bdec3cc5bc3e4b701d366b2587b9ea44bb924 Mon Sep 17 00:00:00 2001 From: Kilian Date: Tue, 16 Jan 2024 13:49:22 -0500 Subject: [PATCH] add stochastic interpolants --- examples/images/cifar10/train_cifar10.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 9e29963..14b8b04 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -17,6 +17,7 @@ ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, TargetConditionalFlowMatcher, + VariancePreservingConditionalFlowMatcher, ) from torchcfm.models.unet.unet import UNetModelWrapper @@ -128,9 +129,11 @@ def train(argv): FM = ConditionalFlowMatcher(sigma=sigma) elif FLAGS.model == "fm": FM = TargetConditionalFlowMatcher(sigma=sigma) + elif FLAGS.model == "si": + FM = VariancePreservingConditionalFlowMatcher(sigma=sigma) else: raise NotImplementedError( - f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm']" + f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']" ) savedir = FLAGS.output_dir + FLAGS.model + "/" @@ -162,7 +165,7 @@ def train(argv): "optim": optim.state_dict(), "step": step, }, - savedir + f"cifar10_weights_step_{step}.pt", + savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt", )