ADD file via upload
This commit is contained in:
parent
c736b59799
commit
b378fecc74
|
@ -0,0 +1,68 @@
|
|||
import numpy as np
|
||||
from models.Update import *
|
||||
from models.Fed import *
|
||||
from models.test import *
|
||||
from utils.utilis import *
|
||||
import copy
|
||||
|
||||
def test_with_loss(net_glob, dataset_test, args):
|
||||
|
||||
# testing
|
||||
acc_test, loss_test = test_img(net_glob, dataset_test, args)
|
||||
|
||||
print("Testing accuracy: {:.2f}".format(acc_test))
|
||||
|
||||
return acc_test.item(), loss_test
|
||||
|
||||
def Cloud(args, net_glob, dataset_train, dataset_test, dict_users, dict_public):
|
||||
net_glob.train()
|
||||
|
||||
print("Start Cloud")
|
||||
|
||||
isNext = True
|
||||
today = 1
|
||||
m = 4
|
||||
|
||||
# training
|
||||
acc = []
|
||||
GB = []
|
||||
transf = 161
|
||||
total_transf = 0
|
||||
|
||||
while isNext:
|
||||
if today >= args.physical_time:
|
||||
isNext = False
|
||||
client_train_list = random.sample(
|
||||
dict_public, m
|
||||
)
|
||||
# print("*" * 80)
|
||||
print("today: {:3d}".format(today))
|
||||
|
||||
for iter in range(args.epochs):
|
||||
|
||||
w_locals = []
|
||||
lens = []
|
||||
max_time = 0
|
||||
for idx_client in client_train_list:
|
||||
local = LocalUpdate(
|
||||
args=args, dataset=dataset_train, idxs=dict_users[idx_client]
|
||||
)
|
||||
w = local.train(net=copy.deepcopy(net_glob).to(args.device))
|
||||
|
||||
w_locals.append(copy.deepcopy(w))
|
||||
lens.append(len(dict_users[idx_client]))
|
||||
|
||||
w_glob = Aggregation(w_locals, lens)
|
||||
# copy weight to net_glob
|
||||
net_glob.load_state_dict(w_glob)
|
||||
|
||||
today += 1
|
||||
total_transf += transf
|
||||
|
||||
item_acc, item_loss = test_with_loss(net_glob, dataset_test, args)
|
||||
ta, tl = test_with_loss(net_glob, dataset_train, args)
|
||||
acc.append(item_acc)
|
||||
GB.append(total_transf)
|
||||
|
||||
save_result(acc, "test_acc", args)
|
||||
save_result(GB, "test_GB", args)
|
Loading…
Reference in New Issue