Pipeline for training a multiclass classifier in PyTorch with the Timm model library and TensorRT/ONNX export
The script expects image/label information in a CSV file and the images in a dataset folder. See the Scene Classification dataset (linked in the references) for an example.
Command : python3 train.py --model_name resnet50 --epochs 100 --batch_size 64 --lr 0.0001 --img_size 256 --device 0 --optimizer adam --lr_scheduler CosineAnnealingLR --dataset /home/SceneData --split 0.2 --target_size 3 --early_stop 10 --loss_func CrossEntropyLoss --save_checkpoint_folder ./checkpoints --save_model_folder ./weights --exp_name testExp --labels night,day,noon --wandb --projec_name SceneClassifier --seed 22 --workers 4
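As a rough illustration of the CSV-based workflow, the sketch below reads image/label rows, maps label names to class indices in the order given to --labels, and holds out a validation fraction as --split does. The column names `image` and `label`, the sample file names, and the helper functions are assumptions made for this example; the actual CSV layout and the logic in train.py may differ.

```python
import csv
import io
import random

# Hypothetical CSV layout: the column names "image" and "label" are assumptions.
SAMPLE_CSV = """image,label
img_001.jpg,night
img_002.jpg,day
img_003.jpg,noon
img_004.jpg,day
img_005.jpg,night
"""


def load_samples(csv_file, labels):
    """Return (filename, class_index) pairs; indices follow the order of `labels`."""
    label_to_index = {name: i for i, name in enumerate(labels)}
    reader = csv.DictReader(csv_file)
    return [(row["image"], label_to_index[row["label"]]) for row in reader]


def train_val_split(samples, val_fraction, seed):
    """Shuffle deterministically, then hold out `val_fraction` of the samples."""
    shuffled = samples[:]
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_fraction)
    return shuffled[n_val:], shuffled[:n_val]


# Mirrors --labels night,day,noon --split 0.2 --seed 22 from the command above.
samples = load_samples(io.StringIO(SAMPLE_CSV), ["night", "day", "noon"])
train, val = train_val_split(samples, val_fraction=0.2, seed=22)
print(len(train), len(val))
```

Fixing the seed makes the split repeatable across runs, which is the same idea behind the --seed argument.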
To train on a dataset stored in class folders, pass the --image_folder flag along with --train_image_folder and --val_image_folder:
python3 train.py OTHER_ARGUMENTS_AS_MENTIONED_ABOVE --image_folder --train_image_folder trainFolder --val_image_folder validFolder
In this case, the dataset should be structured as follows:
DatasetName
|-> train_data
|   |-> class_1
|   |-> class_2
|   |-> class_3
|-> val_data
    |-> class_1
    |-> class_2
    |-> class_3
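A folder layout like the one above is typically turned into (image path, class index) pairs by treating each subfolder name as a class. The sketch below shows one way to do that with only the standard library. The `scan_class_folders` helper and the accepted file suffixes are assumptions for illustration, not the repository's dataset code; the sorted-name-to-index convention is the same one `torchvision.datasets.ImageFolder` uses.

```python
import tempfile
from pathlib import Path

# Assumption: which file suffixes count as images.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def scan_class_folders(split_dir):
    """Return (class_names, samples), where samples are (path, class_index) pairs."""
    split_dir = Path(split_dir)
    # Sorted subfolder names define the class indices.
    classes = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    samples = []
    for index, name in enumerate(classes):
        for path in sorted((split_dir / name).iterdir()):
            if path.suffix.lower() in IMAGE_SUFFIXES:
                samples.append((path, index))
    return classes, samples


# Build a tiny throwaway dataset that mirrors the structure above.
with tempfile.TemporaryDirectory() as root:
    for cls in ("class_1", "class_2", "class_3"):
        folder = Path(root) / "train_data" / cls
        folder.mkdir(parents=True)
        (folder / "example.jpg").touch()
    classes, samples = scan_class_folders(Path(root) / "train_data")
    print(classes)
    print([(path.name, index) for path, index in samples])
```

The same function can be pointed at val_data to build the validation set, provided both splits contain the same class folders so the indices line up.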
- Custom model file with Timm models
- Custom dataset file
- Creates an experiment folder, allowing you to run continuous training jobs
- Supports training on datasets stored in folders or described by a CSV file
- Categorical Labels - [1], [2], [3]...
- Displays the training configuration so you can cross-check the input
- Runs validation on the train and test sets and saves the confusion matrix as a PNG
- Uses the garbage collector and torch's method to clear the GPU cache
- Early stopping to save your time and resources
- Saves model checkpoint and weight file as accuracy improves
- Albumentations for image augmentation
- Added options for Loss functions
- Multiple LR Schedulers
- Added WandB support
- Added Number of workers parameter
- Added Seed to help reproduce experiment
- Added Torch inference code
- Added Torch to ONNX model export
- Added ONNX inference code
- Added ONNX to TensorRT conversion code
- Added TensorRT inference code
- Heatmap of features
- Add LabelSmoothing
- Add Gradient Clipping
- Add Mixed Precision Training
- Create a virtual environment with venv (recommended)
- Install all the requirements using requirements.txt
- To install TensorRT, refer to NVIDIA's link
- Scene Classification Dataset
- Kaggle Notebook : Transfer Learning with Timm
- Kaggle Notebook : EfficientNet Mixup Leak free
- Kaggle Notebook : Scene classification
- Convert PyTorch model to TensorRT - Link
- Getting started with PyTorch model & Timm - Link
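The early-stopping feature listed above can be sketched as a small counter over validation accuracy. This is a minimal illustration, assuming that --early_stop is a patience value measured in epochs without improvement; the `EarlyStopping` class and the accuracy history are hypothetical and are not taken from train.py.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without a new best accuracy."""

    def __init__(self, patience):
        self.patience = patience
        self.best = None
        self.bad_epochs = 0

    def step(self, accuracy):
        """Record this epoch's accuracy; return True if training should stop."""
        if self.best is None or accuracy > self.best:
            # New best: this is also the point where a checkpoint would be saved.
            self.best = accuracy
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation accuracies, one per epoch.
history = [0.60, 0.72, 0.71, 0.72, 0.70, 0.69]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, accuracy in enumerate(history, start=1):
    if stopper.step(accuracy):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

Requiring a strict improvement means a tie with the best accuracy still counts toward the patience budget; a tolerance could be added if small fluctuations should not count.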