You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

40 lines
1.2 KiB

import inspect

import torch
from torch import nn
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Raw input sample: a string of decimal digits, one digit per feature.
list_data = "2945807302157368193036426212997220033224538323640562241549909514547527720405608656907029313758584719540613589525481454212472019860395476200753292612652064279287757447621682752174888515904584744529078454748554565275582823574162998649840329792320732021527380675691933505646185089414885945266985722969732915061599825966637476"

# list_data holds 322 digits, but the model's first layer, nn.Linear(321, ...),
# takes 321 features, so the leading digit is skipped.
# NOTE(review): confirm whether that first digit is a label/prefix, or whether
# it should instead be removed from the data itself.
list1 = torch.tensor(
    [int(digit) for digit in list_data[1:]],
    dtype=torch.float32,
    device=DEVICE,  # build on the target device directly instead of copying
)
class Model(nn.Module):
    """Fully connected network: 321 inputs -> 200 -> 100 -> 3 outputs.

    The two hidden layers are followed by ReLU. The final layer is linear with
    no activation, so ``forward`` returns 3 raw output scores.

    The layers are held in ``self.net`` (an ``nn.Sequential``), which makes the
    parameter names ``net.0.*``, ``net.2.*`` and ``net.4.*``.
    """

    def __init__(self):
        super().__init__()
        layers = (
            nn.Linear(321, 200),
            nn.ReLU(),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 3),
        )
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Pass ``input`` through the network and return the output scores.

        The last dimension of ``input`` must be 321, the first layer's width.
        """
        scores = self.net(input)
        return scores
# The checkpoint is loaded as a whole module object, not a state_dict; it is
# called directly below. Unpickling presumably needs the Model class defined
# in this script.
# SECURITY: torch.load unpickles arbitrary Python objects, so only load
# checkpoint files from a trusted source.
#
# map_location remaps tensors saved on a CUDA device (the file name suggests a
# GPU checkpoint) onto DEVICE. Without it, loading fails during
# deserialization on CPU-only machines, before .to(DEVICE) ever runs.
load_kwargs = {"map_location": DEVICE}
# PyTorch 2.6 made weights_only=True the default, and that mode rejects a
# pickled nn.Module. Opt out explicitly, but only where the parameter exists
# (it was added in PyTorch 1.13); older versions would reject the keyword.
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
model = torch.load("Modle_0_GPU.pth", **load_kwargs).to(DEVICE)
model.eval()  # switch to evaluation mode before inference
with torch.no_grad():  # no autograd bookkeeping is needed for a forward pass
    output = model(list1)
# Index of the highest score. Assuming the loaded model matches Model, output
# has 3 entries for the 1-D input.
result = output.argmax().item()
print('这是等级{}'.format(result))  # prints "This is level <result>"