
Is KAN (pykan) sufficient for classification tasks? #477

Open
SaranDS opened this issue Oct 5, 2024 · 3 comments

@SaranDS

SaranDS commented Oct 5, 2024

I implemented the following code snippet for binary classification on tabular data, using stratified K-fold cross-validation (K=10). The performance results seem exceptionally good. Can someone review it and suggest improvements to the implementation?

```python
import torch
from sklearn.model_selection import train_test_split
from kan import KAN

# X_scaled, y, kf (a stratified 10-fold splitter) and get_clf_eval
# are defined earlier in my script
for train_idx, test_idx in kf.split(X_scaled, y):
    # Create a fresh model in every fold: a single model reused across folds
    # keeps training on rows that later folds use as their test set
    model = KAN(width=[38, 5, 3, 2], grid=5, k=3)

    X_train, X_test = X_scaled[train_idx], X_scaled[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    # Split the training part into train and validation sets
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.1, random_state=5)

    train_input = torch.tensor(X_train, dtype=torch.float32)
    train_label = torch.tensor(y_train, dtype=torch.long)
    val_input = torch.tensor(X_val, dtype=torch.float32)
    val_label = torch.tensor(y_val, dtype=torch.long)
    test_input = torch.tensor(X_test, dtype=torch.float32)
    test_label = torch.tensor(y_test, dtype=torch.long)

    # fit() reads the 'train_*' and 'test_*' keys; the validation split is
    # passed as 'test_*' so the real test fold stays out of fit()
    results = model.fit({'train_input': train_input, 'train_label': train_label,
                         'test_input': val_input, 'test_label': val_label},
                        opt="LBFGS", steps=10,
                        loss_fn=torch.nn.CrossEntropyLoss(), update_grid=False)

    # Predictions: index of the larger of the two output logits
    test_preds = torch.argmax(model.forward(test_input).detach(), dim=1)

    # Evaluate metrics on the test set
    PD, PF, auc, balance, fir, accuracy, precision = get_clf_eval(test_label, test_preds)
```
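`get_clf_eval` is the author's own helper and is not shown, so the exact formulas behind its outputs are unknown. As a hedged sketch, the pure-Python functions below (names hypothetical) compute several of the same metrics under definitions common in software defect prediction: PD is recall, PF is the false-positive rate, and balance is `1 - sqrt(PF**2 + (1 - PD)**2) / sqrt(2)`; `fir` is left out because its definition isn't given. One detail worth checking: the snippet passes argmax labels, and a ROC AUC computed from hard 0/1 predictions reduces to `(PD + 1 - PF) / 2`, so it carries no ranking information. Passing a score such as `torch.softmax(logits, dim=1)[:, 1]` instead gives a genuine AUC.

```python
import math

def confusion_counts(y_true, y_pred):
    """TP, FP, FN, TN for binary labels where 1 is the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    return tp, fp, fn, tn

def label_metrics(y_true, y_pred):
    """Hypothetical stand-in for part of get_clf_eval (FIR omitted)."""
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    pd = tp / (tp + fn) if tp + fn else 0.0   # probability of detection (recall)
    pf = fp / (fp + tn) if fp + tn else 0.0   # probability of false alarm
    balance = 1 - math.sqrt(pf ** 2 + (1 - pd) ** 2) / math.sqrt(2)
    return {
        "PD": pd,
        "PF": pf,
        "balance": balance,
        "accuracy": (tp + tn) / len(y_true),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
    }

def auc_from_scores(y_true, scores):
    """ROC AUC as the probability that a random positive scores higher than
    a random negative (ties count as 1/2)."""
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

y_true = [1, 1, 0, 0, 0]
y_pred = [1, 0, 0, 0, 1]
m = label_metrics(y_true, y_pred)
# With hard labels as "scores", AUC equals (PD + 1 - PF) / 2
print(auc_from_scores(y_true, y_pred), (m["PD"] + 1 - m["PF"]) / 2)
```

The pairwise AUC is quadratic in the test-fold size, which is fine for a fold of under two thousand rows; a library routine such as scikit-learn's `roc_auc_score` gives the same value more efficiently.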


Dataset description:
- Features: 39
- Data points: 16,900 (32,900 after SMOTE)
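These counts also indicate how much of the data is synthetic. Assuming SMOTE oversampled only the minority class up to the size of the majority class (imbalanced-learn's default `sampling_strategy="auto"` behaviour, assumed here rather than confirmed by the post) and taking the rounded 16,900 and 32,900 at face value, the arithmetic below (helper name hypothetical) recovers the implied per-class counts.

```python
def smote_balance_counts(n_before, n_after):
    """Per-class counts implied by a 1:1 SMOTE balance (an assumption):
    the majority class is untouched and the minority is grown to match it."""
    majority = n_after // 2           # both classes have this size after SMOTE
    minority = n_before - majority    # real minority rows before SMOTE
    synthetic = n_after - n_before    # rows created by SMOTE
    return majority, minority, synthetic

majority, minority, synthetic = smote_balance_counts(16_900, 32_900)
print(majority, minority, synthetic)                  # → 16450 450 16000
print(f"{synthetic / (minority + synthetic):.1%}")    # → 97.3%
print(f"{synthetic / 32_900:.1%}")                    # → 48.6%
```

Under that assumption, roughly 450 real minority rows would have produced about 16,000 synthetic ones. If oversampling happens before the cross-validation split, most minority rows in each test fold are then likely to be interpolations between rows that sit in the training folds.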

@YuriyKabanenko

@SaranDS
Could you please share more details about the data you trained the model on? I've been trying hard to increase accuracy, but the best I've reached is 78%.
First I tried BCEWithLogitsLoss for my binary classification task.
Then I wondered whether the CrossEntropy loss function was somehow boosting accuracy, so I refactored my model to output 2 values and take the argmax as the prediction. But accuracy decreased to 74%.
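For a binary task the two setups are closely related. With two logits `z0` and `z1`, cross-entropy depends only on the difference `z1 - z0` and equals binary cross-entropy with logits applied to that difference; `argmax` also picks class 1 exactly when `z1 > z0`, the same decision as thresholding `sigmoid(z1 - z0)` at 0.5. The pure-Python sketch below re-implements the per-sample formulas of both losses (not PyTorch's code; function names are hypothetical) so the equality can be checked numerically.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bce_with_logits(z, y):
    """Per-sample binary cross-entropy on a single logit z, label y in {0, 1}."""
    # -log(sigmoid(z)) = softplus(-z) and -log(1 - sigmoid(z)) = softplus(z)
    return y * softplus(-z) + (1 - y) * softplus(z)

def cross_entropy_2class(z0, z1, y):
    """Per-sample softmax cross-entropy on two logits, class index y in {0, 1}."""
    m = max(z0, z1)
    logsumexp = m + math.log(math.exp(z0 - m) + math.exp(z1 - m))
    return logsumexp - (z1 if y == 1 else z0)

for z0, z1 in [(0.0, 0.0), (1.5, -2.0), (-3.0, 4.0)]:
    for y in (0, 1):
        print(z0, z1, y, cross_entropy_2class(z0, z1, y), bce_with_logits(z1 - z0, y))
```

So switching the loss mainly changes the parameterization rather than the objective. The two runs above also differ in width and grid (`[16, 3, 1]` with `grid=3` versus `[16, 5, 3, 2]` with `grid=5`), which may account for part of the accuracy gap.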

Here is the code with 1 output:

```python
import time
import torch
from kan import KAN

# custom_dataset and x_test are defined earlier in my script;
# labels are float tensors of shape (N, 1)
dtype = torch.get_default_dtype()

model = KAN(width=[16, 3, 1], grid=3, k=3)

def train_acc():
    # sigmoid + round turns the single output logit into a 0/1 prediction
    return torch.mean((torch.round(torch.sigmoid(model(custom_dataset['train_input']))[:, 0])
                       == custom_dataset['train_label'][:, 0]).type(dtype))

def test_acc():
    return torch.mean((torch.round(torch.sigmoid(model(custom_dataset['test_input']))[:, 0])
                       == custom_dataset['test_label'][:, 0]).type(dtype))

start_time = time.time()

print(custom_dataset['train_input'].dtype)
print(custom_dataset['test_input'].dtype)

results = model.fit(custom_dataset, opt="LBFGS", steps=100,
                    batch=x_test.shape[0], metrics=(train_acc, test_acc),
                    loss_fn=torch.nn.BCEWithLogitsLoss())
end_time = time.time()

results['train_acc'][-1], results['test_acc'][-1]
```

Code with 2 outputs:

```python
import time
import torch
from kan import KAN

# custom_dataset and x_test are defined earlier in my script
model = KAN(width=[16, 5, 3, 2], grid=5, k=3)

start_time = time.time()

results = model.fit(custom_dataset, opt="LBFGS", steps=100,
                    batch=x_test.shape[0], loss_fn=torch.nn.CrossEntropyLoss(),
                    update_grid=False)
end_time = time.time()
```

@SaranDS

SaranDS commented Oct 21, 2024

@YuriyKabanenko
The dataset is from the software domain and is a binary classification task. It has 38+1 columns (38 input features plus the label) and 16,962 data points, all numerical. Before training, I applied MinMax scaling for normalization and SMOTE to address the class imbalance. To assess generalizability, I used stratified 10-fold cross-validation to split the dataset.
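The order of these steps matters. If MinMax scaling and SMOTE are applied to the full dataset before the stratified split, the scaler's minimum and maximum come partly from test rows, and because SMOTE interpolates between real minority rows, a test fold can contain synthetic points built from training rows (and vice versa). A common remedy is to fit the scaler and run the oversampler inside each fold, on the training part only. The pure-Python sketch below illustrates that ordering with a simplified, hand-rolled SMOTE (random interpolation toward one of the k nearest minority neighbours), not imbalanced-learn's implementation; all function names are hypothetical. With a scikit-learn-compatible estimator, `imblearn.pipeline.Pipeline` combining `MinMaxScaler`, `SMOTE` and the classifier applies the sampler only when fitting, which achieves the same ordering.

```python
import random

def fit_minmax(rows):
    """Per-column min and max, computed on training rows only."""
    cols = list(zip(*rows))
    return [min(c) for c in cols], [max(c) for c in cols]

def apply_minmax(rows, mins, maxs):
    # Test values outside the training range may fall outside [0, 1]; that is expected
    return [[(v - lo) / (hi - lo) if hi > lo else 0.0
             for v, lo, hi in zip(row, mins, maxs)] for row in rows]

def simple_smote(rows, labels, minority, k=5, seed=0):
    """Simplified SMOTE for binary labels: add synthetic minority rows, each
    interpolated between a minority row and one of its k nearest minority
    neighbours, until both classes have the same size."""
    rng = random.Random(seed)
    minority_rows = [r for r, y in zip(rows, labels) if y == minority]
    if len(minority_rows) < 2:
        return rows, labels  # nothing to interpolate between
    n_needed = (len(labels) - len(minority_rows)) - len(minority_rows)
    synthetic = []
    for _ in range(max(n_needed, 0)):
        base = rng.choice(minority_rows)
        # k nearest minority rows by squared Euclidean distance, excluding base
        neighbours = sorted((r for r in minority_rows if r is not base),
                            key=lambda r: sum((a - b) ** 2 for a, b in zip(r, base)))[:k]
        other = rng.choice(neighbours)
        gap = rng.random()
        synthetic.append([a + gap * (b - a) for a, b in zip(base, other)])
    return rows + synthetic, labels + [minority] * len(synthetic)

def foldwise_preprocess(X, y, train_idx, test_idx, minority=1):
    """Scale and oversample using the training fold only."""
    X_tr, y_tr = [X[i] for i in train_idx], [y[i] for i in train_idx]
    X_te, y_te = [X[i] for i in test_idx], [y[i] for i in test_idx]
    mins, maxs = fit_minmax(X_tr)  # the scaler never sees test rows
    X_tr, X_te = apply_minmax(X_tr, mins, maxs), apply_minmax(X_te, mins, maxs)
    X_tr, y_tr = simple_smote(X_tr, y_tr, minority)  # the test fold stays real data
    return X_tr, y_tr, X_te, y_te
```

In a cross-validation loop, `foldwise_preprocess` would be called once per `(train_idx, test_idx)` pair on the raw, unscaled features, and the model trained on the returned training rows only.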

The results on this dataset (and on 3 other datasets from the same domain, each with different features and numbers of data points) are exceptionally good, which raises the concern that test data might be leaking into training. To check this, I added print statements for the sizes of the training, validation, and test sets, and the output confirmed that the data was split into the expected sizes. However, I am not sure how else to validate these results.
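Checking split sizes confirms that the folds have the right shape, but it cannot reveal two other common leakage routes. The pure-Python sketch below (all names hypothetical; plain unshuffled K-fold stands in for the stratified splitter) shows the first: if one model object is created before the loop and `fit` is called on it in every fold, it has already trained on each later fold's test rows, because every test fold lies inside the other folds' training sets. The second helper is a 1-nearest-neighbour baseline to run on the same folds; if such a simple classifier scores nearly as well as KAN, that hints that test rows sit very close to training rows, which is what oversampling before the split tends to produce. This is a heuristic, not proof of leakage.

```python
def kfold_indices(n, k):
    """Plain (unshuffled) K-fold over range(n); yields (train_idx, test_idx)."""
    start = 0
    for fold in range(k):
        size = n // k + (1 if fold < n % k else 0)
        test = list(range(start, start + size))
        train = [i for i in range(n) if not start <= i < start + size]
        yield train, test
        start += size

def leaked_test_counts(splits, reuse_model=True):
    """Per fold, how many test points the model has already been trained on.
    With reuse_model=True, one model object keeps training across folds."""
    seen, counts = set(), []
    for train, test in splits:
        if not reuse_model:
            seen = set()  # a freshly created model has seen no data
        counts.append(sum(1 for i in test if i in seen))
        seen.update(train)
    return counts

def one_nn_accuracy(train_rows, train_labels, test_rows, test_labels):
    """Accuracy of a 1-nearest-neighbour classifier (squared Euclidean distance)."""
    correct = 0
    for row, label in zip(test_rows, test_labels):
        nearest = min(range(len(train_rows)),
                      key=lambda j: sum((a - b) ** 2 for a, b in zip(row, train_rows[j])))
        correct += train_labels[nearest] == label
    return correct / len(test_rows)

print(leaked_test_counts(kfold_indices(100, 10)))
# → [0, 10, 10, 10, 10, 10, 10, 10, 10, 10]
print(leaked_test_counts(kfold_indices(100, 10), reuse_model=False))
# → [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

With a fresh model per fold every count is 0; with a reused model, every fold after the first is evaluated entirely on rows it has already trained on.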

@SaranDS

SaranDS commented Oct 22, 2024


@KindXiaoming
Is the code above a valid way to predict test samples on tabular data with KAN?
