MNIST Non-IID
import torch
import torchvision
import numpy as np
import math
mnist_trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]))
mnist_testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]))
## Parameters:
n_epochs = 3
batch_size_train = 10000
batch_size_test = 500
log_interval = 500
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=batch_size_test, shuffle=False)
import papayaclient
class TheModel(torch.nn.Module):
    def __init__(self):
        super(TheModel, self).__init__()
        self.linear1 = torch.nn.Linear(784, 400)
        self.linear2 = torch.nn.Linear(400, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = x.flatten(start_dim=1)
        return self.linear2(self.relu(self.linear1(x1)))
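The model is a plain two-layer MLP: each 28×28 image is flattened to 784 features, passed through 400 ReLU units, and mapped to 10 logits. As a quick, illustrative sanity check (the dummy tensor below is not part of the notebook's data), a forward pass should produce one row of 10 logits per image:
# Illustrative shape check: a dummy batch of four 28x28 images -> a 4 x 10 logit tensor
dummy = torch.zeros(4, 1, 28, 28)
print(TheModel()(dummy).shape)  # torch.Size([4, 10])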
clients = []
list_of_data = []
list_of_labels = []
for batchno, (ex_data, ex_labels) in enumerate(train_loader):
    list_of_data.append(ex_data)
    list_of_labels.append(ex_labels)
data, labels = torch.cat(list_of_data), torch.cat(list_of_labels)
np.random.seed(42)
node_data = [[] for _ in range(6)]
node_labels = [[] for _ in range(6)]
for num in range(10):  # one pass per digit class
    data_by_class = data[labels == num]
    label_by_class = labels[labels == num]
    # five random cut points split this class into six contiguous, unequal shards
    idx = [0] + sorted(np.random.choice(len(data_by_class) - 1, 5, replace=False) + 1) + [len(data_by_class)]
    for i in range(6):
        ex_data = data_by_class[idx[i]:idx[i + 1]]
        ex_labels = label_by_class[idx[i]:idx[i + 1]]
        node_data[i].append(ex_data)
        node_labels[i].append(ex_labels)
for i in range(6):
    # shuffle each node's shard so its local batches mix the per-class segments
    ex_data = torch.cat(node_data[i])
    ex_labels = torch.cat(node_labels[i])
    rand_idx = torch.randperm(len(ex_data))
    ex_data = ex_data[rand_idx]
    ex_labels = ex_labels[rand_idx]
    clients.append(papayaclient.PapayaClient(dat=ex_data,
                                             labs=ex_labels,
                                             batch_sz=500,
                                             num_partners=5,
                                             model_class=TheModel,
                                             loss_fn=torch.nn.CrossEntropyLoss))
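Because each class is cut at random points, the six nodes receive very different numbers of examples per digit, which is what makes the split non-IID. A short check over the node_labels lists built above makes that skew visible:
# Count how many examples of each digit (0-9) every node received
for i in range(6):
    counts = torch.bincount(torch.cat(node_labels[i]), minlength=10)
    print(f"node {i}: {counts.tolist()}")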
## Train the Nodes
num_epochs_total = 100
num_epochs_per_swap = 5
num_times = (num_epochs_total // num_epochs_per_swap)
for i in range(0, num_times):
    # local phase: every node trains for num_epochs_per_swap epochs on its own shard
    for n in clients:
        for j in range(0, num_epochs_per_swap):
            n.model_train_epoch()
            print(n.epochs_trained)
    # communication phase (skipped for the first two rounds and the final round)
    if i > 1 and i < num_times - 1:
        for n in clients:
            n.select_partners(3)
        for n in clients:
            for k in range(0, 4):  # k, not i, to avoid shadowing the round counter
                n.update_partner_weights()
            n.average_partners()
1
2
3
4
5
...
96
97
98
99
100
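The averaging step inside papayaclient is not shown in this notebook. Conceptually, average_partners replaces a node's weights with the mean of its own and its partners' parameters; a minimal sketch of that idea in plain PyTorch (using a hypothetical average_models helper, not papayaclient's actual API) looks like this:
# Minimal sketch of parameter averaging across models with identical architectures.
# Hypothetical helper for illustration only; papayaclient's internals may differ.
def average_models(models):
    state_dicts = [m.state_dict() for m in models]
    avg = {k: torch.stack([sd[k] for sd in state_dicts]).mean(dim=0)
           for k in state_dicts[0]}
    for m in models:
        m.load_state_dict(avg)

# e.g. average_models([c.model for c in clients]) would synchronize all six nodes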
for c in clients:
    print(c.logs['stringy'][99])
node3811epoch 99 loss 0.2988002896308899
node2244epoch 99 loss 0.30985724925994873
node1856epoch 99 loss 0.12980203330516815
node4116epoch 99 loss 0.19149348139762878
node4327epoch 99 loss 0.24858054518699646
node689epoch 99 loss 0.19274069368839264
accuracies = {}
with torch.no_grad():
    for c in clients:
        accuracies_node = []
        for batchno, (ex_data, ex_labels) in enumerate(test_loader):
            accuracies_node.append(
                (c.model(ex_data).argmax(dim=1) == ex_labels).float().mean().item())
        accuracies[c.node_id] = np.array(accuracies_node).mean()
accuracies
{3811: 0.9123000055551529,
2244: 0.9122999995946884,
1856: 0.9106000006198883,
4116: 0.9099000036716461,
4327: 0.9087999999523163,
689: 0.9097000092267991}
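Averaging the per-node accuracies gives a single summary figure for the run (about 0.91 here):
# Mean test accuracy across the six nodes
print(np.mean(list(accuracies.values())))  # ~0.91 for the run above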