This repository showcases fine-tuning Google's DistilBERT model to classify news articles into one of five categories: tech, business, politics, entertainment, and sport. Included are the preprocessing and training script (`learning.py`) and the prediction script (`predict.py`) for the fine-tuned model. Please note that due to size limitations, the fine-tuned model itself is not included.
News Classification with DistilBERT is a project that harnesses cutting-edge natural language processing (NLP) techniques to categorize news articles into predefined topics. This project relies on Hugging Face Transformers and TensorFlow for its implementation.
- Built an NLP pipeline to fine-tune Google's DistilBERT language model using TensorFlow and Hugging Face Transformers for multi-class text classification.
- Fine-tuned the model on a custom BBC text classification dataset.
- Used DistilBERT, which retains 97% of BERT's language understanding while being 40% smaller and 60% faster to train.
To begin using this project, follow these steps:

1. Clone the repository to your local machine:

   ```bash
   git clone https://github.com/akhmadmamirov/fineTuningBert.git
   ```

2. Install the required dependencies:

   ```bash
   pip install transformers tensorflow pandas scikit-learn
   ```
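To confirm the environment is set up correctly, a quick sanity check (not part of the repository) is to import the key packages and print their versions:

```python
# Verify that the core dependencies installed correctly.
import tensorflow as tf
import transformers

print("transformers:", transformers.__version__)
print("tensorflow:", tf.__version__)
```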
Before using the model, you should preprocess your dataset. If you choose to use your own custom dataset, follow these steps (a preprocessing sketch follows the list):

1. Tokenize your dataset and assign a label to each news article.
2. Convert the labels to a binary (one-hot) format compatible with the model.
3. For a quick start, you can use the provided example dataset (BBC Text Classification) and follow the preprocessing steps outlined in the code.
4. Fine-tune the DistilBERT model on your preprocessed dataset.
5. Save the fine-tuned model in the `saved_models` directory (make sure you have write permission for the current path).
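As a rough illustration of steps 1-2, here is a minimal preprocessing sketch. The file name `bbc-text.csv` and the column names `category` and `text` are assumptions based on the common BBC dataset layout; the actual code in `learning.py` may differ:

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from transformers import DistilBertTokenizerFast

# Load the dataset (assumed columns: "category" and "text").
df = pd.read_csv("bbc-text.csv")

# One-hot encode the five category labels (step 2).
binarizer = LabelBinarizer()
labels = binarizer.fit_transform(df["category"])

train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["text"].tolist(), labels, test_size=0.2, random_state=42
)

# Tokenize the articles with the DistilBERT tokenizer (step 1).
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=512)
```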
You can use the fine-tuned DistilBERT model to classify news articles. Run the following script and enter the news article text when prompted:
```bash
python predict.py
```

The script will classify the news article into one of the following categories: Business, Entertainment, Politics, Sport, or Tech.
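For reference, here is a minimal sketch of what such a prediction script might look like. It assumes the fine-tuned model and tokenizer were saved to `saved_models` with `save_pretrained`, and that the label indices follow the alphabetical order produced by the preprocessing sketch above; the actual `predict.py` may differ:

```python
import tensorflow as tf
from transformers import DistilBertTokenizerFast, TFDistilBertForSequenceClassification

# Assumed label order: alphabetical, as produced by LabelBinarizer.
CATEGORIES = ["Business", "Entertainment", "Politics", "Sport", "Tech"]

# Load the fine-tuned model and tokenizer (assumed saved via save_pretrained).
tokenizer = DistilBertTokenizerFast.from_pretrained("saved_models")
model = TFDistilBertForSequenceClassification.from_pretrained("saved_models")

text = input("Enter the news article text: ")
inputs = tokenizer(text, truncation=True, padding=True, max_length=512, return_tensors="tf")

# Pick the highest-scoring class from the model's logits.
logits = model(inputs).logits
predicted = int(tf.argmax(logits, axis=-1)[0])
print(f"Predicted category: {CATEGORIES[predicted]}")
```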
If you want to fine-tune the DistilBERT model on your own dataset, follow these steps:

1. Prepare your dataset in a format similar to the example dataset (BBC Text Classification).
2. Modify the code to load and preprocess your dataset.
3. Fine-tune the model using the `TFTrainer` (see the sketch after this list).
4. Save the fine-tuned model in the `saved_models` directory.
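Since `TFTrainer` is deprecated in recent releases of Transformers in favor of the native Keras API, the sketch below uses `model.fit` instead. It continues from the preprocessing sketch above (reusing `train_encodings`, `val_encodings`, the label arrays, and `tokenizer`); the batch size, learning rate, and epoch count are illustrative, not the repository's settings:

```python
import tensorflow as tf
from transformers import TFDistilBertForSequenceClassification

# Build tf.data pipelines from the tokenized encodings and one-hot labels.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (dict(train_encodings), train_labels)
).batch(16)
val_dataset = tf.data.Dataset.from_tensor_slices(
    (dict(val_encodings), val_labels)
).batch(16)

# Load DistilBERT with a fresh 5-way classification head.
model = TFDistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=5
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(train_dataset, validation_data=val_dataset, epochs=3)

# Persist the fine-tuned model and tokenizer for predict.py.
model.save_pretrained("saved_models")
tokenizer.save_pretrained("saved_models")
```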
Contributions to this project are welcome! If you encounter issues or have suggestions for improvements, please open an issue or create a pull request.
- Please be aware that training time may vary depending on the dataset size and the number of epochs used for training.
- The average training time in my case was under 7 hours.
- Currently working on the `train.py` file.
- github.com/rohan-paul
- ChatGPT