MNIST Experiment: IID and Balanced Dataset
import torch
import torchvision
import numpy as np
import math
Loading the data. We pull down MNIST and normalize with the standard MNIST mean and standard deviation (0.1307, 0.3081).
mnist_trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ]))
mnist_testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ]))
Setting the hyperparameters
## Parameters:
n_epochs = 3
batch_size_train = 10000  # each 10,000-example training batch becomes one client's local shard
batch_size_test = 500
log_interval = 500
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=batch_size_test, shuffle=False)
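Because the loader shuffles before batching, each 10,000-example batch is an approximately class-balanced IID shard. A quick sanity check (illustrative, not part of the original experiment) can confirm this:
import collections
# Count labels in each 10,000-example shard; with shuffle=True every digit
# should appear roughly 1,000 times per shard (MNIST is close to balanced).
for batchno, (_, labs) in enumerate(train_loader):
    counts = collections.Counter(labs.tolist())
    print(f"shard {batchno}: {sorted(counts.items())}")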
import papayaclient
Writing the model class. The two-layer variant is left commented out, so the active model is a single fully connected layer (multinomial logistic regression) over the flattened 28×28 images.
class TheModel(torch.nn.Module):
    def __init__(self):
        super(TheModel, self).__init__()
        self.linear1 = torch.nn.Linear(784, 10)
        # self.linear2 = torch.nn.Linear(400, 10)
        # self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = x.flatten(start_dim=1)
        return self.linear1(x1)
        # return self.linear2(self.relu(self.linear1(x1)))
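A quick smoke test (not in the original notebook) confirms the forward pass maps a batch of 1×28×28 images to 10 logits:
# Illustrative shape check on dummy input.
m = TheModel()
logits = m(torch.zeros(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 10])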
Creating the clients with papaya. Each 10,000-example batch from the shuffled train loader becomes one client's local dataset, so the 60,000 training images yield six clients with IID, class-balanced shards.
clients = []
for batchno, (ex_data, ex_labels) in enumerate(train_loader):
    clients.append(papayaclient.PapayaClient(dat=ex_data,
                                             labs=ex_labels,
                                             batch_sz=500,
                                             num_partners=5,
                                             model_class=TheModel,
                                             loss_fn=torch.nn.CrossEntropyLoss))
import random
random.shuffle(clients)
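papayaclient's internals are not shown in this notebook. As a rough mental model, average_partners presumably averages model parameters across a node and its partners, in the spirit of the following sketch (a hypothetical helper, not the library's actual code):
# Hypothetical sketch of decentralized parameter averaging -- NOT
# papayaclient's real implementation.
def average_state_dicts(models):
    keys = models[0].state_dict().keys()
    return {k: torch.stack([m.state_dict()[k].float() for m in models]).mean(dim=0)
            for k in keys}

# e.g. each participating model could then reload the averaged weights:
# for m in models: m.load_state_dict(average_state_dicts(models))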
## Train the Nodes
num_epochs_total = 100
num_epochs_per_swap = 5
num_times = num_epochs_total // num_epochs_per_swap

for i in range(num_times):
    # Local phase: every node trains for num_epochs_per_swap epochs on its own shard.
    for n in clients:
        for j in range(num_epochs_per_swap):
            n.model_train_epoch()
    # Gossip phase: skipped in the first two rounds and the final round.
    if 1 < i < num_times - 1:
        for n in clients:
            n.select_partners(3)
        for n in clients:
            for k in range(4):  # renamed from i to avoid shadowing the round counter
                n.update_partner_weights()
            n.average_partners()
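For clarity (illustrative, not in the original): the loop above runs 20 rounds of 5 local epochs each, with partner selection and averaging only in rounds 2 through 18:
# Rounds in which the gossip condition above fires.
rounds_with_gossip = [i for i in range(num_times) if 1 < i < num_times - 1]
print(rounds_with_gossip)  # [2, 3, ..., 18]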
for c in clients:
    print(c.logs['stringy'][99])  # final-epoch training-loss log for each node
node3010epoch 99 loss 0.2327485978603363
node3802epoch 99 loss 0.27158960700035095
node3790epoch 99 loss 0.3395092487335205
node432epoch 99 loss 0.2874259650707245
node2642epoch 99 loss 0.2938651144504547
node2787epoch 99 loss 0.3081212043762207
accuracies = {}
with torch.no_grad():
    for c in clients:
        accuracies_node = []
        for batchno, (ex_data, ex_labels) in enumerate(test_loader):
            preds = c.model(ex_data).argmax(dim=1)
            accuracies_node.append((preds == ex_labels).float().mean().item())
        accuracies[c.node_id] = np.array(accuracies_node).mean()
accuracies
{3010: 0.9139999985694885,
3802: 0.9152999937534332,
3790: 0.9150999933481216,
432: 0.914999994635582,
2642: 0.9155000001192093,
2787: 0.9150000065565109}
Above we see the accuracy that each node's model achieves on the held-out test set; all six nodes converge to roughly 91.5%.
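For reference (computed from the printed dictionary above; not in the original notebook), the mean accuracy across nodes:
# Mean test accuracy across the six nodes -- roughly 0.915.
print(np.mean(list(accuracies.values())))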