diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index a109725..0000000
--- a/.gitignore
+++ /dev/null
@@ -1,164 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-CIFAR10
-model
-logs
-.idea
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/#use-with-ide
-.pdm.toml
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
diff --git a/QCQ_ResNet18.py b/QCQ_ResNet18.py
deleted file mode 100644
index fcae726..0000000
--- a/QCQ_ResNet18.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-
-class ResBlock(nn.Module):
-    def __init__(self, ch_in, ch_out, stride=2):
-        super(ResBlock, self).__init__()
-        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)  # (h-3+2)/2 + 1 = h/2, spatial size is halved
-        self.bn1 = nn.BatchNorm2d(ch_out)
-        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)  # h-3+2*1+1 = h, spatial size unchanged
-        self.bn2 = nn.BatchNorm2d(ch_out)
-
-        self.extra = nn.Sequential(
-            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),  # the shortcut sees the original input, so for the
-            # element-wise add its spatial size must also be halved: (h-1)/2 + 1 = h/2
-            nn.BatchNorm2d(ch_out)
-        )
-
-    def forward(self, x):
-        out = x
-        x = torch.relu(self.bn1(self.conv1(x)))
-        x = self.bn2(self.conv2(x))
-        # shortcut
-        # an element-wise add of [b,ch_in,h,w] and [b,ch_out,h,w] only works when ch_in == ch_out,
-        # so self.extra projects the shortcut to the same channel count and spatial size
-        out = x + self.extra(out)
-        return out
-
-
-class ResNet18(nn.Module):
-    def __init__(self):
-        super(ResNet18, self).__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # spatial size unchanged
-            nn.BatchNorm2d(64)
-        )
-        # 4 ResBlocks, each halves the spatial size (stride=2)
-        # [b,64,h,w] --> [b,128,h/2,w/2]
-        self.block1 = ResBlock(64, 128)
-        # [b,128,h/2,w/2] --> [b,256,h/4,w/4]
-        self.block2 = ResBlock(128, 256)
-        # [b,256,h/4,w/4] --> [b,512,h/8,w/8]
-        self.block3 = ResBlock(256, 512)
-        # [b,512,h/8,w/8] --> [b,512,h/16,w/16]
-        self.block4 = ResBlock(512, 512)
-
-        self.outlayer = nn.Linear(512, 10)
-
-    def forward(self, x):
-        x = torch.relu(self.conv1(x))
-        # [b,64,h,w] --> [b,512,h/16,w/16] after the four blocks
-        x = self.block1(x)
-        x = self.block2(x)
-        x = self.block3(x)
-        x = self.block4(x)
-        # print("after conv:", x.shape)
-        # [b,512,h/16,w/16] --> [b,512,1,1]
-        x = F.adaptive_avg_pool2d(x, [1, 1])
-        # flatten
-        x = x.view(x.shape[0], -1)
-        x = self.outlayer(x)
-        return x
diff --git a/QCQ_VGG16.py b/QCQ_VGG16.py
deleted file mode 100644
index 53ed451..0000000
--- a/QCQ_VGG16.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from torch import nn
-
-
-cfg = {'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
-
-
-class QCQ(nn.Module):
-
-    def __init__(self, vgg):
-        super(QCQ, self).__init__()
-        self.features = self._make_layers(vgg)
-        self.dense = nn.Sequential(
-            nn.Linear(512, 4096),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.4),
-            nn.Linear(4096, 4096),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.4),
-        )
-        self.classifier = nn.Linear(4096, 10)
-
-    def forward(self, x):
-        out = self.features(x)
-        out = out.view(out.size(0), -1)
-        out = self.dense(out)
-        out = self.classifier(out)
-        return out
-
-    def _make_layers(self, vgg):
-        layers = []
-        in_channels = 3
-        for x in vgg:
-            if x == 'M':
-                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
-            else:
-                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
-                           nn.BatchNorm2d(x),
-                           nn.ReLU(inplace=True)]
-                in_channels = x
-
-        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
-        return nn.Sequential(*layers)
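Both model files are removed by this commit. For anyone checking out the old revision, a minimal smoke test like the following (hypothetical, not part of the repo) documents what they computed: each model maps a CIFAR-10-sized batch to 10 class logits. The `vgg` list mirrors the config used in demo.py and train_CIFAR10_complete.py below.

```python
# Hypothetical smoke test, assuming QCQ_ResNet18.py and QCQ_VGG16.py are still
# present (e.g. on the parent commit). Not code from this repository.
import torch
from QCQ_ResNet18 import ResNet18
from QCQ_VGG16 import QCQ

vgg = [96, 96, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
dummy = torch.randn(4, 3, 32, 32)   # a batch of 4 CIFAR-10-sized images

print(ResNet18()(dummy).shape)      # expected: torch.Size([4, 10])
print(QCQ(vgg)(dummy).shape)        # expected: torch.Size([4, 10])
```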
diff --git a/README.md b/README.md
index f87d3d0..a461cd0 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,2 @@
 # CIFAR10
-Learning CIFAR10
-My environment:
-python 3.6.13
-torch 1.10.2
-cuda 11.3
-torchvision 0.11.3
-pillow 8.4.0
-tensorboard 2.10.1
\ No newline at end of file
diff --git a/demo.py b/demo.py
deleted file mode 100644
index 5c28b41..0000000
--- a/demo.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import os
-
-import torch
-import torchvision.transforms
-from PIL import Image, ImageDraw, ImageFont
-
-from QCQ_VGG16 import *
-from QCQ_ResNet18 import *
-
-CIFAR10_class = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
-vgg = [96, 96, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
-
-
-def readImage(img_path='img/test.png'):
-    img = Image.open(img_path).convert('RGB')
-    transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
-                                                torchvision.transforms.ToTensor()])
-    img = transform(img)
-    print(img.shape)
-    img = torch.reshape(img, (1, 3, 32, 32))
-    return img
-
-
-def DrawImageTxt(imageFile, targetImageFile, txtnum):
-    # set the font and size
-    font = ImageFont.truetype('img/abc.ttf', 50)
-    # open the image
-    im = Image.open(imageFile)
-    # draw the predicted label in the top-left corner
-    draw = ImageDraw.Draw(im)
-    draw.text((0, 0), txtnum, (255, 255, 0), font=font)
-    # save the annotated copy
-    im.save(targetImageFile)
-    # close the file
-    im.close()
-
-
-if __name__ == '__main__':
-    print("Recognize ten classes with a trained model: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck")
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    models = []
-    i = 1
-    for root, dirs, files in os.walk('model'):
-        for file in files:
-            if '.pth' in file:
-                file_path = root + '/' + file
-                models.append(file_path)
-                print(f'{i}.' + file)
-                i += 1
-    if len(models) == 0:
-        print('No .pth model files found in the model folder')
-    else:
-        select = int(input('Select a model\n'))
-        model_path = models[select - 1]
-        if 'VGG16' in model_path:
-            qcq_test = QCQ(vgg)
-        elif 'ResNet18' in model_path:
-            qcq_test = ResNet18()
-        else:
-            print('The selected model name contains neither "VGG16" nor "ResNet18"')
-            exit()
-
-        qcq_test.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
-
-        imgs = []
-        i = 1
-        for root, dirs, files in os.walk('img'):
-            for file in files:
-                if ('.png' in file or '.jpg' in file) and 'pre' not in file:
-                    file_path = root + '/' + file
-                    imgs.append(file_path)
-                    print(f'{i}.' + file)
-                    i += 1
-        if len(imgs) == 0:
-            print('No images found in the img folder')
-        else:
-            select = int(input('Select a test image\n'))
-            image_path = imgs[select - 1]
-            image_name, image_type = image_path.rsplit('.', 1)
-            targetImageFile = image_name + '_pre.png'
-            img = readImage(image_path)
-            qcq_test.eval()
-            with torch.no_grad():
-                output = qcq_test(img)
-            pre = output.argmax(1)
-            txtnum = CIFAR10_class[pre.item()]
-            DrawImageTxt(image_path, targetImageFile, txtnum)
-            print(output)
-            print(pre)
-            print(txtnum)
-            Image.open(targetImageFile).show()
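The `readImage` helper above does all of the preprocessing for the demo. A standalone check of that pipeline (a sketch, not repo code; `img/test.png` is one of the images removed in this commit) shows the shapes it produces:

```python
# Sketch: exercise the same Resize + ToTensor pipeline used by demo.py's readImage.
import torchvision.transforms as T
from PIL import Image

transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
img = transform(Image.open('img/test.png').convert('RGB'))
print(img.shape)          # torch.Size([3, 32, 32]), float values in [0, 1]
batch = img.unsqueeze(0)  # same effect as torch.reshape(img, (1, 3, 32, 32)) in demo.py
print(batch.shape)        # torch.Size([1, 3, 32, 32])
```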
diff --git a/img/abc.ttf b/img/abc.ttf
deleted file mode 100644
index c5f12a8..0000000
Binary files a/img/abc.ttf and /dev/null differ
diff --git a/img/test.png b/img/test.png
deleted file mode 100644
index c34f02b..0000000
Binary files a/img/test.png and /dev/null differ
diff --git a/img/test2.png b/img/test2.png
deleted file mode 100644
index 5bfb099..0000000
Binary files a/img/test2.png and /dev/null differ
diff --git a/img/test2_pre.png b/img/test2_pre.png
deleted file mode 100644
index cd19fb5..0000000
Binary files a/img/test2_pre.png and /dev/null differ
diff --git a/img/test_pre.png b/img/test_pre.png
deleted file mode 100644
index c8116b2..0000000
Binary files a/img/test_pre.png and /dev/null differ
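The last file removed is the training script, train_CIFAR10_complete.py, below. Its VGG16 branch decays the learning rate with `StepLR(step_size=5, gamma=0.4)` on SGD starting at 1e-2. As a quick, standalone sanity check of what that schedule does over the 10 VGG epochs (a sketch, not repo code; the dummy parameter exists only so an optimizer can be constructed):

```python
# Sketch: trace the StepLR schedule used by the VGG16 branch of the script below.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter, not from the repo
optimizer = torch.optim.SGD(params, lr=1e-2, weight_decay=5e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4)

for epoch in range(1, 11):
    optimizer.step()   # one training epoch would run here
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])
# epochs 1-4: 0.01, epochs 5-9: 0.004, epoch 10: 0.0016
```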
diff --git a/train_CIFAR10_complete.py b/train_CIFAR10_complete.py
deleted file mode 100644
index 6af27c0..0000000
--- a/train_CIFAR10_complete.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import os
-
-import torch
-import torchvision.datasets
-import torchvision.transforms
-from torch.utils.tensorboard import SummaryWriter
-
-from QCQ_VGG16 import *
-from QCQ_ResNet18 import *
-from torch.utils.data import DataLoader
-
-# training dataset
-train_data = torchvision.datasets.CIFAR10("CIFAR10/CIFAR10_train", train=True,
-                                          transform=torchvision.transforms.ToTensor(), download=True)
-# test dataset
-test_data = torchvision.datasets.CIFAR10("CIFAR10/CIFAR10_test", train=False,
-                                         transform=torchvision.transforms.ToTensor(), download=True)
-
-# dataset sizes
-train_data_size = len(train_data)
-test_data_size = len(test_data)
-print("Size of the training set: {}".format(train_data_size))
-print("Size of the test set: {}".format(test_data_size))
-
-# wrap the datasets in DataLoaders
-train_dataloader = DataLoader(train_data, batch_size=64)
-test_dataloader = DataLoader(test_data, batch_size=64)
-
-# select the device
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-# build the network
-vgg = [96, 96, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
-
-select = input("Select a model:\n"
-               "1. VGG16\n"
-               "2. ResNet18\n")
-if select == '1':
-    qcq = QCQ(vgg)
-    model_name = 'VGG16'
-    learning_rate = 1e-2
-    optimizer = torch.optim.SGD(qcq.parameters(), lr=learning_rate, weight_decay=5e-3)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.4, last_epoch=-1)
-    # number of epochs
-    epoch = 10
-else:
-    qcq = ResNet18()
-    model_name = 'ResNet18'
-    # optimizer and learning rate
-    learning_rate = 1e-3
-    optimizer = torch.optim.Adam(qcq.parameters(), lr=learning_rate)
-    # number of epochs
-    epoch = 50
-qcq = qcq.to(device)
-
-# loss function
-loss_fun = nn.CrossEntropyLoss().to(device)
-
-# training bookkeeping
-# number of training steps
-total_train_step = 0
-# number of test rounds
-total_test_step = 0
-
-# TensorBoard writer
-if not os.path.exists('logs'):
-    os.mkdir('logs')
-writer = SummaryWriter('logs')
-
-for i in range(epoch):
-    print("-------- Epoch %5d training starts --------" % (i + 1))
-
-    # training phase
-    qcq.train()
-    for data in train_dataloader:
-        imgs, targets = data
-
-        imgs = imgs.to(device)
-        targets = targets.to(device)
-
-        outputs = qcq(imgs)
-        loss = loss_fun(outputs, targets)
-        # optimization step
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        total_train_step += 1
-        if total_train_step % 100 == 0:
-            print('Training step: %-8d, loss: %f' % (total_train_step, loss.item()))
-            writer.add_scalar('train_loss', loss.item(), total_train_step)
-
-    # evaluation phase
-    qcq.eval()
-    total_test_loss = 0
-    total_accuracy = 0
-    with torch.no_grad():
-        for data in test_dataloader:
-            imgs, targets = data
-
-            imgs = imgs.to(device)
-            targets = targets.to(device)
-
-            outputs = qcq(imgs)
-            loss = loss_fun(outputs, targets)
-            total_test_loss += loss.item()
-            accuracy = (outputs.argmax(1) == targets).sum().item()
-            total_accuracy += accuracy
-    print('Total loss on the test set: %f' % total_test_loss)
-    print('Accuracy on the test set: %f' % (total_accuracy / test_data_size))
-    writer.add_scalar('test_loss', total_test_loss, total_test_step)
-    writer.add_scalar('test_accuracy_rate', total_accuracy / test_data_size, total_test_step)
-    if select == '1':
-        scheduler.step()
-    total_test_step += 1
-
-    # save a checkpoint every epoch
-    if not os.path.exists('model'):
-        os.mkdir('model')
-    torch.save(qcq.state_dict(), 'model/' + model_name + '_%d_epoch.pth' % (i + 1))
-    print("Model checkpoint saved")
-
-writer.close()
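Finally, for anyone who still has checkpoints produced by the deleted training script, a hedged sketch of how one could be re-evaluated on the test set. The checkpoint filename follows the `model/<name>_<epoch>_epoch.pth` pattern used by the script above but is otherwise a placeholder.

```python
# Hypothetical follow-up (not part of the repo): reload a checkpoint saved by the
# loop above and recompute test-set accuracy the same way the script does.
import torch
import torchvision
from torch.utils.data import DataLoader
from QCQ_ResNet18 import ResNet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_data = torchvision.datasets.CIFAR10(
    "CIFAR10/CIFAR10_test", train=False,
    transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(test_data, batch_size=64)

model = ResNet18().to(device)
# placeholder checkpoint name; substitute one of your saved epochs
model.load_state_dict(torch.load('model/ResNet18_50_epoch.pth', map_location=device))
model.eval()

correct = 0
with torch.no_grad():
    for imgs, targets in test_loader:
        imgs, targets = imgs.to(device), targets.to(device)
        correct += (model(imgs).argmax(1) == targets).sum().item()
print('accuracy: %.4f' % (correct / len(test_data)))
```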