|
|
import argparse
|
|
|
from collections import OrderedDict
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import flwr as fl
|
|
|
import torch.nn as nn
|
|
|
from tqdm import tqdm
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
|
# Number of local training epochs run each time FlowerClient.fit() calls train().
EPOCH = 3

# Pick the compute device once at import time; every tensor/model in this
# file is moved to DEVICE.
if torch.cuda.is_available():
    print('uses GPU')
    DEVICE = torch.device('cuda')
else:
    print('uses CPU')
    DEVICE = torch.device('cpu')
|
|
|
|
|
|
#Data_part
|
|
|
class DataRemake(Dataset):
    """Map-style dataset built from a text file of digit strings.

    Each non-blank line of the file is one sample:

    * its first character is the class label (``'2'``, ``'3'`` or ``'4'``),
      stored as a float32 one-hot vector of length 3;
    * every following character is parsed as one float feature, and the
      feature vector is zero-padded or truncated to exactly ``FEATURE_LEN``
      (321) values so it matches the model's input layer.

    Attributes:
        data: list of float32 feature tensors, each of shape ``(FEATURE_LEN,)``.
        label: list of float32 one-hot label tensors, each of shape ``(3,)``.
        len: number of samples.
    """

    # Number of features per sample after padding/truncation.
    FEATURE_LEN = 321
    # Raw label digit -> one-hot class vector.
    LABEL_TO_ONE_HOT = {
        2: [1, 0, 0],
        3: [0, 1, 0],
        4: [0, 0, 1],
    }

    def __init__(self, path):
        self.data, self.label = self.transform(path)
        self.len = len(self.label)
        # Report the size once here rather than in __len__: DataLoader,
        # samplers and progress bars call len() repeatedly.
        print('数据集长度为:', self.len)

    def __getitem__(self, index):
        """Return the ``(features, one_hot_label)`` pair at *index*."""
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len

    def transform(self, path):
        """Parse *path* into parallel lists of feature and label tensors.

        Args:
            path: UTF-8 text file, one sample per line.

        Returns:
            ``(data_tensor_list, label_list)``.

        Raises:
            ValueError: if a line's first character is not a digit, or is a
                digit other than 2, 3 or 4.
        """
        data_tensor_list = []
        label_list = []
        with open(path, mode='r', encoding='utf-8') as fp:
            data_str_list = [line.strip() for line in fp]

        for line in data_str_list:
            if not line:
                # Blank lines carry no sample; indexing line[0] would raise
                # IndexError.
                continue

            label = int(line[0])
            if label not in self.LABEL_TO_ONE_HOT:
                raise ValueError(f"未知的标签值:{label}")
            one_hot = self.LABEL_TO_ONE_HOT[label]

            chars = list(line[1:])
            if len(chars) < self.FEATURE_LEN:
                # Zero-pad short rows up to the model's input width.
                chars.extend([0] * (self.FEATURE_LEN - len(chars)))
            else:
                # Truncate long rows (no-op when already FEATURE_LEN long).
                chars = chars[:self.FEATURE_LEN]

            features = np.array(list(map(float, chars))).astype(np.float32)
            label_vec = np.array(one_hot).astype(np.float32)
            data_tensor_list.append(torch.from_numpy(features))
            label_list.append(torch.from_numpy(label_vec))

        return data_tensor_list, label_list
|
|
|
|
|
|
#Model_part
|
|
|
class Model(nn.Module):
    """Three-layer MLP classifier: 321 input features -> 3 class scores.

    ``forward`` returns raw, unnormalised logits; for an input of shape
    ``(batch, 321)`` the output has shape ``(batch, 3)``.

    ``torch.nn.CrossEntropyLoss`` applies log-softmax internally. A trailing
    ``nn.Softmax`` in the network would normalise twice, flatten the
    gradients and slow training.

    Apply ``torch.softmax(logits, dim=1)`` when probabilities are needed.
    ``argmax`` over logits picks the same class as over probabilities.
    """

    def __init__(self):
        super(Model, self).__init__()
        # No output Softmax (see class docstring). Softmax has no parameters,
        # so the state_dict keys (net.0.*, net.2.*, net.4.*) are the same
        # ones exchanged with the federated server.
        self.net = nn.Sequential(
            nn.Linear(321, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 3),
        )

    def forward(self, input):
        """Return class logits for *input*."""
        return self.net(input)
|
|
|
|
|
|
def train(net, trainloader, epochs, partition_id, lr=0.001, momentum=0.9):
    """Train *net* in place with SGD + momentum, then save it to disk.

    Args:
        net: model to train; its parameters are updated in place.
        trainloader: iterable of ``(features, labels)`` batches. Labels are
            passed straight to ``CrossEntropyLoss``; DataRemake produces
            one-hot float vectors, i.e. class-probability targets.
        epochs: number of full passes over *trainloader*.
        partition_id: used only to name the checkpoint file.
        lr: SGD learning rate (default 0.001).
        momentum: SGD momentum (default 0.9).

    Side effects:
        Writes the whole pickled module to ``Modle_<partition_id>_GPU.pth``
        (the name says GPU whichever device is used) and prints a
        confirmation message.
    """
    loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
    # A new optimizer per call, so momentum buffers do not carry over
    # between federated rounds.
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    # Ensure training mode even if a previous evaluation switched to eval().
    net.train()
    for _ in range(epochs):
        for img, label in tqdm(trainloader, 'Training'):
            images = img.to(DEVICE)
            labels = label.to(DEVICE)
            optimizer.zero_grad()
            loss = loss_fc(net(images), labels)
            loss.backward()
            optimizer.step()
    # Checkpoint once, after the final epoch.
    torch.save(net, 'Modle_{}_GPU.pth'.format(partition_id))
    print('模型已保存')
|
|
|
|
|
|
def test(net, testloader):
    """Evaluate *net* on *testloader*, log the metrics and return them.

    Args:
        net: model to evaluate.
        testloader: DataLoader of ``(features, labels)`` batches. Labels may
            be one-hot / probability vectors of shape ``(batch, classes)``
            (as DataRemake produces) or class indices of shape ``(batch,)``.

    Returns:
        ``(loss, accuracy)``:
        * ``loss`` is the SUM of the per-batch mean cross-entropy losses;
        * ``accuracy`` is the fraction of samples whose predicted class
          matches the label (0.0 for an empty dataset).

    Side effects:
        Appends ``loss`` to ``vision/text-loss-1`` and ``accuracy`` to
        ``vision/text-accuracy-1`` (creating ``vision/`` if needed), and
        prints both.
    """
    import os

    loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
    correct, loss = 0, 0.0
    net.eval()
    with torch.no_grad():
        for img, label in tqdm(testloader, 'Testing'):
            images = img.to(DEVICE)
            labels = label.to(DEVICE)
            output = net(images)
            loss += loss_fc(output, labels).item()
            # Compare class indices. Comparing argmax indices of shape (B,)
            # against one-hot labels of shape (B, C) would broadcast and
            # count matching elements, not matching samples.
            targets = labels.argmax(dim=1) if labels.dim() > 1 else labels
            correct += (output.argmax(dim=1) == targets).sum().item()

    num_samples = len(testloader.dataset)
    accuracy = correct / num_samples if num_samples else 0.0

    os.makedirs('vision', exist_ok=True)
    with open('vision/text-loss-1', 'a') as fp1:
        fp1.write(str(loss) + '\n')
    with open('vision/text-accuracy-1', 'a') as fp2:
        fp2.write(str(accuracy) + '\n')
    print('TEST-ACCURACY={}, TEST-LOSS={}'.format(accuracy, loss))
    return loss, accuracy
|
|
|
|
|
|
def load_data(train_path='data/traindata/traindata_1.txt',
              test_path='data/testdata/testdata_1.txt',
              batch_size=1,
              shuffle_train=False):
    """Build the train and test DataLoaders from DataRemake text files.

    Args:
        train_path: training data file (default is the previous hard-coded
            path).
        test_path: test data file (default is the previous hard-coded path).
        batch_size: samples per batch for both loaders (default 1).
        shuffle_train: reshuffle the training set every epoch. Off by
            default to keep the deterministic order.

    Returns:
        ``(trainloader, testloader)``. The test loader is never shuffled.

    NOTE(review): the default paths are the same for every --partition-id;
    if each partition has its own file, pass partition-specific paths —
    confirm against the data layout.
    """
    train_data = DataRemake(train_path)
    trainloader = DataLoader(dataset=train_data, batch_size=batch_size,
                             shuffle=shuffle_train)
    test_data = DataRemake(test_path)
    testloader = DataLoader(dataset=test_data, batch_size=batch_size)
    return trainloader, testloader
|
|
|
|
|
|
#FL_part
|
|
|
# Parse this client's data-partition id from the command line
|
|
|
# Command-line interface: which data partition this client represents.
parser = argparse.ArgumentParser(description='Flower')
parser.add_argument(
    '--partition-id',
    type=int,
    choices=list(range(3)),
    required=True,
    help='Partition of the dataset divided into 3 iid partitions created artificially.',
)
cli_args = parser.parse_args()
partition_id = cli_args.partition_id

# Module-level model and loaders; FlowerClient's callbacks read these globals.
net = Model().to(DEVICE)
trainloader, testloader = load_data()
|
|
|
|
|
|
#define client
|
|
|
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client backed by the module-level ``net`` and loaders."""

    def get_parameters(self, config):
        """Return the model weights as NumPy arrays, in ``state_dict`` order."""
        weights = net.state_dict().values()
        return [tensor.cpu().numpy() for tensor in weights]

    def set_parameters(self, parameters):
        """Load NumPy weight arrays (``state_dict`` order) into ``net``."""
        names = net.state_dict().keys()
        new_state = OrderedDict(
            (name, torch.tensor(array)) for name, array in zip(names, parameters)
        )
        net.load_state_dict(new_state, strict=True)

    def fit(self, parameters, config):
        """Load global weights, train locally, return updated weights.

        Returns ``(weights, number_of_training_samples, {})``.
        """
        print('\n》 》 》 FIT 启动! 《 《 《')
        self.set_parameters(parameters)
        train(net, trainloader, epochs=EPOCH, partition_id=partition_id)
        updated = self.get_parameters(config={})
        return updated, len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load global weights and evaluate them on the local test set.

        Returns ``(loss, number_of_test_samples, {'accuracy': accuracy})``.
        """
        print('\n》 》 》 EVALUATE 启动! 《 《 《')
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        metrics = {'accuracy': accuracy}
        return loss, len(testloader.dataset), metrics
|
|
|
|
|
|
#start client
|
|
|
if __name__ == '__main__':
    # Connect this client to the Flower server at 127.0.0.1:50987.
    flower_client = FlowerClient().to_client()
    fl.client.start_client(server_address='127.0.0.1:50987', client=flower_client)
|
|
|
|