
nash-mtl's Introduction

Nash-MTL

Official implementation of "Multi-Task Learning as a Bargaining Game".

Setup environment

conda create -n nashmtl python=3.9.7
conda activate nashmtl
conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=10.2 -c pytorch
conda install pyg -c pyg -c conda-forge

Install the repo:

git clone https://github.com/AvivNavon/nash-mtl.git
cd nash-mtl
pip install -e .
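To confirm that the setup steps above worked, a minimal standard-library check could look like the sketch below. `missing_packages` is a hypothetical helper name. The only assumption is that `torch_geometric` is the import name of the `pyg` conda package.

```python
import importlib.util

# Import names of the packages installed in the setup steps.
# The conda package "pyg" is imported as "torch_geometric".
REQUIRED = ["torch", "torchvision", "torch_geometric"]


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("all packages found" if not missing else f"missing: {missing}")
```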

Run experiment

To run experiments:

cd experiment/<experiment name>
python trainer.py --method=nashmtl

Follow the instructions in each experiment's README file for more information, e.g., regarding datasets.

Here <experiment name> is one of [toy, quantum_chemistry, nyuv2]. You can also replace nashmtl with any of the MTL methods listed below.

We also support experiment tracking with Weights & Biases with two additional parameters:

python trainer.py --method=nashmtl --wandb_project=<project-name> --wandb_entity=<entity-name>
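To compare several methods on one experiment, the commands above can be scripted. The sketch below assumes the `experiment/<experiment name>` layout and the `--method`, `--wandb_project` and `--wandb_entity` flags described above. `build_command` and `run_sweep` are hypothetical helpers, and the method code names are copied from the MTL methods table.

```python
import subprocess

EXPERIMENTS = ["toy", "quantum_chemistry", "nyuv2"]
METHODS = ["nashmtl", "cagrad", "pcgrad", "imtl", "mgda", "dwa", "uw", "ls", "scaleinvls", "rlw"]


def build_command(method, wandb_project=None, wandb_entity=None):
    """Command line for one run; the W&B flags are appended only when given."""
    cmd = ["python", "trainer.py", f"--method={method}"]
    if wandb_project is not None:
        cmd.append(f"--wandb_project={wandb_project}")
    if wandb_entity is not None:
        cmd.append(f"--wandb_entity={wandb_entity}")
    return cmd


def run_sweep(experiment, methods=METHODS, **wandb_kwargs):
    """Run each method sequentially inside experiment/<experiment>; stop on the first failure."""
    for method in methods:
        subprocess.run(build_command(method, **wandb_kwargs), cwd=f"experiment/{experiment}", check=True)
```

For example, `run_sweep("toy", methods=["nashmtl", "ls"], wandb_project="my-project")` would run two methods one after the other. Adjust the `cwd` path if your checkout is laid out differently.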

MTL methods

We support the following MTL methods with a unified API. To run an experiment with MTL method X, simply run:

python trainer.py --method=X

Method (code name) | Paper (notes)
Nash-MTL (nashmtl) | Multi-Task Learning as a Bargaining Game
CAGrad (cagrad) | Conflict-Averse Gradient Descent for Multi-task Learning
PCGrad (pcgrad) | Gradient Surgery for Multi-Task Learning
IMTL-G (imtl) | Towards Impartial Multi-task Learning
MGDA (mgda) | Multi-Task Learning as Multi-Objective Optimization
DWA (dwa) | End-to-End Multi-Task Learning with Attention
Uncertainty weighting (uw) | Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
Linear scalarization (ls) | - (equal weighting)
Scale-invariant baseline (scaleinvls) | - (see Nash-MTL paper for details)
Random Loss Weighting (rlw) | A Closer Look at Loss Weighting in Multi-Task Learning
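The two simplest baselines in the table combine per-task losses with fixed or random weights. The sketch below contrasts them. It assumes unit weights for linear scalarization (a mean would only rescale the loss) and an RLW variant that, at every step, draws logits from a standard normal distribution and passes them through a softmax. This follows our reading of the RLW paper and is not the repository's code.

```python
import math
import random


def linear_scalarization(losses):
    """Equal weighting: every task loss gets weight 1."""
    return sum(losses)


def random_loss_weighting(losses, rng=random):
    """Weights = softmax of standard-normal draws, resampled on every call."""
    logits = [rng.gauss(0.0, 1.0) for _ in losses]
    top = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    return sum(w * loss for w, loss in zip(weights, losses)), weights
```

Because the RLW weights sum to 1, the weighted loss is always a convex combination of the task losses. Linear scalarization, by contrast, grows with the number of tasks.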

Citation

If you find Nash-MTL to be useful in your own research, please consider citing the following paper:

@article{navon2022multi,
  title={Multi-Task Learning as a Bargaining Game},
  author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan},
  journal={arXiv preprint arXiv:2202.01017},
  year={2022}
}

nash-mtl's People

Contributors

avivnavon, avivsham, explcre, soumik12345


nash-mtl's Issues

Questions about the convergence properties in Theorems 5.4 and 5.5

Hi experts,
Thanks for sharing this excellent work on MTL.

  1. Theorem 5.4 states $\|1/\alpha^{(t)}\| \ge \sigma_K\big((G^{(t)})^\top G^{(t)}\big)\,\|\alpha^{(t)}\|$, where $\sigma_K\big((G^{(t)})^\top G^{(t)}\big)$ is the smallest singular value of $(G^{(t)})^\top G^{(t)}$.
  • What type of norm is this?
    • I think this is the spectral norm; then $\|(G^{(t)})^\top G^{(t)}\|$ is the largest singular value of $(G^{(t)})^\top G^{(t)}$, so the inequality holds. Is this right?
  2. As $\sigma_K\big((G^{(t)})^\top G^{(t)}\big) \to 0$, we have from continuity that $\sigma_K(G_*^\top G_*) = 0$, where $G_*$ is the matrix of gradients at $\theta_*$.
  • Does this imply that, to reach a Pareto-optimal point $\theta_*$, we should check the smallest singular value of $(G^{(t)})^\top G^{(t)}$ at every iteration to determine whether training has converged?
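A quick numerical check of the bound in question 1 follows. It assumes both norms are Euclidean (ℓ2) vector norms, since $\alpha^{(t)}$ and $1/\alpha^{(t)}$ are vectors. For any square matrix $M$ and vector $x$, $\|Mx\|_2 \ge \sigma_{\min}(M)\,\|x\|_2$. At a Nash solution, $(G^{(t)})^\top G^{(t)}\alpha^{(t)} = 1/\alpha^{(t)}$, so the left-hand side is $\|1/\alpha^{(t)}\|$. `sigma_k` is a hypothetical helper that could also be logged every iteration, in the spirit of question 2.

```python
import numpy as np


def sigma_k(G):
    """Smallest singular value of G^T G, where G holds one task gradient per column (d x K).
    G^T G is symmetric PSD, so its singular values equal its eigenvalues."""
    eigs = np.linalg.eigvalsh(G.T @ G)  # eigenvalues in ascending order
    return max(float(eigs[0]), 0.0)  # clip tiny negative round-off


rng = np.random.default_rng(0)
G = rng.normal(size=(100, 3))  # d = 100 parameters, K = 3 tasks
alpha = rng.uniform(0.1, 2.0, size=3)
lhs = np.linalg.norm(G.T @ G @ alpha)  # equals ||1/alpha|| when alpha solves G^T G alpha = 1/alpha
rhs = sigma_k(G) * np.linalg.norm(alpha)
print(lhs >= rhs)
```

When two task gradients become linearly dependent, `sigma_k` drops to (numerically) zero, which is the degenerate case described in point 2.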

Hi, I want to apply Nash-MTL to my multi-task learning network.

Hello, my network is a two-task network based on ResNet-18. I have tried PCGrad and am now trying Nash-MTL. However, performance dropped when I added PCGrad, and with Nash-MTL the network does not even converge. Are there any conditions under which Nash-MTL is applicable? For example, what are the requirements on the loss of each task and on the multi-task network? In addition, how should I tune the Nash-MTL hyperparameters 'update_weights_every', 'optim_niter', and 'max_norm'?
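One way to reason about 'update_weights_every' is sketched below. It assumes, from the parameter name only and not from the repository's code, that the task weights are recomputed every k steps and reused in between. `PeriodicWeights` is a hypothetical helper, and `compute_weights` stands in for whatever expensive weight computation is used.

```python
class PeriodicWeights:
    """Hypothetical helper: recompute task weights every `every` steps, reuse the last ones in between."""

    def __init__(self, compute_weights, every=1):
        self.compute_weights = compute_weights  # expensive call, e.g. an inner optimization
        self.every = every
        self.step = 0
        self.weights = None

    def __call__(self, *args, **kwargs):
        # Recompute on the first call and then on every `every`-th step.
        if self.weights is None or self.step % self.every == 0:
            self.weights = self.compute_weights(*args, **kwargs)
        self.step += 1
        return self.weights
```

Under this reading, a larger `every` saves compute but lets the weights lag behind the current gradients, so `every=1` is the natural starting point when training does not converge.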

Is there code for torch.distributed?

I would like to express my sincere gratitude for the excellent research you have conducted. Your work has helped me define the direction of my own research and get it started. Specifically, I am working on multi-task learning for 3D object detection and BEV segmentation in autonomous driving.

I have a question about the repository you provided and have opened this issue to discuss it. Does your repository include torch.distributed code for multi-GPU training? I ask because torch.distributed averages gradients across GPUs, which changes the computation. I am facing some challenges in this area and would greatly appreciate your guidance or insights.

Thank you for your support; I look forward to your response.
Junghokim
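The concern about gradient averaging can be made concrete with a small numpy simulation. Nash-MTL's weights depend on the Gram matrix $G^\top G$ of per-task gradients. If each rank builds it from its own gradients, different ranks can obtain different weights. One possible approach, which is an assumption and not something the repository is known to provide, is to average the per-task gradient matrix across ranks before forming $G^\top G$ so that every rank solves the same problem. In a real job this could use `torch.distributed.all_reduce` followed by division by the world size. The sketch below only simulates that average.

```python
import numpy as np


def local_gram(G_rank):
    """G^T G computed from one rank's (d, K) per-task gradient matrix only."""
    return G_rank.T @ G_rank


def synced_gram(per_rank_G):
    """G^T G of the rank-averaged gradients (simulating all-reduce sum / world size)."""
    G_mean = sum(per_rank_G) / len(per_rank_G)
    return G_mean.T @ G_mean


rng = np.random.default_rng(0)
per_rank_G = [rng.normal(size=(50, 2)) for _ in range(4)]  # 4 ranks, 2 tasks
print(np.allclose(local_gram(per_rank_G[0]), local_gram(per_rank_G[1])))
```

Averaging the per-rank Gram matrices is also not equivalent. It exceeds the Gram matrix of the averaged gradients by the positive semidefinite term $\frac{1}{R}\sum_r (G_r - \bar G)^\top (G_r - \bar G)$.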

Logging

Hi @AvivNavon, thank you for your great work. I'm experimenting with something new, so could you please share the log files for the NYUv2 experiments so that I can benchmark my runs?

About GradNorm

Could you implement GradNorm so that it can be compared with the other algorithms, if it's convenient?

MTL for multiple modules?

Thanks for your awesome work; it's really cool and helpful!

My question is about applying the MTL methods to multiple NN modules. For example:

import torch
import torch.nn as nn
import torch.nn.functional as F


class First_model(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(First_model, self).__init__()
        # use the constructor arguments (self.input_dim / self.state_dim were never defined)
        self.nn = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, inputs):
        return self.nn(inputs)


class Second_model(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(Second_model, self).__init__()
        self.nn = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, inputs):
        return self.nn(inputs)


class Third_model(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(Third_model, self).__init__()
        self.nn = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, inputs):
        return self.nn(inputs)


first_model = First_model(25, 1)
second_model = Second_model(25, 1)
third_model = Third_model(25, 1)

x = torch.randn(10, 25)
y = torch.randn(10, 1)
m = torch.randn(10, 1)
n = torch.randn(10, 1)

# one loss per task (the function is F.mse_loss; F.mse does not exist)
# tasks 3 and 4 both depend on third_model
losses = []
losses.append(F.mse_loss(first_model(x), y))
losses.append(F.mse_loss(second_model(x), m))
losses.append(F.mse_loss(third_model(x), n))
losses.append(F.mse_loss(third_model(x), y))

Now I want to use MTL methods to train these models, but I'm not sure I am using them correctly. Here is my template code:

weight_method = WeightMethods(
        method,
        n_tasks=4, 
        device=device,
        **weight_method_params[method],
    )
loss, extra_outputs = weight_method.backward(
                losses=losses,
                shared_parameters=,
                task_specific_parameters=,
                last_shared_parameters=,
                representation=features,
            )

What should I pass to 'weight_method.backward' in this case?
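One way to see which parameters are actually shared in a setup like this is to ask autograd which task losses reach each parameter. The sketch below assumes that "shared" means "receives gradient from more than one task". `split_parameters` is a hypothetical helper, and how `WeightMethods.backward` treats `shared_parameters` versus `task_specific_parameters` should be checked in `methods/weight_methods.py`.

```python
import torch
import torch.nn as nn


def split_parameters(losses, modules):
    """Return (shared, task_specific): a parameter is shared if more than one task loss
    has a gradient path to it, and task-specific if exactly one does."""
    params = [p for module in modules for p in module.parameters() if p.requires_grad]
    n_tasks_using = [0] * len(params)
    for loss in losses:
        # allow_unused=True yields None for parameters this loss does not depend on
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        for i, g in enumerate(grads):
            if g is not None:
                n_tasks_using[i] += 1
    shared = [p for p, k in zip(params, n_tasks_using) if k > 1]
    task_specific = [p for p, k in zip(params, n_tasks_using) if k == 1]
    return shared, task_specific
```

Applied to the four losses in the example above, only the six weight and bias tensors of `third_model` (used by tasks 3 and 4) would come out as shared. `first_model` and `second_model` are each used by a single task.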

Question about large weights

Thank you for the great work.
When I tried the method, I found that the computed weights are much larger than those of other MTL methods, e.g., [47.24413258, 732.26542343] vs. [0.4, 0.6].

Is this normal? Should I rescale the weights? I suspect such large loss weights may interfere with regularization.
Thank you!
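Large Nash-MTL weights follow from the defining equation $G^\top G\,\alpha = 1/\alpha$ with $\alpha > 0$. If the task gradients are orthogonal, $G^\top G$ is diagonal and $\alpha_i = 1/\|g_i\|$. Scaling every gradient by $c$ scales $\alpha$ by $1/c$, so small gradient norms produce large weights. For example, gradient norms of 0.02 and 0.0014 would give weights of 50 and about 714. The sketch below solves the equation with a generic damped Newton method. It is not the repository's solver, which, according to the traceback in another issue, uses cvxpy.

```python
import numpy as np


def nash_weights(gtg, n_iter=100, tol=1e-14):
    """Solve (G^T G) a = 1/a for a > 0 with damped Newton steps on the strictly convex
    potential phi(a) = 0.5 * a^T (G^T G) a - sum(log a), whose gradient is (G^T G) a - 1/a."""

    def phi(a):
        return 0.5 * a @ gtg @ a - np.sum(np.log(a))

    a = np.ones(gtg.shape[0])
    for _ in range(n_iter):
        grad = gtg @ a - 1.0 / a
        step = np.linalg.solve(gtg + np.diag(1.0 / a**2), grad)
        decrement = grad @ step  # squared Newton decrement, non-negative
        if decrement < tol:
            break
        t = 1.0
        # backtrack: keep a > 0 and require a sufficient decrease of phi
        while t > 1e-10 and (
            np.any(a - t * step <= 0) or phi(a - t * step) > phi(a) - 0.25 * t * decrement
        ):
            t *= 0.5
        a = a - t * step
    return a


# Orthogonal gradients with norms 0.02 and 0.0014
print(nash_weights(np.diag([0.02**2, 0.0014**2])))
```

Since the update direction is $G\alpha$, rescaling $\alpha$ (for example, to sum to 1) changes the effective step size as well as the loss weights. This is a consequence of the equation, not a recommendation from the authors.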

About two warnings when training the model

Hi there. I get two warnings when training the Nash-MTL model. The first is "the problem is not DPP", and the second is "Solution may be inaccurate"; on closer inspection, the solver status is "OPTIMAL_INACCURATE". Do you also see these warnings? Do you have any suggestions for fixing them? Thanks a lot.

Here is the full warning output:

\.conda\envs\xxx\lib\site-packages\cvxpy\reductions\solvers\solving_chain.py:213: UserWarning: You are solving a parameterized problem that is not DPP. Because the problem is not DPP, subsequent solves will not be faster than the first one. For more information, see the documentation on Discplined Parametrized Programming, at
	https://www.cvxpy.org/tutorial/advanced/index.html#disciplined-parametrized-programming
  warnings.warn(dpp_error_msg)
.conda\envs\xxx\lib\site-packages\cvxpy\problems\problem.py:1387: UserWarning: Solution may be inaccurate. Try another solver, adjusting the solver settings, or solve with verbose=True for more information.
    - warnings.warn(
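According to its own text, the first warning only means that repeated solves will not be faster than the first one. The second warning concerns solution quality. One defensive pattern, sketched below under the assumption that reusing the last good weights is acceptable in your training loop, is to record warnings around the solver call and fall back when the "Solution may be inaccurate" message appears. `solve_or_fallback`, `solve_fn`, and `fallback` are hypothetical placeholders, not names from the repository.

```python
import warnings


def solve_or_fallback(solve_fn, fallback, marker="Solution may be inaccurate"):
    """Call solve_fn(); if it warns with `marker`, return (fallback, False), else (result, True).
    Note: every warning raised inside is recorded here instead of being printed."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record repeated warnings too
        result = solve_fn()
    if any(marker in str(w.message) for w in caught):
        return fallback, False
    return result, True
```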

About an error during the training process

Hi there. When I train the model, the code sometimes raises "ValueError: Parameter value must be real". The call path goes from [self.get_weighted_loss] to [self.solve_optimization(GTG.cpu().detach().numpy())] to [self.G_param.value = gtg] to [self._value = self._validate_value(val)].

Could you help me deal with this error? Thanks a lot!

Traceback (most recent call last):
  File "xxx.py", line 435, in <module>
    main()
  File "xxx.py", line 427, in main
    trainer.training(epoch)
  File "xxx.py", line 158, in training
    loss1, extra_outputs = self.weight_method.backward(
  File "methods\weight_methods.py", line 810, in backward
    return self.method.backward(losses, **kwargs)
  File "methods\weight_methods.py", line 263, in backward
    loss, extra_outputs = self.get_weighted_loss(
  File "methods\weight_methods.py", line 237, in get_weighted_loss
    alpha = self.solve_optimization(GTG.cpu().detach().numpy())
  File "methods\weight_methods.py", line 134, in solve_optimization
    self.G_param.value = gtg
  File "C:\Users\xxx\.conda\envs\xxx\lib\site-packages\cvxpy\expressions\constants\parameter.py", line 87, in value
    self._value = self._validate_value(val)
  File "C:\Users\xxx\.conda\envs\xxx\lib\site-packages\cvxpy\expressions\leaf.py", line 442, in _validate_value
    raise ValueError(
ValueError: Parameter value must be real.

Process finished with exit code 1
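A cheap diagnostic is to validate the Gram matrix before it is assigned to the cvxpy parameter. We assume, without confirming it against cvxpy's source for this version, that non-finite entries (for example, from a NaN loss or exploding gradients) are one possible cause of the "must be real" error. `validated_gram` is a hypothetical helper that would wrap the `GTG.cpu().detach().numpy()` value seen in the traceback.

```python
import numpy as np


def validated_gram(gtg):
    """Return G^T G as float64, raising early if it contains NaN or inf values."""
    gtg = np.asarray(gtg, dtype=np.float64)
    bad = ~np.isfinite(gtg)
    if bad.any():
        raise FloatingPointError(
            f"G^T G has {int(bad.sum())} non-finite entries at {np.argwhere(bad).tolist()}"
        )
    return gtg
```

If this raises, the problem lies upstream (losses or gradients), and fixing it there is more useful than changing the solver.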

MT10 Training Code

First of all, thank you for the great repository :)
Could you also upload the RL training code for the Meta-World MT10 tasks?

Thank you!

Mean Rank metric?

I have run your NYUv2 experiment on my own computer, but I cannot find the Mean Rank (MR) metric that you report in the paper. How can I compute it myself?

Thank you very much
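A sketch of one common definition follows. It assumes MR is each method's rank among the compared methods on every metric, averaged over metrics, with lower values being better. The paper's exact definition, including which metrics are counted and how ties are handled, should be checked there. Because MR is relative, it requires results for every compared method, not just one run. `mean_rank` is a hypothetical helper.

```python
def mean_rank(results, higher_is_better):
    """results maps method -> list of metric values (same metric order for every method);
    higher_is_better[j] says whether a larger value of metric j is better."""
    methods = list(results)
    totals = {method: 0.0 for method in methods}
    for j, larger_better in enumerate(higher_is_better):
        values = {method: results[method][j] for method in methods}
        for method in methods:
            v = values[method]
            n_better = sum((values[o] > v) if larger_better else (values[o] < v) for o in methods)
            n_tied = sum(values[o] == v for o in methods) - 1  # exclude the method itself
            # tied methods share the average of the positions they occupy
            totals[method] += 1 + n_better + n_tied / 2
    return {method: totals[method] / len(higher_is_better) for method in methods}


# e.g. metric 0 is a score (higher is better), metric 1 an error (lower is better)
print(mean_rank({"A": [0.5, 0.1], "B": [0.4, 0.2], "C": [0.6, 0.4]}, [True, False]))
```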

Possible error in the implementation of CAGrad

Thank you so much for the great work and for putting the code for all MTL algorithms together in a unified manner.

I found a minor error in your implementation of CAGrad. In lines 563-567 of methods/weight_methods.py, I believe the intent is to retain the computation graph for all tasks except the last one, to save memory. However, in the current implementation the code always enters the first if branch, so the graph is retained for all tasks. I found this because I ran into out-of-memory errors when using your code. Hope this helps!
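The intent described above, keeping the graph alive for every task except the last, can be written with an explicit condition on the task index. This is a sketch of the pattern, not the repository's CAGrad code. `per_task_gradients` is a hypothetical helper that returns the per-task gradient matrix that gradient-manipulation methods typically operate on.

```python
import torch


def per_task_gradients(losses, shared_params):
    """Stack flattened per-task gradients into an (n_tasks, n_params) matrix.
    The graph is retained for every task except the last, so it can be freed afterwards."""
    n_tasks = len(losses)
    rows = []
    for i, loss in enumerate(losses):
        grads = torch.autograd.grad(
            loss, shared_params, retain_graph=(i < n_tasks - 1), allow_unused=True
        )
        # parameters a task does not touch contribute zeros
        rows.append(torch.cat([
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(grads, shared_params)
        ]))
    return torch.stack(rows)
```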
