"""Flower federated-learning client that trains a 3-class MLP on one data partition.

Each line of the train/test text files holds one sample: the first character
is the label digit (2, 3 or 4) and every following character is parsed as one
numeric feature. Rows are padded with zeros or truncated to FEATURE_LEN values.
"""
import argparse
import os
from collections import OrderedDict

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

EPOCH = 3
# Number of feature values per sample (the model's input width).
FEATURE_LEN = 321

if torch.cuda.is_available():
    print('ues GPU')
    DEVICE = torch.device('cuda')
else:
    print('uses CPU')
    DEVICE = torch.device('cpu')


# Data part
class DataRemake(Dataset):
    """Dataset of (feature tensor, one-hot label tensor) pairs read from a text file."""

    # Raw label digit -> one-hot target over the model's 3 output classes.
    LABEL_TO_ONE_HOT = {2: [1, 0, 0], 3: [0, 1, 0], 4: [0, 0, 1]}

    def __init__(self, path):
        self.data, self.label = self.transform(path)
        self.len = len(self.label)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        print('数据集长度为:', self.len)
        return self.len

    def transform(self, path):
        """Parse *path* into parallel lists of feature and label tensors.

        Args:
            path: UTF-8 text file with one sample per line; the first character
                is the label digit, each following character is one feature.

        Returns:
            (data_tensor_list, label_list): float32 tensors of shape
            (FEATURE_LEN,) and (3,) respectively.

        Raises:
            ValueError: if a label is not 2, 3 or 4, or a character cannot be
                parsed as a number.
        """
        data_tensor_list = []
        label_list = []
        with open(path, mode='r', encoding='utf-8') as fp:
            data_str_list = [line.strip() for line in fp]
        for line in data_str_list:
            # Blank lines (e.g. a trailing newline at EOF) carry no sample and
            # would otherwise raise IndexError on line[0].
            if not line:
                continue
            chars = list(line)
            label = int(chars[0])
            # Convert the label to a one-hot encoding.
            if label not in self.LABEL_TO_ONE_HOT:
                raise ValueError(f"未知的标签值:{label}")
            one_hot = self.LABEL_TO_ONE_HOT[label]

            features = chars[1:]
            # Normalise the row to exactly FEATURE_LEN values: pad short rows
            # with 0, truncate long ones.
            if len(features) < FEATURE_LEN:
                features.extend([0] * (FEATURE_LEN - len(features)))
            else:
                features = features[:FEATURE_LEN]

            data = np.array(list(map(float, features))).astype(np.float32)
            label_arr = np.array(one_hot).astype(np.float32)
            data_tensor_list.append(torch.from_numpy(data))
            label_list.append(torch.from_numpy(label_arr))
        return data_tensor_list, label_list


# Model part
class Model(nn.Module):
    """MLP FEATURE_LEN -> 200 -> 100 -> 3 that returns raw class logits.

    There is deliberately no final Softmax: nn.CrossEntropyLoss applies
    log-softmax internally, so feeding it probabilities would apply softmax
    twice and weaken the training signal. Softmax has no parameters, so the
    state_dict keys/shapes (net.0, net.2, net.4) exchanged with the Flower
    server are unchanged. Use torch.softmax(output, dim=1) for probabilities;
    argmax-based predictions are unaffected.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(FEATURE_LEN, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 3),
        )

    def forward(self, input):
        return self.net(input)


def train(net, trainloader, epochs, partition_id):
    """Train *net* in place with SGD and cross-entropy for *epochs* epochs.

    After training, the whole module is pickled to
    'Modle_{partition_id}_GPU.pth' (file name kept for existing consumers).
    """
    loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    net.train()
    for _ in range(epochs):
        for img, label in tqdm(trainloader, 'Training'):
            images = img.to(DEVICE)
            labels = label.to(DEVICE)
            optimizer.zero_grad()
            output = net(images)
            # One-hot float labels are accepted as class-probability targets.
            loss_fc(output, labels).backward()
            optimizer.step()
    torch.save(net, 'Modle_{}_GPU.pth'.format(partition_id))
    print('模型已保存')


def test(net, testloader):
    """Evaluate *net* on *testloader*.

    Returns:
        (loss, accuracy): cross-entropy loss summed over batches, and the
        fraction of correctly classified samples. Both values are also
        appended to vision/text-loss-1 and vision/text-accuracy-1.
    """
    loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
    correct, loss = 0, 0.0
    net.eval()
    with torch.no_grad():
        for img, label in tqdm(testloader, 'Testing'):
            images = img.to(DEVICE)
            labels = label.to(DEVICE)
            output = net(images)
            loss += loss_fc(output, labels).item()
            # Labels are one-hot (batch, 3): compare class indices. Comparing
            # predicted indices with the one-hot vectors directly would
            # broadcast into an element-wise (batch, 3) comparison.
            predicted = output.argmax(dim=1)
            expected = labels.argmax(dim=1)
            correct += (predicted == expected).sum().item()
    accuracy = correct / len(testloader.dataset)
    # The metric files are appended to; make sure their directory exists.
    os.makedirs('vision', exist_ok=True)
    with open('vision/text-loss-1', 'a') as fp1:
        fp1.write(str(loss) + '\n')
    with open('vision/text-accuracy-1', 'a') as fp2:
        fp2.write(str(accuracy) + '\n')
    print('TEST-ACCURACY={}, TEST-LOSS={}'.format(accuracy, loss))
    return loss, accuracy


def load_data():
    """Build batch-size-1 DataLoaders for this client's train and test files."""
    train_data = DataRemake('data/traindata/traindata_1.txt')
    trainloader = DataLoader(dataset=train_data, batch_size=1)
    test_data = DataRemake('data/testdata/testdata_1.txt')
    testloader = DataLoader(dataset=test_data, batch_size=1)
    return trainloader, testloader


# FL part
# Read this client's partition id from the command line.
parser = argparse.ArgumentParser(description='Flower')
parser.add_argument(
    '--partition-id',
    choices=[0, 1, 2],
    required=True,
    type=int,
    help='Partition of the dataset divided into 3 iid partitions created artificially.'
)
partition_id = parser.parse_args().partition_id

# Load the model and data.
net = Model().to(DEVICE)
trainloader, testloader = load_data()


# Define the client.
class FlowerClient(fl.client.NumPyClient):
    """NumPyClient that trains/evaluates the module-level `net` on local data."""

    def get_parameters(self, config):
        """Return the model's state_dict values as NumPy arrays, in key order."""
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load NumPy arrays (in state_dict key order) into the model."""
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally from the server's parameters; return updated ones."""
        print('\n》 》 》 FIT 启动! 《 《 《')
        self.set_parameters(parameters)
        train(net, trainloader, epochs=EPOCH, partition_id=partition_id)
        return self.get_parameters(config={}), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the server's parameters on the local test set."""
        print('\n》 》 》 EVALUATE 启动! 《 《 《')
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return loss, len(testloader.dataset), {'accuracy': accuracy}


# Start the client.
if __name__ == '__main__':
    fl.client.start_client(
        server_address='127.0.0.1:50987',
        client=FlowerClient().to_client()
    )