davyneven / spatialembeddings

Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth

Home Page: https://arxiv.org/pdf/1906.11109.pdf

License: Other

Python 100.00%

Contributors: davyneven

spatialembeddings's Issues

Regarding variable `xym`

Hello,

Thank you for an excellent implementation and publication. I am learning a lot from your code and from how you packaged this project. 👍

I wanted to run one question by you regarding the line of code that creates the variable xym. I suspect the concatenation order should be reversed, since the variable is later accessed as [channel, height, width]. So I suggest changing:

# coordinate map
xm = torch.linspace(0, 2, 2048).view(1, 1, -1).expand(1, 1024, 2048)
ym = torch.linspace(0, 1, 1024).view(1, -1, 1).expand(1, 1024, 2048)
xym = torch.cat((xm, ym), 0)

should become

# coordinate map
xm = torch.linspace(0, 2, 2048).view(1, 1, -1).expand(1, 1024, 2048)
ym = torch.linspace(0, 1, 1024).view(1, -1, 1).expand(1, 1024, 2048)
xym = torch.cat((ym, xm), 0)

This would also require equivalent changes in this line of code. I may be interpreting this completely wrongly, but I wanted to check with you. Thank you for your time!
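For reference, here is a scaled-down version of the quoted construction (4 x 8 instead of 1024 x 2048; the sizes are arbitrary). It shows that torch.cat((xm, ym), 0) already yields a [channel, height, width] tensor; the cat order only decides whether channel 0 holds the x or the y coordinate. Whether the rest of the code expects x in channel 0 is not visible from this excerpt, so the sketch does not settle the question either way.

```python
import torch

# Same construction as above, with a small H x W so the result is easy to inspect.
h, w = 4, 8
xm = torch.linspace(0, 2, w).view(1, 1, -1).expand(1, h, w)  # x grows along the width
ym = torch.linspace(0, 1, h).view(1, -1, 1).expand(1, h, w)  # y grows along the height
xym = torch.cat((xm, ym), 0)  # shape (2, H, W): channel 0 = x, channel 1 = y

print(xym.shape)     # torch.Size([2, 4, 8])
print(xym[0, 0])     # first row of channel 0: x values from 0 to 2
print(xym[1, :, 0])  # first column of channel 1: y values from 0 to 1
```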

Multi-class settings

Hi Neven, could you update this repo to the multi-class setting described in your paper? I'm trying to reproduce the results on Cityscapes. Thanks.

Did you use multi-scale testing?

Hi,

Could you clarify whether you used single-scale or multi-scale testing to get the 27.6 AP on the Cityscapes dataset? I could not find details about it in the paper.

Thanks!

Datasets except Cityscapes

Hi, as mentioned in the CVPR version of your paper, do you have results on datasets other than Cityscapes?

Training problem

Hi, I get the following error when running train.py. Do you know how to fix it?

lr_scheduler.py:122: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
learning rate: 0.0005
  0%|                                                                  | 0/187 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 187, in <module>
    train_loss = train(epoch)
  File "train.py", line 108, in train
    loss = criterion(output, instances, class_labels, **args['loss_w'])
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/my_loss.py", line 101, in forward
    lovasz_hinge(dist*2-1, in_mask)
  File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 90, in lovasz_hinge
    for log, lab in zip(logits, labels))
  File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 231, in mean
    acc = next(l)
  File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 90, in <genexpr>
    for log, lab in zip(logits, labels))
  File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 112, in lovasz_hinge_flat
    grad = lovasz_grad(gt_sorted)
  File "/data/4TB/Andi/semantic/SpatialEmbeddings/src/criterions/lovasz_losses.py", line 26, in lovasz_grad
    union = gts.float() + (1 - gt_sorted).float().cumsum(0)
  File "/home/server0/anaconda3/envs/deep36_andi/lib/python3.6/site-packages/torch/tensor.py", line 394, in __rsub__
    return _C._VariableFunctions.rsub(self, other)
RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
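The last line of the traceback points at `1 - gt_sorted`, where gt_sorted is a bool tensor, which newer PyTorch versions reject. Below is a minimal sketch of a fix, assuming lovasz_losses.py follows Berman et al.'s reference Lovász hinge implementation; only line 26 of that file appears in the traceback, so the rest of the function body here is an assumption. The idea is to cast gt_sorted to float once at the top so every arithmetic step works on floats. Writing `(~gt_sorted).float()` instead of `(1 - gt_sorted).float()` would also avoid the error.

```python
import torch

def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension w.r.t. sorted errors (after Berman et al.).

    gt_sorted: 1-D tensor of ground-truth labels (bool or 0/1), sorted by descending error.
    """
    # Cast once up front: newer PyTorch rejects `1 - bool_tensor`.
    gt_sorted = gt_sorted.float()
    p = gt_sorted.numel()
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.cumsum(0)
    union = gts + (1 - gt_sorted).cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # turn cumulative Jaccard values into per-position increments
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard
```

With the cast in place, a bool label tensor gives the same result as the equivalent 0/1 float tensor.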

Regarding the Lovasz Hinge loss

Hello,

I have more of a question here than an issue. I was curious about this line of code. As I understand it, dist lies between 0 (the pixel embedding is very far from the instance center) and 1 (the pixel embedding lies on top of the instance center). By that logic, dist*2 - 1 lies between -1 and 1:

-1 < dist*2 - 1 <= 1

Could you share any intuition on why dist*2-1 is preferred over plain dist as the first argument to the lovasz_hinge function? For example, in your opinion, does it give the loss better convergence properties? Thank you!
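For intuition, the per-pixel errors inside a Lovász hinge can be computed directly. This sketch assumes the hinge follows Berman et al.'s reference code, where errors are 1 - logits * signs and signs are +1 for foreground and -1 for background; the two pixel values are made up. With dist itself as the logit, a background pixel's error can never drop below 1 because dist is always positive, whereas dist*2-1 puts the decision boundary at dist = 0.5 and treats foreground and background symmetrically. This is one plausible reading, not the author's stated reason.

```python
import torch

def hinge_errors(logits, labels):
    # Per-pixel hinge errors as in Berman et al.'s reference Lovasz hinge:
    # signs are +1 for foreground pixels and -1 for background pixels.
    signs = 2.0 * labels.float() - 1.0
    return 1.0 - logits * signs

dist = torch.tensor([0.9, 0.1])       # made-up Gaussian scores in (0, 1]
labels = torch.tensor([True, False])  # pixel 0 foreground, pixel 1 background

raw = hinge_errors(dist, labels)             # approximately [0.1, 1.1]
mapped = hinge_errors(dist * 2 - 1, labels)  # approximately [0.2, 0.2]
```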

How to choose the value of n_sigma

It looks like a one-channel sigma map works well, and from the code we can see that the sigma map can have multiple channels, depending on n_sigma. Does this make a difference, and how should a proper value be chosen?
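To illustrate what n_sigma changes, here is a sketch of the exp(-sum(diff^2 * s)) form used in my_loss.py, with made-up values of s. With n_sigma = 1, a single s is broadcast over x and y, so the margin is circular; with n_sigma = 2, x and y get separate values, so the margin can be elliptical. Which setting works better for a given class is an empirical question that this sketch does not answer.

```python
import torch

def gaussian_score(spatial_emb, center, s):
    # spatial_emb: (2, N) pixel embeddings; center: (2, 1) instance center.
    # s: (n_sigma, 1) per-instance term; with n_sigma = 1 it broadcasts over x and y.
    return torch.exp(-torch.sum(torch.pow(spatial_emb - center, 2) * s, 0, keepdim=True))

center = torch.zeros(2, 1)
# Two pixels: one offset by 0.1 along x, one offset by 0.1 along y.
emb = torch.tensor([[0.1, 0.0],
                    [0.0, 0.1]])

iso = gaussian_score(emb, center, torch.tensor([[50.0]]))           # n_sigma = 1: equal scores
aniso = gaussian_score(emb, center, torch.tensor([[50.0], [5.0]]))  # n_sigma = 2: y is more tolerant
```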

Computing var_loss

Hello Davy @davyneven

I was wondering whether, instead of this line:

 var_loss = var_loss + torch.mean(torch.pow(sigma_in - s.detach(), 2))

it should rather be:

 var_loss = var_loss + torch.mean(torch.pow(sigma_in - s[..., 0].detach(), 2))

I suggest this because sigma_in has shape 2 x N while s has shape 2 x 1 x 1, and subtracting tensors of different shapes broadcasts them, which could lead to unintended results.

For example:

import numpy as np
import torch

sigma_in = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])
sigma_in = torch.from_numpy(sigma_in)  # shape is 2 x 3
>>> sigma_in
tensor([[1., 2., 3.],
        [2., 4., 6.]], dtype=torch.float64)
s = sigma_in.mean(1).view(2, 1, 1)  # shape is 2 x 1 x 1
>>> s
tensor([[[2.]],
        [[4.]]], dtype=torch.float64)
result = sigma_in - s.detach()  # broadcasts to shape 2 x 2 x 3
>>> result
tensor([[[-1.,  0.,  1.],
         [ 0.,  2.,  4.]],
        [[-3., -2., -1.],
         [-2.,  0.,  2.]]], dtype=torch.float64)
result_edited = sigma_in - s[..., 0].detach()  # shape is 2 x 3
>>> result_edited
tensor([[-1.,  0.,  1.],
        [-2.,  0.,  2.]], dtype=torch.float64)

I just wanted to ask whether the current behavior is intended. Wouldn't we want each margin bandwidth channel to be subtracted only from its own channel, instead of all 2 x 2 combinations? Thank you!
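Extending the example to the loss value itself (same toy numbers, in float32): the broadcast version averages over 12 entries, including cross-channel differences, while the indexed version averages over the 6 intended ones, so the two disagree whenever n_sigma = 2. With n_sigma = 1 the shapes are 1 x N and 1 x 1 x 1; broadcasting gives 1 x 1 x N with the same values, so both forms agree in that case.

```python
import torch

sigma_in = torch.tensor([[1.0, 2.0, 3.0],
                         [2.0, 4.0, 6.0]])  # (n_sigma=2, N=3), toy values from the issue
s = sigma_in.mean(1).view(2, 1, 1)          # (2, 1, 1)

# Current form: broadcasts to (2, 2, 3) and mixes channels, mean over 12 entries.
broadcast = torch.mean(torch.pow(sigma_in - s.detach(), 2))          # 44 / 12, about 3.667
# Suggested form: s[..., 0] is (2, 1), so each channel uses its own mean, 6 entries.
per_channel = torch.mean(torch.pow(sigma_in - s[..., 0].detach(), 2))  # 10 / 6, about 1.667
```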

seed loss after downsampling

Hi @davyneven, thanks for your work.
I am having trouble getting a good seed map after downsampling the input to 1/4 of its original size inside the model. I have read through the paper. Do I also need to change the loss function for the seed map?

Thanks!

Why isn't GT used in the instance seed loss calculation?

seed loss

            seed_loss += self.foreground_weight * torch.sum(
                torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2))

We usually compute a loss between a prediction and the ground truth, but in your loss function both seed_map and dist are predictions. Why isn't GT used in the instance seed loss? Shouldn't it be either

    seed_loss += self.foreground_weight * torch.sum(
        torch.pow(gt[in_mask] - dist[in_mask].detach(), 2))

or

    seed_loss += self.foreground_weight * torch.sum(
        torch.pow(seed_map[in_mask] - gt[in_mask].detach(), 2))

instead?
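One way to read the quoted line: dist acts as the regression target for seed_map, and because it is detached, the seed term sends gradients only into seed_map, never back into dist. The toy check below uses made-up values and drops foreground_weight and the mask; it only demonstrates the gradient flow, not the design rationale.

```python
import torch

# Toy values (made up): two foreground pixels.
seed_map = torch.tensor([0.2, 0.7], requires_grad=True)  # predicted seediness
dist = torch.tensor([0.9, 0.4], requires_grad=True)      # Gaussian score from the embedding branch

# Same form as the quoted seed term (foreground_weight and in_mask omitted).
seed_loss = torch.sum(torch.pow(seed_map - dist.detach(), 2))
seed_loss.backward()

print(seed_map.grad)  # 2 * (seed_map - dist): approximately [-1.4, 0.6]
print(dist.grad)      # None: detach() blocks gradient flow into dist
```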

Best Loss Weight

Hi, in the config file the loss weights are 'w_inst':1, 'w_var':1, 'w_seed':10. When training this way, I can't reproduce your results on cars because my seed map is poor.

Did you modify the weights during training? For example, first set them to 1, 1, 10 to optimize the seed map, then to 1, 10, 1 to optimize the sigma map?

Thank you very much!

Ablation Experiments

I'm trying to reproduce the ablation experiments, but the results are not good.

  1. These experiments were done with a single-class model, right? If so, should I use the cropped dataset produced by generate_crops.py? For example, to train the person class, I would first train on the (512, 512) cropped dataset (OBJ_ID=26) and then on the (1024, 1024) cropped dataset (OBJ_ID=26).

  2. How is the cluster_with_gt() function used? Is it meant for the ablation experiments?

Any help would be great, thanks!!

Post-processing (Clustering) Slow

Thank you for publishing the code!

I ran the test using your pretrained model on the car class and found that the average post-processing time is about 391 ms per image on a Titan Xp GPU, which is much slower than the number reported in the paper. Did I miss anything? Thanks!
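One thing worth ruling out when comparing timings is the measurement itself: CUDA kernels run asynchronously, and the first calls include one-time setup costs. A minimal timing helper under those assumptions is sketched below; time_ms is a hypothetical name, and you would pass in whatever clustering call your test script uses. It will not explain a gap by itself, but it makes the measured number comparable.

```python
import time
import torch

def time_ms(fn, *args, warmup=3, iters=10):
    """Average wall-clock time of fn(*args) in milliseconds (hypothetical helper)."""
    # Warm-up calls absorb one-time costs such as CUDA context setup and memory allocation.
    for _ in range(warmup):
        fn(*args)
    # CUDA kernels launch asynchronously; synchronize so that pending GPU work
    # is neither hidden from nor leaked into the timed region.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000.0 / iters
```

Usage would look like `time_ms(my_cluster_fn, prediction)`, where my_cluster_fn and prediction are placeholders for your own clustering call and network output.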

How to visualize offset vectors?

As shown in Figure 2, I want to visualize the predicted offset vectors, but I could not find a description of how they were visualized. I'm studying the paper and code carefully, and I would really appreciate your help.
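A common way to draw an offset field is a subsampled quiver plot. The sketch below is not necessarily how Figure 2 was made; plot_offsets and step are hypothetical names, and it assumes the offsets are given in pixels as a (2, H, W) array with x in channel 0 and y in channel 1. If the offsets are in the normalized units of the xym coordinate map from the first issue (0 to 2 across 2048 px, 0 to 1 across 1024 px), multiply them by roughly 1024 first.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt

def plot_offsets(offsets, step=8):
    """Quiver plot of a (2, H, W) offset field in pixels, one arrow every `step` pixels."""
    _, h, w = offsets.shape
    ys, xs = np.mgrid[0:h:step, 0:w:step]  # arrow origins on a coarse grid
    u = offsets[0, ::step, ::step]         # x component
    v = offsets[1, ::step, ::step]         # y component (image rows grow downward)
    fig, ax = plt.subplots()
    # angles/scale_units="xy" with scale=1 draws each arrow with its true length in pixels.
    ax.quiver(xs, ys, u, v, angles="xy", scale_units="xy", scale=1)
    ax.invert_yaxis()  # match image coordinates: origin at the top-left
    ax.set_aspect("equal")
    return fig
```

For a quick check, offsets that all point toward one pixel should show arrows converging on that location.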

Gaussian calculation

Hi, I'm reading my_loss.py and have a question about how the Gaussian is calculated.

    s = torch.exp(s*10)
    # calculate gaussian
    dist = torch.exp(-1*torch.sum(
        torch.pow(spatial_emb - center, 2)*s, 0, keepdim=True))

According to Equation 5 in the paper, shouldn't this be:

    dist = torch.exp(-1*torch.sum(torch.pow(spatial_emb - center, 2)/(2*s**2), 0, keepdim=True))

Am I missing something? Thank you.
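One reading that reconciles the two: if s in the code plays the role of 1/(2σ²) rather than σ itself, the two expressions are identical, and predicting it as exp(10·x) simply keeps it positive. This is an interpretation, not something the quoted code confirms. The sketch checks the equivalence numerically with made-up values.

```python
import torch

# Made-up values: offset of one pixel embedding from its instance center, per axis.
diff = torch.tensor([[0.05], [-0.02]])  # (2, 1): x and y
sigma = torch.tensor([[0.1], [0.1]])    # margin bandwidth per axis

# Equation 5 style: divide the squared distance by 2 * sigma^2.
dist_eq5 = torch.exp(-1 * torch.sum(torch.pow(diff, 2) / (2 * sigma ** 2), 0, keepdim=True))

# Code style: multiply by s, where s is read as 1 / (2 * sigma^2).
s = 1.0 / (2 * sigma ** 2)
dist_code = torch.exp(-1 * torch.sum(torch.pow(diff, 2) * s, 0, keepdim=True))

# Under s = exp(10 * raw), this is the raw network output that would produce that s.
raw = torch.log(s) / 10
```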
