You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 lines
5.8 KiB

import argparse
from collections import OrderedDict
import numpy as np
import torch
import flwr as fl
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
# Number of local training epochs; passed to train() by the Flower client's fit().
EPOCH = 3

# Select the compute device once at import time and announce the choice.
if torch.cuda.is_available():
    print('uses GPU')
    DEVICE = torch.device('cuda')
else:
    print('uses CPU')
    DEVICE = torch.device('cpu')
#Data_part
class DataRemake(Dataset):
    """Dataset built from a text file of single-character-encoded samples.

    Every non-blank line is split into individual characters: the first
    character is the class label (2, 3 or 4) and each following character is
    one feature value.  Features are padded with zeros or truncated so that
    every sample holds exactly ``NUM_FEATURES`` values.

    Attributes:
        data: list of 1-D float32 tensors, one per sample.
        label: list of one-hot float32 tensors of length 3, one per sample.
        len: number of samples.
    """

    # Length every feature vector is padded or truncated to.
    NUM_FEATURES = 321
    # Raw label value -> one-hot encoding.
    _ONE_HOT = {2: [1, 0, 0], 3: [0, 1, 0], 4: [0, 0, 1]}

    def __init__(self, path):
        """Load and convert all samples found in the file at *path*."""
        self.data, self.label = self.transform(path)
        self.len = len(self.label)

    def __getitem__(self, index):
        """Return the ``(features, one_hot_label)`` pair at *index*."""
        return self.data[index], self.label[index]

    def __len__(self):
        """Return the number of samples.

        Deliberately free of side effects: ``len()`` is called repeatedly by
        data loaders and progress bars, so printing here floods the output.
        """
        return self.len

    def transform(self, path):
        """Parse the file at *path* into feature and label tensors.

        Args:
            path: path of a UTF-8 text file, one sample per line.

        Returns:
            A ``(data_tensor_list, label_list)`` tuple of equally long lists.

        Raises:
            ValueError: if a label is not 2, 3 or 4, or if a label/feature
                character cannot be converted to a number.
        """
        data_tensor_list = []
        label_list = []
        with open(path, mode='r', encoding='utf-8') as fp:
            for raw_line in fp:
                line = raw_line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue

                label = int(line[0])
                if label not in self._ONE_HOT:
                    raise ValueError(f"未知的标签值:{label}")

                # Truncate first, so characters beyond the window are never parsed.
                features = [float(ch) for ch in line[1:self.NUM_FEATURES + 1]]
                # Zero-pad short samples up to the fixed length.
                features.extend([0.0] * (self.NUM_FEATURES - len(features)))

                data_tensor_list.append(
                    torch.tensor(features, dtype=torch.float32)
                )
                label_list.append(
                    torch.tensor(self._ONE_HOT[label], dtype=torch.float32)
                )
        return data_tensor_list, label_list
#Model_part
class Model(nn.Module):
    """Three-layer fully connected classifier: 321 inputs -> 3 class scores.

    The network returns raw logits.  ``torch.nn.CrossEntropyLoss`` applies
    log-softmax internally, so placing a ``Softmax`` layer in front of it
    squashes the scores twice and weakens the gradients.  Callers that need
    probabilities can apply ``torch.softmax(logits, dim=-1)``; the arg-max
    class is the same either way.  Parameter names and shapes
    (``net.0``, ``net.2``, ``net.4``) are unaffected by the missing
    parameter-free softmax layer.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(321, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 3),
        )

    def forward(self, input):
        """Return class logits for *input* (last dimension must be 321)."""
        return self.net(input)
def train(net, trainloader, epochs, partition_id):
    """Train *net* with SGD and save it to disk afterwards.

    Args:
        net: the module to optimise in place.
        trainloader: iterable yielding ``(features, labels)`` batches.
        epochs: number of full passes over *trainloader*.
        partition_id: value embedded in the name of the saved model file.
    """
    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    sgd = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for _ in range(epochs):
        progress = tqdm(trainloader, 'Training')
        for batch_x, batch_y in progress:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)

            sgd.zero_grad()
            batch_loss = criterion(net(batch_x), batch_y)
            batch_loss.backward()
            sgd.step()

    # The whole module object is serialised, not only its state_dict.
    torch.save(net, f'Modle_{partition_id}_GPU.pth')
    print('模型已保存')
def test(net, testloader):
    """Evaluate *net* on *testloader* and append the results to log files.

    Args:
        net: the module to evaluate.
        testloader: iterable yielding ``(features, labels)`` batches; it must
            expose ``.dataset`` with a length.  Labels may be one-hot /
            probability rows (2-D) or plain class indices (1-D).

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the sum of the per-batch
        cross-entropy values (not a mean) and ``accuracy`` is the fraction of
        correctly classified samples (``0.0`` for an empty dataset).
    """
    import os  # local import: only needed to create the log directory

    loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
    correct, loss = 0, 0.0
    with torch.no_grad():
        for img, label in tqdm(testloader, 'Testing'):
            images = img.to(DEVICE)
            labels = label.to(DEVICE)
            output = net(images)
            loss += loss_fc(output, labels).item()

            predicted = output.argmax(dim=1)
            # One-hot labels must be reduced to class indices first; comparing
            # indices against the one-hot rows would broadcast and miscount.
            target = labels.argmax(dim=1) if labels.dim() > 1 else labels
            correct += (predicted == target).sum().item()

    total = len(testloader.dataset)
    accuracy = correct / total if total else 0.0

    # Make sure the log directory exists so the evaluation is not lost.
    os.makedirs('vision', exist_ok=True)
    with open('vision/text-loss-1', 'a') as fp1:
        fp1.write(str(loss) + '\n')
    with open('vision/text-accuracy-1', 'a') as fp2:
        fp2.write(str(accuracy) + '\n')

    print('TEST-ACCURACY={}, TEST-LOSS={}'.format(accuracy, loss))
    return loss, accuracy
def load_data(train_path='data/traindata/traindata_1.txt',
              test_path='data/testdata/testdata_1.txt',
              batch_size=1):
    """Build the training and test data loaders.

    Called without arguments it loads the same files with the same batch size
    as before; the parameters allow other files or batch sizes to be used.

    Args:
        train_path: text file with the training samples.
        test_path: text file with the test samples.
        batch_size: number of samples per batch for both loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple.
    """
    # NOTE(review): the default paths are fixed to the "_1" files -- confirm
    # whether each client is meant to load its own partition's files instead.
    train_data = DataRemake(train_path)
    trainloader = DataLoader(dataset=train_data, batch_size=batch_size)
    test_data = DataRemake(test_path)
    testloader = DataLoader(dataset=test_data, batch_size=batch_size)
    return trainloader, testloader
#FL_part
#get id
# Command-line interface: the client must be told which partition it is (0, 1 or 2).
parser = argparse.ArgumentParser(description='Flower')
parser.add_argument(
    '--partition-id',
    type=int,
    choices=[0, 1, 2],
    required=True,
    help='Partition of the dataset divided into 3 iid partitions created artificially.',
)
partition_id = parser.parse_args().partition_id

# Module-level model and data loaders, shared by the Flower client below.
net = Model()
net = net.to(DEVICE)
trainloader, testloader = load_data()
#define client
class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates the module-level ``net``.

    The weights are exchanged as a list of NumPy arrays ordered like the
    entries of ``net.state_dict()``.
    """

    def get_parameters(self, config):
        """Return the current weights of ``net`` as NumPy arrays."""
        return [tensor.cpu().numpy() for tensor in net.state_dict().values()]

    def set_parameters(self, parameters):
        """Load *parameters* into ``net``, matching them to state-dict keys by position."""
        new_state = OrderedDict()
        for key, array in zip(net.state_dict().keys(), parameters):
            new_state[key] = torch.tensor(array)
        net.load_state_dict(new_state, strict=True)

    def fit(self, parameters, config):
        """Train locally from *parameters*; return new weights, sample count and metrics."""
        print('\n》 》 》 FIT 启动! 《 《 《')
        self.set_parameters(parameters)
        train(net, trainloader, epochs=EPOCH, partition_id=partition_id)
        updated = self.get_parameters(config={})
        num_examples = len(trainloader.dataset)
        return updated, num_examples, {}

    def evaluate(self, parameters, config):
        """Evaluate *parameters* locally; return loss, sample count and accuracy metric."""
        print('\n》 》 》 EVALUATE 启动! 《 《 《')
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        num_examples = len(testloader.dataset)
        metrics = {'accuracy': accuracy}
        return loss, num_examples, metrics
#start client
if __name__ == '__main__':
    # Wrap the NumPy client and start it against the local Flower server.
    flower_client = FlowerClient().to_client()
    fl.client.start_client(server_address='127.0.0.1:50987', client=flower_client)