Adding FID statistics calculation as an option (can now do "train", "eval", or "fid_stats") #5

Open
wants to merge 6 commits into
base: main

AlexiaJM

@AlexiaJM AlexiaJM commented Mar 6, 2021

With these small changes, you can compute the FID statistics by running with --mode "fid_stats". It loops through the dataset for one epoch and extracts the FID statistics, which makes it easier to add new datasets.

The only issue is that I do not get the same FID when evaluating a model against the Google Drive FID statistics as I do with the statistics produced by this new mode. Can you verify that my implementation is correct?

I could be misusing the scaler or inverse scaler.
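For context on what the statistics feed into: FID is commonly defined as the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. The sketch below uses plain NumPy and random stand-in features. `feature_stats` and `frechet_distance` are hypothetical helper names, not functions from this repository, and the repository's own evaluation code may compute the score differently.

```python
import numpy as np


def feature_stats(features):
    """Return the mean vector and covariance matrix of an (N, D) feature array."""
    features = np.asarray(features, dtype=np.float64)
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # tr(sqrt(sigma1 @ sigma2)) is the sum of the square roots of the
    # eigenvalues of the product, which are real and non-negative in theory.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


# Stand-in features (assumption: real use would pass Inception pool_3 outputs).
rng = np.random.default_rng(0)
real = rng.standard_normal((2000, 8))
shifted = real + 0.5  # same covariance, mean moved by 0.5 in each of 8 dims

fid_same = frechet_distance(*feature_stats(real), *feature_stats(real))
fid_shifted = frechet_distance(*feature_stats(real), *feature_stats(shifted))
```

Because only the means and covariances enter the formula, any preprocessing difference that changes the features (scaling, dtype, dropped samples) changes the resulting score.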

@yang-song
Owner

After quickly looking through the code, I think you should always disable uniform dequantization, and in addition:

  1. Samples fed to run_inception_distributed must have type uint8 and range [0, 255].
  2. The default dataset loader discards the last mini-batch to make each minibatch the same size. You will miss some data points by calculating FID stats with the data loader.
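The two points above can be sketched in plain NumPy. This is only an illustration under stated assumptions: the images are already scaled to [0, 1] with no dequantization noise, and `to_uint8` and `iter_batches` are hypothetical helpers rather than functions from this repository.

```python
import numpy as np


def to_uint8(images):
    """Map float images in [0, 1] to uint8 values in [0, 255]."""
    images = np.clip(np.asarray(images, dtype=np.float32), 0.0, 1.0)
    # Rounding (rather than truncating) avoids turning 254.9999 into 254.
    return np.round(images * 255.0).astype(np.uint8)


def iter_batches(images, batch_size):
    """Yield consecutive batches, keeping the final, smaller batch."""
    for start in range(0, len(images), batch_size):
        yield images[start:start + batch_size]


# 10 fake images with batch size 4: the last batch holds the remaining 2.
data = np.random.default_rng(0).random((10, 4, 4, 3), dtype=np.float32)
batches = [to_uint8(b) for b in iter_batches(data, batch_size=4)]
```

A loader that drops the remainder would process only 8 of these 10 images, which is the kind of mismatch point 2 warns about.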

Author

@AlexiaJM AlexiaJM left a comment

removed lines that should not be there

@AlexiaJM
Author

AlexiaJM commented Mar 6, 2021

Alright, will change this.

@AlexiaJM
Author

AlexiaJM commented Mar 6, 2021

Will now test to see if I can get the same FID on a model.

@AlexiaJM
Author

AlexiaJM commented Mar 6, 2021

Very close results, but not exactly the same, sadly!

FID: 534.9150390625
vs
FID: 535.936279296875

Let me know if you figure out anything to be changed. At least it seems very close now.

@yang-song
Owner

Our stats files were computed on TPUs, where they replace float32 with bfloat16 automatically in matrix multiplications, so the computation will have slightly different numerical values from, say, running the same code on GPUs.
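A rough illustration of the precision effect, assuming bfloat16 can be approximated by keeping only the top 16 bits of each float32 value. This is truncation, whereas real hardware typically rounds, so the numbers are indicative only. `to_bfloat16_like` is a hypothetical helper, not a NumPy or JAX function.

```python
import numpy as np


def to_bfloat16_like(x):
    """Emulate bfloat16 by zeroing the low 16 bits of each float32 value."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    bits = x.view(np.uint32) & np.uint32(0xFFFF0000)
    return bits.view(np.float32)


rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

full = a @ b                                        # float32 inputs
reduced = to_bfloat16_like(a) @ to_bfloat16_like(b)  # reduced-precision inputs

# Largest absolute deviation, relative to the largest entry of the result.
rel_err = float(np.abs(full - reduced).max() / np.abs(full).max())
```

The deviation is small but nonzero, which is consistent with statistics computed on different hardware yielding FID values that agree only to a few digits.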

Why are the FID scores so large?

@AlexiaJM
Author

AlexiaJM commented Mar 6, 2021

Then that might be fine; floating-point errors are acceptable.

It's trained for 8 iterations on a tiny batch 😂.

@AlexiaJM
Author

AlexiaJM commented Mar 7, 2021

Good news: I tried your pre-trained model (cifar10_continuous_ve) at checkpoint 24, computing FID on only 2k samples (so it finishes quickly enough).

The results from two runs with the FID statistics from the new code:

ckpt-24 --- inception_score: 9.995276e+00, FID: 1.574258e+01, KID: 3.736457e-04
ckpt-24 --- inception_score: 9.995672e+00, FID: 1.574303e+01, KID: 3.736335e-04

The results from one run with the FID statistics from your Google Drive:

ckpt-24 --- inception_score: 9.995425e+00, FID: 1.574306e+01, KID: 3.733854e-04

So the new code works. Thanks for your help!

Alexia

@pbizimis

pbizimis commented Apr 7, 2021

Hey,
thanks for this code. I am trying to run --mode fid_stats on my score_sde_pytorch model, but it does not seem to work. Is there some way to calculate the FID stats for a custom dataset when the model is trained with the score_sde_pytorch version?

Thanks!

I am using @AlexiaJM's fork with my custom config. I get the following error:

  File "main.py", line 71, in <module>
    app.run(main)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.7/dist-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "main.py", line 65, in main
    run_lib.fid_stats(FLAGS.config, FLAGS.fid_folder)
  File "/content/score_sde/run_lib.py", line 609, in fid_stats
    for batch_id in range(len(train_ds)):
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 454, in __len__
    raise TypeError("dataset length is unknown.")
TypeError: dataset length is unknown.

Update
It seems to be a dataset problem. I am using this script to convert my images to tfrecords. I fixed the problem by catching the StopIteration exception once the iterator is exhausted.

  batch_id = 0

  while True:
    try:
      batch = next(bpd_iter)
    except StopIteration:
      # The dataset iterator is exhausted; every batch has been processed.
      break

    if jax.host_id() == 0:
      logging.info("Making FID stats -- step: %d", batch_id)

    batch_ = jax.tree_map(lambda x: x._numpy(), batch)
    batch_ = (batch_['image']*255).astype(np.uint8).reshape((-1, config.data.image_size, config.data.image_size, 3))

    # Force garbage collection before calling TensorFlow code for Inception network
    gc.collect()
    latents = evaluation.run_inception_distributed(batch_, inception_model,
                                                   inceptionv3=inceptionv3)
    all_pools.append(latents["pool_3"])
    # Force garbage collection again before returning to JAX code
    gc.collect()
    batch_id += 1

Feel free to correct me if I am wrong. I am not sure if this is the best/right solution.
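One alternative sketch, assuming the batch source is an ordinary Python iterator: a for loop with enumerate stops cleanly when the iterator is exhausted, so neither len() nor a manual try/except is needed. `process_all` and `handle_batch` are hypothetical names used only for illustration.

```python
def process_all(batch_iter, handle_batch):
    """Apply handle_batch to every batch from an iterator of unknown length."""
    num_batches = 0
    for batch_id, batch in enumerate(batch_iter):
        handle_batch(batch_id, batch)
        num_batches = batch_id + 1
    return num_batches


# Stand-in data: three batches, the last one smaller than the others.
seen = []
n = process_all(iter([[1, 2], [3, 4], [5]]),
                lambda i, b: seen.append((i, len(b))))
```

In the loop above, `handle_batch` would hold the Inception call and the `all_pools.append(...)` step.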
