ADD file via upload

main
plor4hs9e 7 months ago
parent 07cba00131
commit 3fa829c679

@ -0,0 +1,170 @@
import argparse
from collections import OrderedDict
import numpy as np
import torch
import flwr as fl
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
EPOCH = 2
if torch.cuda.is_available():
print('ues GPU')
DEVICE = torch.device('cuda')
else:
print('uses CPU')
DEVICE = torch.device('cpu')
#Data_part
class DataRemake(Dataset):
def __init__(self, path):
self.data, self.label = self.transform(path)
self.len = len(self.label)
def __getitem__(self, index):
label = self.label[index]
data = self.data[index]
return data, label
def __len__(self):
print('数据集长度为:', self.len)
return self.len
def transform(self, path):
data_tensor_list = []
label_list = []
with open(path, mode='r', encoding='utf-8') as fp:
data_str_list = [line.strip() for line in fp]
for i in data_str_list:
data = list(i)
label = int(data[0])
# 转换标签为 one-hot 编码
if label == 2:
label = [1, 0, 0]
elif label == 3:
label = [0, 1, 0]
elif label == 4:
label = [0, 0, 1]
else:
raise ValueError(f"未知的标签值:{label}")
data = data[1:]
# 检查数据的长度并进行处理
if len(data) != 321:
# 如果数据长度不是321进行填充或截断操作
if len(data) < 322:
# 填充数据这里假设用0填充
data.extend([0] * (321 - len(data)))
else:
# 截断数据
data = data[:321]
data = np.array(list(map(float, data))).astype(np.float32)
label = np.array(label).astype(np.float32)
data = torch.from_numpy(data)
label = torch.from_numpy(label)
data_tensor_list.append(data)
label_list.append(label)
return data_tensor_list, label_list
#Model_part
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.net = nn.Sequential(
nn.Linear(321, 200),
nn.ReLU(),
nn.Linear(200,100),
nn.ReLU(),
nn.Linear(100, 3),
nn.Softmax(dim=1)
)
def forward(self, input):
return self.net(input)
def train(net, trainloader, epochs, partition_id):
loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for i in range(epochs):
for img, label in tqdm(trainloader, 'Training'):
images = img.to(DEVICE)
labels = label.to(DEVICE)
optimizer.zero_grad()
output = net(images)
loss_fc(output, labels).backward()
optimizer.step()
torch.save(net, 'Modle_{}_GPU.pth'.format(partition_id))
print('模型已保存')
def test(net, testloader):
loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
correct, loss = 0, 0.0
with torch.no_grad():
for img, label in tqdm(testloader, 'Testing'):
images = img.to(DEVICE)
labels = label.to(DEVICE)
output = net(images)
loss += loss_fc(output, labels).item()
correct += (torch.max(output.data, 1)[1] == labels).sum().item()
accuracy = correct/len(testloader.dataset)
with open('vision/text-loss-0', 'a') as fp1:
fp1.write(str(loss) + '\n')
with open('vision/text-accuracy-0', 'a') as fp2:
fp2.write(str(accuracy) + '\n')
print('TEST-ACCURACY={}, TEST-LOSS={}'.format(accuracy, loss))
return loss, accuracy
def load_data():
train_data = DataRemake('data/traindata/traindata_0.txt')
trainloader = DataLoader(dataset=train_data, batch_size=1)
test_data = DataRemake('data/testdata/testdata_0.txt')
testloader = DataLoader(dataset=test_data, batch_size=1)
return trainloader, testloader
#FL_part
#get id
parser = argparse.ArgumentParser(description='Flower')
parser.add_argument(
'--partition-id',
choices=[0, 1, 2],
required=True,
type=int,
help='Partition of the dataset divided into 3 iid partitions created artificially.'
)
partition_id = parser.parse_args().partition_id
#load model and data
net = Model().to(DEVICE)
trainloader, testloader = load_data()
#define client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
print('\n》 》 》 FIT 启动! 《 《 《')
self.set_parameters(parameters)
train(net, trainloader, epochs=EPOCH, partition_id=partition_id)
return self.get_parameters(config={}), len(trainloader.dataset), {}
def evaluate(self, parameters, config):
print('\n》 》 》 EVALUATE 启动! 《 《 《')
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {'accuracy':accuracy}
#start client
if __name__ == '__main__':
fl.client.start_client(
server_address='127.0.0.1:50987',
client=FlowerClient().to_client()
)
Loading…
Cancel
Save