
[FEATURE] Save dataset_train.reader.class_to_idx inside the model and load it by default during inference and validation #1817

Open
JustinMBrown opened this issue May 19, 2023 · 2 comments
Labels
enhancement New feature or request

Comments

@JustinMBrown

Is your feature request related to a problem? Please describe.
I'm not using ImageNet, but during inference the script loads the ImageNet class_map by default.

Describe the solution you'd like
Instead, dataset_train.reader.class_to_idx should be saved somewhere inside the model checkpoint and loaded into class_map by default during inference. And of course, if someone still wants to override the class_map for whatever reason, they could still do so.

I'd make a PR myself, but y'all probably have other considerations for exactly where to save/load it, so here's a sample solution.

Sample solution:

  from pathlib import Path
  import torch

  # load (e.g. in inference.py / validate.py, after parsing args)
  if args.class_map == '':
      ckpt = torch.load(args.checkpoint, map_location='cpu')
      args.class_map = ckpt.get('class_to_idx', '')

  # save (e.g. at the end of train.py, stash the mapping in the best checkpoint)
  output_dir = Path(output_dir)
  best_model = list(output_dir.rglob("*best*"))[0]
  ckpt = torch.load(best_model, map_location='cpu')
  ckpt["class_to_idx"] = class_to_idx
  torch.save(ckpt, best_model)

Describe alternatives you've considered
We could save the class_to_idx into a class map file and ship it alongside the model, but that's cumbersome and tedious. The proposed solution just works by default.
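For reference, a rough sketch of that alternative (the helper name is mine, and the exact class-map format should be double-checked against timm's class map loader):

  # Hypothetical helper: dump class_to_idx to a plain text file, one class name
  # per line, ordered by index, then point --class-map at it for inference/validation.
  def write_class_map(class_to_idx, path='class_map.txt'):
      names = sorted(class_to_idx, key=class_to_idx.get)  # names ordered by index
      with open(path, 'w') as f:
          f.write('\n'.join(names) + '\n')

  # write_class_map(dataset_train.reader.class_to_idx)
  # then: python inference.py /path/to/data --model ... --checkpoint ... --class-map class_map.txt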

Additional context
The same should probably be done with the args.yaml file. There are a ton of timm models on Hugging Face with pretrained weights but no args.yaml file alongside them, which makes it nearly impossible to reproduce their results.

@JustinMBrown JustinMBrown added the enhancement New feature or request label May 19, 2023
@rwightman
Collaborator

rwightman commented May 19, 2023

@JustinMBrown it's a reasonable idea; the only issue is that it ends up being a big change. ALL pretrained checkpoints right now are bare state_dicts with no extra layer in the dict: every key is a param/buffer and every value a tensor. The train checkpoints (which do have extra keys) are stripped of everything but the state dict before being published. I followed torchvision and other 'original' libs when this decision was made long ago.

The timm load functions would support stripping this automatically (and could be modified to extract other specific keys like class maps), but it would break anyone trying to just use torch.load(), which works right now... https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_helpers.py#L31
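Roughly, something like this could keep both formats working (just a sketch; the helper and key names are hypothetical, not the current _helpers.py code):

  import torch

  def load_weights_and_meta(path):
      # Hypothetical loader: accept either a bare state_dict (current published
      # format) or a wrapped train checkpoint, and pull out any stored class map.
      ckpt = torch.load(path, map_location='cpu')
      if isinstance(ckpt, dict) and 'state_dict' in ckpt:
          return ckpt['state_dict'], ckpt.get('class_to_idx')
      return ckpt, None  # bare state_dict, no metadata available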

I think I should stash the args in the train state dict regardless, though; I've thought about this, as I've had numerous instances where I toasted the original train folders in disk cleanup, had only the checkpoint left, and lost the hparams :/
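Something along these lines on the save side (sketch only; the extra keys are hypothetical and not what the current checkpoint saver writes):

  import torch

  def save_train_checkpoint(path, model, optimizer, args, class_to_idx):
      # Sketch: keep the weights under 'state_dict' and stash hparams/class map
      # alongside them; the 'args' and 'class_to_idx' keys are hypothetical.
      torch.save({
          'state_dict': model.state_dict(),
          'optimizer': optimizer.state_dict(),
          'args': vars(args),                  # argparse Namespace -> plain dict
          'class_to_idx': class_to_idx,
      }, path)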

Although I will point out, I've had multiple occasions where people were provided with the exact hparams and I got an angry 'it doesn't work' because they didn't understand that things change when you change the global batch size, use a different dataset, etc. The highest value is in seeing the templates and building an intuition for how to adapt different recipe templates to different uses...

I will ponder this some more. FYI, if you publish the weights to the HF hub, the pretrained_cfg has fields for labels: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_pretrained.py#L41-L42 ... so for the inference script and any side-loading it'd make sense to serialize/deserialize a pretrained_cfg instance with the weights. The timm inference widget for the Hub loads this, as the cfg is built from the config.json file (https://huggingface.co/timm/nf_resnet50.ra2_in1k/blob/main/config.json). Right now, if you pass 'label_names' and 'label_descriptions' fields to the push_to_hub fn in timm via the model_cfg dict, the HF hub widget will do inference with the correct label names: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/_hub.py#L215-L227
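For example, something like this when publishing (a sketch only; the repo id and class names are made up, and the exact kwarg name for the config dict may differ between timm versions):

  import timm
  from timm.models import push_to_hf_hub  # available in recent timm versions

  model = timm.create_model('resnet50', pretrained=True, num_classes=3)
  # ... fine-tune on your own dataset ...

  class_to_idx = {'cat': 0, 'dog': 1, 'fox': 2}           # hypothetical classes
  label_names = sorted(class_to_idx, key=class_to_idx.get)

  push_to_hf_hub(
      model,
      'your-username/resnet50-custom',                    # hypothetical repo id
      model_config={'label_names': label_names},          # ends up in config.json
  )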

See how the inference widget uses this (inference.py should be updated to be closer to it, by providing a pretrained_cfg json or state dicts with it embedded)... https://github.com/huggingface/api-inference-community/blob/main/docker_images/timm/app/pipelines/image_classification.py#L25-L42

@orena1

orena1 commented Nov 14, 2023

Thanks @rwightman for the details, I am also interested in the labels of the classes. I did try the code that you suggested:

import timm

model_id = 'timm/resnetrs350.tf_in1k'
model = timm.create_model(f"hf_hub:{model_id}", pretrained=True)
model.eval()

label_names = model.pretrained_cfg.get("label_names", None)
label_descriptions = model.pretrained_cfg.get("label_descriptions", None)
print(label_names, label_descriptions)

But I get None, None. Any idea how to get the label per index?
Thank you
