Comments (12)
@Optimox I think the problem here is how to train the weak learners.
In boosted trees, each weak tree is trained with a split criterion such as the Gini index, while the cross-entropy is used by the overall algorithm to compute the residuals, and the next tree is trained on those residuals.
Here, however, each step needs gradients to be trained, unlike a tree, which only needs the Gini index.
One solution could be to train each step with cross-entropy for one or two epochs or gradient steps (so that it remains a weak learner), predicting class probabilities, and then use those probabilities to compute the residuals on which the next step is trained with cross-entropy in the same way, and so on?
@Optimox
You may want to have a look at Friedman's paper on gradient boosting:
https://statweb.stanford.edu/~jhf/ftp/trebst.pdf
You'll see what I meant by using regressors only.
@Jaskaran170599
In the case of the gradient boosting technique, the output of each step would be multiplied by a learning rate and summed to get the log-odds, to which we can apply a sigmoid to get a probability (between 0 and 1)?
Exactly. The idea is to fit the decision function before the sigmoid is applied, and to compute the gradient with respect to the values of this decision function. So at each step, the weak learner is trained to fit the gradient (hence the regressor), and the result is added (with a weight) to the previous decision function. Class probabilities can then be computed by applying the sigmoid function for binary problems or the softmax for multi-class problems.
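A minimal sketch of the binary case described above, in plain Python. It assumes a single numeric feature, least-squares regression stumps as the weak regressors, a fixed learning rate instead of a line search, and the log-loss, whose negative gradient with respect to the decision function F is `y - sigmoid(F)`. The names `fit_stump`, `gradient_boost_binary` and `predict_proba` are hypothetical and not part of TabNet.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_stump(xs, targets):
    # Least-squares regression stump on one feature: (threshold, left_value, right_value).
    best = None
    for t in sorted(set(xs)):
        left = [r for x, r in zip(xs, targets) if x <= t]
        right = [r for x, r in zip(xs, targets) if x > t]
        lv = sum(left) / len(left)
        rv = sum(right) / len(right) if right else lv
        sse = sum((r - lv) ** 2 for r in left) + sum((r - rv) ** 2 for r in right)
        if best is None or sse < best[0]:
            best = (sse, t, lv, rv)
    return best[1:]


def predict_stump(stump, x):
    t, lv, rv = stump
    return lv if x <= t else rv


def gradient_boost_binary(xs, ys, n_rounds=50, learning_rate=0.3):
    # Initial decision function: log-odds of the positive class (both classes must be present).
    p0 = sum(ys) / len(ys)
    f0 = math.log(p0 / (1.0 - p0))
    scores = [f0] * len(xs)
    stumps = []
    for _ in range(n_rounds):
        # Negative gradient of the log-loss with respect to the decision function F.
        residuals = [y - sigmoid(f) for y, f in zip(ys, scores)]
        stump = fit_stump(xs, residuals)  # the weak learner is a regressor
        stumps.append(stump)
        # Add the weak learner's output, scaled by the learning rate.
        scores = [f + learning_rate * predict_stump(stump, x) for f, x in zip(scores, xs)]
    return f0, learning_rate, stumps


def predict_proba(model, x):
    f0, learning_rate, stumps = model
    f = f0 + sum(learning_rate * predict_stump(s, x) for s in stumps)
    return sigmoid(f)  # the probability is the sigmoid of the summed decision function


xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [0, 0, 0, 0, 1, 1, 1, 1]
model = gradient_boost_binary(xs, ys)
print(predict_proba(model, 2), predict_proba(model, 7))
```

The sigmoid is only applied at prediction time; during training every step works on the raw decision function, which is why a regressor suffices.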
Thanks @bibhabasumohapatra, looks promising. Is there a research paper related to the repo?
Why not try a plain application of gradient boosting? Each step fits the negative gradient of the loss function (as computed so far) and adds it (using a line search) to the previous result. Only regression is needed internally (to fit the gradient), and it works for both regression and classification.
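The line search step can be sketched as follows for the binary log-loss: given the current decision-function values `scores` and a new weak learner's outputs `h`, find the step ρ that minimizes Σ L(yᵢ, Fᵢ + ρ·hᵢ). This is an illustration under our own assumptions, not Friedman's exact procedure: it uses a damped one-dimensional Newton iteration (each Newton step is halved until the loss does not increase), and `line_search` and `total_log_loss` are hypothetical names.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(z):
    # Stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def total_log_loss(ys, scores, h, rho):
    # Binary log-loss written on the decision function: log(1 + e^F) - y * F.
    return sum(softplus(f + rho * hi) - y * (f + rho * hi) for y, f, hi in zip(ys, scores, h))


def line_search(ys, scores, h, n_iter=50, tol=1e-10):
    rho = 0.0
    loss = total_log_loss(ys, scores, h, rho)
    for _ in range(n_iter):
        # First and second derivatives of the loss with respect to rho.
        grad = hess = 0.0
        for y, f, hi in zip(ys, scores, h):
            p = sigmoid(f + rho * hi)
            grad += (p - y) * hi
            hess += p * (1.0 - p) * hi * hi
        if abs(grad) < tol or hess < 1e-12:
            break
        step = grad / hess
        t = 1.0
        while t > 1e-8:
            new_rho = rho - t * step
            new_loss = total_log_loss(ys, scores, h, new_rho)
            if new_loss <= loss:
                break
            t *= 0.5
        else:
            break  # no acceptable step found
        rho, loss = new_rho, new_loss
    return rho


ys = [0, 1, 1, 0, 1]
scores = [0.0] * 5               # current decision function F_{m-1}(x_i)
h = [-1.0, 1.0, 1.0, 1.0, 1.0]   # outputs of the new weak regressor h_m(x_i)
rho = line_search(ys, scores, h)
print(rho)  # ≈ 1.3863, i.e. ln 4
```

For this toy input the derivative simplifies to 5·σ(ρ) − 4, so the optimal step is ρ = ln 4; the boosted model would then be updated as Fᵢ + ρ·hᵢ.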
Interesting.
For classification, I think we could follow the same approach as the gradient boosting algorithm and/or AdaBoost, as mentioned in the paper, using a cross-entropy loss function?
In the case of the gradient boosting technique, the output of each step would be multiplied by a learning rate and summed to get the log-odds, to which we can apply a sigmoid to get a probability (between 0 and 1)?
For AdaBoost, we could perhaps use the same weighting formula as in the paper.
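The comment does not say which paper's weighting formula is meant. For reference, the standard discrete AdaBoost update (labels in {−1, +1}) computes a learner weight α = ½·ln((1 − ε)/ε) from the weighted error ε, then up-weights the misclassified samples; the final prediction is the sign of Σ αₜ·hₜ(x). A minimal sketch with hypothetical function names (`adaboost_round`, `adaboost_predict`); how such weights would map onto TabNet steps is left open, as in the discussion.

```python
import math


def adaboost_round(weights, ys, preds):
    # One round of discrete AdaBoost; ys and preds are labels in {-1, +1}.
    total = sum(weights)
    eps = sum(w for w, y, p in zip(weights, ys, preds) if y != p) / total
    eps = min(max(eps, 1e-12), 1.0 - 1e-12)  # keep the log finite
    alpha = 0.5 * math.log((1.0 - eps) / eps)  # weight of this weak learner
    # Misclassified samples (y * p = -1) are multiplied by e^alpha, the others by e^-alpha.
    new_weights = [w * math.exp(-alpha * y * p) for w, y, p in zip(weights, ys, preds)]
    z = sum(new_weights)
    return alpha, [w / z for w in new_weights]


def adaboost_predict(alphas, learner_outputs):
    # learner_outputs[t] is h_t(x) in {-1, +1}; the ensemble votes with weights alpha_t.
    score = sum(a * h for a, h in zip(alphas, learner_outputs))
    return 1 if score >= 0 else -1


alpha, w = adaboost_round([0.25] * 4, [1, 1, -1, -1], [1, 1, -1, 1])
print(alpha, w)  # alpha = 0.5 * ln 3 ≈ 0.549; w ≈ [1/6, 1/6, 1/6, 0.5]
```

After one round the single misclassified sample carries half of the total weight, which is what forces the next weak learner to focus on it.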
It would be interesting to somehow use the mask weights to give each step an "importance weight" in its contribution to the final prediction, since the mask heatmaps show that some masks are less activated than others. This may improve decision making.
The difference would be in the training: boosted algorithms train one tree and then use it in boosting, whereas here all the weak learners would be learning simultaneously.
I would like to do some research on this and contribute.
#Abhishek-eBook
@AlexisMignon Approaching classification problems with regression could be a solution, but I feel it's not satisfying, especially for multi-class classification...
@JaskaranSingh-Precily TabNet already uses cross-entropy, but you need integer targets to apply cross-entropy, so I don't see how a boosted version could use cross-entropy at every step. Could you explain and/or give some links to the literature? I probably just need to dig a bit deeper into how XGBoost deals with multi-class classification.
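On using cross-entropy with integer targets in a boosted setting: in the gradient boosting formulation, the integer labels only enter through a one-hot encoding, and the negative gradient of the softmax cross-entropy with respect to the per-class decision values (logits) is `one_hot(y) - softmax(F)`. Each boosting round can therefore fit one regressor per class to these pseudo-residuals; to our understanding, XGBoost handles multi-class objectives the same way (one tree per class per round). A minimal sketch of the residual computation, with hypothetical function names:

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def multiclass_pseudo_residuals(scores, labels, n_classes):
    # scores[i] holds the K decision values of sample i; labels[i] is an integer class id.
    residuals = []
    for row, label in zip(scores, labels):
        probs = softmax(row)
        one_hot = [1.0 if k == label else 0.0 for k in range(n_classes)]
        residuals.append([o - p for o, p in zip(one_hot, probs)])
    return residuals


# Column k of the result is the regression target for the class-k weak learner.
res = multiclass_pseudo_residuals([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], [1, 0], 3)
print(res[0])  # ≈ [-1/3, 2/3, -1/3]
```

Each residual row sums to zero, and the class probabilities at prediction time are the softmax of the summed per-class decision functions.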
@Jaskaran170599 To be honest, I'm not sure you'll double your chances of winning Abhishek's ebook that way! :)
@Optimox Actually, I commented with the company account; that was not my personal account.
@AlexisMignon Yes, and I think in TabNet's case the weak learner would be one block of the architecture, and the main difference from boosting algorithms is how to train that block.
https://github.com/tusharsarkar3/XBNet
Thanks @bibhabasumohapatra, looks promising. Is there a research paper related to the repo?
Yes.
This is good work, but from my point of view it is rather a completely different design.