ravinkohli / tabpfn

This project forked from automl/tabpfn

Official implementation of the TabPFN and the tabpfn package.

Home Page: https://arxiv.org/abs/2207.01848

License: Apache License 2.0

Python 69.05% Jupyter Notebook 30.95%

tabpfn's Introduction

TabPFN

The TabPFN is a neural network that learned to do tabular data prediction. This is the original CUDA-supporting PyTorch implementation.

We created a Colab that lets you play with our scikit-learn interface.

We also created two demos: one to experiment with the TabPFN's predictions (https://huggingface.co/spaces/TabPFN/TabPFNPrediction) and one to check cross-validation ROC AUC scores on new datasets (https://huggingface.co/spaces/TabPFN/TabPFNEvaluation). Both run on a weak CPU, so they can take a little while to respond. Both demos are based on a scikit-learn interface that makes using the TabPFN as easy as using a scikit-learn SVM.

Installation

pip install tabpfn

If you want to train and evaluate our method as we did in the paper (including baselines), please install with:

pip install tabpfn[full]

To run the autogluon and autosklearn baselines, please create a separate environment and install autosklearn / autogluon==0.4.0. Installing them in the same environment as our other baselines is not possible.

Getting started

A simple example of using our sklearn interface:

from sklearn.metrics import accuracy_score
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

from tabpfn.scripts.transformer_prediction_interface import TabPFNClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# N_ensemble_configurations controls the number of model predictions that are ensembled with feature and class rotations (See our work for details).
# When N_ensemble_configurations > #features * #classes, no further averaging is applied.

classifier = TabPFNClassifier(device='cpu', N_ensemble_configurations=32)

classifier.fit(X_train, y_train)
y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)

print('Accuracy', accuracy_score(y_test, y_eval))
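
The comment about N_ensemble_configurations above can be made concrete with a small counting sketch. It is not TabPFN's internal code: it assumes that a "rotation" is a cyclic shift of the feature or class indices, and the helper names rotate and distinct_configurations are hypothetical. Under that assumption there are #features * #classes distinct combinations; for the breast cancer dataset (30 features, 2 classes) that is 60, so the 32 configurations requested above stay below the limit.

```python
from itertools import product


def rotate(seq, shift):
    """Cyclically rotate a sequence to the left by `shift` positions."""
    shift %= len(seq)
    return list(seq[shift:]) + list(seq[:shift])


def distinct_configurations(n_features, n_classes):
    """Enumerate every (feature shift, class shift) pair.

    Hypothetical helper: it only counts combinations and does not
    reproduce how TabPFN selects or orders its ensemble members.
    """
    return list(product(range(n_features), range(n_classes)))


# Breast cancer dataset: 30 features, 2 classes.
configs = distinct_configurations(n_features=30, n_classes=2)
print(len(configs))

# Example: the feature order seen by a member with feature shift 2.
print(rotate([0, 1, 2, 3, 4], 2))
```

Asking for more configurations than this count cannot produce new rotations, which matches the statement that no further averaging is applied beyond #features * #classes.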

Tips and Tricks

TabPFN is different from other methods you might know for tabular classification. Here, we list some tips and tricks that might help you understand how to use it best.

  • TabPFN expects scalar values only (you need to encode categoricals as integers, e.g. with OrdinalEncoder). It works best on data that does not contain any categorical or NaN data (see Appendix B.1 of the paper).
  • TabPFN ensembles multiple input encodings by default. It feeds different index rotations of the features and labels to the model per ensemble member. You can control the ensembling with TabPFNClassifier(...,N_ensemble_configurations=?)
  • TabPFN pre-processes its inputs. It applies a z-score normalization ((x - train_x.mean()) / train_x.std()) per feature (fitted on the training set) and log-scales outliers heuristically. Finally, TabPFN applies a PowerTransform to all features for every second ensemble member. Pre-processing is important for the TabPFN to make sure that the real-world dataset lies in the distribution of the synthetic datasets seen during training. So to get the best results, do not apply a PowerTransformation to the inputs yourself.
  • TabPFN does not use any statistics from the test set. That means predicting each test example one-by-one will yield the same result as feeding the whole test set together.
  • TabPFN is differentiable in principle; only the pre-processing, which relies on numpy, is not.
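
The z-score step and the "no test statistics" property from the list above can be illustrated with a minimal sketch. It is not TabPFN's actual pre-processing: the function names fit_zscore and apply_zscore are hypothetical, the population standard deviation is an assumption, and the outlier handling and PowerTransform steps are left out. It uses only the standard library to stay dependency-free.

```python
from statistics import mean, pstdev


def fit_zscore(train_rows):
    """Compute per-feature mean and standard deviation from training rows only."""
    columns = list(zip(*train_rows))
    means = [mean(col) for col in columns]
    # Assumption: population standard deviation; constant features fall back to 1.0.
    stds = [pstdev(col) or 1.0 for col in columns]
    return means, stds


def apply_zscore(rows, means, stds):
    """Normalize rows with statistics that were fitted on the training set."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


train = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
test = [[2.0, 40.0], [4.0, 10.0]]

means, stds = fit_zscore(train)

# Normalizing the whole test set at once ...
batch = apply_zscore(test, means, stds)
# ... gives the same values as normalizing each test row separately,
# because no statistic is ever computed from the test rows.
one_by_one = [apply_zscore([row], means, stds)[0] for row in test]

print(batch == one_by_one)
```

Because the mean and standard deviation come from the training rows alone, the normalized value of a test row does not depend on which other test rows are passed alongside it.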

Our Paper

Read our paper for more information about the setup (or contact us ☺️). If you use our method, please cite us using

@misc{tabpfn,
  doi = {10.48550/ARXIV.2207.01848},
  url = {https://arxiv.org/abs/2207.01848},
  author = {Hollmann, Noah and Müller, Samuel and Eggensperger, Katharina and Hutter, Frank},
  keywords = {Machine Learning (cs.LG), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}

License

Copyright 2022 Noah Hollmann, Samuel Müller, Katharina Eggensperger, Frank Hutter

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
