
Comments (6)

johannbrehmer commented on August 14, 2024

Results of line profiling the RASCAL training:

Timer unit: 1e-06 s

Total time: 165.879 s
File: /Users/johannbrehmer/work/projects/madminer/madminer/madminer/utils/ml/ratio_trainer.py
Function: train_ratio_model at line 17

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    17                                           def train_ratio_model():
 [...]
    63                                               # Convert to Tensor
    64         1          4.0      4.0      0.0      if theta0s is not None:
    65         1    1051220.0 1051220.0      0.6          data.append(torch.stack([tensor(i, requires_grad=True) for i in theta0s]))
    66         1          5.0      5.0      0.0      if theta1s is not None:
    67                                                   data.append(torch.stack([tensor(i, requires_grad=True) for i in theta1s]))
    68         1          4.0      4.0      0.0      if xs is not None:
    69         1    1081236.0 1081236.0      0.7          data.append(torch.stack([tensor(i) for i in xs]))
    70         1          4.0      4.0      0.0      if ys is not None:
    71         1     828597.0 828597.0      0.5          data.append(torch.stack([tensor(i) for i in ys]))
    72         1          4.0      4.0      0.0      if r_xzs is not None:
    73         1     761371.0 761371.0      0.5          data.append(torch.stack([tensor(i) for i in r_xzs]))
    74         1          6.0      6.0      0.0      if t_xz0s is not None:
    75         1     809017.0 809017.0      0.5          data.append(torch.stack([tensor(i) for i in t_xz0s]))
    76         1          4.0      4.0      0.0      if t_xz1s is not None:
    77                                                   data.append(torch.stack([tensor(i) for i in t_xz1s]))
[...]
   171                                                   # Loop over batches
   172       201    3307821.0  16456.8      2.0          for i_batch, batch_data in enumerate(train_loader):
[...]
   210       200   44179952.0 220899.8     26.6                      s_hat, log_r_hat, t_hat0 = model(theta0, x)
[...]
   225                                                       # Evaluate loss
   226                                                       losses = [
   227       200       1112.0      5.6      0.0                  loss_function(s_hat, log_r_hat, t_hat0, t_hat1, y, r_xz, t_xz0, t_xz1)
   228       200     109029.0    545.1      0.1                  for loss_function in loss_functions
   229                                                       ]
   230       200        934.0      4.7      0.0              if grad_x_regularization is not None:
   231                                                           losses.append(torch.mean(x_gradient ** 2))
   232                                           
   233       200       7794.0     39.0      0.0              loss = loss_weights[0] * losses[0]
   234       400       2431.0      6.1      0.0              for _w, _l in zip(loss_weights[1:], losses[1:]):
   235       200       4106.0     20.5      0.0                  loss += _w * _l
   236                                           
   237       600     133223.0    222.0      0.1              for i, individual_loss in enumerate(losses):
   238       400       4621.0     11.6      0.0                  individual_train_loss[i] += individual_loss.item()
   239       200        934.0      4.7      0.0              total_train_loss += loss.item()
   240                                           
   241                                                       # Calculate gradient and update optimizer
   242       200  113418481.0 567092.4     68.4              loss.backward()
   243       200      84310.0    421.6      0.1              optimizer.step()
[...]
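Aside from the gradient, the training profile above also shows roughly a second spent on each torch.stack([tensor(i) ...]) line. That per-element construction can be avoided by stacking once in NumPy and converting once; a minimal sketch (theta0s here is synthetic stand-in data, not madminer's actual inputs):

```python
import numpy as np
import torch

# Hypothetical stand-in for theta0s: a list of parameter-point rows.
theta0s = [np.random.randn(2) for _ in range(1000)]

# Per-element construction, as profiled above: one Python-level
# tensor() call per row, then a stack -- this is where the time goes.
slow = torch.stack([torch.tensor(i, requires_grad=True) for i in theta0s])

# Cheaper: stack once in NumPy, convert once, then enable gradients.
fast = torch.from_numpy(np.asarray(theta0s)).requires_grad_()
```

Both paths produce the same values; the second avoids the Python-loop overhead of constructing thousands of small tensors.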

and for the forward pass:

Timer unit: 1e-06 s

Total time: 43.4621 s
File: /Users/johannbrehmer/work/projects/madminer/madminer/madminer/utils/ml/models/ratio.py
Function: forward at line 40

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    40                                               def forward(self, theta, x, track_score=True, return_grad_x=False):
    41                                           
    42                                                   """ Calculates estimated log likelihood ratio and the derived score. """
    43                                           
    44                                                   # Track gradient wrt theta
    45       200        912.0      4.6      0.0          if track_score and not theta.requires_grad:
    46                                                       theta.requires_grad = True
    47                                           
    48                                                   # Track gradient wrt x
    49       200        182.0      0.9      0.0          if return_grad_x and not x.requires_grad:
    50                                                       x.requires_grad = True
    51                                           
    52                                                   # log r estimator
    53       200       8901.0     44.5      0.0          log_r_hat = torch.cat((theta, x), 1)
    54                                           
    55       800       4914.0      6.1      0.0          for i, layer in enumerate(self.layers):
    56       600        455.0      0.8      0.0              if i > 0:
    57       400      85777.0    214.4      0.2                  log_r_hat = self.activation(log_r_hat)
    58       600      99077.0    165.1      0.2              log_r_hat = layer(log_r_hat)
    59                                           
    60                                                   # Bayes-optimal s
    61       200       8673.0     43.4      0.0          s_hat = 1. / (1. + torch.exp(log_r_hat))
    62                                           
    63                                                   # Score t
    64       200        168.0      0.8      0.0          if track_score:
    65       200        191.0      1.0      0.0              t_hat = grad(log_r_hat, theta,
    66       200       1952.0      9.8      0.0                           grad_outputs=torch.ones_like(log_r_hat.data),
    67       200   43250036.0 216250.2     99.5                           only_inputs=True, create_graph=True)[0]
    68                                                   else:
    69                                                       t_hat = None
    70                                           
    71                                                   # Calculate gradient wrt x
    72       200        695.0      3.5      0.0          if return_grad_x:
    73                                                       x_gradient = grad(log_r_hat, x,
    74                                                                     grad_outputs=torch.ones_like(log_r_hat.data),
    75                                                                     only_inputs=True, create_graph=True)[0]
    76                                           
    77                                                       return s_hat, log_r_hat, t_hat, x_gradient
    78                                           
    79       200        173.0      0.9      0.0          return s_hat, log_r_hat, t_hat

So the bottleneck seems to be the gradient computation. I'm trying to figure out how to speed that part up. Any pointers are much appreciated!
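The 99.5% line in the forward profile is the grad() call that computes the score while keeping the result differentiable (create_graph=True), so that a loss on t_hat can itself be backpropagated. A minimal sketch of that pattern with a toy network (the network, shapes, and seed here are illustrative, not madminer's actual model):

```python
import torch
from torch.autograd import grad

torch.manual_seed(0)

theta = torch.randn(64, 2, requires_grad=True)  # parameter points
x = torch.randn(64, 5)                          # observables
net = torch.nn.Sequential(
    torch.nn.Linear(7, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)

log_r_hat = net(torch.cat((theta, x), dim=1))

# Score t_hat = d log_r_hat / d theta. create_graph=True keeps this
# result differentiable so a loss on t_hat can be backpropagated --
# this is the second-order step the profile attributes 99.5% to.
t_hat = grad(log_r_hat, theta,
             grad_outputs=torch.ones_like(log_r_hat),
             only_inputs=True, create_graph=True)[0]

loss = (t_hat ** 2).mean()
loss.backward()  # differentiates through the grad() call above
```

The backward here has to traverse both the forward graph and the graph built by grad(), which is why score-based losses pay so much more than plain classification losses.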

from madminer.

johannbrehmer commented on August 14, 2024

Commit 5c1b414 reduces the number of times the gradients are calculated. This should significantly speed up the training of CARL, ROLR, and ALICE, and the evaluation of all ratio-based methods.
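The gating pattern that change describes can be sketched as follows: the expensive grad() call is skipped whenever the training objective does not use the score t_hat (the function and network here are a simplified illustration, not the actual madminer code):

```python
import torch
from torch.autograd import grad

def forward(net, theta, x, track_score=True):
    """Sketch of the gating pattern: only build the score graph
    when the loss actually uses t_hat."""
    if track_score and not theta.requires_grad:
        theta.requires_grad_()
    log_r_hat = net(torch.cat((theta, x), dim=1))
    s_hat = 1.0 / (1.0 + torch.exp(log_r_hat))
    t_hat = None
    if track_score:  # only score-based losses (RASCAL, ALICES) need this
        t_hat = grad(log_r_hat, theta,
                     grad_outputs=torch.ones_like(log_r_hat),
                     only_inputs=True, create_graph=True)[0]
    return s_hat, log_r_hat, t_hat

net = torch.nn.Linear(7, 1)
theta, x = torch.randn(8, 2), torch.randn(8, 5)
s, r, t = forward(net, theta, x, track_score=False)  # cheap path
```

With track_score=False, both the grad() call in the forward pass and the second-order work in loss.backward() disappear, which is why CARL/ROLR/ALICE training benefits.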


cranmer commented on August 14, 2024

Is this issue fundamentally fixed, or did you make some change that is a work-around?

I asked Gilles if he could investigate a bit.


johannbrehmer commented on August 14, 2024

The issue is fixed for the methods that don't need the gradient of the network output (and for the evaluation of all methods).

But those techniques that do need the gradient (RASCAL, ALICES, ...) still train slowly. If Gilles has time to investigate, that would be great! There's a new toy example in the repository that might be helpful.

Re-opening the issue for now.


glouppe commented on August 14, 2024

Hi,

In general the backward pass is expected to be 2-3x slower than the forward pass. If we require second-order gradients, that results in a 4-9x slowdown.

See e.g. pytorch/pytorch#7714 (comment)

I am not sure there is any easy fix to that :/
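That slowdown can be seen directly with a micro-benchmark, comparing a plain first-order gradient against one that keeps the graph for a second-order step (the toy MLP and iteration counts here are illustrative):

```python
import time
import torch

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(10, 100), torch.nn.Tanh(),
                          torch.nn.Linear(100, 1))
x = torch.randn(256, 10, requires_grad=True)

def backward_time(create_graph):
    """Total time for 50 gradient computations; with create_graph=True
    we also backpropagate through the gradient (second-order)."""
    t0 = time.perf_counter()
    for _ in range(50):
        y = net(x).sum()
        (g,) = torch.autograd.grad(y, x, create_graph=create_graph)
        if create_graph:
            (g ** 2).mean().backward()  # the extra second-order step
    return time.perf_counter() - t0

first_order = backward_time(False)
second_order = backward_time(True)
```

On any backend the second-order path does strictly more work, since the grad() call itself becomes part of the graph that backward() must traverse.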


johannbrehmer commented on August 14, 2024

Thanks. That's unfortunate.

Is anything about this specific to reverse-mode auto-differentiation? That is, would it be different with TensorFlow?

