HospEdgeComputing/Fed.py

27 lines
604 B
Python

import copy
import torch
def Aggregation(w, lens=None):
    """Return the weighted average of a sequence of model state dicts.

    For every key ``k`` of ``w[0]`` the result holds
    ``sum_i(w[i][k] * lens[i]) / sum(lens)``.

    Args:
        w: Non-empty sequence of mappings from parameter name to tensor.
            Every mapping must contain all keys of ``w[0]``.
        lens: Optional per-mapping weights, one per entry of ``w``.
            ``None`` weights every entry equally (a plain mean).

    Returns:
        A new mapping of the same type as ``w[0]`` holding the averaged
        tensors. The input mappings and their tensors are not modified.

    Raises:
        ValueError: if ``w`` is empty, ``lens`` does not have exactly one
            weight per entry of ``w``, or the weights sum to zero.
    """
    if len(w) == 0:
        raise ValueError("Aggregation requires at least one state dict")
    if lens is None:
        lens = [1.0] * len(w)
    elif len(lens) != len(w):
        # A longer lens would otherwise inflate the denominator silently.
        raise ValueError(
            f"lens has {len(lens)} weights but w has {len(w)} state dicts"
        )
    total_count = sum(lens)
    if total_count == 0:
        raise ValueError("weights in lens sum to zero; cannot average")

    # Shallow copy keeps w[0]'s mapping type and attributes without
    # duplicating every tensor; every value is replaced below anyway.
    w_avg = copy.copy(w[0])
    for k in w_avg.keys():
        # The product is a fresh tensor, so the in-place += below never
        # touches the caller's tensors.
        acc = w[0][k] * lens[0]
        for state, weight in zip(w[1:], lens[1:]):
            acc += state[k] * weight
        w_avg[k] = torch.div(acc, total_count)
    return w_avg