Skip to content

Commit

Permalink
fix quick test
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Jan 7, 2025
1 parent 16f7dba commit 120fed4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 38 deletions.
1 change: 1 addition & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
protocol: "https"
host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
port: "443"
gui_url: "https://ux.priorlabs.ai"
endpoints:
root:
path: "/"
Expand Down
76 changes: 38 additions & 38 deletions tabpfn_client/tests/quick_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import logging
from unittest.mock import patch

from sklearn.datasets import load_breast_cancer, load_diabetes
from sklearn.model_selection import train_test_split
Expand All @@ -18,41 +19,40 @@


if __name__ == "__main__":
# set logging level to debug
# logging.basicConfig(level=logging.DEBUG)

use_server = True
# use_server = False

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)
y_train = y_train
y_test = y_test

tabpfn = TabPFNClassifier(n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
print("predicting")
print(tabpfn.predict(X_test))
print("predicting_proba")
print(tabpfn.predict_proba(X_test))

print(UserDataClient.get_data_summary())

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)

tabpfn = TabPFNRegressor(n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
print("predicting reg")
print(tabpfn.predict(X_test, output_type="mean"))

print(UserDataClient.get_data_summary())
# test predict_full
print("predicting ")
print(tabpfn.predict(X_test[:30], output_type="full", quantiles=[0.1, 0.5, 0.9]))
# Patch webbrowser.open to prevent browser login
with patch("webbrowser.open", return_value=False):
use_server = True
# use_server = False

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)

tabpfn = TabPFNClassifier(n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
print("predicting")
print(tabpfn.predict(X_test))
print("predicting_proba")
print(tabpfn.predict_proba(X_test))

print(UserDataClient.get_data_summary())

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)

tabpfn = TabPFNRegressor(n_estimators=3)
# print("checking estimator", check_estimator(tabpfn))
tabpfn.fit(X_train[:99], y_train[:99])
print("predicting reg")
print(tabpfn.predict(X_test, output_type="mean"))

print(UserDataClient.get_data_summary())
# test predict_full
print("predicting ")
print(
tabpfn.predict(X_test[:30], output_type="full", quantiles=[0.1, 0.5, 0.9])
)

0 comments on commit 120fed4

Please sign in to comment.