# MNIST training A Bigger Model

In [1]:
import torch
import torchvision

import numpy as np
import math

In [2]:
mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
mnist_testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

In [3]:
## Parameters:
n_epochs = 3
batch_size_train = 10000
batch_size_test = 500
log_interval = 500

In [4]:
train_loader = torch.utils.data.DataLoader(mnist_trainset,batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset,batch_size=batch_size_test, shuffle=False)

In [5]:
import papayaclient

In [6]:
class TheModel(torch.nn.Module):

    def __init__(self):
        super(TheModel, self).__init__()

        self.conv1 = torch.nn.Conv2d(1, 1, 5)
        self.linear1 = torch.nn.Linear((24 * 24), 10)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = x1.flatten(start_dim = 1)
        return self.linear1(x2)

In [7]:
clients = []
for batchno, (ex_data, ex_labels) in enumerate(train_loader):
    clients.append(papayaclient.PapayaClient(dat = ex_data,
                                            labs = ex_labels,
                                            batch_sz = 500,
                                            num_partners = 5,
                                            model_class = TheModel,
                                            loss_fn = torch.nn.CrossEntropyLoss))


In [8]:
## Train the Nodes
num_epochs_total = 100
num_epochs_per_swap = 5
num_times = (num_epochs_total // num_epochs_per_swap)
for i in range(0, num_times):
    for n in clients:
        for j in range(0, num_epochs_per_swap):
            n.model_train_epoch()
            # print(n.logs['stringy'][n.epochs_trained - 1])
    if i > 1 and i < num_times - 1 :
        for n in clients:
            n.select_partners(3)
        for n in clients:
            for i in range(0, 4) :
                n.update_partner_weights()
            n.average_partners()

node1888epoch 0 loss 1.5461899042129517
node1888epoch 1 loss 0.807489275932312
node1888epoch 2 loss 0.6377502679824829
node1888epoch 3 loss 0.5672084093093872
node1888epoch 4 loss 0.5278582572937012
node2786epoch 0 loss 1.9848732948303223
node2786epoch 1 loss 1.3048557043075562
node2786epoch 2 loss 0.8567324876785278
node2786epoch 3 loss 0.6288896203041077
node2786epoch 4 loss 0.5104590654373169
node3408epoch 0 loss 1.9299290180206299
node3408epoch 1 loss 0.9124659895896912
node3408epoch 2 loss 0.5893670916557312
node3408epoch 3 loss 0.49376410245895386
node3408epoch 4 loss 0.4466298222541809
node3302epoch 0 loss 1.4667850732803345
node3302epoch 1 loss 0.6843005418777466
node3302epoch 2 loss 0.5147010684013367
node3302epoch 3 loss 0.44649094343185425
node3302epoch 4 loss 0.4092641770839691
node4282epoch 0 loss 1.3072646856307983
node4282epoch 1 loss 0.7020670175552368
node4282epoch 2 loss 0.5624675750732422
node4282epoch 3 loss 0.5059300065040588
node4282epoch 4 loss 0.4741377234458923

In [9]:
for c in clients :
    print(c.logs['stringy'][99])

node1888epoch 99 loss 0.35409218072891235
node2786epoch 99 loss 0.2528506815433502
node3408epoch 99 loss 0.2721341550350189
node3302epoch 99 loss 0.27993401885032654
node4282epoch 99 loss 0.29084330797195435
node1202epoch 99 loss 0.22220300137996674


In [10]:
accuracies = {}
with torch.no_grad():
    for i in clients :
        accuracies_node = []
        for batchno, (ex_data, ex_labels) in enumerate(test_loader) :
            accuracies_node.append(((i.model.forward(ex_data).argmax(dim = 1) == ex_labels).float().mean()).item())
        accuracies[i.node_id] = np.array(accuracies_node).mean()

In [11]:
accuracies

{1888: 0.9180999994277954,
 2786: 0.9181000024080277,
 3408: 0.9206999987363815,
 3302: 0.9205000013113022,
 4282: 0.9197999954223632,
 1202: 0.918299999833107}