This project presents a comprehensive framework for 3D federated few-shot learning (3DFFL), focused on 3D point cloud classification. The approach integrates Federated Learning (FL) with Few-Shot Learning (FSL) techniques, using PointNet++ for feature extraction and Prototypical Networks (ProtoNet) for classification. The framework preserves data privacy and uses collaborative learning to cope with data scarcity and heterogeneity across clients.
- PointNet++ for feature extraction
- Attention mechanism
- Loss functions for embedding and learnable tasks
- Differential privacy updates
- Data augmentation with Mixup
- Federated training process
- Few-shot learning with Prototypical Networks
- Python 3.x
- PyTorch
- NumPy
- scikit-learn
pip install torch numpy scikit-learn
The `load_data` function generates synthetic data for testing purposes:
def load_data(num_nodes=5, num_samples=100, num_points=1024, point_dim=3, num_classes=10):
    """Generate synthetic federated point-cloud data for smoke testing.

    Args:
        num_nodes: number of simulated clients (datasets returned).
        num_samples: point clouds per client.
        num_points: points per cloud.
        point_dim: coordinates per point.
        num_classes: width of each label vector.

    Returns:
        A list of ``num_nodes`` ``(X, Y)`` pairs. ``X`` is a float tensor of
        shape ``(num_samples, point_dim, num_points)`` with values in [0, 1);
        ``Y`` is a ``(num_samples, num_classes)`` float tensor of random 0/1
        (multi-hot) labels.
    """
    return [
        (
            torch.rand(num_samples, point_dim, num_points),
            torch.randint(0, 2, (num_samples, num_classes)).float(),
        )
        for _ in range(num_nodes)
    ]
To perform federated training with the provided framework:
# Five synthetic client datasets and a global model (on GPU when available).
local_datasets = load_data()
global_model = GlobalModel(
    input_dim=3,
    feature_dim=128,
    num_classes=10,
).to('cuda' if torch.cuda.is_available() else 'cpu')
# Weights for the three loss terms combined by the overall loss.
lambda1, lambda2, lambda3 = 1.0, 1.0, 0.1
trained_global_model = federated_training_process(
    num_rounds=50,
    local_datasets=local_datasets,
    global_model=global_model,
    lambda1=lambda1,
    lambda2=lambda2,
    lambda3=lambda3,
    apply_privacy=True,
)
To perform few-shot training with the trained global model:
# Episode configuration: classes per episode, support and query samples per class.
n_way, k_shot, q_query = 5, 5, 15
# One tensor of 100 random point clouds (3 x 1024) per class.
few_shot_data = [torch.rand((100, 3, 1024)) for _ in range(n_way)]
trained_few_shot_model = few_shot_training(
    trained_global_model,
    few_shot_data,
    num_rounds=10,
    n_way=n_way,
    k_shot=k_shot,
    q_query=q_query,
)
To save the trained model:
# Persist only the learned weights (state dict), not the pickled module object.
torch.save(obj=trained_few_shot_model.state_dict(), f='few_shot_trained_model.pth')
To load the trained model:
# Rebuild the architecture, then load the saved weights into it.
# map_location='cpu' lets a checkpoint written from a GPU run load on a
# CPU-only machine; move the model with .to(device) afterwards if needed.
model = GlobalModel(input_dim=3, feature_dim=128, num_classes=10)
model.load_state_dict(torch.load('few_shot_trained_model.pth', map_location='cpu'))
The `PointNetPP` class implements a simplified version of PointNet++ for feature extraction:
class PointNetPP(nn.Module):
    """Simplified PointNet++ feature extractor (placeholder in this listing).

    The initialization and forward methods are omitted here; this docstring
    keeps the class definition syntactically valid (a class body consisting
    only of a comment is a SyntaxError).
    """
    # Initialization and forward methods here
The `AttentionLayer` class implements an attention mechanism:
class AttentionLayer(nn.Module):
    """Score each element of a set with a small MLP and reweight it.

    A two-layer MLP maps every ``feature_dim`` vector (the last axis of the
    input) to one scalar score. The scores are softmax-normalised across the
    element axis (second-to-last axis, ``dim=-2``) and multiplied back onto
    the input.

    Args:
        feature_dim: size of the last axis of the input.

    The input must have at least two dimensions, ``(..., num_elements,
    feature_dim)``; for 2-D input ``(N, feature_dim)`` the N rows are the
    elements. The output has the same shape as the input.
    """

    def __init__(self, feature_dim):
        super(AttentionLayer, self).__init__()
        # Linear layers stay at indices 0 and 2 so existing state_dict keys
        # (attention.0.*, attention.2.*) remain valid.
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 1),
            # Normalise across elements (second-to-last axis). Normalising over
            # the size-1 score axis would make every weight exactly 1 for 2-D
            # input, turning the layer into a no-op. For 3-D input this is the
            # same axis as dim=1.
            nn.Softmax(dim=-2),
        )

    def forward(self, x):
        """Return ``x`` scaled element-wise by its attention weights."""
        attention_weights = self.attention(x)  # (..., num_elements, 1)
        return x * attention_weights
The following functions calculate the different loss components:
- `calculate_embedding_loss`
- `calculate_learnable_loss`
- `calculate_overall_loss`
The `differential_privacy_update` function perturbs the model's parameters with zero-mean Gaussian noise (a differential-privacy-style update):
def differential_privacy_update(model, noise_multiplier=0.1):
    """Perturb every parameter of ``model`` in place with i.i.d. Gaussian noise.

    Args:
        model: module whose ``parameters()`` are modified in place.
        noise_multiplier: standard deviation of the zero-mean noise added to
            each parameter element. Must be non-negative.

    Raises:
        ValueError: if ``noise_multiplier`` is negative.

    Note:
        This only adds noise; it does not clip updates or track a privacy
        budget.
    """
    if noise_multiplier < 0:
        raise ValueError("noise_multiplier must be non-negative")
    with torch.no_grad():
        for param in model.parameters():
            # randn_like samples directly on the parameter's device and dtype,
            # avoiding a host allocation plus a device copy for every tensor.
            param.add_(torch.randn_like(param), alpha=noise_multiplier)
The following functions handle data augmentation using Mixup:
- `mixup_data`
- `mixup_criterion`
The `federated_training_process` function performs federated training:
def federated_training_process(num_rounds, local_datasets, global_model, lambda1, lambda2, lambda3,
                               apply_privacy=False, *, local_epochs=5, lr=0.001, momentum=0.9,
                               tol=1e-3):
    """Train ``global_model`` with federated averaging plus server-side momentum.

    In each round, every node trains a private copy of the global model on its
    own data and optionally perturbs it with ``differential_privacy_update``.
    The server then moves the global weights by a momentum-smoothed average of
    the per-node weight *changes* (local weights minus global weights).

    Args:
        num_rounds: maximum number of communication rounds.
        local_datasets: sequence of ``(X, Y)`` tensor pairs, one per node.
            The tensors are moved to the model's device; the sequence itself
            is not modified.
        global_model: model exposing ``feature(X)``; updated in place.
        lambda1, lambda2, lambda3: weights forwarded to ``calculate_overall_loss``.
        apply_privacy: if True, add Gaussian noise to each local model before
            it is aggregated.
        local_epochs: full-batch optimisation steps per node per round.
        lr: Adam learning rate for local training.
        momentum: server momentum coefficient; the applied update is an
            exponential moving average of the averaged per-node deltas.
        tol: stop early once the mean per-step loss changes by less than this
            between consecutive rounds (checked after aggregation).

    Returns:
        ``global_model`` (the same object, trained in place).

    Raises:
        ValueError: if ``local_datasets`` is empty.
    """
    num_nodes = len(local_datasets)
    if num_nodes == 0:
        raise ValueError("local_datasets must contain at least one (X, Y) pair")
    device = _model_device(global_model)
    # Server momentum buffers, kept only for floating-point entries; integer
    # entries cannot be meaningfully averaged (see _aggregate_with_momentum).
    global_momentum = {
        key: torch.zeros_like(value)
        for key, value in global_model.state_dict().items()
        if torch.is_floating_point(value)
    }
    previous_loss = float('inf')
    for _ in range(num_rounds):
        local_updates = []
        total_loss = 0.0
        for X_i, Y_i in local_datasets:
            local_state, node_loss = _train_local_node(
                global_model, X_i.to(device), Y_i.to(device),
                lambda1, lambda2, lambda3, local_epochs, lr, apply_privacy,
            )
            local_updates.append(local_state)
            total_loss += node_loss
        # Aggregate before the convergence check so the final round's local
        # work is not thrown away when training stops early.
        _aggregate_with_momentum(global_model, local_updates, global_momentum, momentum)
        average_loss = total_loss / max(1, num_nodes * local_epochs)
        if abs(previous_loss - average_loss) < tol:
            break
        previous_loss = average_loss
    return global_model


def _model_device(model):
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _train_local_node(global_model, X, Y, lambda1, lambda2, lambda3, local_epochs, lr, apply_privacy):
    """Train a private copy of ``global_model`` on one node; return (state_dict, summed loss)."""
    local_model = copy.deepcopy(global_model)
    local_model.train()
    optimizer = torch.optim.Adam(local_model.parameters(), lr=lr)
    # Node-local attention layer: not part of global_model, so it is not
    # aggregated and is discarded after this round.
    attention_layer = None
    loss_sum = 0.0
    for _ in range(local_epochs):
        optimizer.zero_grad()
        features = local_model.feature(X)
        if attention_layer is None:
            # Built once (its width is only known after the first forward pass)
            # and registered with the optimizer so it is actually trained,
            # instead of being re-initialised randomly on every step.
            attention_layer = AttentionLayer(features.size(1)).to(features.device)
            optimizer.add_param_group({'params': list(attention_layer.parameters())})
        attended_features = attention_layer(features)
        embedding_loss = calculate_embedding_loss(attended_features, Y)
        learnable_loss = calculate_learnable_loss(attended_features, features, Y)
        # Fixed placeholder value for the third (comp_loss) term.
        comp_loss = torch.tensor(0.1, device=features.device)
        overall_loss = calculate_overall_loss(
            embedding_loss, learnable_loss, comp_loss, lambda1, lambda2, lambda3,
        )
        overall_loss.backward()
        optimizer.step()
        loss_sum += overall_loss.item()
    if apply_privacy:
        differential_privacy_update(local_model)
    return local_model.state_dict(), loss_sum


def _aggregate_with_momentum(global_model, local_updates, global_momentum, momentum):
    """Move ``global_model`` by the momentum-smoothed mean of per-node weight deltas."""
    global_state = global_model.state_dict()
    new_state = {}
    for key, global_value in global_state.items():
        if key not in global_momentum:
            # Non-floating-point entries (e.g. integer counters in some buffers)
            # cannot be averaged; take node 0's value instead.
            new_state[key] = local_updates[0][key]
            continue
        # Average the *change* each node made, not its absolute weights;
        # adding absolute weights to the global model makes them compound.
        mean_delta = torch.stack([update[key] - global_value for update in local_updates]).mean(dim=0)
        global_momentum[key].mul_(momentum).add_(mean_delta, alpha=1 - momentum)
        new_state[key] = global_value + global_momentum[key]
    global_model.load_state_dict(new_state)
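The `ProtoNet` class and the `few_shot_training` function implement few-shot learning with Prototypical Networks. Note that the `class ProtoNet(...)` header line is not included in the listing below; only its methods are shown, followed by the `euclidean_dist` helper: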
def __init__(self, model_func, n_way, n_support):
    """Initialise the prototypical network by delegating to its base class.

    All three arguments are passed through unchanged to the base-class
    constructor.

    NOTE(review): the ``class ProtoNet(...)`` header and its base class are not
    shown in this listing; presumably the base class defines ``parse_feature``
    and ``n_query`` used by the other methods -- confirm.
    """
    super(ProtoNet, self).__init__(model_func, n_way, n_support)
def set_forward(self, x, is_feature=False):
    """Score each query embedding against the per-class prototypes.

    Prototypes are the mean of the ``n_support`` support embeddings of each of
    the ``n_way`` classes. The returned scores are negative squared Euclidean
    distances, shaped ``(n_way * n_query, n_way)``; larger means closer.
    """
    support_emb, query_emb = self.parse_feature(x, is_feature)
    class_prototypes = (
        support_emb.contiguous()
        .view(self.n_way, self.n_support, -1)
        .mean(1)
    )
    flat_queries = query_emb.contiguous().view(self.n_way * self.n_query, -1)
    return -euclidean_dist(flat_queries, class_prototypes)
def set_forward_loss(self, x):
    """Return the cross-entropy loss of the query predictions for one episode.

    Query targets follow the episode layout: the first ``n_query`` queries
    belong to class 0, the next ``n_query`` to class 1, and so on.
    """
    scores = self.set_forward(x)
    # Build targets on the scores' device instead of forcing CUDA (which
    # crashes on CPU-only machines), and as int64, which cross_entropy requires
    # regardless of the platform's default NumPy integer width.
    y_query = torch.arange(self.n_way, device=scores.device).repeat_interleave(self.n_query)
    return F.cross_entropy(scores, y_query)
def euclidean_dist(x, y):
    """Pairwise squared Euclidean distances between the rows of two matrices.

    ``x`` is ``(n, d)`` and ``y`` is expected to be ``(m, d)``. Returns an
    ``(n, m)`` tensor whose ``[i, j]`` entry is ``sum_k (x[i, k] - y[j, k]) ** 2``.
    """
    rows_x, rows_y, width = x.size(0), y.size(0), x.size(1)
    grid_shape = (rows_x, rows_y, width)
    pairwise_diff = x.unsqueeze(1).expand(*grid_shape) - y.unsqueeze(0).expand(*grid_shape)
    return pairwise_diff.pow(2).sum(2)
def few_shot_training(global_model, data, num_rounds, n_way, k_shot, q_query):
    """Fine-tune ``global_model`` with prototypical-network episodes.

    Each round draws one episode with ``create_few_shot_batches``. Class
    prototypes are the mean support embedding per class. Query logits are
    negative (non-squared) Euclidean distances to the prototypes, computed
    with ``torch.cdist``. Loss is printed after every round.

    Args:
        global_model: model exposing ``feature(x)``; trained in place.
        data: per-class samples forwarded to ``create_few_shot_batches``.
        num_rounds: number of episodes (one optimiser step each).
        n_way: classes per episode.
        k_shot: support samples per class.
        q_query: query samples per class.

    Returns:
        ``global_model`` (the same object, trained in place).
    """
    optimizer = torch.optim.Adam(global_model.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Episodes are built from `data`, which may live on the CPU; move them to
    # the model's device so a GPU-resident model does not hit a device mismatch.
    device = next(global_model.parameters()).device
    global_model.train()
    for round_idx in range(num_rounds):
        optimizer.zero_grad()
        support_set, query_set, labels = create_few_shot_batches(data, n_way, k_shot, q_query)
        support_set = support_set.to(device)
        query_set = query_set.to(device)
        labels = labels.to(device)
        support_features = global_model.feature(support_set)
        query_features = global_model.feature(query_set)
        # One prototype per class: mean of its k_shot support embeddings
        # (assumes support_set is ordered class by class -- TODO confirm).
        prototypes = support_features.view(n_way, k_shot, -1).mean(dim=1)
        logits = -torch.cdist(query_features, prototypes)
        # Reshape logits to match the label dimensions
        logits = logits.view(-1, n_way)
        # Presumably `labels` holds one class index per class; repeat each
        # q_query times to get one target per query sample -- verify.
        labels = labels.repeat_interleave(q_query)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        print(f'Round {round_idx + 1}/{num_rounds}, Loss: {loss.item()}')
    return global_model
The following example demonstrates how to use the provided framework:
# Main execution: guarded so that importing this module does not start training.
if __name__ == "__main__":
    local_datasets = load_data()
    global_model = GlobalModel(input_dim=3, feature_dim=128, num_classes=10).to(
        'cuda' if torch.cuda.is_available() else 'cpu'
    )
    lambda1, lambda2, lambda3 = 1.0, 1.0, 0.1
    # Federated training
    trained_global_model = federated_training_process(
        num_rounds=50,
        local_datasets=local_datasets,
        global_model=global_model,
        lambda1=lambda1,
        lambda2=lambda2,
        lambda3=lambda3,
        apply_privacy=True,
    )
    # Few-shot learning
    n_way, k_shot, q_query = 5, 5, 15
    few_shot_data = [torch.rand((100, 3, 1024)) for _ in range(n_way)]
    trained_few_shot_model = few_shot_training(
        trained_global_model,
        few_shot_data,
        num_rounds=10,
        n_way=n_way,
        k_shot=k_shot,
        q_query=q_query,
    )
    # Save the trained model state
    torch.save(trained_few_shot_model.state_dict(), 'few_shot_trained_model.pth')