Commit

add back classes_ attribute (#69)

LeoGrin authored Jan 7, 2025
1 parent e97d1ee commit 78baa19
Showing 2 changed files with 53 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tabpfn_client/estimator.py
@@ -3,8 +3,8 @@

from typing import Optional, Literal, Dict, Union
import logging

import numpy as np
import pandas as pd
from tabpfn_client.config import init
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted
@@ -134,6 +134,7 @@ def fit(self, X, y):
init()

validate_data_size(X, y)
self._validate_targets_and_classes(y)
_check_paper_version(self.paper_version, X)

estimator_param = self.get_params()
@@ -192,6 +193,17 @@ def _predict(self, X, output_type):
)
return res

    def _validate_targets_and_classes(self, y) -> None:
from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import check_classification_targets

y_ = column_or_1d(y, warn=True)
check_classification_targets(y)
# Get classes and encode before type conversion to guarantee correct class labels.
not_nan_mask = ~pd.isnull(y)
# TODO: should pass this from the server
self.classes_ = np.unique(y_[not_nan_mask])


class TabPFNRegressor(BaseEstimator, RegressorMixin, TabPFNModelSelection):
_AVAILABLE_MODELS = [
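
For reference, a minimal usage sketch of the restored attribute (not part of the diff above). It assumes an authenticated tabpfn_client session, since fit and predict run on the server, and the usual scikit-learn convention that classes_ holds the sorted unique training labels and that predict_proba columns follow that same order; the toy data is made up.

import numpy as np
from tabpfn_client.estimator import TabPFNClassifier

# Hypothetical toy data; string labels exercise the classes_ bookkeeping.
X_train = np.random.rand(20, 4)
y_train = np.array(["cat", "dog"] * 10)
X_test = np.random.rand(5, 4)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)   # fit() now records np.unique(y) into self.classes_
print(clf.classes_)         # array(['cat', 'dog'], dtype='<U3')

# Assuming predict_proba columns are ordered like classes_ (scikit-learn convention),
# the most probable label per row can be recovered from the attribute:
proba = clf.predict_proba(X_test)
labels = clf.classes_[np.argmax(proba, axis=1)]
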
40 changes: 40 additions & 0 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
@@ -434,6 +434,46 @@ def test_predict_params_output_type(self):
predict_params = mock_predict.call_args[1]["predict_params"]
self.assertEqual(predict_params, {"output_type": "probas"})

def test_string_label_predictions(self):
"""Test that string labels in y are preserved in predictions."""
X = np.random.rand(10, 5)
test_X = np.random.rand(5, 5)

# Test with pure string labels
y_str = np.array(["cat", "dog"] * 5)
clf_str = TabPFNClassifier()

# Mock fit and predict
with patch.object(InferenceClient, "fit") as mock_fit:
mock_fit.return_value = "dummy_uid"
clf_str.fit(X, y_str)

with patch.object(InferenceClient, "predict") as mock_predict:
mock_predict.return_value = np.array(["cat", "dog"] * 5)
y_pred = clf_str.predict(test_X)

self.assertTrue(np.all(np.isin(y_pred, ["cat", "dog"])))
self.assertTrue(
y_pred.dtype.kind in {"U", "O"}, "Predictions should be string type"
)

# Test with numeric-like strings
y_num_str = np.array(["0", "1"] * 5)
clf_num_str = TabPFNClassifier()

with patch.object(InferenceClient, "fit") as mock_fit:
mock_fit.return_value = "dummy_uid"
clf_num_str.fit(X, y_num_str)

with patch.object(InferenceClient, "predict") as mock_predict:
mock_predict.return_value = np.array(["0", "1"] * 5)
y_pred = clf_num_str.predict(test_X)

self.assertTrue(np.all(np.isin(y_pred, ["0", "1"])))
self.assertTrue(
y_pred.dtype.kind in {"U", "O"}, "Predictions should be string type"
)


class TestTabPFNModelSelection(unittest.TestCase):
def setUp(self):
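
As a side note, here is a standalone sketch of the masking step inside _validate_targets_and_classes, showing why missing targets are dropped before computing classes_ (toy values, not taken from the commit):

import numpy as np
import pandas as pd

y = np.array(["cat", None, "dog", "cat"], dtype=object)

not_nan_mask = ~pd.isnull(y)           # [True, False, True, True]
classes_ = np.unique(y[not_nan_mask])  # array(['cat', 'dog'], dtype=object)

# Without the mask, the missing marker would leak into classes_ (or np.unique
# would fail outright when comparing None against strings during its sort).
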
