Comments (8)

clemsgrs commented on May 27, 2024

For the record, I got 0.883 ± 0.06 AUC for the breast subtyping task (ILC vs. IDC) using the same dataset (same 875 slides, same 10 folds).

This is on par with the results reported in the paper (0.874 ± 0.06, see Table 1).

The slight difference comes from me using different region-level pre-extracted features: I slightly adapted the CLAM patching code to generate [4096, 4096] regions per slide, then used the provided pre-trained weights to produce region-level features of shape [M, 192]. For each slide, I therefore end up with slightly different regions.

clemsgrs commented on May 27, 2024

Finally got it working!

[image: working_curves]

The issue was coming from me using img_size = [256] instead of [224] when instantiating the VisionTransformer / VisionTransformer4K components in the HIPT model. This caused the pre-trained weights for the positional embedding parameter to be skipped when loading the state dict because of the mismatching shape. As a result, when generating region-level features, the positional embeddings were left as initialised, i.e. random tensors (normal distribution)! This caused my region-level features to be garbage...

I've re-generated the region-level features with img_size = [224] and now get decent loss profiles & AUC numbers, great!
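
In case it helps others hitting the same silent failure, here is a minimal sketch of why the img_size value matters (the exact loading code in the repo may differ; this only illustrates the shape-filtered skip described above):

import torch

# Pretrained pos_embed: img_size=224 with 16-sized tokens -> (224 // 16) ** 2 + 1 = 197 positions.
pretrained = {"pos_embed": torch.randn(1, (224 // 16) ** 2 + 1, 384)}    # [1, 197, 384]

# A model instantiated with img_size=[256] allocates (256 // 16) ** 2 + 1 = 257 positions.
model_params = {"pos_embed": torch.randn(1, (256 // 16) ** 2 + 1, 384)}  # [1, 257, 384]

# A shape-filtered load keeps only keys whose shapes match, so pos_embed is dropped
# and stays at its random initialisation -- which is what garbled the region features.
loadable = {k: v for k, v in pretrained.items()
            if k in model_params and v.shape == model_params[k].shape}
print(loadable.keys())   # dict_keys([]) -> pos_embed silently skipped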

Before I close this issue, I have two small follow-up questions:

  1. I get a warning when training with [M, 256, 384] features that comes from interpolating the positional embeddings (lines 214 to 218):
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
        mode='bicubic',
    )

    Here's the associated warning: UserWarning: Default upsampling behavior when mode=bicubic is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.

To suppress it, I would add align_corners=False, but wanted to make sure this is the behaviour you would also go for.

  2. Just to make sure I'm not missing something: could you confirm that the only reason one would want to train a model on [M, 256, 384] features (i.e. a HIPT model with a pre-trained self.local_vit component), instead of training the global aggregation transformer on [M, 192] features, is to fine-tune the pre-trained self.local_vit (by allowing gradients to flow through this component)?

If the pre-trained self.local_vit component is frozen, the former should yield the same results as the latter. But given that the former has a longer forward pass (the features additionally have to go through self.local_vit), one should favour the latter (see the sketch below).
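
Here is a toy sketch of the equivalence I have in mind (hypothetical stand-in modules, not the actual HIPT classes): with self.local_vit frozen, recomputing region features on the fly gives exactly the same [M, 192] tensor as using features pre-extracted offline with the same weights, so the long path only adds compute.

import torch
import torch.nn as nn

torch.manual_seed(0)
local_vit = nn.Sequential(nn.Flatten(1), nn.Linear(256 * 384, 192))  # stand-in for the ViT-4K
for p in local_vit.parameters():
    p.requires_grad = False       # analogous to freeze_4k=True
local_vit.eval()

patch_features = torch.randn(8, 256, 384)                   # M = 8 regions, [M, 256, 384]

with torch.no_grad():
    on_the_fly = local_vit(patch_features)                  # long path:  [M, 192], recomputed each step
    pre_extracted = local_vit(patch_features).clone()       # short path: [M, 192], computed once offline

assert torch.allclose(on_the_fly, pre_extracted)            # identical inputs to the global transformer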

Thank you for your previous answer, it really helped me find where the issue was coming from! Now that it is fixed, I'll try to reproduce the experiments you report in the paper & look at the ones you've recently run and linked in your answer above. Will be interesting!

Richarizardd commented on May 27, 2024
  1. Yes - align_corners=False (see the sketch after this list).
  2. Though the "longer forward pass" that uses self.local_vit is more expensive to run, one can do more data augmentation by running self.local_vit with dropout. For larger datasets, fine-tuning self.local_vit may also be helpful. Lastly, another advantage is that with features at both the 256- and 4096-level, one can also explore other variations such as concatenating: 1) the slide feature from aggregating 256-level features via ABMIL, 2) the slide feature from aggregating 4096-level features via ABMIL, and 3) the slide feature from the last Transformer. I have not tried other strategies, but it seems intuitive for capturing the "different scales of features" across resolutions. It would also be fun to mix-and-match different aggregation functions.
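
Regarding point 1, a minimal sketch of the interpolation call with the flag made explicit (align_corners=False matches the current PyTorch default, so behaviour is unchanged and the warning goes away; the values below are examples only, and the small 0.1 offset mirrors the trick used in the DINO ViT code to avoid floating-point rounding of the output size):

import math
import torch
import torch.nn as nn

dim, N, w0, h0 = 384, 196, 16, 16                 # N = 14 * 14 pretrained positions
patch_pos_embed = torch.randn(1, N, dim)
w0, h0 = w0 + 0.1, h0 + 0.1                       # avoid floor() landing one pixel short

patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
    scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
    mode='bicubic',
    align_corners=False,                          # explicit, matching the default since PyTorch 0.4
)
print(patch_pos_embed.shape)                      # torch.Size([1, 384, 16, 16])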

Thank you for reporting these issues again. I will reflect these changes sometime this weekend.

Richarizardd commented on May 27, 2024

Hi @clemsgrs - thank you for the detailed post; I will do my best to respond.

  1. Why is the patch_size argument passed when instantiating VisionTransformer4K but not used?

Apologies - it was a typo. I wanted to make the token size an argument for instantiating VisionTransformer4K, but ended up making the arguments for VisionTransformer4K the same as the regular VisionTransformer class, as there is no change in ViT sequence length complexity. Whether it is 256-sized images with 16-sized patching or 4096-sized images with 256-sized patching, the sequence length is always 16*16 = 256. To be more exact, technically the image size for VisionTransformer4K should be 3584 while the patch size is 256, since during pretraining the maximum global crop size is [14 x 14] in a [16 x 16 x 384] 2D grid of pre-extracted feature embeddings of 256-sized patches. However, since VisionTransformer4K doesn't actually take in 3584/4096-sized images but rather the 2D grid of pre-extracted feature embeddings, it was easier to keep VisionTransformer4K the same as VisionTransformer. I will fix some of these arguments so that it is less confusing.

  2. Why is the img_size argument passed when instantiating VisionTransformer4K left as the default (i.e. img_size = [224]) and not set to [256]...?

See the above comment. In addition, I would note that, as in the original VisionTransformer from DINO, we can't set img_size = [256] when instantiating VisionTransformer4K, as the model was trained with img_size = [224] and thus the sequence length in self.pos_embed is (224/16)**2 + 1 = 197. If you change img_size = [256], you will not be able to load the pretrained weights. Despite some of the typos, everything ended up working in a roundabout way since the ViT complexities are consistent across image resolutions, but apologies for the confusion!
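
For reference, the arithmetic behind the two answers above:

tokens_vit256 = (256 // 16) ** 2       # 256-sized image, 16-sized patches   -> 256 tokens
tokens_vit4k  = (4096 // 256) ** 2     # 4096-sized region, 256-sized tokens -> 256 tokens
pos_embed_len = (224 // 16) ** 2 + 1   # pretrained pos_embed length (img_size=224) -> 197
print(tokens_vit256, tokens_vit4k, pos_embed_len)   # 256 256 197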

  3. Isn't there a confusion between B (supposed to account for the batch size) and M (the number of [4096, 4096] regions per slide)? Shouldn't the cls_token tensor be of shape [batch_size, 1, 192]?

I am not sure what the confusion is and may need clarification. However, I would say that in training the local aggregation (of 256-sized features to learn 4K-sized features) in HIPT, you can treat the number of [4096, 4096] regions essentially like a "minibatch" in processing all [M x 256 x 384] features at once. The actual "batch size" (# of WSIs) for weakly-supervised learning is 1.
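
A shape-only sketch of this point (stand-in modules, shapes only, not the actual HIPT classes): the M regions of one slide act as the "minibatch" for the local aggregation, so the cls_token is expanded to one copy per region, while the WSI-level batch size stays 1.

import torch
import torch.nn as nn

M = 12
patch_features = torch.randn(M, 256, 384)        # [M, 256, 384] pre-extracted 256-level features

phi = nn.Linear(384, 192)                        # stand-in projection to the ViT-4K embed dim
tokens = phi(patch_features)                     # [M, 256, 192]

cls_token = nn.Parameter(torch.zeros(1, 1, 192))
cls_tokens = cls_token.expand(M, -1, -1)         # [M, 1, 192] -- one cls token per region, not per WSI
tokens = torch.cat((cls_tokens, tokens), dim=1)  # [M, 257, 192]

region_features = tokens[:, 0]                   # [M, 192] after the (omitted) Transformer blocks
slide_batch = region_features.unsqueeze(0)       # [1, M, 192]: the actual batch size (# of WSIs) is 1
print(cls_tokens.shape, region_features.shape, slide_batch.shape)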

  4. I've also tried training only the global aggregation layers by directly feeding the region-level pre-extracted features (of shape [M, 192]), without success (training loss not really decreasing either). Could you confirm that this should work just as well as training the intermediate transformer + the global aggregation layers on the [M, 256, 384] features?

I am sorry that you have not had success using the available region-level pre-extracted feature embeddings. What weakly-supervised scaffold code did you use? In this work, CLAM was used for weakly-supervised learning, which I slightly modified for HIPT. Here are the areas of the repository that may help you with understanding the loss curves and reproducibility.

  • Training Logs For All Experiments: All training logs are made available in the following results directory, which you can freely inspect using TensorBoard.

  • Self-Supervised KNN Performance: This notebook provides a simple sanity check where you can: 1) load all of the pre-extracted region embeddings, 2) take the mean, 3) plug-and-play into scikit-learn with StratifiedKFold(k=10) and KNeighborsClassifier (a rough sketch of this check follows after this list). Using randomly-generated splits, the performance is roughly on par with the splits used in the paper.

  • Plugging Pre-Extracted Region Embeddings into CLAM: With all of the 4096-level features pre-extracted, the problem is essentially reduced to a "bag-of-instances" problem (where instead of 256-level features, we have 4096-level features). Using both the official CLAM repository and the modified CLAM scaffold used for this work (training commands detailed in the README), I ran a quick experiment that checks how well these features perform on TCGA-BRCA/NSCLC/RCC subtyping using my 10-fold CV. The experiment was run on a machine with a different version of PyTorch+CUDA than the one used in the paper, so results may not be exact (and I did this somewhat quickly, in < 1 hour, so I may also have made some mistakes), but you can see here that: 1) results are on par with the results reported in the paper - both vanilla MIL (pooling [M x 192] features) and a "global aggregation only" version of HIPT (performing vanilla self-attention on [M x 192] features) were trained with 25% / 100% of the training data; 2) training logs are found here; 3) since all features were made available on GitHub, one can simply rerun these experiments following the README; 4) both the official CLAM repository and my modified CLAM version gave similar results, and I would be happy to provide the former as well.
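
A rough sketch of the KNN sanity check mentioned above, using randomly-generated stand-in data (the actual notebook may differ): slide_embeddings plays the role of the pre-extracted region embeddings, one [M_i, 192] array per slide, and labels the corresponding subtype labels.

import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# 1) stand-in for the loaded pre-extracted region embeddings (one array per slide)
slide_embeddings = [np.random.randn(np.random.randint(10, 40), 192) for _ in range(100)]
labels = np.random.randint(0, 2, size=100)

# 2) take the mean over regions to get one 192-d vector per slide
X = np.stack([emb.mean(axis=0) for emb in slide_embeddings])

# 3) plug-and-play into scikit-learn with 10-fold stratified CV + KNN
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
scores = cross_val_score(KNeighborsClassifier(n_neighbors=5), X, labels,
                         cv=cv, scoring='roc_auc')
print(scores.mean(), scores.std())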

What problems are you looking to apply HIPT to? I appreciated reading your detailed response on getting this method to work correctly, and I would be happy to understand and work through any pain points you have in using this method on TCGA (and other downstream tasks).

clemsgrs commented on May 27, 2024

Hi @Richarizardd, thank you for answering so quickly & in such detail!

  1. Ok, makes sense!
  2. Indeed, when using img_size = [256], all pre-trained weights are loaded nicely, except the positional embedding (because of the mismatching shape). I just realised that, in that case, my positional embedding would be a random tensor during the whole training process (given it gets initialised as such & then gets frozen). I'll stick to img_size = [224] for now.
  3. True, the confusion was mine! I thought this was happening in the last transformer block (aggregating region-level features into a single slide-level representation), but as you pointed out it's happening in the intermediate transformer block, where it makes total sense to have one cls_token per region.
  4. As weakly-supervised scaffold code, I'm using the same modified CLAM scaffold as you. Thanks for having added the training commands & the HIPT_GP_FC model (I had implemented the same model). I was using a different set of pre-extracted features than the ones you provide via Git LFS (I slightly adapted the CLAM patch extraction pipeline).
    → I'll switch to using yours as a first step.
    Thank you for the list of resources I can use to debug what is happening on my side. I had forgotten about the self-supervised KNN check; it seems a good place to start!

I'm looking to apply HIPT to various computational pathology problems and compare how other methods perform on the same tasks.

Once again, big thanks for the fast & detailed answer. Will reach back to you when I have something new!

bryanwong17 commented on May 27, 2024

Hi @clemsgrs, I have a small question. When training at the slide level, did you set freeze_4k = True?

clemsgrs commented on May 27, 2024

Hi @bryanwong17, when training on region-level features (i.e. a sequence of embeddings shaped [M, 192]), I did set freeze_4k = True.

bryanwong17 commented on May 27, 2024

Hi @clemsgrs, is the input for training HIPT_LGP_FC [M, 256, 384]? And when we set 'pretrain_4k != None', it loads 'vit4k_xs_dino.pth' and reduces the dimension to [M, 192]? Then we set 'freeze_4k=True'?
