authorlink: "https://tilburgsciencehub.com/contributors/kheirysohooli/"
---

## Overview
In the realm of machine learning, model evaluation is crucial for assessing a model's reliability. It helps in detecting overfitting and underfitting after training algorithms on datasets. Cross-validation stands out as one of the most effective approaches to pinpoint overfitting.
In this short article, you'll become familiar with the primary causes of overfitting and some solutions to address them. Additionally, we'll explore K-fold cross-validation, the main method for detecting overfitting.

## Overfitting
Overfitting occurs when a model performs well on the training data but poorly on the test data. For example, in the image below, the training accuracy is close to 1, while the test accuracy is much lower, less than 0.7. In this case, the model fails to generalize to new samples.
<div align="center">
<img src="../images/overfitting.png" width="400">
</div>
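The pattern in the figure can be reproduced with a minimal sketch, assuming a toy dataset and a deliberately flexible model (the data, seed, and polynomial degree below are illustrative choices, not from this article):

{{% codeblock %}}
```Python
# Illustrative sketch: a very flexible model (degree-9 polynomial) fit to 10 noisy
# training points. The training error is near zero, while the error on new points is larger.
import numpy as np

rng = np.random.default_rng(0)
x_train = np.linspace(0, 1, 10)
y_train = np.sin(2 * np.pi * x_train) + rng.normal(scale=0.2, size=10)
x_test = np.linspace(0.05, 0.95, 10)
y_test = np.sin(2 * np.pi * x_test) + rng.normal(scale=0.2, size=10)

coefs = np.polyfit(x_train, y_train, deg=9)            # enough parameters to memorize the data
mse_train = np.mean((np.polyval(coefs, x_train) - y_train) ** 2)
mse_test = np.mean((np.polyval(coefs, x_test) - y_test) ** 2)
print("Train MSE:", mse_train, "Test MSE:", mse_test)   # test MSE is typically much larger
```
{{% /codeblock %}}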

### Why overfitting happens
Overfitting generally happens when the model is too complex relative to the amount of available data. For example, consider a neural network model with a complex structure trained on a small dataset. Some of the reasons that cause overfitting are as follows:

1. *Complex Model:* When the model includes many features, it can become overly complex and end up fitting noise as well as the underlying signal.

2. *Small test dataset:* If the test dataset is small and doesn't adequately represent all possible input data, it may not provide a reliable measure of the model's performance.

3. *Training Data Errors and Outliers:* Errors and outliers in the training data introduce noise. If the algorithm interprets this noise as meaningful patterns, it can negatively impact the model's ability to generalize to new data.

### Solutions for overfitting
Several solutions can mitigate overfitting:

1. *Reducing Model Complexity:* Simplify the model by decreasing the number of parameters or attributes. For example, in linear regression, include only the most predictive features instead of all available features. Another example is simplifying the structure of neural network models by reducing the number of nodes and layers in a complex architecture.

2. *Increasing Sample Size:* Gathering more real data can help, although it may be costly. Alternatively, data augmentation methods can generate additional samples from the existing data.

3. *Noise Reduction:* Clean the data by removing outliers and erroneous data to prevent the model from fitting to noisy data.

4. *Regularization:* Regularization in linear regression is a technique used to prevent overfitting. When a linear model has large coefficients for its features, it tends to overfit the training data. Regularization terms like [L1 and L2](https://tilburgsciencehub.com/topics/analyze/machine-learning/supervised/ml_objective_functions/) work by penalizing large coefficients, which can shrink them or even set them to zero. This process makes the model smoother and helps prevent overfitting.

5. *Hyperparameter tuning:* By finding the best combination of hyperparameters, the model's complexity can be managed. Examples of hyperparameters include the learning rate in gradient descent and the depth of a decision tree. Overall, hyperparameter tuning helps strike a balance between model complexity and generalization performance, ultimately reducing the risk of overfitting. A short sketch of regularization and hyperparameter tuning follows this list.
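
To make the last two solutions more concrete, here is a minimal sketch, assuming the `diabetes` dataset that is also used later in this article; the `alpha` value and the `max_depth` grid are illustrative assumptions, not values from the original example.

{{% codeblock %}}
```Python
# Illustrative sketch: L2 regularization (Ridge) and hyperparameter tuning (GridSearchCV).
# The alpha value and max_depth grid below are assumptions for demonstration only.
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Ridge adds an L2 penalty on the coefficients; larger alpha means stronger shrinkage
ridge = Ridge(alpha=1.0).fit(X_train, y_train)
print("Ridge test R^2:", ridge.score(X_test, y_test))

# Tune the depth of a decision tree to control its complexity
param_grid = {"max_depth": [2, 3, 5, 10, None]}
search = GridSearchCV(DecisionTreeRegressor(random_state=42), param_grid, cv=5)
search.fit(X_train, y_train)
print("Best max_depth:", search.best_params_["max_depth"])
```
{{% /codeblock %}}

A shallower tree selected by the grid search will typically generalize better than an unconstrained one.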

## K-Fold Cross-Validation
One method to check how well a model works on new data is using *train-test-split*. Here, we set aside some data to test the model after training it. But a better method is *K-fold cross-validation*. In this technique, we split the dataset into `k` subsets or folds. The model is trained `k` times, each time using a different fold for testing and the rest for training. This gives us `k` evaluation scores. Finally, we average these scores to get the overall performance. This approach provides a more reliable evaluation compared to just using *train-test-split*.
Using K-fold cross-validation doesn't directly prevent overfitting, but it aids in detecting and addressing it. Let's explore it further by using an example.
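
To see concretely how the folds are formed, here is a minimal sketch on ten toy samples (the data and seed are illustrative, not part of the article's example):

{{% codeblock %}}
```Python
# Illustrative sketch: how KFold partitions ten toy samples into 5 folds.
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(10).reshape(-1, 1)  # ten toy samples
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for i, (train_idx, test_idx) in enumerate(kf.split(X)):
    # each sample appears in exactly one test fold across the 5 iterations
    print(f"Fold {i}: train={train_idx}, test={test_idx}")
```
{{% /codeblock %}}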

We'll use the `diabetes` dataset from `Sklearn` and train a `DecisionTreeRegressor` to predict diabetes progression.
{{% tip %}}
Tree-based algorithms are particularly vulnerable to overfitting.
{{% /tip %}}
{{% codeblock %}}
```Python
# Note: the earlier lines of this block were collapsed in the source; what follows is a
# reconstructed sketch of the assumed setup, not the article's original code verbatim.
# Exact MSE values depend on the split and tree parameters used in the original article.
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# An unconstrained decision tree can memorize the training data
model = DecisionTreeRegressor(random_state=42)
model.fit(X_train, y_train)

mse_train = mean_squared_error(y_train, model.predict(X_train))
mse_test = mean_squared_error(y_test, model.predict(X_test))
print("MSE for Training Data:", mse_train)
print("MSE for Test Data:", mse_test)
```
{{% /codeblock %}}

The result shows that the training MSE is zero, whereas the test MSE is `4933`, indicating a severe case of overfitting. We can see this because the MSE on the test set is **much** larger than on the training set. When we implement K-fold cross-validation in the next code block, it becomes even more evident that the model overfits the training data.

{{% codeblock %}}
```Python
# Note: the earlier lines of this block were collapsed in the source; this is a
# reconstructed sketch of the assumed cross-validation step, not the original code verbatim.
from sklearn.model_selection import cross_val_score

scores = cross_val_score(DecisionTreeRegressor(random_state=42), X, y,
                         cv=5, scoring="neg_mean_squared_error")
mse_cross_val = -scores.mean()  # flip the sign: the scorer returns negative MSE
print("MSE with 5-fold cross-validation:", mse_cross_val)
```
{{% /codeblock %}}

By executing the above code, you will get an MSE of `6740` through cross-validation, which is higher than the previous number obtained using a simple `train_test_split`. This example shows that cross-validation can reveal overfitting more clearly than a simple `train_test_split`.

Using cross-validation during model training improves our chances of creating a model that not only fits the training data well but also accurately predicts outcomes for new, unseen data. By using multiple subsets of the data, cross-validation helps prevent overfitting and ensures consistent performance across different data segments. It averages the model's performance across `k` folds, providing robust and reliable results by reducing the variability in training and testing datasets. This effect is especially significant with larger datasets. However, with small sample sizes, cross-validation might not be ideal, as the already limited data is further divided into smaller subsets, making the evaluation less robust and reliable.

{{% summary %}}
- *Overfitting:* We defined overfitting and discussed its causes and the ways to address it.
- *K-fold cross-validation:* We explored the advantages of K-fold cross-validation compared to the traditional train-test-split method and provided an example of its implementation in Python.
{{% /summary %}}
