diff --git a/use.py b/use.py deleted file mode 100644 index 5d20a15..0000000 --- a/use.py +++ /dev/null @@ -1,45 +0,0 @@ -import time -import torch -import torchvision -from torch import nn - -list1=[] -list_data = "2945807302157368193036426212997220033224538323640562241549909514547527720405608656907029313758584719540613589525481454212472019860395476200753292612652064279287757447621682752174888515904584744529078454748554565275582823574162998649840329792320732021527380675691933505646185089414885945266985722969732915061599825966637476" -# print(len(list_data)) -for i in range(1,len(list_data)): - list1.append(int(list_data[i])) -# print(list1) -list1=torch.tensor(list1, dtype=torch.float32) -# print(list1) -# print(list1.shape) - - - - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.net = nn.Sequential( - nn.Linear(321, 159), - nn.ReLU(), - nn.Linear(159,81), - nn.ReLU(), - nn.Linear(81, 3), - ) - - def forward(self, input): - return self.net(input) - - -model = Model() # 导入网络结构 -model.load_state_dict(torch.load('model.pth', map_location='cpu')) # 导入网络的参数 - -# print(model) - - -with torch.no_grad(): - output = model(list1) - -result = output.argmax().item() -print('这是等级.。。。。。。。。。。。。。。。{}'.format(result)) -