diff --git a/client0.py b/client0.py new file mode 100644 index 0000000..78f4036 --- /dev/null +++ b/client0.py @@ -0,0 +1,170 @@ +import argparse +from collections import OrderedDict +import numpy as np +import torch +import flwr as fl +import torch.nn as nn +from tqdm import tqdm +from torch.utils.data import DataLoader, Dataset + +EPOCH = 2 + + +if torch.cuda.is_available(): + print('ues GPU') + DEVICE = torch.device('cuda') +else: + print('uses CPU') + DEVICE = torch.device('cpu') + +#Data_part +class DataRemake(Dataset): + def __init__(self, path): + self.data, self.label = self.transform(path) + self.len = len(self.label) + + def __getitem__(self, index): + label = self.label[index] + data = self.data[index] + return data, label + + def __len__(self): + print('数据集长度为:', self.len) + return self.len + + def transform(self, path): + data_tensor_list = [] + label_list = [] + with open(path, mode='r', encoding='utf-8') as fp: + data_str_list = [line.strip() for line in fp] + for i in data_str_list: + data = list(i) + label = int(data[0]) + # 转换标签为 one-hot 编码 + if label == 2: + label = [1, 0, 0] + elif label == 3: + label = [0, 1, 0] + elif label == 4: + label = [0, 0, 1] + else: + raise ValueError(f"未知的标签值:{label}") + + data = data[1:] + # 检查数据的长度并进行处理 + if len(data) != 321: + # 如果数据长度不是321,进行填充或截断操作 + if len(data) < 322: + # 填充数据,这里假设用0填充 + data.extend([0] * (321 - len(data))) + else: + # 截断数据 + data = data[:321] + + data = np.array(list(map(float, data))).astype(np.float32) + label = np.array(label).astype(np.float32) + data = torch.from_numpy(data) + label = torch.from_numpy(label) + data_tensor_list.append(data) + label_list.append(label) + return data_tensor_list, label_list + +#Model_part +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.net = nn.Sequential( + nn.Linear(321, 200), + nn.ReLU(), + nn.Linear(200,100), + nn.ReLU(), + nn.Linear(100, 3), + nn.Softmax(dim=1) + ) + + def forward(self, input): + return self.net(input) + +def train(net, trainloader, epochs, partition_id): + loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE) + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + for i in range(epochs): + for img, label in tqdm(trainloader, 'Training'): + images = img.to(DEVICE) + labels = label.to(DEVICE) + optimizer.zero_grad() + output = net(images) + loss_fc(output, labels).backward() + optimizer.step() + torch.save(net, 'Modle_{}_GPU.pth'.format(partition_id)) + print('模型已保存') + +def test(net, testloader): + loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE) + correct, loss = 0, 0.0 + with torch.no_grad(): + for img, label in tqdm(testloader, 'Testing'): + images = img.to(DEVICE) + labels = label.to(DEVICE) + output = net(images) + loss += loss_fc(output, labels).item() + correct += (torch.max(output.data, 1)[1] == labels).sum().item() + accuracy = correct/len(testloader.dataset) + with open('vision/text-loss-0', 'a') as fp1: + fp1.write(str(loss) + '\n') + with open('vision/text-accuracy-0', 'a') as fp2: + fp2.write(str(accuracy) + '\n') + print('TEST-ACCURACY={}, TEST-LOSS={}'.format(accuracy, loss)) + return loss, accuracy + +def load_data(): + train_data = DataRemake('data/traindata/traindata_0.txt') + trainloader = DataLoader(dataset=train_data, batch_size=1) + test_data = DataRemake('data/testdata/testdata_0.txt') + testloader = DataLoader(dataset=test_data, batch_size=1) + return trainloader, testloader + +#FL_part +#get id +parser = argparse.ArgumentParser(description='Flower') +parser.add_argument( + '--partition-id', + choices=[0, 1, 2], + required=True, + type=int, + help='Partition of the dataset divided into 3 iid partitions created artificially.' +) +partition_id = parser.parse_args().partition_id + +#load model and data +net = Model().to(DEVICE) +trainloader, testloader = load_data() + +#define client +class FlowerClient(fl.client.NumPyClient): + def get_parameters(self, config): + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + def set_parameters(self, parameters): + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + def fit(self, parameters, config): + print('\n》 》 》 FIT 启动! 《 《 《') + self.set_parameters(parameters) + train(net, trainloader, epochs=EPOCH, partition_id=partition_id) + return self.get_parameters(config={}), len(trainloader.dataset), {} + + def evaluate(self, parameters, config): + print('\n》 》 》 EVALUATE 启动! 《 《 《') + self.set_parameters(parameters) + loss, accuracy = test(net, testloader) + return loss, len(testloader.dataset), {'accuracy':accuracy} + +#start client +if __name__ == '__main__': + fl.client.start_client( + server_address='127.0.0.1:50987', + client=FlowerClient().to_client() + )