27 lines
604 B
Python
27 lines
604 B
Python
import copy
|
|
import torch
|
|
|
|
|
|
def Aggregation(w, lens):
    """Return the weighted element-wise average of a list of state dicts.

    Computes ``sum(lens[i] * w[i][k]) / sum(lens)`` for every key ``k`` of
    ``w[0]``, dividing with ``torch.div``.

    Args:
        w: Non-empty sequence of mappings (e.g. ``Module.state_dict()``
            results) from parameter name to tensor. Every mapping must
            contain all keys of ``w[0]``.
        lens: Per-entry weights, one per element of ``w`` (any iterable),
            or ``None`` for an unweighted mean.

    Returns:
        A new mapping of the same type as ``w[0]`` holding the averaged
        tensors. The input mappings and their tensors are not modified.

    Raises:
        ValueError: if ``w`` is empty, if ``lens`` does not have one weight
            per entry of ``w``, or if the weights sum to zero.
    """
    if not w:
        raise ValueError("Aggregation() requires at least one state dict")

    if lens is None:
        # Unweighted mean: every entry counts once.
        total_count = len(w)
        lens = [1.0] * len(w)
    else:
        # Materialize first so generators/iterators can be summed AND indexed.
        lens = list(lens)
        if len(lens) != len(w):
            raise ValueError(
                "Aggregation() got %d weights for %d state dicts"
                % (len(lens), len(w))
            )
        total_count = sum(lens)

    if total_count == 0:
        # Dividing by zero would silently produce NaN/inf parameters.
        raise ValueError("Aggregation() weights must not sum to zero")

    # A shallow copy keeps w[0]'s mapping type (e.g. OrderedDict together
    # with its ``_metadata`` attribute) without duplicating every tensor;
    # each value is replaced by a freshly computed tensor below, so w[0]
    # itself is never mutated.
    w_avg = copy.copy(w[0])
    for k in w_avg:
        w_avg[k] = w[0][k] * lens[0]

    for state, weight in zip(w[1:], lens[1:]):
        for k in w_avg:
            # Out-of-place add lets the accumulator's dtype promote
            # (e.g. int -> float) instead of failing an in-place cast.
            w_avg[k] = w_avg[k] + state[k] * weight

    for k in w_avg:
        w_avg[k] = torch.div(w_avg[k], total_count)

    return w_avg