Fashion MNIST Non-IID Dataset

Fashion MNIST Non-IID Dataset

import torch
import torchvision

import numpy as np
import math
# Both splits use identical preprocessing: ToTensor converts images to
# float tensors in [0, 1], then Normalize((0.5,), (0.5,)) computes
# (x - 0.5) / 0.5 on the single channel, mapping pixels into [-1, 1].
fashion_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])
mnist_trainset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=fashion_transform)
mnist_testset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=fashion_transform)
## Parameters:
n_epochs, log_interval = 3, 500  # not referenced in the cells shown here
batch_size_train = 10000
batch_size_test = 500

# Training data is shuffled; test data keeps a fixed order so that
# evaluation is deterministic.
train_loader = torch.utils.data.DataLoader(
    mnist_trainset, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_testset, batch_size=batch_size_test, shuffle=False)
import papayaclient
class TheModel(torch.nn.Module):
    """Two-layer fully connected classifier: 784 -> 400 -> ReLU -> 10 outputs.

    The input is flattened from dimension 1 onward, so any batch whose
    per-sample element count is 784 (e.g. (N, 1, 28, 28) images) is accepted.
    The raw (un-normalised) scores of the final linear layer are returned.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (and registration order) are kept stable: they fix
        # the state_dict keys and the order in which parameters are initialised.
        self.linear1 = torch.nn.Linear(784, 400)
        self.linear2 = torch.nn.Linear(400, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        flat = torch.flatten(x, start_dim=1)
        hidden = self.relu(self.linear1(flat))
        return self.linear2(hidden)
# Non-IID partition: every class is cut into num_nodes randomly sized,
# non-overlapping pieces and piece i goes to node i, so each node ends up
# with a different mix of class proportions.
clients = []
num_nodes = 6
num_classes = 10  # FashionMNIST labels are 0-9

# Drain the (shuffled) training loader once to get the whole training set
# as two tensors.
list_of_data = []
list_of_labels = []
for ex_data, ex_labels in train_loader:
    list_of_data.append(ex_data)
    list_of_labels.append(ex_labels)
data, labels = torch.cat(list_of_data), torch.cat(list_of_labels)
# Makes the cut points reproducible (torch.randperm below is not seeded here).
np.random.seed(42)

node_data = [[] for _ in range(num_nodes)]
node_labels = [[] for _ in range(num_nodes)]
for num in range(num_classes):
    class_mask = labels == num
    data_by_class = data[class_mask]
    label_by_class = labels[class_mask]

    # num_nodes - 1 distinct cut points drawn from [1, len - 1]; with 0 and
    # len as the outer bounds, every node gets a non-empty share of the class.
    cuts = np.random.choice(len(data_by_class) - 1, num_nodes - 1, replace=False) + 1
    idx = [0] + sorted(cuts.tolist()) + [len(data_by_class)]
    for i in range(num_nodes):
        # Slice the per-class tensors (not the full dataset) so that the
        # pieces really are disjoint, single-class chunks.
        node_data[i].append(data_by_class[idx[i]:idx[i + 1]])
        node_labels[i].append(label_by_class[idx[i]:idx[i + 1]])

for i in range(num_nodes):
    ex_data = torch.cat(node_data[i])
    ex_labels = torch.cat(node_labels[i])
    # Shuffle within the node so its batches are not grouped by class.
    rand_idx = torch.randperm(len(ex_data))
    ex_data = ex_data[rand_idx]
    ex_labels = ex_labels[rand_idx]
    clients.append(papayaclient.PapayaClient(dat=ex_data,
                                             labs=ex_labels,
                                             batch_sz=500,
                                             # presumably the number of other nodes
                                             # (num_nodes - 1) — confirm against PapayaClient
                                             num_partners=5,
                                             model_class=TheModel,
                                             loss_fn=torch.nn.CrossEntropyLoss))
## Train the Nodes
num_epochs_total = 100
num_epochs_per_swap = 5
# Argument passed to select_partners; presumably how many partners each
# client picks — confirm against PapayaClient.
num_partners_to_select = 3
# How many times update_partner_weights() is called before averaging.
num_weight_exchanges = 4
num_times = (num_epochs_total // num_epochs_per_swap)

# Each round, every client trains locally for num_epochs_per_swap epochs.
# Except in the first two rounds and the final round, the clients then pick
# partners, refresh their partners' weights and average with them.
for round_idx in range(num_times):
    for n in clients:
        for _ in range(num_epochs_per_swap):
            n.model_train_epoch()
    if 1 < round_idx < num_times - 1:
        for n in clients:
            n.select_partners(num_partners_to_select)
        for n in clients:
            # A distinct loop variable: the original reused the outer `i` here.
            for _ in range(num_weight_exchanges):
                n.update_partner_weights()
            n.average_partners()
# Print each client's log entry for its final training epoch. The index is
# derived from the epochs actually run by the training loop
# (num_times * num_epochs_per_swap) instead of a hard-coded 99.
# NOTE(review): assumes logs['stringy'] holds one entry per trained epoch
# (index 99 printed "epoch 99" after 100 epochs) — confirm against PapayaClient.
last_epoch = num_times * num_epochs_per_swap - 1
for c in clients:
    print(c.logs['stringy'][last_epoch])
node3303epoch 99 loss 0.4667908251285553
node4585epoch 99 loss 0.533348560333252
node1239epoch 99 loss 0.4478834867477417
node4671epoch 99 loss 0.4271458685398102
node1943epoch 99 loss 0.4137968122959137
node4803epoch 99 loss 0.44756975769996643
# Evaluate each client's model on the full test set, keyed by node_id.
# Accuracy is total correct / total samples, so it stays exact even when the
# last test batch is smaller than batch_size_test (averaging per-batch means,
# as before, would over-weight that batch).
accuracies = {}
with torch.no_grad():
    for client in clients:
        correct = 0
        total = 0
        for ex_data, ex_labels in test_loader:
            # Call the module itself rather than .forward so any hooks run.
            preds = client.model(ex_data).argmax(dim=1)
            correct += (preds == ex_labels).sum().item()
            total += ex_labels.numel()
        accuracies[client.node_id] = correct / total if total else float('nan')
accuracies
{3303: 0.7975999981164932,
 4585: 0.801100006699562,
 1239: 0.8006000012159348,
 4671: 0.801900002360344,
 1943: 0.796000000834465,
 4803: 0.799600002169609}