3DFFL: Privacy-Preserving Federated Few-Shot Learning for 3D Point Clouds in Autonomous Vehicles

This project presents a framework for privacy-preserving federated few-shot learning (3DFFL), focused on 3D point cloud classification. The approach integrates Federated Learning (FL) with Few-Shot Learning (FSL), using PointNet++ for feature extraction and a Prototypical Network (ProtoNet) for classification. Data privacy is preserved through differentially private model updates, while collaborative training mitigates data scarcity and heterogeneity.

Key Features

  • PointNet++ for feature extraction
  • Attention mechanism
  • Dedicated loss functions for the embedding and learnable components
  • Differential privacy updates
  • Data augmentation with Mixup
  • Federated training process
  • Few-shot learning with Prototypical Networks

Table of Contents

  1. Requirements
  2. Installation
  3. Usage
  4. Key Components
  5. Example

Requirements

  • Python 3.x
  • PyTorch
  • NumPy
  • scikit-learn

Installation

pip install torch numpy scikit-learn

Usage

Loading Data

The load_data function generates synthetic data for testing purposes:

import torch

def load_data():
    # Synthetic stand-in for real point-cloud data: 5 local datasets, each with
    # N samples of P points in D dimensions and multi-label targets over C classes.
    N = 100   # Number of samples per local dataset
    P = 1024  # Number of points per sample
    D = 3     # Dimensionality of each point
    C = 10    # Number of classes
    return [(torch.rand(N, D, P), torch.randint(0, 2, (N, C)).float()) for _ in range(5)]

Federated Training

To perform federated training with the provided framework:

local_datasets = load_data()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
global_model = GlobalModel(input_dim=3, feature_dim=128, num_classes=10).to(device)
lambda1, lambda2, lambda3 = 1.0, 1.0, 0.1

trained_global_model = federated_training_process(
    num_rounds=50,
    local_datasets=local_datasets,
    global_model=global_model,
    lambda1=lambda1, lambda2=lambda2, lambda3=lambda3,
    apply_privacy=True,
)

Few-Shot Training

To perform few-shot training with the trained global model:

n_way, k_shot, q_query = 5, 5, 15
few_shot_data = [torch.rand((100, 3, 1024)) for _ in range(n_way)]

trained_few_shot_model = few_shot_training(
    trained_global_model, few_shot_data,
    num_rounds=10, n_way=n_way, k_shot=k_shot, q_query=q_query,
)

Saving and Loading the Model

To save the trained model:

torch.save(trained_few_shot_model.state_dict(), 'few_shot_trained_model.pth')

To load the trained model:

model = GlobalModel(input_dim=3, feature_dim=128, num_classes=10)
model.load_state_dict(torch.load('few_shot_trained_model.pth'))
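
If the checkpoint was saved on a GPU machine and is being loaded on a CPU-only one, pass map_location:

model.load_state_dict(torch.load('few_shot_trained_model.pth', map_location='cpu'))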

Key Components

PointNet++

The PointNetPP class implements a simplified version of PointNet++ for feature extraction:

class PointNetPP(nn.Module):
    # Initialization and forward methods here
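
The repository shows only this placeholder. As a rough orientation, a minimal sketch of such a simplified encoder (a shared per-point MLP with global max pooling; the set-abstraction layers of the real PointNet++ are omitted, so treat this as an assumption), together with a hypothetical GlobalModel wrapper inferred from the constructor and the model.feature(...) calls used elsewhere in this README, might look like:

import torch
import torch.nn as nn

class PointNetPP(nn.Module):
    # Sketch only: a shared per-point MLP followed by global max pooling.
    # The real PointNet++ adds hierarchical set-abstraction layers.
    def __init__(self, input_dim=3, feature_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Conv1d(input_dim, 64, 1), nn.BatchNorm1d(64), nn.ReLU(),
            nn.Conv1d(64, 128, 1), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Conv1d(128, feature_dim, 1),
        )

    def forward(self, x):
        # x: (batch, input_dim, num_points)
        x = self.mlp(x)                # (batch, feature_dim, num_points)
        return torch.max(x, dim=2)[0]  # global max pool -> (batch, feature_dim)

class GlobalModel(nn.Module):
    # Hypothetical wrapper matching GlobalModel(input_dim, feature_dim, num_classes).
    def __init__(self, input_dim, feature_dim, num_classes):
        super().__init__()
        self.feature_extractor = PointNetPP(input_dim, feature_dim)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def feature(self, x):
        return self.feature_extractor(x)

    def forward(self, x):
        return self.classifier(self.feature(x))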

Attention Layer

The AttentionLayer class implements an attention mechanism:

class AttentionLayer(nn.Module):
    def __init__(self, feature_dim):
        super(AttentionLayer, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        attention_weights = self.attention(x)
        x = x * attention_weights
        return x
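
For inputs shaped (batch, num_points, feature_dim) — the assumed layout here — the Softmax over dim=1 normalizes the attention weights across points:

import torch

layer = AttentionLayer(feature_dim=128)
points = torch.rand(8, 1024, 128)  # (batch, num_points, feature_dim)
weighted = layer(points)           # same shape; weights sum to 1 across the 1024 points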

Loss Functions

The following functions compute the individual loss components; representative sketches follow the list:

  • calculate_embedding_loss
  • calculate_learnable_loss
  • calculate_overall_loss
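
Their exact definitions are not reproduced in this README; the first two sketches below are assumptions inferred from the call sites in federated_training_process, while calculate_overall_loss follows directly from its signature:

import torch.nn.functional as F

def calculate_embedding_loss(attended_features, labels):
    # Assumed: a multi-label classification loss on the attended features,
    # matching the (N, num_classes) float targets produced by load_data
    # (requires feature_dim >= num_classes).
    logits = attended_features[:, :labels.size(1)]
    return F.binary_cross_entropy_with_logits(logits, labels)

def calculate_learnable_loss(attended_features, features, labels):
    # Assumed: a consistency term keeping attended features close to the raw
    # backbone features (labels unused in this sketch).
    return F.mse_loss(attended_features, features)

def calculate_overall_loss(embedding_loss, learnable_loss, comp_loss, lambda1, lambda2, lambda3):
    # Weighted sum of the three components.
    return lambda1 * embedding_loss + lambda2 * learnable_loss + lambda3 * comp_loss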

Differential Privacy

The differential_privacy_update function applies differential privacy updates to the model:

def differential_privacy_update(model, noise_multiplier=0.1):
    # Add zero-mean Gaussian noise to every parameter before sharing the update
    for param in model.parameters():
        noise = torch.normal(0, noise_multiplier, size=param.size()).to(param.device)
        param.data.add_(noise)

Data Augmentation

The following functions handle data augmentation using Mixup; a representative sketch follows the list:

  • mixup_data
  • mixup_criterion
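
Their bodies are not shown in the repository; the sketch below is the standard Mixup formulation (Zhang et al., 2018) that the names suggest:

import numpy as np
import torch

def mixup_data(x, y, alpha=0.2):
    # Blend each sample with a randomly permuted partner
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # The loss is the same convex combination applied to the two label sets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)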

Federated Training Process

The federated_training_process function performs federated training:

import copy

import torch

def federated_training_process(num_rounds, local_datasets, global_model, lambda1, lambda2, lambda3, apply_privacy=False):
    num_nodes = len(local_datasets)
    momentum = 0.9  # Server-side momentum for aggregating local updates
    global_momentum = {key: torch.zeros_like(value).float() for key, value in global_model.state_dict().items()}
    previous_loss = float('inf')

    for t in range(num_rounds):
        local_updates = []
        total_loss = 0

        for i in range(num_nodes):
            X_i, Y_i = local_datasets[i]
            local_model = copy.deepcopy(global_model)
            optimizer = torch.optim.Adam(local_model.parameters(), lr=0.001)

            local_model.train()
            for epoch in range(5):
                optimizer.zero_grad()
                features = local_model.feature(X_i)

                attention_layer = AttentionLayer(features.size(1)).to(features.device)  # fresh layer each step; not part of the optimizer
                attended_features = attention_layer(features)

                embedding_loss = calculate_embedding_loss(attended_features, Y_i)
                learnable_loss = calculate_learnable_loss(attended_features, features, Y_i)
                comp_loss = torch.tensor(0.1).to(features.device)  # constant placeholder for the third (comp) loss term
                overall_loss = calculate_overall_loss(embedding_loss, learnable_loss, comp_loss, lambda1, lambda2, lambda3)
                overall_loss.backward()
                optimizer.step()

                total_loss += overall_loss.item()

            if apply_privacy:
                differential_privacy_update(local_model)

            local_updates.append(local_model.state_dict())

        average_loss = total_loss / num_nodes
        if abs(previous_loss - average_loss) < 1e-3:  # Convergence criterion
            break
        previous_loss = average_loss

        global_model_dict = global_model.state_dict()
        for key in global_model_dict.keys():
            # Average the local weights for this parameter across nodes, then apply
            # server-side momentum to the delta from the current global weights
            local_weights = torch.stack([local_updates[i][key].float() for i in range(num_nodes)], dim=0)
            average_weights = local_weights.mean(dim=0)
            delta = average_weights - global_model_dict[key].float()
            global_momentum[key] = momentum * global_momentum[key] + (1 - momentum) * delta
            global_model_dict[key] = global_model_dict[key].float() + global_momentum[key]

        global_model.load_state_dict(global_model_dict)

        # Local models are re-created from the updated global model at the start
        # of the next round (via the deepcopy above), so no explicit broadcast is needed.

    return global_model

Few-Shot Learning

The ProtoNet class (shown here with an assumed MetaTemplate base class providing parse_feature) and the few_shot_training function implement few-shot learning using Prototypical Networks:

import numpy as np
import torch
import torch.nn.functional as F

class ProtoNet(MetaTemplate):  # MetaTemplate: assumed base class providing parse_feature
    def __init__(self, model_func, n_way, n_support):
        super(ProtoNet, self).__init__(model_func, n_way, n_support)

    def set_forward(self, x, is_feature=False):
        z_support, z_query = self.parse_feature(x, is_feature)
        # Prototype for each class: mean of its support embeddings
        z_proto = z_support.contiguous().view(self.n_way, self.n_support, -1).mean(1)
        z_query = z_query.contiguous().view(self.n_way * self.n_query, -1)
        # Score queries by negative squared Euclidean distance to each prototype
        dists = euclidean_dist(z_query, z_proto)
        scores = -dists
        return scores

    def set_forward_loss(self, x):
        scores = self.set_forward(x)
        y_query = torch.from_numpy(np.repeat(range(self.n_way), self.n_query)).to(scores.device)
        return F.cross_entropy(scores, y_query)

def euclidean_dist(x, y):
    # Pairwise squared Euclidean distances between rows of x (n, d) and y (m, d)
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)

def few_shot_training(global_model, data, num_rounds, n_way, k_shot, q_query):
    optimizer = torch.optim.Adam(global_model.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()

    for round in range(num_rounds):
        global_model.train()
        optimizer.zero_grad()

        support_set, query_set, labels = create_few_shot_batches(data, n_way, k_shot, q_query)  # episode sampler (sketch below)
        support_features = global_model.feature(support_set)
        query_features = global_model.feature(query_set)

        prototypes = support_features.view(n_way, k_shot, -1).mean(dim=1)

        dists = torch.cdist(query_features, prototypes)
        logits = -dists

        # Reshape logits to match the label dimensions
        logits = logits.view(-1, n_way)
        labels = labels.repeat_interleave(q_query) # Adjust labels for each query sample

        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        print(f'Round {round + 1}/{num_rounds}, Loss: {loss.item()}')

    return global_model
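
few_shot_training calls a create_few_shot_batches helper whose body is not shown above. A hypothetical sketch consistent with how it is called (and with the per-class data layout in the usage example, where data is a list of n_way tensors shaped (num_samples, 3, num_points)) is:

import torch

def create_few_shot_batches(data, n_way, k_shot, q_query):
    # Sample one episode: k_shot support and q_query query examples per class
    support, query, labels = [], [], []
    for class_idx, class_data in enumerate(data[:n_way]):
        perm = torch.randperm(class_data.size(0))
        support.append(class_data[perm[:k_shot]])
        query.append(class_data[perm[k_shot:k_shot + q_query]])
        labels.append(class_idx)
    support_set = torch.cat(support, dim=0)  # (n_way * k_shot, 3, num_points), class-major
    query_set = torch.cat(query, dim=0)      # (n_way * q_query, 3, num_points), class-major
    labels = torch.tensor(labels)            # (n_way,); expanded per query in the training loop
    return support_set, query_set, labels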

Example

The following example demonstrates how to use the provided framework:

# Main execution
local_datasets = load_data()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
global_model = GlobalModel(input_dim=3, feature_dim=128, num_classes=10).to(device)
lambda1, lambda2, lambda3 = 1.0, 1.0, 0.1

# Federated training
trained_global_model = federated_training_process(
    num_rounds=50,
    local_datasets=local_datasets,
    global_model=global_model,
    lambda1=lambda1, lambda2=lambda2, lambda3=lambda3,
    apply_privacy=True,
)

# Few-shot learning
n_way, k_shot, q_query = 5, 5, 15
few_shot_data = [torch.rand((100, 3, 1024)) for _ in range(n_way)]
trained_few_shot_model = few_shot_training(
    trained_global_model, few_shot_data,
    num_rounds=10, n_way=n_way, k_shot=k_shot, q_query=q_query,
)

# Save the trained model state
torch.save(trained_few_shot_model.state_dict(), 'few_shot_trained_model.pth')
