ugenteraan / deep_hierarchical_classification

PyTorch Implementation of Deep Hierarchical Classification for Category Prediction in E-commerce System

License: MIT License

Python 100.00%
deep-learning hierarchical-models hierarchical-classification pytorch-implementation image-classification python neural-networks pytorch

deep_hierarchical_classification's Introduction

Hey there! I'm a Senior AI Engineer and a Machine Learning enthusiast. Apart from working a full-time job and freelancing for a living, I enjoy contributing to open-source projects to inspire the growth of technology. All my projects here are free to be used in any way you like. Visit https://ugenteraan.github.io/ for more details about my work.

If you'd like to hire me for freelance work or consultation, do visit https://cleverx.com/@Ugenteraan-M

You can find my resume and more about me on LinkedIn, or message me directly on Telegram.

deep_hierarchical_classification's People

Contributors

ugenteraan

deep_hierarchical_classification's Issues

cifar100

In the paper, the CIFAR-100 coarse accuracy is 92.21% and the fine accuracy is 75.91%, but when I ran the code I got a coarse accuracy of 71% and a fine accuracy of 58%.
Why?

Confusion over the dependence loss (dloss)

Hi there,

Based on this line of your dloss, why is 1 subtracted from the product of the ploss terms (- 1)? Shouldn't the product be multiplied by -1 instead?

In the paper, equation 6 defines dloss as -1 * ploss_l-1 * ploss_l. Please correct me if I am wrong.

Thank you.

Questions regarding the meta file

I am applying this to a custom dataset. I do not have a metafile. What do I do in this case? Should I create a metafile and if so how do I create it?
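
One possible way to build a metafile, sketched under assumptions: the Python version of CIFAR-100 ships a `meta` file that is a pickled dictionary with the keys `fine_label_names` and `coarse_label_names`. Assuming this repository reads a file of that shape (not confirmed in this thread), a custom metafile could be written as below. The class names and the helper names `save_meta` and `load_meta` are hypothetical placeholders.

```python
import pickle

# Hypothetical class names for a custom two-level dataset.
# List positions are assumed to be the numeric class ids.
coarse_label_names = ["tops", "bottoms"]
fine_label_names = ["t-shirt", "shirt", "jeans", "shorts"]

# Same key names as the CIFAR-100 (Python version) meta file.
meta = {
    "coarse_label_names": coarse_label_names,
    "fine_label_names": fine_label_names,
}


def save_meta(meta_dict, path):
    """Pickle the metadata dictionary to the given path."""
    with open(path, "wb") as f:
        pickle.dump(meta_dict, f)


def load_meta(path):
    """Read a pickled metadata dictionary back from disk."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

Calling `save_meta(meta, "meta")` and then `load_meta("meta")` should return the same dictionary. Whether the loader in this repository needs additional keys is something to check against load_dataset.py.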

Random KeyError occurring in check_hierarchy()

I am attempting to train with my custom dataset which has the following value for numeric_hierarchy. I double-checked I modified train.csv, test.csv, metafile, level_dict.py, and load_dataset.py for my needs.

{
0: [0, 1, 2, 3, 4, 5], 
1: [6, 7, 8, 9], 
2: [10, 11, 12, 13, 14], 
3: [15, 16, 17, 18]
}

(4 superclasses, 19 subclasses)

However, I am getting random KeyError in the bool_tensor portion of check_hierarchy() below.

def check_hierarchy(self, current_level, previous_level):
    '''Check if the predicted class at level l is a children of the class predicted at level l-1 for the entire batch.
    '''

    #check using the dictionary whether the current level's prediction belongs to the superclass (prediction from the prev layer)
    bool_tensor = [not current_level[i] in self.numeric_hierarchy[previous_level[i].item()] for i in range(previous_level.size()[0])]

    return torch.FloatTensor(bool_tensor).to(self.device)

Traceback (most recent call last):
  File "train.py", line 88, in <module>
    dloss = HLN.calculate_dloss(prediction, [batch_y1, batch_y2])
  File "/root/fashion-effnet/hier-resnet/model/hierarchical_loss.py", line 69, in calculate_dloss
    D_l = self.check_hierarchy(current_lvl_pred, prev_lvl_pred)
  File "/root/fashion-effnet/hier-resnet/model/hierarchical_loss.py", line 42, in check_hierarchy
    bool_tensor = [not current_level[i] in self.numeric_hierarchy[previous_level[i].item()] for i in range(previous_level.size()[0])]
  File "/root/fashion-effnet/hier-resnet/model/hierarchical_loss.py", line 42, in <listcomp>
    bool_tensor = [not current_level[i] in self.numeric_hierarchy[previous_level[i].item()] for i in range(previous_level.size()[0])]
KeyError: 6

The key reported in the KeyError changes every time (4, 9, 14, etc.) and I can't work out how to fix this. I would really appreciate it if anyone could provide help or insights!
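
A small diagnostic sketch, not the repository's code: every key reported above (6, 4, 9, 14) lies outside the four superclass ids 0 to 3 that numeric_hierarchy uses as keys, so the tensor passed as previous_level seems to contain values that are not superclass ids. Possible causes (assumptions, not confirmed here) include a superclass output layer with more than four units, or the two levels being passed in swapped order. The helpers below are hypothetical and work on plain integer lists, for example the result of calling `.tolist()` on the prediction tensors.

```python
numeric_hierarchy = {
    0: [0, 1, 2, 3, 4, 5],
    1: [6, 7, 8, 9],
    2: [10, 11, 12, 13, 14],
    3: [15, 16, 17, 18],
}


def find_unknown_parents(hierarchy, previous_level):
    """Return the predicted parent ids that are not keys of the hierarchy."""
    return sorted({p for p in previous_level if p not in hierarchy})


def check_hierarchy_safe(hierarchy, current_level, previous_level):
    """Return 1.0 where the child is not listed under its predicted parent.

    An unknown parent id is treated as a violation instead of raising KeyError.
    """
    return [
        float(child not in hierarchy.get(parent, ()))
        for child, parent in zip(current_level, previous_level)
    ]


# Hypothetical batch: the second sample has an out-of-range parent id (6).
previous_level = [0, 6, 3]
current_level = [2, 7, 15]

unknown = find_unknown_parents(numeric_hierarchy, previous_level)
violations = check_hierarchy_safe(numeric_hierarchy, current_level, previous_level)
```

If `unknown` is non-empty on real batches, the number of output units of the superclass head and the order of the predictions passed to the loss are worth checking before changing the lookup itself.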

About the default parameter setting

Greetings!
May I know why the default num_classes is 6? I thought CIFAR-100 contains 20 classes (superclasses), each with 5 subclasses, and I can't find any related information in the original paper.
Sincere thanks for any help.

requirement file

Hi Ugenteraan Manogaran,

This is great, would you also provide a requirements.txt file?
Thanks!

best
Cheng

Gradient unable to backprop if we use argmax or torch.where

Hi,
Correct me if I am wrong, but consider the code snippet that calculates D_l for the dependency loss:

current_lvl_pred = torch.argmax(nn.Softmax(dim=1)(predictions[l]), dim=1)

Here argmax is non-differentiable, so the gradient of dloss won't be propagated back to the predictions variables and, subsequently, to the parameters of the neural net. That means the model won't be able to learn from the dloss penalty. I ran this loss on my NLP project and the parameter updates were the same regardless of the value of beta, which led me to this theory. Can you help me check this?
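
The concern can be illustrated without autograd. The sketch below uses plain Python and finite differences on a hypothetical three-class example: an indicator built on argmax is piecewise constant, so small changes to the logits leave it unchanged and its numerical gradient is zero. For contrast, it also evaluates a soft surrogate (the probability mass placed on fine classes outside the predicted coarse class). That surrogate is an illustration only, not something the repository or the paper uses, and all function names here are made up for the example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


def hard_violation(fine_logits, parent_of, coarse_pred):
    """1.0 if the argmax fine class is not a child of coarse_pred, else 0.0."""
    return float(parent_of[argmax(softmax(fine_logits))] != coarse_pred)


def soft_violation(fine_logits, parent_of, coarse_pred):
    """Probability mass on fine classes that are not children of coarse_pred."""
    probs = softmax(fine_logits)
    return sum(p for i, p in enumerate(probs) if parent_of[i] != coarse_pred)


def finite_difference(f, x, i, eps=1e-4):
    """Forward-difference estimate of df/dx[i]."""
    bumped = list(x)
    bumped[i] += eps
    return (f(bumped) - f(x)) / eps


parent_of = {0: 0, 1: 0, 2: 1}  # hypothetical fine -> coarse mapping
logits = [2.0, 0.5, 1.0]        # hypothetical fine-level logits
coarse_pred = 1                 # hypothetical coarse-level prediction

hard_grads = [
    finite_difference(lambda z: hard_violation(z, parent_of, coarse_pred), logits, i)
    for i in range(len(logits))
]
soft_grads = [
    finite_difference(lambda z: soft_violation(z, parent_of, coarse_pred), logits, i)
    for i in range(len(logits))
]
```

Here `hard_grads` is all zeros while `soft_grads` is not, which is consistent with the observation that a term built purely from argmax results cannot change the parameter updates. Whether other parts of dloss in this repository carry gradient is a separate question that the sketch does not answer.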
