Comments (2)
hello @xywust2014,
There are two distinct metrics to consider:
- the loss function that your model is trained to minimize
- the early stopping metric, which is used to select the best epoch and stop the training when no further improvement has been seen.
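To make the difference between the two concrete: the loss drives the gradient updates, while the early stopping metric is only read once per epoch on the validation set to decide which epoch to keep and when to stop. The sketch below is a generic patience-based rule written for illustration, not TabNet's own code. The function name `early_stopping` and the `patience`/`maximize` arguments are assumptions chosen for the example.

```python
from typing import List, Tuple


def early_stopping(valid_metrics: List[float], patience: int, maximize: bool = True) -> Tuple[int, int]:
    """Hypothetical helper: return (best_epoch, stopped_epoch).

    valid_metrics[i] is the early stopping metric measured on the validation
    set after epoch i. Training stops once `patience` consecutive epochs pass
    without beating the best value seen so far.
    """
    if not valid_metrics:
        raise ValueError("need at least one epoch of metrics")
    best_epoch, best_value = 0, valid_metrics[0]
    for epoch in range(1, len(valid_metrics)):
        value = valid_metrics[epoch]
        # AUC/accuracy: higher is better; MSE: lower is better.
        improved = value > best_value if maximize else value < best_value
        if improved:
            best_epoch, best_value = epoch, value
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch  # stop here, keep the weights from best_epoch
    return best_epoch, len(valid_metrics) - 1  # ran out of epochs without triggering the stop


# Example: validation AUC per epoch, waiting 3 epochs for an improvement.
auc_per_epoch = [0.70, 0.75, 0.78, 0.77, 0.76, 0.78, 0.79]
print(early_stopping(auc_per_epoch, patience=3))  # (2, 5)
```

Note that epoch 6 would have improved on epoch 2, but it is never reached: with this rule, the choice of metric and patience directly decides which model you end up with, independently of how the loss evolves.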
Currently you can easily change the loss function by passing the `loss_fn` parameter to `fit`. You can pick any loss function from PyTorch (https://pytorch.org/docs/stable/nn.html) or even write your own. By default, TabNetClassifier uses cross-entropy while TabNetRegressor uses mean squared error.
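A minimal sketch of a custom loss, assuming `loss_fn` is called as `loss_fn(y_pred, y_true)` on torch tensors and must return a scalar tensor (an assumption about the calling convention; check the version you have installed). `mae_loss` is a hypothetical name, and the `fit` call is left as a comment so the block runs without pytorch-tabnet; `X_train` and `y_train` are placeholders, and validation data should be passed the way your version of `fit` expects.

```python
import torch


def mae_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Hypothetical custom regression loss: mean absolute error.

    Built only from differentiable torch ops so autograd can backpropagate it.
    """
    return torch.mean(torch.abs(y_pred - y_true))


# A built-in PyTorch loss object works the same way, e.g. torch.nn.L1Loss()
# computes the same value as mae_loss above.
builtin_mae = torch.nn.L1Loss()

# Usage with TabNet (commented out, placeholders):
# from pytorch_tabnet.tab_model import TabNetRegressor
# reg = TabNetRegressor()
# reg.fit(X_train, y_train, loss_fn=mae_loss)   # or loss_fn=builtin_mae

preds = torch.tensor([[1.0], [3.0]])
targets = torch.tensor([[2.0], [1.0]])
print(mae_loss(preds, targets).item())  # 1.5
```

Writing the loss as a plain function is convenient when you need something PyTorch does not ship; when a built-in exists, passing the `torch.nn` object is simpler and less error-prone.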
As for the early stopping metric, it can't easily be changed in the current implementation; this will be improved in the future. For binary classification, the early stopping metric (which is also the metric displayed during training) is AUC. For multiclass classification it is accuracy, and for regression it is mean squared error.
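For reference, here is what each of those three metrics computes, written in plain Python for illustration only. These are not the library's implementations, and the function names are hypothetical; for real work a tested implementation such as scikit-learn's is preferable.

```python
from typing import Sequence


def binary_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of samples whose predicted class equals the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def mean_squared_error(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Average squared difference between targets and predictions."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


print(binary_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(accuracy([0, 1, 2, 2], [0, 2, 2, 2]))              # 0.75
print(mean_squared_error([1.0, 2.0], [1.5, 3.0]))        # 0.625
```

Keep the direction in mind when reading the training logs: AUC and accuracy should go up, while mean squared error should go down.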
You can access the training loss and early stopping metric with `clf.history['train']['loss']` and `clf.history['train']['metric']`, and the validation metric with the same command on the valid set: `clf.history['valid']['metric']`.
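A small sketch of how that history could be used to find the best epoch after training. It assumes the nested layout described above (`history['train']['loss']`, `history['train']['metric']`, `history['valid']['metric']`); the layout may differ across pytorch-tabnet versions, so check the keys of `clf.history` first. The helper name and the numbers in the example are made up.

```python
def summarize_history(history: dict, maximize: bool = True) -> dict:
    """Hypothetical helper: report the epoch with the best validation metric.

    Assumes history["train"]["loss"], history["train"]["metric"] and
    history["valid"]["metric"] are lists with one value per epoch.
    """
    valid_metric = history["valid"]["metric"]
    pick = max if maximize else min  # AUC/accuracy: max, MSE: min
    best_epoch = pick(range(len(valid_metric)), key=lambda i: valid_metric[i])
    return {
        "best_epoch": best_epoch,
        "best_valid_metric": valid_metric[best_epoch],
        "train_loss_at_best": history["train"]["loss"][best_epoch],
        "train_metric_at_best": history["train"]["metric"][best_epoch],
    }


# Made-up values standing in for clf.history on a binary task (metric = AUC).
fake_history = {
    "train": {"loss": [0.69, 0.55, 0.48, 0.45], "metric": [0.70, 0.80, 0.85, 0.88]},
    "valid": {"metric": [0.68, 0.76, 0.79, 0.78]},
}
print(summarize_history(fake_history))
```

Comparing the train and valid metrics at the best epoch is also a quick way to spot overfitting: a train metric that keeps improving while the valid metric stalls is the usual sign.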
Hope this helps!
Thank you, Optimox! This helps a lot.
Related Issues
- Minimal working example for TabNetRegressor/Classifier
- Transfer learning, capability to change structure of model
- Generate Embeddings for Tabular Data
- TabNet overfits (help wanted, not a bug)
- TabNetRegressor vs other networks
- spike in memory when training ends
- Severe overfitting
- OOM problem when I search hyperparameters with Tabnet
- Support for complex-valued datasets
- Different classification variables in the test set and train set
- Struggling to get model to fit - Help Wanted
- Optimizing TabNet for Disease Classification with Continuous Audio Features
- Interpreting Sparsity on Global Importance
- ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
- Validation loss
- Lightweight Fine-tunning or few-shot learning for limited labeled data
- Maybe `drop_last` should be set as False in default?
- Incompatiblity of current round() method with pytorch tensors when performing early stopping
- Retraining a saved model on different dataset
- change device seems not work