
Comments (5)

chrischoy commented on June 2, 2024

The proposed method is probabilistic, like the original paper, and does not break the tree formulation. The leaf distribution is represented as a softmax of W_{leaf}, which always lies on the simplex, and the decision nodes follow the original paper. Please refer to the slides for more details.


yuvval commented on June 2, 2024

In the paper, W_leaf has the interpretation of Pr(y | leaf, x), and then Pr(y | x) = E_{leaf | x}[Pr(y | leaf, x)].
The modification you suggested discards this notion and tries to approximate Pr(y | x) by softmax(W_leaf' · Pr(leaf | x)). However, using a softmax does not guarantee that it will learn to produce the true Pr(y | x) with a single layer of W_leaf; instead, you might need a multi-layer net just to estimate the true Pr(y | x) given (leaf, x).
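
To make the distinction concrete, here is a minimal numpy sketch; the shapes and numbers are illustrative, not taken from the repository:

    import numpy as np

    def softmax(z, axis=-1):
        z = z - z.max(axis=axis, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=axis, keepdims=True)

    # Toy setup: 4 leaves, 3 classes, a single input x.
    routing = np.array([0.1, 0.2, 0.3, 0.4])   # Pr(leaf | x), sums to 1
    pi = np.array([[0.7, 0.2, 0.1],            # Pr(y | leaf): one distribution per
                   [0.1, 0.8, 0.1],            # leaf, each row on the simplex
                   [0.3, 0.3, 0.4],
                   [0.2, 0.2, 0.6]])

    # Paper's formulation: Pr(y | x) is the routing-weighted mixture of the leaf
    # distributions, i.e. the expectation over leaves. It already sums to 1.
    p_paper = routing @ pi

    # The modification described above: an unconstrained W_leaf with a softmax
    # applied on top of the dot product.
    W_leaf = np.random.randn(4, 3)
    p_modified = softmax(routing @ W_leaf)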

I can suggest a simple alternative you can use if you don't want to use the update rule given by the authors.

You can remove the softmax and replace the gradient update of W_leaf with:

    # Take a gradient step on W_leaf (raw gradient; scale by a learning rate if desired)
    leaf_grad = optimizer(*optimizer_params).compute_gradients(loss, var_list=[W_leaf])[0][0]
    updated = W_leaf - leaf_grad
    # Project back onto the simplex: clip to >= 0 ...
    updated = tf.maximum(updated, 0.0)
    # ... then normalize each row to sum to 1
    epsilon = 1e-7
    updated = tf.transpose(tf.div(tf.transpose(updated), epsilon + tf.reduce_sum(updated, axis=1)))
    leaf_update_op = tf.assign(W_leaf, updated)

For your loss, you can then use keras.losses.categorical_crossentropy instead of the softmax cross-entropy.
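
For instance, a minimal sketch of wiring that loss to probabilities that are already normalized; the placeholder names are illustrative, not from the repository:

    import tensorflow as tf  # TF1-style graph code, matching the snippet above

    y_onehot = tf.placeholder(tf.float32, [None, 3])  # one-hot labels
    py_x = tf.placeholder(tf.float32, [None, 3])      # Pr(y | x), rows already sum to 1
    # categorical_crossentropy expects probabilities, so no softmax is applied inside
    loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_onehot, py_x))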


chrischoy commented on June 2, 2024

The implementation I proposed in this repository does not discard any of the interpretations in the original paper.

First, let me define the notation, since there is no such thing as $W_l$ in the original paper.

In Eq. 1 of the paper, $\pi$ is the leaf node distribution and is defined as a probability distribution (any distribution of the right dimension), while $W_l$ is a parameter I introduced in this repository, which parametrizes the leaf node probability distribution $\pi$. $W_l$ can be any vector, and if you take a softmax of $W_l$ you always get a valid probability distribution. So instead of using a bare probability distribution, I simply used a parametric probability distribution with the same degrees of freedom as the bare one. That is the only change I made to the original formulation, and it does not break any interpretation.

This parametrization, however, also allows us to train $\pi$, and it replaces the alternating optimization step involving Eq. 11 with a gradient descent step.

This is justifiable because the original formulation does not involve latent variables. To elaborate, alternating optimization is required when you have a latent variable that depends on the other variables and admits a good closed-form solution; well-known examples of such alternating optimization with closed-form solutions are the EM steps in graphical-model inference and ADMM. However, there is a class of problems that can be solved with joint optimization: neural networks. If you treated all the layer parameters of a neural network as latent variables, you would need layer-wise alternating optimization, as many people did for RBMs. Instead, we can jointly optimize all parameters simply because the gradient computation is fast and cheap. The same basic idea applies to the Neural Decision Forest.
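
To make this concrete, here is a minimal TF1-style sketch of the joint setup; the shapes, names, and toy loss are illustrative, not the repository's actual code:

    import tensorflow as tf

    N_LEAF, N_CLASS = 16, 10
    # Parametrize the leaf distributions: W_l is unconstrained, and a row-wise
    # softmax always yields valid probability distributions pi.
    W_l = tf.Variable(tf.random_normal([N_LEAF, N_CLASS]))
    pi = tf.nn.softmax(W_l)

    # The routing probabilities Pr(leaf | x) come from the decision nodes; here
    # placeholders stand in for them and for the one-hot labels.
    routing = tf.placeholder(tf.float32, [None, N_LEAF])
    y_onehot = tf.placeholder(tf.float32, [None, N_CLASS])

    # Eq. 1: Pr(y | x) is the routing-weighted mixture of the leaf distributions.
    py_x = tf.matmul(routing, pi)

    # Because pi is a differentiable function of W_l, a single optimizer can update
    # W_l jointly with the decision-node parameters; no alternating Eq. 11 step.
    loss = -tf.reduce_mean(tf.reduce_sum(y_onehot * tf.log(py_x + 1e-7), axis=1))
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)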

I think you misunderstood the implementation or my slides. If you can elaborate on why you think I'm trying to approximate Pr(y | x) by softmax(W_leaf' · Pr(leaf | x)), even though I'm not approximating anything, I'd be happy to clarify the misconception.

TL;DR: this is the exact formulation following Eq. 1, and the differences proposed in this repository are: 1. use a parametric function for $\pi$ ($\pi = \text{softmax}(W)$ where $W$ is a vector), and 2. update $\pi$ using stochastic gradient descent instead of Eq. 11, thus removing the alternating optimization.


yuvval commented on June 2, 2024


chrischoy commented on June 2, 2024

Hmm. I'm not sure how you can disagree with the notation defined by the authors. To me, the notation is very clear, and I don't agree that the authors' notation in the paper is confusing.

  1. The mathematical notation does not have to follow the convenience of the implementation. I implemented the terminal node probabilities using a matrix out of convenience, but that is not a necessary condition. You could implement them using a list of vectors, which would be exactly what the notation in the paper proposes (see the short sketch after this list).

  2. Yes, that is the routing probability.
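
To illustrate point 1, a tiny sketch of the two equivalent representations; the names are illustrative:

    import numpy as np

    N_LEAF, N_CLASS = 4, 3
    # Matrix form, used in the implementation out of convenience.
    W = np.zeros((N_LEAF, N_CLASS))
    # Equivalent form matching the paper's notation: one parameter vector per leaf.
    W_per_leaf = [W[l] for l in range(N_LEAF)]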

Yes, you can take gradient steps directly on $\pi$ without using the softmax; the only thing you need to maintain is projecting $\pi$ onto the simplex ([0, 1]) and making sure its rows sum to 1. That is the main idea of what I proposed in this repository.

Regarding "In the modification you suggested, Pr(Y = y | x) = softmax(dotproduct(W_leaf, Pr_{\theta}(leaf = l | x)))": no, I'm not doing that. See the line https://github.com/chrischoy/fully-differentiable-deep-ndf-tf/blob/github/demo_fully_diff_ndf.py#L251.

I finally understand what you are trying to say, and the answer is no, I'm not breaking the original formulation. I am closing the issue as invalid.


