diff --git a/use.py b/use.py new file mode 100644 index 0000000..5d20a15 --- /dev/null +++ b/use.py @@ -0,0 +1,45 @@ +import time +import torch +import torchvision +from torch import nn + +list1=[] +list_data = "2945807302157368193036426212997220033224538323640562241549909514547527720405608656907029313758584719540613589525481454212472019860395476200753292612652064279287757447621682752174888515904584744529078454748554565275582823574162998649840329792320732021527380675691933505646185089414885945266985722969732915061599825966637476" +# print(len(list_data)) +for i in range(1,len(list_data)): + list1.append(int(list_data[i])) +# print(list1) +list1=torch.tensor(list1, dtype=torch.float32) +# print(list1) +# print(list1.shape) + + + + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.net = nn.Sequential( + nn.Linear(321, 159), + nn.ReLU(), + nn.Linear(159,81), + nn.ReLU(), + nn.Linear(81, 3), + ) + + def forward(self, input): + return self.net(input) + + +model = Model() # 导入网络结构 +model.load_state_dict(torch.load('model.pth', map_location='cpu')) # 导入网络的参数 + +# print(model) + + +with torch.no_grad(): + output = model(list1) + +result = output.argmax().item() +print('这是等级.。。。。。。。。。。。。。。。{}'.format(result)) +