Comments (2)

ankurmahesh commented on August 20, 2024

If I change _update_mean to return only the sum of the elements in the current iteration and the number of elements seen in the current iteration, I see the behavior I expected. A full reproducible script is below. (AlternateMean is the same as Mean; I only needed to redefine it so that it calls this modified _update_mean.)

from modulus.metrics.general.ensemble_metrics import Mean, Variance     
import torch.distributed as dist                                        
from modulus.distributed import DistributedManager                      
import torch                                                            
from typing import Union, Tuple, List                                   
Tensor = torch.Tensor                                                   
                                                                        
                                                                        
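# Modified _update_mean: returns only the current iteration's sum and element
# count (old_sum and old_n are ignored) instead of running totals.
def _update_mean(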
    old_sum: Tensor,                                                    
    old_n: Union[int, Tensor],                                          
    input: Tensor,                                                      
    device,                                                             
    batch_dim: Union[int, None] = 0,                                    
) -> Tuple[Tensor, Union[int, Tensor]]:                                                                                          
    if batch_dim is None:                                               
        input = torch.unsqueeze(input, 0)                               
        batch_dim = 0         
    new_sum = torch.sum(input, dim=batch_dim)                           
    new_n = torch.Tensor([input.size()[batch_dim]]).to(device).int()    
                                                                        
    return new_sum, new_n                                               
                                                                        
                                                                        
class AlternateMean(Mean):                                              
    """Utility class that computes the mean over a batched or ensemble dimension
                                                                        
    This is particularly useful for distributed environments and sequential computation.
                                                                        
    Parameters                                                          
    ----------                                                          
    input_shape : Union[Tuple, List]                                    
        Shape of broadcasted dimensions                                 
    """                                                                 
                                                                        
    def __init__(self, input_shape: Union[Tuple, List], **kwargs):      
        super().__init__(input_shape, **kwargs)                         
                                                                        
    def update(self, input: Tensor) -> Tensor:                          
        """Update current mean and essential statistics with new data   
                                                                        
        Parameters                                                      
        ----------                                                      
        input : Tensor                                                  
            Input tensor      
        Returns                                                         
        -------                                                         
        Tensor                                                          
            Current mean value                                          
        """                                                             
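        self._check_shape(input)
        # TODO(Dallas) Move distributed calls into finalize.
        if DistributedManager.is_initialized() and dist.is_initialized():
            # Local increments for this call only (use the input's device
            # rather than relying on the global dm from __main__).
            sums, n = _update_mean(self.sum, self.n, input, device=input.device,
                                   batch_dim=0)
            # Reduce the per-iteration increments across ranks, then accumulate.
            dist.all_reduce(sums, op=dist.ReduceOp.SUM)
            dist.all_reduce(n, op=dist.ReduceOp.SUM)
            self.sum += sums
            self.n += n
        else:
            # Single process: _update_mean requires device and returns only the
            # increments, so accumulate them instead of overwriting the totals.
            sums, n = _update_mean(self.sum, self.n, input, device=input.device,
                                   batch_dim=0)
            self.sum += sums
            self.n += n
        return self.sum / self.n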
                                                                        
                                                                        
if __name__ == '__main__':                                              
                                                                        
                                                                        
    DistributedManager.initialize()                                     
    dm = DistributedManager()                                           
    if dm.rank == 0:                                                    
        print(dm.world_size)                                            
                                                                        
    tensor = torch.Tensor([[1]]).to(dm.device)    

    m = AlternateMean(tensor.shape, device=dm.device)                   
    for a in range(5):                                                  
        _ = m.update(tensor)                                            
                                                                        
        if dm.rank == 0:                                                
            print("n after {} iterations".format(a+1))                  
            print(m.n)

With two ranks, this outputs the following (the first line is the world size printed by rank 0):

2
n after 1 iterations
tensor([2], device='cuda:0', dtype=torch.int32)
n after 2 iterations
tensor([4], device='cuda:0', dtype=torch.int32)
n after 3 iterations
tensor([6], device='cuda:0', dtype=torch.int32)
n after 4 iterations
tensor([8], device='cuda:0', dtype=torch.int32)
n after 5 iterations
tensor([10], device='cuda:0', dtype=torch.int32)
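These counts are what one would expect: with two ranks each contributing one element per call, n should grow by 2 per iteration. A plain-Python sketch of the arithmetic follows, assuming that Mean.update has the same structure as AlternateMean.update above and that the original _update_mean returns running totals (old_n plus the batch count). The helper simulate_counts is hypothetical; it models dist.all_reduce with ReduceOp.SUM as multiplication by world_size, which only holds here because every rank contributes the same value.

```python
def simulate_counts(world_size, batch_size, iterations, reduce_running_totals):
    """Track self.n on one rank across `iterations` calls to update().

    Hypothetical model: every rank sees `batch_size` elements per call and
    holds the same self.n, so all_reduce(SUM) of that value equals
    world_size * value.
    """
    n = 0
    history = []
    for _ in range(iterations):
        if reduce_running_totals:
            # Assumed original pattern: the reduced value already contains
            # old_n, so the previous total is re-added world_size times.
            local = n + batch_size
        else:
            # AlternateMean pattern: reduce only this iteration's count.
            local = batch_size
        reduced = world_size * local  # stands in for dist.all_reduce(..., SUM)
        n += reduced                  # mirrors self.n += n
        history.append(n)
    return history


print(simulate_counts(2, 1, 5, reduce_running_totals=False))  # [2, 4, 6, 8, 10]
print(simulate_counts(2, 1, 5, reduce_running_totals=True))   # [2, 8, 26, 80, 242]
```

Under these assumptions, reducing running totals makes n grow geometrically (n_new = (world_size + 1) * n_old + world_size * batch_size) rather than linearly.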


dallasfoster commented on August 20, 2024

Thank you for submitting this issue. It appears that the bug is in the update method of Mean(EnsembleMetrics): it should compute the local sums and n before reducing across devices, as the Variance(EnsembleMetrics) class already does. Note that the fix should not go in _update_mean, since that function behaves as expected.
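The order of operations described here (local statistics first, then the cross-rank reduction, then accumulation) can be sketched in plain Python. RunningMean and distributed_update are hypothetical names, scalar samples stand in for tensors, and the all-reduce is simulated by summing over a list of per-rank objects rather than calling torch.distributed; this is an illustration of the ordering, not the actual change in #63.

```python
from typing import List, Tuple


class RunningMean:
    """Hypothetical per-rank running mean over scalar samples."""

    def __init__(self) -> None:
        self.sum = 0.0
        self.n = 0

    def local_stats(self, batch: List[float]) -> Tuple[float, int]:
        # This call's contribution only; the running totals are not mixed in.
        return sum(batch), len(batch)

    def accumulate(self, reduced_sum: float, reduced_n: int) -> float:
        # Add the globally reduced increment to this rank's running totals.
        self.sum += reduced_sum
        self.n += reduced_n
        return self.sum / self.n


def distributed_update(
    ranks: List[RunningMean], batches: List[List[float]]
) -> List[float]:
    # 1. Each rank computes its local sum and count.
    stats = [rank.local_stats(batch) for rank, batch in zip(ranks, batches)]
    # 2. Simulated all_reduce(SUM): every rank receives the same global totals.
    total_sum = sum(s for s, _ in stats)
    total_n = sum(n for _, n in stats)
    # 3. Each rank accumulates the reduced increment.
    return [rank.accumulate(total_sum, total_n) for rank in ranks]


ranks = [RunningMean(), RunningMean()]
print(distributed_update(ranks, [[1.0, 2.0], [3.0]]))  # [2.0, 2.0]
print(distributed_update(ranks, [[4.0], [5.0, 6.0]]))  # [3.5, 3.5]
print(ranks[0].n)  # 6
```

Because every rank adds the same reduced increment, all ranks keep identical running totals, and the result equals the mean over every element seen by every rank (here the mean of 1 through 6).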

Fixed by #63
