import argparse
from collections import OrderedDict
import numpy as np
import torch
import flwr as fl
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset

# Number of local training epochs run in each federated `fit` round.
EPOCH = 3


# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    print('uses GPU')
    DEVICE = torch.device('cuda')
else:
    print('uses CPU')
    DEVICE = torch.device('cpu')

#Data_part
class DataRemake(Dataset):
    """Line-oriented text dataset of fixed-length numeric feature vectors.

    Each non-blank line of the file at ``path`` is a string of single
    characters. The first character is the class label (``2``, ``3`` or
    ``4``). The remaining characters are feature values that ``float()`` can
    parse. Features are zero-padded or truncated to ``FEATURE_LEN`` values.
    Labels are one-hot encoded. Both are stored as float32 tensors.
    """

    # Features per sample; must match Model's first layer (nn.Linear(321, ...)).
    FEATURE_LEN = 321
    # Raw label value -> position of the 1 in the one-hot target vector.
    LABEL_TO_INDEX = {2: 0, 3: 1, 4: 2}

    def __init__(self, path):
        self.data, self.label = self.transform(path)
        self.len = len(self.label)
        # Report the dataset size once. Printing inside __len__ repeated the
        # message every time DataLoader/tqdm/test() asked for the length.
        print('数据集长度为:', self.len)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len

    def transform(self, path):
        """Parse ``path`` into parallel lists of feature and label tensors.

        Args:
            path: UTF-8 text file, one sample per line (see class docstring).

        Returns:
            ``(data_tensor_list, label_list)``. Each feature tensor is float32
            of shape ``(FEATURE_LEN,)``. Each label tensor is float32 one-hot
            of shape ``(3,)``. Blank lines are skipped.

        Raises:
            ValueError: if a label is not 2, 3 or 4, or a character cannot
                be parsed as a number.
        """
        num_classes = len(self.LABEL_TO_INDEX)
        data_tensor_list = []
        label_list = []
        with open(path, mode='r', encoding='utf-8') as fp:
            for raw_line in fp:
                line = raw_line.strip()
                if not line:
                    # A blank line would otherwise crash on line[0].
                    continue

                label = int(line[0])
                if label not in self.LABEL_TO_INDEX:
                    raise ValueError(f"未知的标签值:{label}")
                one_hot = np.zeros(num_classes, dtype=np.float32)
                one_hot[self.LABEL_TO_INDEX[label]] = 1.0

                # Truncate long rows, then zero-pad short rows to FEATURE_LEN.
                features = list(line[1:])[:self.FEATURE_LEN]
                features.extend([0] * (self.FEATURE_LEN - len(features)))
                values = np.array([float(v) for v in features], dtype=np.float32)

                data_tensor_list.append(torch.from_numpy(values))
                label_list.append(torch.from_numpy(one_hot))
        return data_tensor_list, label_list

#Model_part
class Model(nn.Module):
    """Three-layer MLP mapping 321 input features to 3 softmax outputs.

    Layers: Linear(321->200) -> ReLU -> Linear(200->100) -> ReLU ->
    Linear(100->3) -> Softmax over dim 1.

    NOTE(review): forward() returns softmax probabilities, but train() and
    test() pass them to CrossEntropyLoss, which applies log-softmax itself.
    That is a double softmax, which flattens gradients. Consider dropping the
    final Softmax (state_dict keys would not change) once all consumers of the
    model's output are confirmed to accept logits.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(321, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 3),
            nn.Softmax(dim=1),
        ]
        # Keep the attribute name `net` so state_dict keys stay `net.<i>.*`,
        # which the federated parameter exchange relies on.
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        probabilities = self.net(input)
        return probabilities

def train(net, trainloader, epochs, partition_id):
    """Train ``net`` with SGD (lr=0.001, momentum=0.9) and CrossEntropyLoss.

    After every epoch the whole module is pickled with ``torch.save`` to
    ``Modle_<partition_id>_GPU.pth`` and a confirmation is printed.

    NOTE(review): the checkpoint name always says "GPU", even when DEVICE is
    the CPU. Confirm that downstream loaders expect this exact name before
    changing it.
    """
    criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    checkpoint_path = 'Modle_{}_GPU.pth'.format(partition_id)

    for _ in range(epochs):
        for batch_x, batch_y in tqdm(trainloader, 'Training'):
            inputs = batch_x.to(DEVICE)
            targets = batch_y.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = criterion(net(inputs), targets)
            batch_loss.backward()
            optimizer.step()
        # Checkpoint after each full pass over the local data.
        torch.save(net, checkpoint_path)
        print('模型已保存')

def test(net, testloader):
    """Evaluate ``net`` on ``testloader`` and append the results to log files.

    Args:
        net: model producing ``(batch, num_classes)`` scores.
        testloader: DataLoader yielding ``(features, labels)`` batches. Labels
            may be one-hot (as DataRemake produces) or class indices.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the CrossEntropyLoss summed over
        batches. ``accuracy`` is the fraction of samples whose arg-max
        prediction matches the true class (0.0 for an empty dataset).

    Side effects:
        Appends ``loss`` to ``vision/text-loss-1`` and ``accuracy`` to
        ``vision/text-accuracy-1``, creating ``vision/`` if needed. Prints
        both values.
    """
    import os

    loss_fc = torch.nn.CrossEntropyLoss().to(DEVICE)
    correct, loss = 0, 0.0
    with torch.no_grad():
        for img, label in tqdm(testloader, 'Testing'):
            images = img.to(DEVICE)
            labels = label.to(DEVICE)
            output = net(images)
            loss += loss_fc(output, labels).item()
            predicted = output.argmax(dim=1)
            # Labels are one-hot (B, C). Comparing the (B,) predicted indices
            # directly against them broadcasts to (B, C) and miscounts, so
            # reduce the labels to class indices first.
            target = labels.argmax(dim=1) if labels.dim() > 1 else labels
            correct += (predicted == target).sum().item()

    total = len(testloader.dataset)
    accuracy = correct / total if total else 0.0

    os.makedirs('vision', exist_ok=True)
    with open('vision/text-loss-1', 'a') as fp1:
        fp1.write(str(loss) + '\n')
    with open('vision/text-accuracy-1', 'a') as fp2:
        fp2.write(str(accuracy) + '\n')
    print('TEST-ACCURACY={}, TEST-LOSS={}'.format(accuracy, loss))
    return loss, accuracy

def load_data(train_path='data/traindata/traindata_1.txt',
              test_path='data/testdata/testdata_1.txt',
              batch_size=1):
    """Build the train and test DataLoaders from DataRemake text files.

    Args:
        train_path: training data file (defaults to the original hard-coded path).
        test_path: test data file (defaults to the original hard-coded path).
        batch_size: batch size for both loaders (default 1, as before).

    Returns:
        ``(trainloader, testloader)``. Neither loader shuffles.

    NOTE(review): the default paths are always the ``_1`` files and do not
    depend on ``--partition-id``. Confirm that every client is meant to read
    the same files.
    """
    train_data = DataRemake(train_path)
    trainloader = DataLoader(dataset=train_data, batch_size=batch_size)
    test_data = DataRemake(test_path)
    testloader = DataLoader(dataset=test_data, batch_size=batch_size)
    return trainloader, testloader

#FL_part
# Parse this client's dataset partition id from the command line.
parser = argparse.ArgumentParser(description='Flower')
parser.add_argument(
    '--partition-id',
    type=int,
    choices=[0, 1, 2],
    required=True,
    help='Partition of the dataset divided into 3 iid partitions created artificially.',
)
cli_args = parser.parse_args()
partition_id = cli_args.partition_id

# Create the local model on DEVICE and load this client's train/test data.
# NOTE(review): this runs at import time (argv parsing plus file loading), so
# importing the module from elsewhere triggers both. Consider moving it under
# the __main__ guard if the module is ever reused.
net = Model().to(DEVICE)
trainloader, testloader = load_data()

#define client
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping the module-level model and data loaders.

    The client keeps no state of its own. It reads and updates the globals
    ``net``, ``trainloader``, ``testloader`` and ``partition_id``.
    """

    def get_parameters(self, config):
        """Return every state_dict tensor of ``net`` as a CPU NumPy array, in key order."""
        return [tensor.cpu().numpy() for tensor in net.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (arrays in state_dict key order) into ``net``."""
        keys = net.state_dict().keys()
        state_dict = OrderedDict(
            (key, torch.tensor(array)) for key, array in zip(keys, parameters)
        )
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally for EPOCH epochs, starting from the server's weights.

        Returns the updated weights, the number of local training samples and
        an empty metrics dict.
        """
        print('\n》 》 》 FIT 启动! 《 《 《')
        self.set_parameters(parameters)
        train(net, trainloader, epochs=EPOCH, partition_id=partition_id)
        updated_weights = self.get_parameters(config={})
        num_examples = len(trainloader.dataset)
        return updated_weights, num_examples, {}

    def evaluate(self, parameters, config):
        """Evaluate the server's weights on the local test set.

        Returns the test loss, the number of test samples and
        ``{'accuracy': accuracy}``.
        """
        print('\n》 》 》 EVALUATE 启动! 《 《 《')
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        num_examples = len(testloader.dataset)
        return loss, num_examples, {'accuracy': accuracy}

#start client
if __name__ == '__main__':
    # Connect this client to the Flower server at 127.0.0.1:50987.
    flower_client = FlowerClient().to_client()
    fl.client.start_client(
        server_address='127.0.0.1:50987',
        client=flower_client,
    )