Skip to content

Commit

Permalink
add stochastic interpolants
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Jan 16, 2024
1 parent cfafc28 commit 952bdec
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions examples/images/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
TargetConditionalFlowMatcher,
VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

Expand Down Expand Up @@ -128,9 +129,11 @@ def train(argv):
FM = ConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "fm":
FM = TargetConditionalFlowMatcher(sigma=sigma)
elif FLAGS.model == "si":
FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
else:
raise NotImplementedError(
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm']"
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
)

savedir = FLAGS.output_dir + FLAGS.model + "/"
Expand Down Expand Up @@ -162,7 +165,7 @@ def train(argv):
"optim": optim.state_dict(),
"step": step,
},
savedir + f"cifar10_weights_step_{step}.pt",
savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt",
)


Expand Down

0 comments on commit 952bdec

Please sign in to comment.