aih-sgml / mixmil

Code for the paper: Mixed Models with Multiple Instance Learning

Home Page: https://arxiv.org/abs/2311.02455

License: Apache License 2.0

Topics: attention-mechanism, empirical-bayes, generalized-linear-mixed-models, mixed-models, multi-instance, multi-instance-learning, variational-inference

mixmil's Introduction

MixMIL

Code for the paper: Mixed Models with Multiple Instance Learning

Accepted at AISTATS 2024 as an oral presentation and Outstanding Student Paper Highlight.

Please raise an issue for questions and bug reports.

Installation

Install with:

pip install mixmil

Alternatively, if you want to include the optional experiment and test dependencies, use:

pip install "mixmil[experiments,test]"

or if you want to adapt the code:

git clone https://github.com/AIH-SGML/mixmil.git
cd mixmil
pip install -e ".[experiments,test]"

To enable computations on the GPU, please follow the installation instructions for PyTorch and PyTorch Scatter. MixMIL works with, for example, PyTorch 2.1.
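For example, a minimal sketch assuming PyTorch 2.1 with CUDA 12.1 (the version tags and wheel index are assumptions; defer to the official PyTorch and PyTorch Scatter installation instructions for your setup):

pip install torch==2.1.0
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html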

Experiments

See the notebooks in the experiments folder for examples of how to run the simulation and histopathology experiments.

Make sure the experiment requirements are installed:

pip install "mixmil[experiments]"

Histopathology

The histopathology experiment was performed on the CAMELYON16 dataset.

Download Data

To download the embeddings provided by the DSMIL authors, either:

  • Full embeddings: python scripts/dsmil_data_download.py
  • PCA reduced embeddings: Google Drive

Microscopy

The full BBBC021 dataset can be downloaded here.

Download Data

  • We make the featurized cells available at BBBC021
  • The features are stored as an AnnData object. We recommend using the scanpy package to read and process them (see the sketch after this list).
  • The weights of the featurizer trained with the SimCLR algorithm can be downloaded from the original GitHub repository
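For example, a minimal sketch of loading the features with scanpy (the file name is a placeholder for the downloaded AnnData file):

import scanpy as sc

adata = sc.read_h5ad("bbbc021_features.h5ad")  # placeholder path to the downloaded file
print(adata)         # summary: n_obs x n_vars plus .obs / .var annotations
X = adata.X          # per-cell feature matrix
obs = adata.obs      # per-cell metadata (e.g. compound, plate, well)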

Citation

@inproceedings{engelmann2024mixed,
  title={Mixed Models with Multiple Instance Learning},
  author={Engelmann, Jan P. and Palma, Alessandro and Tomczak, Jakub M. and Theis, Fabian and Casale, Francesco Paolo},
  booktitle={International Conference on Artificial Intelligence and Statistics},
  pages={3664--3672},
  year={2024},
  organization={PMLR}
}

mixmil's People

Contributors

allepalma, jan-engelmann


mixmil's Issues

Question about reparametrization and number of parameters

Hi, thank you for providing the materials for implementing your paper.

Reparametrization:
I've been going over the details of the implementation in the appendix, and I was wondering about the specifics of the reparametrization introduced in section A.1:

I was a bit puzzled about the exact details of the normalization introduced here. I understand why you would want the sample variance to equal the sum of squared $\beta$ so that it is linked to explained variance, but I don't quite understand the derivation.

First, what does it mean that $\eta$ in A.2 has unit mean square? I assumed that it meant $\mathbb{E}[\eta^2] = 1$, but I don't quite see how that is achieved with the rescaling.
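(One possible reading, stated as an assumption rather than the paper's exact construction: if the rescaling divides by the empirical root mean square,

$$\eta = \frac{\widetilde{\eta}}{\sqrt{\frac{1}{N}\sum_{i=1}^{N} \widetilde{\eta}_i^2}} \quad\Rightarrow\quad \frac{1}{N}\sum_{i=1}^{N} \eta_i^2 = 1,$$

then "unit mean square" would refer to the empirical second moment rather than $\mathbb{E}[\eta^2]$.)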

Secondly, why do we need this rescaling with $\eta$ in the normalization formula A.3 instead of normalizing using just the regular formula for $u$? I.e., why can't we do this:

$$u = b \times \frac{u - \operatorname{mean}(u)}{\operatorname{std}(u)}$$

instead of

$$u = b \times \frac{\widetilde{u} - \operatorname{mean}(\widetilde{u})}{\operatorname{std}(\widetilde{u})}$$

Is there a more lengthy explanation somewhere that I can follow?

Number of parameters
In the paper it was explained that there are a total of $2Q + Q(2Q+1)$ variational parameters, assuming full rank, and $2Q + K$ likelihood parameters. I understand where the number of variational parameters comes from, but what exactly comprises the likelihood parameters? Is it $2Q$ parameters coming from $\beta$ and $\gamma$, respectively, and $K$ parameters coming from $\alpha$? Are the trainable parameters from $\sigma_{\beta}$ and $\sigma_{\gamma}$ also included somewhere within the numbers of likelihood/variational parameters?

[attached screenshot of the relevant equations from the paper's appendix]

Bug in KL scaling

Hi, sorry for bothering again. However, I think I found a small bug in the KL divergence term of the loss function.

Currently, the KL loss in the code is already weighted by the batch size divided by the total dataset size (kld_w). So far so good.

However, y.shape[0] refers to the batch size, at least when I tried the CAMELYON16 example, so dividing by it means the KL loss effectively gets divided by the total dataset size, which makes it very small.

Potential fix
I guess y.shape[0] was meant to be the number of outputs P? In that case, I think it should be a simple fix; change it to:
kld_term = kld_w * kld.sum() / y.shape[1]

or
kld_term = kld_w * kld.sum() / self.P

to avoid any confusion about the dimensions. For reference, the current line is:

kld_term = kld_w * kld.sum() / y.shape[0]
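For illustration, a minimal sketch of the weighting the proposed fix implies (the helper name and signature are hypothetical, not part of mixmil):

import torch

def scaled_kl(kld: torch.Tensor, batch_size: int, dataset_size: int, n_outputs: int) -> torch.Tensor:
    """Hypothetical helper: weight the summed KL for mini-batch training."""
    kld_w = batch_size / dataset_size      # compensate for seeing only a mini-batch per step
    return kld_w * kld.sum() / n_outputs   # divide by the number of outputs P, not the batch size

# e.g. a batch of 32 bags out of a dataset of 400 bags, with P = 2 outputs
print(scaled_kl(torch.rand(2), batch_size=32, dataset_size=400, n_outputs=2))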

Sharing of single-cell pre-trained embeddings

Hi, thank you again for the previous explanation and this great work.

I have been through your paper, and I am interested in applying the model to the BBBC021 dataset.
Would it be possible to share the obtained single-cell representations?

That would be really appreciated :)!

Best regards,
Hsiuchi

Very long waiting time for the init step

Hi, I found that the initialization step of this method takes extremely long for me. I have been waiting for one hour. Any suggestions? Is this related to the data scale? I have 276,400 instances for training.

Understanding of the initialization strategies

Hi!

I was going through the code and stumbled on the initialization strategies defined in utils.py: https://github.com/AIH-SGML/mixmil/blob/main/mixmil/utils.py

My question is where these formulas come from, and whether there is any literature that describes them. I understand conceptually what's going on: fitting a standard GLMM to warm-start the training, and regressing out the individual components to get initial parameter estimates. However, the formulas themselves appear out of the blue to me. E.g., I'd like to know how you ended up with formulas like the ones below.

    # Compute bag prediction u and reparametrize
    u = X.dot(beta)
    um = u.mean(0)[None]
    us = u.std(0)[None]
    alpha = alpha + um
    mu_beta = us * beta / np.sqrt((beta**2).mean(0)[None])
    sd_beta = np.sqrt(0.1 * (mu_beta**2).mean()) * np.ones_like(mu_beta)

    alpha = Fiv.dot(np.ones((Fiv.shape[1], 1))).dot(alpha) - b.dot(mu_beta)

    # init prior
    var_z = (mu_beta**2 + sd_beta**2).mean().reshape(1, 1)

The function for the binomial case was almost clear, except for the part where var_z = (mu_beta**2 + sd_beta**2).mean().reshape(1, 1), since I don't understand how that serves as an initialization for the prior in general.
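One plausible reading (an assumption, not confirmed in the thread): for a Gaussian variational posterior $q(\beta) = \mathcal{N}(\mu_\beta, s_\beta^2)$,

$$\mathbb{E}_q[\beta^2] = \mu_\beta^2 + s_\beta^2,$$

so averaging mu_beta**2 + sd_beta**2 over entries is a moment-matching estimate of the second moment of the weights, which equals the prior variance for a zero-mean prior.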

I tried searching for it myself but couldn't immediately find any literature matching the formulas, so I thought I'd ask.

Cheers!

How to train MixMIL with categorical data?

Thanks for the great work (and congrats on being selected as an oral for this year's AISTATS)!

I'm trying to train a mixMIL model on categorical data.

As I had trouble getting it to work on my own data, I tried it on the "mock" test data. I adapted the code from the mock_data_categorical function in tests (which only initializes a mean model but doesn't train it):

def mock_data_categorical():
    N, Q, K = 50, 10, 4
    bag_sizes = torch.randint(5, 15, (N,))
    Xs = [torch.randn(bag_sizes[n], Q) for n in range(N)]  # List of tensors
    F = torch.randn(N, K)  # Fixed effects
    Y = torch.randint(0, 5, (N, 1))  # Labels for categorical
    return Xs, F, Y

and wrote my own mock data training script:

from mixmil import MixMIL
import torch 

device = "cuda:1" if torch.cuda.is_available() else "cpu"

N, Q, K = 50, 10, 4
bag_sizes = torch.randint(5, 15, (N,))
Xs = [torch.randn(bag_sizes[n], Q) for n in range(N)]  # List of tensors
F = torch.randn(N, K)  # Fixed effects
Y = torch.randint(0, 5, (N, 1))  # Labels for categorical
model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="categorical")
model.train(Xs, F, Y, n_epochs=10)

However, even with the mock data I run into a runtime error in model.train():
[screenshot of the RuntimeError traceback]

When I print the shapes of the offending scale_u and scale_z, scale_u has shape [5, 10] and scale_z has shape [1, 10]. Thus we can't torch.cat over dim=1 (since dim 0 differs).
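For illustration, a minimal sketch reproducing that failure mode (shapes taken from the report above, not from the mixmil internals):

import torch

scale_u = torch.randn(5, 10)   # shape [5, 10], as reported
scale_z = torch.randn(1, 10)   # shape [1, 10], as reported
try:
    torch.cat([scale_u, scale_z], dim=1)   # fails: sizes must match except in dim 1
except RuntimeError as err:
    print(err)
# concatenating along dim=0 works here, since dim 1 (=10) matches
stacked = torch.cat([scale_u, scale_z], dim=0)   # shape [6, 10]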

Also, printing the model gives me the following:
[screenshot of the printed model]

I would greatly appreciate any tips to get mixMIL to train on categorical data (e.g. fixes for the above scenario or sample code).

Thanks in advance!
Patrick

Answer on BBBC021 Holdout procedure

In the BBBC021 dataset, compounds are measured across different plates to account for batch effects (technical variability of the cellular response to treatment). Each plate contains multiple wells in which different compounds are tested, with only one compound applied per well. Each of the considered compounds is measured on multiple plates. For example, a drug like Cytochalasin B is measured in wells located on three different plates.

To ensure the model picks up the biological effect of the compounds on cells, we evaluated it by training on two plates and testing its performance on a held-out one. This hold-out procedure is carried out on all the plates in turn, and the performance is averaged across them. So, for all drugs, we first leave out one plate, train the model on the remaining ones, and test on the held-out plate. This is repeated for the other plates where the compounds are measured, and the performances are averaged across drugs and hold-out experiments.
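A minimal sketch of this leave-one-plate-out evaluation (function names and data layout are placeholders, not the mixmil API):

import numpy as np

def leave_one_plate_out(plates, fit_fn, score_fn):
    """plates: dict mapping plate id -> data for that plate (placeholder structure)."""
    scores = []
    for held_out in plates:
        train_data = [data for plate, data in plates.items() if plate != held_out]
        model = fit_fn(train_data)                         # train on the remaining plates
        scores.append(score_fn(model, plates[held_out]))   # evaluate on the held-out plate
    return float(np.mean(scores))                          # average over hold-out experiments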
