CV + Deep Learning - Network Architecture PyTorch Reproduction Series - Classification (1: LeNet5, VGG, AlexNet, ResNet) (2024)

Introduction: This series focuses on reproducing the classic deep learning network models in computer vision (classification, object detection, semantic segmentation), going from simple to advanced so that beginners can put them to use.

All of the code runs without errors!

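First, we reproduce the classic classification networks of deep learning. The backbones (10. and 11.) are specific to object detection, but since their main purpose is feature extraction, they are included here as well. The series covers: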

1. LeNet5 (√)

2. VGG (√)

3. AlexNet (√)

4. ResNet (√)

5. ResNeXt

6. GoogLeNet

7. MobileNet

8. ShuffleNet

9. EfficientNet

10. VovNet

11. DarkNet

...

Notes:

a) The complete code is uploaded to my GitHub:

https://github.com/HanXiaoyiGitHub/Simple-CV-Pytorch-master

b) The build environment is as follows. In fact, you don't have to use this exact environment; fixing whatever bugs come up is fine too!

python == 3.9.12
torch == 1.11.0+cu113
torchvision == 0.11.0+cu113
torchaudio == 0.12.0+cu113
pycocotools == 2.0.4
numpy
Cython
matplotlib
opencv-python
skimage
tensorboard
tqdm
thop

c) The classification datasets are ImageNet or CIFAR10, laid out in the directories below. COCO and VOC are used for object detection; semantic segmentation is not covered yet.

dataset path: /data/data
|
|----coco----|----coco2017
|
|----cifar
|
|----ImageNet----|----ILSVRC2012
|
|----VOCdevkit

coco2017 path: /data/coco/coco2017
coco2017
|
|----annotations
|----train2017
|----test2017
|----val2017

voc path: /data/VOCdevkit
|
|               |----Annotations
|               |----ImageSets
|----VOC2007----|----JPEGImages
|               |----SegmentationClass
|               |----SegmentationObject
|
|               |----Annotations
|               |----ImageSets
|----VOC2012----|----JPEGImages
|               |----SegmentationClass
|               |----SegmentationObject

ILSVRC2012 path: /data/ImageNet/ILSVRC2012
|
|----train
|
|----val

cifar path: /data/cifar
|
|----cifar-10-batches-py
|
|----cifar-10-python.tar.gz
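A small existence check can confirm that the datasets sit where the scripts expect them. The sketch below uses a hypothetical helper: `missing_dirs` and the `EXPECTED_DIRS` table are not part of the project. It assumes that all datasets share one root directory, such as `/data`, as in the layout above.

```python
import os
import tempfile
from pathlib import Path

# Hypothetical table of sub-directories, relative to one data root (assumption: /data),
# mirroring the directory layout described above.
EXPECTED_DIRS = {
    "coco": ["coco/coco2017/annotations", "coco/coco2017/train2017",
             "coco/coco2017/val2017"],
    "cifar": ["cifar/cifar-10-batches-py"],
    "imagenet": ["ImageNet/ILSVRC2012/train", "ImageNet/ILSVRC2012/val"],
    "voc": ["VOCdevkit/VOC2007/Annotations", "VOCdevkit/VOC2007/JPEGImages",
            "VOCdevkit/VOC2012/Annotations", "VOCdevkit/VOC2012/JPEGImages"],
}


def missing_dirs(root, datasets):
    """Return the expected sub-directories of `datasets` that are missing under `root`."""
    root = Path(root)
    missing = []
    for name in datasets:
        for rel in EXPECTED_DIRS[name]:
            if not (root / rel).is_dir():
                missing.append(rel)
    return missing


if __name__ == "__main__":
    # Demo on a throw-away directory that only contains the extracted CIFAR-10 folder.
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "cifar", "cifar-10-batches-py"))
        print(missing_dirs(root, ["cifar"]))     # []
        print(missing_dirs(root, ["imagenet"]))  # ['ImageNet/ILSVRC2012/train', 'ImageNet/ILSVRC2012/val']
```

On a real machine you would call `missing_dirs("/data", ["cifar", "imagenet"])` before starting training. An empty list means the expected folders are present.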

d) AMP mixed precision is used to speed up training on the GPU. If you don't know how to use it, see the following link:

How to speed up network model training with PyTorch (autocast and GradScaler): https://blog.csdn.net/XiaoyYidiaodiao/article/details/124854343?spm=1001.2014.3001.5502

Therefore, @autocast() must be added in front of the forward function of each network model. In addition, because torch 1.4 or later is used, in-place operations such as ReLU and Dropout must be configured with inplace=False.
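For reference, here is a sketch of one mixed-precision training step with autocast and GradScaler. It is a minimal sketch, not the project's training loop:

- The tiny `nn.Sequential` model and the random data are stand-ins.
- It uses the `torch.autocast` context manager rather than the `@autocast()` decorator. Both enable autocasting for the code they wrap.
- AMP is switched on only when a CUDA device is present, so the snippet also runs in plain float32 on a CPU.

```python
import torch
from torch import nn

# Assumption: stand-in model and data; AMP is enabled only when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(inplace=False), nn.Linear(16, 3)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(images, targets):
    optimizer.zero_grad()
    # Forward pass and loss run in float16 where it is safe, when AMP is enabled.
    with torch.autocast(device_type=device, enabled=use_amp):
        outputs = model(images)
        loss = criterion(outputs, targets)
    # Scale the loss so small float16 gradients do not underflow;
    # scaler.step() unscales the gradients before calling optimizer.step().
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 8, device=device)
    y = torch.tensor([0, 1, 2, 0], device=device)
    for step in range(3):
        print(step, train_step(x, y))
```

With `enabled=False`, `scaler.scale(loss)` returns the loss unchanged and `scaler.step(optimizer)` simply calls `optimizer.step()`. The same loop therefore works with and without a GPU.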

e) LeNet5, VGG16 and AlexNet end in fully connected layers, whose input size is fixed. For these architectures, the image size must therefore be fixed during preprocessing.
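Why the size is fixed can be seen with the standard output-size formula for convolution and pooling, ⌊(n + 2p − k) / s⌋ + 1. The pure-Python sketch below traces VGG's five conv/pool stages. It assumes 3×3 convolutions with padding 1 and 2×2 max pooling with stride 2, as in the VGG configurations later in this post.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Side length after a conv/pool layer: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def vgg_feature_size(size, num_pools=5):
    """Spatial size of VGG's last feature map for a square input of side `size`."""
    for _ in range(num_pools):
        size = out_size(size, kernel=3, stride=1, padding=1)  # 3x3 conv, padding 1: size unchanged
        size = out_size(size, kernel=2, stride=2)              # 2x2 max pool, stride 2: size halved
    return size


for side in (224, 32):
    s = vgg_feature_size(side)
    print(side, "->", s, "x", s, "x 512 =", s * s * 512, "inputs to the first FC layer")
# 224 -> 7 x 7 x 512 = 25088 inputs to the first FC layer
# 32 -> 1 x 1 x 512 = 512 inputs to the first FC layer
```

The first fully connected layer is built as `nn.Linear(7 * 7 * 512, 4096)`, so only a 224×224 input produces a matching 25088-element vector. A 32×32 CIFAR image would give 512 features and fail.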

f) Project file structure

I use Ubuntu 20.04, but the code also runs on Windows; I have tried it. Some folders are not needed yet. Just leave them as they are; they will be explained later.

project path: /data/PycharmProject
Simple-CV-master path: /data/PycharmProject/Simple-CV-Pytorch-master
|
|----checkpoints ( resnet50-19c8e357.pth \COCO_ResNet50.pth[RetinaNet]\ VOC_ResNet50.pth[RetinaNet] )
|
|            |----cifar.py ( null, I just use torchvision.datasets.ImageFolder )
|            |----CIAR_labels.txt
|            |----coco.py
|            |----coco_eval.py
|            |----coco_labels.txt
|----data----|----__init__.py
|            |----config.py ( path )
|            |----imagenet.py ( null, I just use torchvision.datasets.ImageFolder )
|            |----ImageNet_labels.txt
|            |----voc0712.py
|            |----voc_eval.py
|            |----voc_labels.txt
|
|                                     |----crash_helmet.jpg
|----images----|----classification----|----sunflower.jpg
|              |                      |----photocopier.jpg
|              |                      |----automobile.jpg
|              |
|              |----detection----|----000001.jpg
|                                |----000001.xml
|                                |----000002.jpg
|                                |----000002.xml
|                                |----000003.jpg
|                                |----000003.xml
|
|----log ( XXX[ detection or classification ]_XXX[ train or test or eval ].info.log )
|
|              |----__init__.py
|              |
|              |              |----__init.py
|              |----anchor----|----RetinaNetAnchors.py
|              |
|              |               |----lenet5.py
|              |               |----alexnet.py
|              |----basenet----|----vgg.py
|              |               |----resnet.py
|              |
|              |                 |----DarkNetBackbone.py
|              |----backbones----|----__init__.py ( Don't finish writing )
|              |                 |----ResNetBackbone.py
|              |                 |----VovNetBackbone.py
|              |
|----models----|----heads----|----__init.py
|              |             |----RetinaNetHeads.py
|              |
|              |              |----RetinaNetLoss.py
|              |----losses----|----__init.py
|              |
|              |             |----FPN.py
|              |----necks----|----__init__.py
|              |             |-----FPN.txt
|              |
|              |----RetinaNet.py
|
|----results ( eg: detection ( VOC or COCO AP ) )
|
|----tensorboard ( Loss visualization )
|
|                                    |----eval.py
|----tools----|----classification----|----train.py
|             |                      |----test.py
|             |
|             |                 |----eval_coco.py
|             |                 |----eval_voc.py
|             |----detection----|----test.py
|                               |----train.py
|
|             |----AverageMeter.py
|             |----BBoxTransform.py
|             |----ClipBoxes.py
|             |----Sampler.py
|             |----iou.py
|----utils----|----__init__.py
|             |----accuracy.py
|             |----augmentations.py
|             |----collate.py
|             |----get_logger.py
|             |----nms.py
|             |----path.py
|
|----FolderOrganization.txt
|
|----main.py
|
|----README.md
|
|----requirements.txt

1. LeNet5 (size: 32 × 32 × 3) [1]

Figure 1.

The code below reproduces the network in Figure 1.

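I added nn.BatchNorm2d() to improve accuracy. For a faithful reproduction, you can of course ignore nn.BatchNorm2d() and remove it from the code.

The output size of the last fully connected layer can be adjusted to the number of categories in your dataset.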

from torch import nn
from torch.cuda.amp import autocast


class lenet5(nn.Module):
    # cifar: 10, ImageNet: 1000
    def __init__(self, num_classes=1000, init_weights=False):
        super(lenet5, self).__init__()
        self.num_classes = num_classes
        self.layers = nn.Sequential(
            # input: 32 * 32 * 3 -> 28 * 28 * 6
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            # 28 * 28 * 6 -> 14 * 14 * 6
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # 14 * 14 * 6 -> 10 * 10 * 16
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # 10 * 10 * 16 -> 5 * 5 * 16
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84))
        self.classifier = nn.Linear(84, self.num_classes)

        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.layers(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
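The shape comments in `lenet5` can be checked with the output-size formula ⌊(n + 2p − k) / s⌋ + 1. In this pure-Python sketch, the layer table is transcribed by hand from `lenet5.layers` above, so it must be updated if that code changes.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Side length after a conv/pool layer."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding, out_channels), transcribed from lenet5.layers
LENET5_FEATURES = [
    ("conv1", 5, 1, 0, 6),
    ("pool1", 2, 2, 0, 6),
    ("conv2", 5, 1, 0, 16),
    ("pool2", 2, 2, 0, 16),
]


def trace(size):
    """Return (name, height, width, channels) after each layer for a square input."""
    shapes = []
    for name, k, s, p, channels in LENET5_FEATURES:
        size = out_size(size, k, s, p)
        shapes.append((name, size, size, channels))
    return shapes


for row in trace(32):
    print(row)
_, h, w, c = trace(32)[-1]
print("nn.Linear input features:", h * w * c)  # 16 * 5 * 5 = 400
```

With a 28×28 input, for example MNIST-sized images, the same trace ends at 4×4×16 = 256 features. That does not match `nn.Linear(16 * 5 * 5, 120)`, which is why the input is resized to 32×32.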

2. AlexNet (size: 224 × 224 × 3) [2]

Figure 2.

The architecture is shown in Figure 2. If it is not clear enough, refer to Figure 3 below.

Figure 3.

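Convert Figure 3 into Figure 4. The original AlexNet was split across two graphics cards because the computing power of the time was insufficient. Today's hardware is powerful enough to run the whole network on a single GPU, so the two halves are merged.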

Figure 4.

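The output size of the last fully connected layer can be adjusted to the number of categories in your dataset.

I added nn.BatchNorm2d() to improve accuracy. For a faithful reproduction, you can of course ignore nn.BatchNorm2d() and remove it from the code.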

import torch.nn as nn
from torch.cuda.amp import autocast


class alexnet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(alexnet, self).__init__()
        self.layers = nn.Sequential(
            # input: 224 * 224 * 3 -> 55 * 55 * (48*2)
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=2, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            # 55 * 55 * (48*2) -> 27 * 27 * (48*2)
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 27 * 27 * (48*2) -> 27 * 27 * (128*2)
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 27 * 27 * (128*2) -> 13 * 13 * (128*2)
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 13 * 13 * (128*2) -> 13 * 13 * (192*2)
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            # 13 * 13 * (192*2) -> 13 * 13 * (192*2)
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            # 13 * 13 * (192*2) -> 13 * 13 * (128*2)
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 13 * 13 * (128*2) -> 6 * 6 * (128*2)
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(6 * 6 * 128 * 2, 2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048, 2048),
            nn.ReLU()
        )
        self.classifier = nn.Linear(2048, num_classes)

        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.layers(x)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
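The shape comments of the single-GPU `alexnet` can be traced in pure Python as well. The sketch below uses a layer table transcribed by hand from `alexnet.layers` above. It confirms that the merged two-GPU channel counts (48 × 2, 128 × 2, 192 × 2) end in a 6 × 6 × 256 feature map, which is what `nn.Linear(6 * 6 * 128 * 2, 2048)` expects.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Side length after a conv/pool layer."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding, out_channels) for alexnet.layers above;
# out_channels is None for pooling layers, which keep the channel count.
ALEXNET_FEATURES = [
    ("conv1", 11, 4, 2, 96),   # 48 * 2: the two original GPU halves merged
    ("pool1", 3, 2, 0, None),
    ("conv2", 5, 1, 2, 256),   # 128 * 2
    ("pool2", 3, 2, 0, None),
    ("conv3", 3, 1, 1, 384),   # 192 * 2
    ("conv4", 3, 1, 1, 384),
    ("conv5", 3, 1, 1, 256),
    ("pool3", 3, 2, 0, None),
]


def trace(size=224, channels=3):
    """Return (name, side, channels) after each layer for a square input."""
    shapes = []
    for name, k, s, p, c in ALEXNET_FEATURES:
        size = out_size(size, k, s, p)
        if c is not None:
            channels = c
        shapes.append((name, size, channels))
    return shapes


shapes = trace()
for row in shapes:
    print(row)
_, side, ch = shapes[-1]
print("flattened:", side * side * ch)  # 6 * 6 * 256 = 9216 = 6 * 6 * 128 * 2
```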

3. VGG (size: 224 × 224 × 3) [3]

Figure 5.

Figure 5: the network architectures enclosed in the green box are reproduced in code.
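The number in each VGG name counts the layers that carry weights: the 3×3 convolutions plus the three fully connected layers. Max pooling has no weights, and the added BatchNorm layers are not counted. Here is a short check, with the configurations copied from the `cfgs` dictionary in the VGG code below.

```python
# Configurations copied from the cfgs dictionary in the VGG code;
# integers are 3x3 conv output channels, "M" is a 2x2 max pool.
CFGS = {
    "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"],
    "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def weight_layers(cfg, num_fc=3):
    """Count layers with weights: one per conv entry plus the fully connected layers."""
    return sum(1 for v in cfg if v != "M") + num_fc


for name, cfg in CFGS.items():
    print(name, weight_layers(cfg))
# vgg11 11
# vgg13 13
# vgg16 16
# vgg19 19
```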

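Because the accuracy was unacceptably poor without it, I added nn.BatchNorm2d(i) and used transfer learning.

The output size of the last fully connected layer can be adjusted to the number of categories in your dataset.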

import torch
from torch import nn
from utils.path import CheckPoints
from torch.cuda.amp import autocast

__all__ = [
    'vgg11',
    'vgg13',
    'vgg16',
    'vgg19',
]
# if your network is limited, you can download them, and put them into CheckPoints(my Project:Simple-CV-Pytorch-master/checkpoints/).
model_urls = {
    # 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg11': '{}/vgg11-bbd30ac9.pth'.format(CheckPoints),
    # 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg13': '{}/vgg13-c768596a.pth'.format(CheckPoints),
    # 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg16': '{}/vgg16-397923af.pth'.format(CheckPoints),
    # 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg19': '{}/vgg19-dcbb9e9d.pth'.format(CheckPoints)
}


def vgg_(arch, num_classes, pretrained, init_weights=False, **kwargs):
    cfg = cfgs["vgg" + arch]
    features = make_features(cfg)
    model = vgg(num_classes=num_classes, features=features, init_weights=init_weights, **kwargs)
    # if you're training for the first time, no pretrained is required!
    if pretrained:
        pretrained_models = torch.load(model_urls["vgg" + arch])
        # transfer learning
        # if you want to train your own dataset
        if arch == '11':
            del pretrained_models['features.8.weight']
            del pretrained_models['features.11.weight']
            del pretrained_models['features.16.weight']
        elif arch == '13':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.15.weight']
            del pretrained_models['features.17.weight']
            del pretrained_models['features.22.weight']
        elif arch == '16':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.14.weight']
            del pretrained_models['features.17.weight']
            del pretrained_models['features.21.weight']
            del pretrained_models['features.24.weight']
            del pretrained_models['features.28.weight']
        elif arch == '19':
            del pretrained_models['features.7.weight']
            del pretrained_models['features.10.weight']
            del pretrained_models['features.14.weight']
            del pretrained_models['features.21.weight']
            del pretrained_models['features.23.weight']
            del pretrained_models['features.28.weight']
            del pretrained_models['features.34.weight']
        else:
            raise ValueError("Pretrained: unsupported VGG depth")

        model.load_state_dict(pretrained_models, strict=False)
    return model


def vgg11(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('11', num_classes, pretrained, init_weights, **kwargs)


def vgg13(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('13', num_classes, pretrained, init_weights, **kwargs)


def vgg16(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('16', num_classes, pretrained, init_weights, **kwargs)


def vgg19(num_classes, pretrained=False, init_weights=False, **kwargs):
    return vgg_('19', num_classes, pretrained, init_weights, **kwargs)


class vgg(nn.Module):
    # cifar: 10, ImageNet: 1000
    def __init__(self, features, num_classes=1000, init_weights=False):
        super(vgg, self).__init__()
        self.num_classes = num_classes
        self.features = features
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 512, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
        )
        self.classifier = nn.Linear(4096, self.num_classes)
        if init_weights:
            self._initialize_weights()

    @autocast()
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm weights are 1-D, and xavier_uniform_ needs at least 2 dimensions,
                # so use the standard initialization: scale 1, shift 0
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_features(cfgs: list):
    layers = []
    in_channels = 3
    for i in cfgs:
        if i == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, i, kernel_size=3, stride=1, padding=1, bias=False)
            layers += [conv2d, nn.BatchNorm2d(i), nn.ReLU()]
            in_channels = i
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
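The `del` lists in `vgg_` are needed because of how the checkpoints are laid out:

- The torchvision checkpoints in `model_urls` are the plain VGGs, without BatchNorm. Each conv there occupies two slots of `features` (conv, ReLU).
- `make_features` above uses three slots per conv (conv, BatchNorm, ReLU).
- The same key name, such as `features.7.weight`, therefore points at different layers in the two models.
- `load_state_dict(..., strict=False)` skips missing or unexpected keys, but it still raises on a shape mismatch.

The sketch below computes the conv positions for both layouts. It also shows a more general alternative: keep only the entries whose name and shape both match. `conv_indices` and `filter_matching` are hypothetical helpers, not part of the project.

```python
from types import SimpleNamespace


def conv_indices(cfg, with_bn):
    """Positions of the Conv2d modules inside the nn.Sequential built from `cfg`."""
    per_conv = 3 if with_bn else 2  # conv + (BatchNorm) + ReLU
    idx, positions = 0, []
    for v in cfg:
        if v == "M":
            idx += 1  # one MaxPool2d
        else:
            positions.append(idx)
            idx += per_conv
    return positions


def filter_matching(pretrained, target):
    """Keep pretrained entries whose key exists in `target` with an identical shape."""
    return {k: v for k, v in pretrained.items()
            if k in target and tuple(v.shape) == tuple(target[k].shape)}


VGG16 = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]

if __name__ == "__main__":
    print(conv_indices(VGG16, with_bn=False))  # [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
    print(conv_indices(VGG16, with_bn=True))   # [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]

    # Stand-ins with a .shape attribute; real code would pass torch state_dicts.
    pre = {"features.0.weight": SimpleNamespace(shape=(64, 3, 3, 3)),
           "features.7.weight": SimpleNamespace(shape=(128, 128, 3, 3))}
    tgt = {"features.0.weight": SimpleNamespace(shape=(64, 3, 3, 3)),
           "features.7.weight": SimpleNamespace(shape=(128, 64, 3, 3))}
    print(sorted(filter_matching(pre, tgt)))   # ['features.0.weight']
```

With real tensors, the call would be `model.load_state_dict(filter_matching(torch.load(path), model.state_dict()), strict=False)`. The trade-off is that layers whose shapes differ are silently left at their initial values, so it is worth printing which keys were kept.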

4. ResNet [4]

Figure 6.

Figure 6: the network architectures (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152) are reproduced in code.

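First, let's look at how each block is reproduced. The block for the 18- and 34-layer networks is shown in the green box of Figure 7 below, and the block for the 50-, 101- and 152-layer networks in the blue box.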

Figure 7.

Block: 18-layer, 34-layer

# 18-layer, 34-layer
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

Block: 50-layer, 101-layer, 152-layer

# 50-layer, 101-layer, 152-layer
class Bottleneck(nn.Module):
    """
    self.conv1(kernel_size=1,stride=2)
    self.conv2(kernel_size=3,stride=1)
    to
    self.conv1(kernel_size=1,stride=1)
    self.conv2(kernel_size=3,stride=2)
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # padding=1 keeps the 3x3 conv from shrinking the feature map, so the main branch
        # matches the shape of the (optionally downsampled) identity branch
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out
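The residual addition `out += identity` only works if the main branch and the shortcut produce the same spatial size. Below is a pure-Python check with the output-size formula. It assumes a 56 × 56 input, which is the size entering `layer1` for a 224 × 224 image. The 3 × 3 `conv2` keeps the size only with `padding=1`. Without padding it yields 54 instead of 56 (or 27 instead of 28 with stride 2), and the addition fails.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Side length after a conv/pool layer."""
    return (size + 2 * padding - kernel) // stride + 1


def bottleneck_main(size, stride, conv2_padding):
    """Spatial size after conv1 (1x1) -> conv2 (3x3, carries the stride) -> conv3 (1x1)."""
    size = out_size(size, 1)
    size = out_size(size, 3, stride, conv2_padding)
    return out_size(size, 1)


def shortcut(size, stride):
    """Spatial size of the identity branch (1x1 projection conv when stride > 1)."""
    return out_size(size, 1, stride)


for stride in (1, 2):
    for pad in (0, 1):
        main, short = bottleneck_main(56, stride, pad), shortcut(56, stride)
        print(f"stride={stride} padding={pad}: main={main} shortcut={short} ok={main == short}")
```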

To build the whole ResNet model, first reproduce the first convolution layer and the max-pooling layer:


class ResNet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channels = 64

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

The subsequent layers are shown in Figure 8. The modules in the figure map to the code as follows:

conv2_x -> self.layer1, conv3_x -> self.layer2, conv4_x -> self.layer3, conv5_x -> self.layer4


        ...
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        ...
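The output size of each stage can be traced by hand. For a 224 × 224 input, conv1 and the max pool bring the feature map down to 56 × 56, and every stage after layer1 halves it. This sketch makes two assumptions: every 3 × 3 convolution uses padding=1, and only the first one in a stage carries the stride. Because the complete code ends in `nn.AdaptiveAvgPool2d((1, 1))`, the fully connected layer always receives `512 * block.expansion` features. Unlike LeNet5, AlexNet and VGG, ResNet therefore does not force a fixed input size.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Side length after a conv/pool layer."""
    return (size + 2 * padding - kernel) // stride + 1


STAGES = [("layer1", 64, 1), ("layer2", 128, 2), ("layer3", 256, 2), ("layer4", 512, 2)]


def stage_shapes(size=224, expansion=4):
    """(stage, side, channels) after each stage; expansion=1 for BasicBlock, 4 for Bottleneck."""
    size = out_size(size, 7, 2, 3)   # conv1: 7x7, stride 2, padding 3
    size = out_size(size, 3, 2, 1)   # maxpool: 3x3, stride 2, padding 1
    shapes = []
    for name, channels, stride in STAGES:
        size = out_size(size, 3, stride, 1)  # first 3x3 conv of the stage carries the stride
        shapes.append((name, size, channels * expansion))
    return shapes


print(stage_shapes(224, expansion=4))
# [('layer1', 56, 256), ('layer2', 28, 512), ('layer3', 14, 1024), ('layer4', 7, 2048)]
```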

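For the 50-, 101- and 152-layer networks, the dashed-line (projection) shortcuts in Figure 8 are reproduced as follows. The 18- and 34-layer networks work the same way, so they are not shown separately.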

Figure 8.

    def _make_layer(self, block, channels, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )
        ...
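The `if` condition in `_make_layer` decides where the dashed (projection) shortcuts appear. The sketch below replays it for the four stages, with the stage plan copied from `self.layer1` to `self.layer4`:

- BasicBlock networks need a projection only at the start of layer2 to layer4.
- Bottleneck networks also need one at the start of layer1, because 64 input channels differ from 64 × 4 = 256 output channels.
- Only the first block of each stage is affected. Later blocks see matching channels and stride 1.

```python
STAGES = [(64, 1), (128, 2), (256, 2), (512, 2)]  # (channels, stride) passed to _make_layer


def first_block_projection(expansion, in_channels=64):
    """Replay the test in _make_layer: does each stage's first block need a 1x1 projection?"""
    needs = []
    for channels, stride in STAGES:
        needs.append(stride != 1 or in_channels != channels * expansion)
        in_channels = channels * expansion  # what _make_layer stores in self.in_channels
    return needs


print("BasicBlock (18/34):", first_block_projection(expansion=1))       # [False, True, True, True]
print("Bottleneck (50/101/152):", first_block_projection(expansion=4))  # [True, True, True, True]
```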

Finally, the ResNet model is instantiated with the chosen depth (18, 34, 50, 101 or 152):


def resnet18(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('18', BasicBlock, [2, 2, 2, 2], num_classes, pretrained, include_top)


def resnet34(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('34', BasicBlock, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet50(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('50', Bottleneck, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet101(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('101', Bottleneck, [3, 4, 23, 3], num_classes, pretrained, include_top)


def resnet152(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('152', Bottleneck, [3, 8, 36, 3], num_classes, pretrained, include_top)
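The depth in each name follows from the block counts: conv1, plus two convolutions per BasicBlock or three per Bottleneck, plus the final fully connected layer. By convention, the 1×1 projection convolutions on the shortcuts are not counted. Here is a quick check, with the block lists copied from the factory functions above.

```python
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}

# Block counts passed to resnet_ by the factory functions
CONFIGS = {
    18: ("BasicBlock", [2, 2, 2, 2]),
    34: ("BasicBlock", [3, 4, 6, 3]),
    50: ("Bottleneck", [3, 4, 6, 3]),
    101: ("Bottleneck", [3, 4, 23, 3]),
    152: ("Bottleneck", [3, 8, 36, 3]),
}


def depth(block, blocks_num):
    """conv1 + the convs inside every residual block + the final fc layer."""
    return 1 + CONVS_PER_BLOCK[block] * sum(blocks_num) + 1


for name, (block, blocks_num) in CONFIGS.items():
    print(name, depth(block, blocks_num))  # prints each depth next to its matching name
```

This is also why ResNet50 and ResNet34 share the block list `[3, 4, 6, 3]`: the Bottleneck's third convolution per block adds 16 layers.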

Complete code

import torch
import torch.nn as nn
from utils.path import CheckPoints
from torch.cuda.amp import autocast

__all__ = [
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152'
]
# if your network is limited, you can download them, and put them into CheckPoints(my Project:Simple-CV-Pytorch-master/checkpoints/).
model_urls = {
    # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet18': '{}/resnet18-5c106cde.pth'.format(CheckPoints),
    # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet34': '{}/resnet34-333f7ec4.pth'.format(CheckPoints),
    # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet50': '{}/resnet50-19c8e357.pth'.format(CheckPoints),
    # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet101': '{}/resnet101-5d3b4d8f.pth'.format(CheckPoints),
    # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnet152': '{}/resnet152-b121ed2d.pth'.format(CheckPoints)
}


def resnet_(arch, block, block_num, num_classes, pretrained, include_top, **kwargs):
    model = resnet(block=block, blocks_num=block_num, num_classes=num_classes, include_top=include_top, **kwargs)
    # if you're training for the first time, no pretrained is required!
    if pretrained:
        # if you want to use cpu, you should modify map_location=torch.device("cpu")
        pretrained_models = torch.load(model_urls["resnet" + arch], map_location=torch.device("cuda:0"))
        # transfer learning
        # if you want to train your own dataset
        # del pretrained_models['module.classifier.bias']
        model.load_state_dict(pretrained_models, strict=False)
    return model


# 18-layer, 34-layer
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    @autocast()
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


# 50-layer, 101-layer, 152-layer
class Bottleneck(nn.Module):
    """
    self.conv1(kernel_size=1,stride=2)
    self.conv2(kernel_size=3,stride=1)
    to
    self.conv1(kernel_size=1,stride=1)
    self.conv2(kernel_size=3,stride=2)
    acc: up 0.5%
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # padding=1 keeps the main branch the same size as the identity branch
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample

    @autocast()
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class resnet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(resnet, self).__init__()
        self.include_top = include_top
        self.in_channels = 64

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.flatten = nn.Flatten()
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channels, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channels, out_channels=channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )

        layers = []
        layers.append(block(in_channels=self.in_channels, out_channels=channels, downsample=downsample, stride=stride))
        self.in_channels = channels * block.expansion

        for _ in range(1, block_num):
            layers.append(
                block(in_channels=self.in_channels, out_channels=channels))

        return nn.Sequential(*layers)

    @autocast()
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = self.flatten(x)
            x = self.fc(x)

        return x


def resnet18(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('18', BasicBlock, [2, 2, 2, 2], num_classes, pretrained, include_top)


def resnet34(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('34', BasicBlock, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet50(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('50', Bottleneck, [3, 4, 6, 3], num_classes, pretrained, include_top)


def resnet101(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('101', Bottleneck, [3, 4, 23, 3], num_classes, pretrained, include_top)


def resnet152(num_classes=1000, pretrained=False, include_top=True):
    return resnet_('152', Bottleneck, [3, 8, 36, 3], num_classes, pretrained, include_top)

Some configuration files

utils/path.py

import os.path
import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(BASE_DIR)

# Gets home dir cross platform
# "/data/"
MyName = "PycharmProject"
Folder = "Simple-CV-Pytorch-master"

# Path to store checkpoint model
CheckPoints = 'checkpoints'
CheckPoints = os.path.join(BASE_DIR, MyName, Folder, CheckPoints)

# Path to store tensorboard load
tensorboard_log = 'tensorboard'
tensorboard_log = os.path.join(BASE_DIR, MyName, Folder, tensorboard_log)

# Path to save log
log = 'log'
log = os.path.join(BASE_DIR, MyName, Folder, log)

# Path to save classification train log
classification_train_log = 'classification_train'

# Path to save classification test log
classification_test_log = 'classification_test'

# Path to save classification eval log
classification_eval_log = 'classification_eval'

# Classification evaluate model path
classification_evaluate = None

# Images classification path
image_cls = 'automobile.jpg'
images_cls_path = 'images/classification'
images_cls_path = os.path.join(BASE_DIR, MyName, Folder, images_cls_path, image_cls)

# Data
DATAPATH = BASE_DIR

# ImageNet/ILSVRC2012
ImageNet = "ImageNet/ILSVRC2012"
ImageNet_Train_path = os.path.join(DATAPATH, ImageNet, 'train')
ImageNet_Eval_path = os.path.join(DATAPATH, ImageNet, 'val')

# CIFAR10
CIFAR = 'cifar'
CIFAR_path = os.path.join(DATAPATH, CIFAR)
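`BASE_DIR` is found by climbing four directories up from path.py itself, so every derived path depends on where the project sits. The pure-Python sketch below repeats that logic, assuming path.py lives at /data/PycharmProject/Simple-CV-Pytorch-master/utils/path.py as in the project tree. `posixpath` is used so that the result is the same on any OS, and the function name `resolve_paths` is hypothetical. If the project is moved one level deeper, `BASE_DIR` and all data paths move with it.

```python
import posixpath


def resolve_paths(path_py="/data/PycharmProject/Simple-CV-Pytorch-master/utils/path.py"):
    """Repeat utils/path.py's path logic for a given location of path.py (POSIX-style)."""
    base_dir = path_py
    for _ in range(4):  # -> utils -> Simple-CV-Pytorch-master -> PycharmProject -> data root
        base_dir = posixpath.dirname(base_dir)
    project = posixpath.join(base_dir, "PycharmProject", "Simple-CV-Pytorch-master")
    return {
        "BASE_DIR": base_dir,
        "CheckPoints": posixpath.join(project, "checkpoints"),
        "ImageNet_Train_path": posixpath.join(base_dir, "ImageNet/ILSVRC2012", "train"),
        "CIFAR_path": posixpath.join(base_dir, "cifar"),
    }


for key, value in resolve_paths().items():
    print(f"{key:20} {value}")
```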

data/config.py

from utils import path

# Path to save log
log = path.log

# Path to save classification train log
classification_train_log = path.classification_train_log

# Path to save classification test log
classification_test_log = path.classification_test_log

# Path to save classification eval log
classification_eval_log = path.classification_eval_log

# Path to store checkpoint model
checkpoint_path = path.CheckPoints

# Classification evaluate model path
classification_evaluate = path.classification_evaluate

# Classification test images
images_cls_root = path.images_cls_path

# Path to save tensorboard
tensorboard_log = path.tensorboard_log

Training code
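One pitfall in the training script's argument parser: flags such as `--cuda`, `--pretrained`, `--tensorboard` and `--init_weights` are declared with `type=str`. Any value given on the command line therefore arrives as a non-empty string, so `--cuda False` still evaluates as true. A common workaround is a converter function. `str2bool` below is a hypothetical helper, not part of the project or of argparse.

```python
import argparse


def str2bool(value):
    """Convert common true/false spellings to a bool for argparse."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--cuda", type=str, default=True)             # as declared in train.py
parser.add_argument("--cuda_fixed", type=str2bool, default=True)  # hypothetical fixed variant

args = parser.parse_args(["--cuda", "False", "--cuda_fixed", "False"])
print(repr(args.cuda), bool(args.cuda))  # 'False' True
print(args.cuda_fixed)                   # False
```

Non-string defaults such as `default=True` are passed through unchanged, so the behavior without command-line flags stays the same.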

tools/classification/train.py

import os
import logging
import argparse
import warnings

warnings.filterwarnings('ignore')

import sys

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
import time
import torch
from data import *
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
from torchvision import transforms
from utils.accuracy import accuracy
from torch.utils.data import DataLoader
from utils.get_logger import get_logger
from models.basenets.lenet5 import lenet5
from models.basenets.alexnet import alexnet
from utils.AverageMeter import AverageMeter
from torch.cuda.amp import autocast, GradScaler
from models.basenets.vgg import vgg11, vgg13, vgg16, vgg19
from models.basenets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset', type=str, default='CIFAR', choices=['ImageNet', 'CIFAR'], help='ImageNet, CIFAR')
    parser.add_argument('--dataset_root', type=str, default=CIFAR_ROOT, choices=[ImageNet_Train_ROOT, CIFAR_ROOT], help='Dataset root directory path')
    parser.add_argument('--basenet', type=str, default='lenet', choices=['resnet', 'vgg', 'lenet', 'alexnet'], help='Pretrained base model')
    parser.add_argument('--depth', type=int, default=5, help='BaseNet depth, including: LeNet of 5, AlexNet of 0, VGG of 11, 13, 16, 19, ResNet of 18, 34, 50, 101, 152')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--resume', type=str, default=None, help='Checkpoint state_dict file to resume training from')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers used in dataloading')
    parser.add_argument('--cuda', type=str, default=True, help='Use CUDA to train model')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum value for optim')
    parser.add_argument('--gamma', type=float, default=0.1, help='Gamma update for SGD')
    parser.add_argument('--accumulation_steps', type=int, default=1, help='Gradient accumulation steps')
    parser.add_argument('--save_folder', type=str, default=config.checkpoint_path, help='Directory for saving checkpoint models')
    parser.add_argument('--tensorboard', type=str, default=False, help='Use tensorboard for loss visualization')
    parser.add_argument('--log_folder', type=str, default=config.log, help='Log Folder')
    parser.add_argument('--log_name', type=str, default=config.classification_train_log, help='Log Name')
    parser.add_argument('--tensorboard_log', type=str, default=config.tensorboard_log, help='Use tensorboard for loss visualization')
    parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('--epochs', type=int, default=30, help='Number of epochs')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    # nargs='+' with type=int parses "--milestones 15 20 30" into [15, 20, 30];
    # type=list would split a command-line string into single characters
    parser.add_argument('--milestones', type=int, nargs='+', default=[15, 20, 30], help='Milestones')
    parser.add_argument('--num_classes', type=int, default=10, help='the number classes, like ImageNet:1000, cifar:10')
    parser.add_argument('--image_size', type=int, default=32, help='image size, like ImageNet:224, cifar:32')
    parser.add_argument('--pretrained', type=str, default=True, help='Models was pretrained')
    parser.add_argument('--init_weights', type=str, default=False, help='Init Weights')

    return parser.parse_args()


args = parse_args()

# 1. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)

# 2. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)


def train():
    # 3. Create SummaryWriter
    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        # tensorboard loss
        writer = SummaryWriter(args.tensorboard_log)

    # vgg16, alexnet and lenet5 need to resize image_size, because of fc.
    if args.basenet == 'vgg' or args.basenet == 'alexnet':
        args.image_size = 224
    elif args.basenet == 'lenet':
        args.image_size = 32

    # 4. Ready dataset
    if args.dataset == 'ImageNet':
        if args.dataset_root == CIFAR_ROOT:
            raise ValueError('Must specify dataset_root if specifying dataset ImageNet2012.')
        elif not os.path.exists(args.dataset_root):
            raise ValueError("ImageNet2012 dataset_root does not exist: {}".format(args.dataset_root))
        dataset = torchvision.datasets.ImageFolder(
            root=args.dataset_root,
            transform=torchvision.transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]))
    elif args.dataset == 'CIFAR':
        if args.dataset_root == ImageNet_Train_ROOT:
            raise ValueError('Must specify dataset_root if specifying dataset CIFAR10.')
        elif args.dataset_root is None:
            raise ValueError("Must provide --dataset_root when training on CIFAR10.")
        dataset = torchvision.datasets.CIFAR10(root=args.dataset_root, train=True,
                                               transform=torchvision.transforms.Compose([
                                                   transforms.Resize((args.image_size, args.image_size)),
                                                   torchvision.transforms.ToTensor()]))
    else:
        raise ValueError('Dataset type not understood (must be ImageNet or CIFAR), exiting.')

    dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True,
                                             num_workers=args.num_workers,
                                             pin_memory=False, generator=torch.Generator(device='cuda'))

    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()

    # 5. Define train model
    # Unfortunately, Lenet5 and Alexnet don't provide pretrained Model.
    if args.basenet == 'lenet':
        if args.depth == 5:
            model = lenet5(num_classes=args.num_classes, init_weights=args.init_weights)
        else:
            raise ValueError('Unsupported LeNet depth!')
    elif args.basenet == 'alexnet':
        model = alexnet(num_classes=args.num_classes, init_weights=args.init_weights)
    elif args.basenet == 'vgg':
        if args.depth == 11:
            model = vgg11(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        elif args.depth == 13:
            model = vgg13(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        elif args.depth == 16:
            model = vgg16(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        elif args.depth == 19:
            model = vgg19(pretrained=args.pretrained, num_classes=args.num_classes, init_weights=args.init_weights)
        else:
            raise ValueError('Unsupported VGG depth!')
    # Unfortunately for my resnet, there is no set init_weight, because I'm going to set object detection algorithm
    elif args.basenet == 'resnet':
        if args.depth == 18:
            model = resnet18(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 50:
            model = resnet50(pretrained=args.pretrained, num_classes=args.num_classes)  # False means the models was not trained
        elif args.depth == 101:
            model = resnet101(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported ResNet depth!')
    else:
        raise ValueError('Unsupported model type!')

    if args.cuda:
        if torch.cuda.is_available():
            model = model.cuda()
            model = torch.nn.DataParallel(model).cuda()
        else:
            model = torch.nn.DataParallel(model)

    # 6. Loading weights
    if args.resume:
        other, ext = os.path.splitext(args.resume)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_load = os.path.join(args.save_folder, args.resume)
            model.load_state_dict(torch.load(model_load))
        else:
            print('Sorry only .pth and .pkl files supported.')

    if args.init_weights:
        # initialize newly added models' weights with xavier method
        if args.basenet == 'resnet':
            print("There is no set init_weight, because I'm going to set object detection algorithm.")
        else:
            print("Initializing weights...")
    else:
        print("Not Initializing weights...")
    if args.pretrained:
        if args.basenet == 'lenet' or args.basenet == 'alexnet':
            print("There is no available pretrained model on the website. ")
        else:
            print("Models was pretrained...")
    else:
        print("Pretrained models is False...")

    model.train()

    iteration = 0

    # 7. Optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)
    scaler = GradScaler()

    # 8. Length
    iter_size = len(dataset) // args.batch_size
    print("len(dataset): {}, iter_size: {}".format(len(dataset), iter_size))
    logger.info(f"args - {args}")
    t0 = time.time()

    # 9. Create batch iterator
    for epoch in range(args.epochs):
        t1 = time.time()
        torch.cuda.empty_cache()
        # 10. Load train data
        for data in dataloader:
            iteration += 1
            images, targets = data
            # 11. Backward
            optimizer.zero_grad()
            if args.cuda:
                images, targets = images.cuda(), targets.cuda()
                criterion = criterion.cuda()
            # 12.
Forward with autocast(): outputs = model(images) loss = criterion(outputs, targets) loss = loss / args.accumulation_steps if args.tensorboard: writer.add_scalar("train_classification_loss", loss.item(), iteration) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() # 13. Measure accuracy and record loss acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) top1.update(acc1.item(), images.size(0)) top5.update(acc5.item(), images.size(0)) losses.update(loss.item(), images.size(0)) if iteration % 100 == 0: logger.info( f"- epoch: {epoch}, iteration: {iteration}, lr: {optimizer.param_groups[0]['lr']}, " f"top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, " f"loss: {loss.item():.3f}, (losses.avg): {losses.avg:3f} " ) scheduler.step(losses.avg) t2 = time.time() h_time = (t2 - t1) // 3600 m_time = ((t2 - t1) % 3600) // 60 s_time = ((t2 - t1) % 3600) % 60 print("epoch {} is finished, and the time is {}h{}min{}s".format(epoch, int(h_time), int(m_time), int(s_time))) # 14. Save train model if epoch != 0 and epoch % 10 == 0: print('Saving state, iter:', epoch) torch.save(model.state_dict(), args.save_folder + '/' + args.dataset + '_' + args.basenet + str(args.depth) + '_' + repr(epoch) + '.pth') torch.save(model.state_dict(), args.save_folder + '/' + args.dataset + "_" + args.basenet + str(args.depth) + '.pth') if args.tensorboard: writer.close() t3 = time.time() h = (t3 - t0) // 3600 m = ((t3 - t0) % 3600) // 60 s = ((t3 - t0) % 3600) % 60 print("The Finished Time is {}h{}m{}s".format(int(h), int(m), int(s))) return top1.avg, top5.avg, losses.avgif __name__ == '__main__': torch.multiprocessing.set_start_method('spawn') logger.info("Program started") top1, top5, loss = train() print("top1 acc: {}, top5 acc: {}, loss:{}".format(top1, top5, loss)) logger.info("Done!")
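The training script builds a `MultiStepLR` scheduler from `--lr`, `--milestones` and `--gamma`. Because `MultiStepLR` counts epochs, it is meant to be stepped once per epoch with no arguments (unlike `ReduceLROnPlateau`, it does not take a loss value). The pure-Python sketch below reproduces the schedule's closed form, `lr = base_lr * gamma ** (number of milestones <= epoch)`, so you can check which rate each epoch trains with. It assumes one `scheduler.step()` call at the end of every epoch; `multistep_lr` and `schedule` are hypothetical helper names, not part of the repository.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in effect during `epoch` (0-based), assuming the
    scheduler is stepped once at the end of every epoch."""
    # Every milestone <= epoch has already been passed, so the base rate
    # has been multiplied by gamma once for each of them.
    passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** passed


def schedule(base_lr, milestones, gamma, epochs):
    """Learning rate for each epoch in range(epochs)."""
    return [multistep_lr(base_lr, milestones, gamma, e) for e in range(epochs)]


if __name__ == "__main__":
    # Defaults from the training script: lr=1e-2, milestones=[15, 20, 30], gamma=0.1, epochs=30
    for epoch, lr in enumerate(schedule(1e-2, [15, 20, 30], 0.1, 30)):
        print(epoch, lr)
```

With the defaults, epochs 0-14 train at 1e-2, epochs 15-19 at 1e-3 and epochs 20-29 at 1e-4; the milestone at 30 is never reached because training stops after epoch 29.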

Test code

tools/classification/test.py

import logging
import os
import argparse
import warnings
warnings.filterwarnings('ignore')
import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
import ast
import time
from data import *
from PIL import Image
import torch.nn.parallel
from torchvision import transforms
from utils.get_logger import get_logger
from models.basenets.lenet5 import lenet5
from models.basenets.alexnet import alexnet
from models.basenets.vgg import vgg11, vgg13, vgg16, vgg19
from models.basenets.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch Classification Testing')
    parser.add_mutually_exclusive_group()
    parser.add_argument('--dataset', type=str, default='CIFAR', choices=['ImageNet', 'CIFAR'], help='ImageNet, CIFAR')
    parser.add_argument('--images_root', type=str, default=config.images_cls_root, help='Path of the image to classify')
    parser.add_argument('--basenet', type=str, default='alexnet', choices=['resnet', 'vgg', 'lenet', 'alexnet'], help='Pretrained base model')
    parser.add_argument('--depth', type=int, default=0, help='BaseNet depth, including: LeNet of 5, AlexNet of 0, VGG of 11, 13, 16, 19, ResNet of 18, 34, 50, 101, 152')
    parser.add_argument('--evaluate', type=str, default=config.classification_evaluate, help='Checkpoint state_dict file to evaluate')
    parser.add_argument('--save_folder', type=str, default=config.checkpoint_path, help='Directory for saving checkpoint models')
    parser.add_argument('--log_folder', type=str, default=config.log, help='Log Folder')
    parser.add_argument('--log_name', type=str, default=config.classification_test_log, help='Log Name')
    # Boolean switches: type=str would keep 'False' as a non-empty (truthy) string.
    parser.add_argument('--cuda', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True, help='Use CUDA to test model')
    parser.add_argument('--num_classes', type=int, default=10, help='the number of classes, like ImageNet:1000, cifar:10')
    parser.add_argument('--image_size', type=int, default=32, help='image size, like ImageNet:224, cifar:32')
    parser.add_argument('--pretrained', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False, help='Load pretrained weights')
    return parser.parse_args()


args = parse_args()

# 1. Torch choose cuda or cpu
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but you aren't using it" +
              "\n You can set the parameter of cuda to True.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

# 2. Log
get_logger(args.log_folder, args.log_name)
logger = logging.getLogger(args.log_name)


def get_label_file(filename):
    if not os.path.exists(filename):
        raise FileNotFoundError("The dataset label file {} does not exist, please create it.".format(filename))
    return filename


def dataset_labels_results(filename, output):
    filename = os.path.join(BASE_DIR, 'data', filename + '_labels.txt')
    get_label_file(filename=filename)
    # The label file holds one dict literal: {0: 'airplane', 1: 'automobile', ...}
    with open(file=filename, mode='r') as f:
        labels = ast.literal_eval(f.read())
    index = int(output.cpu().numpy()[0])
    return labels[index]


def test():
    # VGG, AlexNet and LeNet5 end in fully connected layers, so their input size is fixed.
    if args.basenet == 'vgg' or args.basenet == 'alexnet':
        args.image_size = 224
    elif args.basenet == 'lenet':
        args.image_size = 32

    # 3. Ready image
    if args.images_root is None:
        raise ValueError("The images is None, you should load image!")
    image = Image.open(args.images_root).convert('RGB')
    preprocess = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor()]
    if args.dataset == 'ImageNet':
        # Match the normalization that train.py applies to ImageNet
        preprocess.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transform = transforms.Compose(preprocess)
    image = transform(image)
    image = image.reshape(1, 3, args.image_size, args.image_size)

    # 4. Define the model
    if args.basenet == 'lenet':
        if args.depth == 5:
            model = lenet5(num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported LeNet depth!')
    elif args.basenet == 'alexnet':
        model = alexnet(num_classes=args.num_classes)
    elif args.basenet == 'vgg':
        if args.depth == 11:
            model = vgg11(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 13:
            model = vgg13(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 16:
            model = vgg16(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 19:
            model = vgg19(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported VGG depth!')
    elif args.basenet == 'resnet':
        if args.depth == 18:
            model = resnet18(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 34:
            model = resnet34(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 50:
            # False means the model is not pretrained
            model = resnet50(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 101:
            model = resnet101(pretrained=args.pretrained, num_classes=args.num_classes)
        elif args.depth == 152:
            model = resnet152(pretrained=args.pretrained, num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported ResNet depth!')
    else:
        raise ValueError('Unsupported model type!')

    use_cuda = args.cuda and torch.cuda.is_available()
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = torch.nn.DataParallel(model)

    # 5. Loading model
    if args.evaluate:
        other, ext = os.path.splitext(args.evaluate)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            model_evaluate_load = os.path.join(args.save_folder, args.evaluate)
            model.load_state_dict(torch.load(model_evaluate_load, map_location='cuda' if use_cuda else 'cpu'))
        else:
            print('Sorry only .pth and .pkl files supported.')
    else:
        print("Sorry, you should load weights! ")
    model.eval()

    # 6. print
    logger.info(f"args - {args}")

    # 7. Test
    with torch.no_grad():
        t0 = time.time()
        # 8. Forward
        if use_cuda:
            image = image.cuda()
        output = model(image)
        output = output.argmax(1)
        t1 = time.time()
        m = (t1 - t0) // 60
        s = (t1 - t0) % 60
        folder_name = args.dataset
        output = dataset_labels_results(filename=folder_name, output=output)
        logger.info(f"output: {output}")
        print("It took a total of {}m{}s to complete the testing.".format(int(m), int(s)))
    return output


if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    logger.info("Program started")
    output = test()
    logger.info("Done!")
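A common argparse pitfall with on/off switches such as `--cuda` and `--pretrained`: declaring them with `type=str` (or `type=bool`) turns any non-empty string, including `'False'`, into a truthy value, and `type=list` splits `--milestones 15` into single characters. A minimal sketch of one way around both, using only the standard library; `str2bool` and `build_parser` are hypothetical helper names, not functions from the repository.

```python
import argparse


def str2bool(value):
    """Convert common textual spellings of a boolean into a real bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser(description="Boolean and list flag sketch")
    # str2bool turns '--cuda False' into the bool False instead of a truthy string.
    parser.add_argument("--cuda", type=str2bool, default=True)
    parser.add_argument("--pretrained", type=str2bool, default=False)
    # nargs='+' with type=int collects whole integers: '--milestones 15 20' -> [15, 20].
    parser.add_argument("--milestones", type=int, nargs="+", default=[15, 20, 30])
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--cuda", "False", "--milestones", "10", "20"])
    print(args.cuda, args.pretrained, args.milestones)
```

argparse only applies `type` to defaults given as strings, so the non-string defaults above are used unchanged when a flag is omitted.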

Labels

CIFAR_label.txt

{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
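Each label file is a single Python dict literal mapping a class index to its name. Indexing into `f.readlines()` only lines up with class indices when every entry sits on its own line, so parsing the whole file with `ast.literal_eval` is more robust. A minimal sketch, assuming the file contents are already read into a string; `load_labels` and `index_to_label` are hypothetical helper names.

```python
import ast


def load_labels(text):
    """Parse a label file written as a dict literal, e.g. {0: 'airplane', ...}."""
    labels = ast.literal_eval(text)  # evaluates literals only, never arbitrary code
    if not isinstance(labels, dict):
        raise ValueError("label file must contain a dict literal")
    return labels


def index_to_label(labels, index):
    """Map a predicted class index (e.g. the result of output.argmax(1)) to its name."""
    return labels[int(index)]


if __name__ == "__main__":
    cifar_text = ("{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', "
                  "5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}")
    labels = load_labels(cifar_text)
    print(index_to_label(labels, 3))
```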

ImageNet_label.txt

{0: 'tench, Tinca tinca', 1: 'goldfish, Carassius auratus', 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 3: 'tiger shark, Galeocerdo cuvieri', 4: 'hammerhead, hammerhead shark', 5: 'electric ray, crampfish, numbfish, torpedo', 6: 'stingray', 7: 'cock', 8: 'hen', 9: 'ostrich, Struthio camelus', 10: 'brambling, Fringilla montifringilla', 11: 'goldfinch, Carduelis carduelis', 12: 'house finch, linnet, Carpodacus mexicanus', 13: 'junco, snowbird', 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', 15: 'robin, American robin, Turdus migratorius', 16: 'bulbul', 17: 'jay', 18: 'magpie', 19: 'chickadee', 20: 'water ouzel, dipper', 21: 'kite', 22: 'bald eagle, American eagle, Haliaeetus leucocephalus', 23: 'vulture', 24: 'great grey owl, great gray owl, Strix nebulosa', 25: 'European fire salamander, Salamandra salamandra', 26: 'common newt, Triturus vulgaris', 27: 'eft', 28: 'spotted salamander, Ambystoma maculatum', 29: 'axolotl, mud puppy, Ambystoma mexicanum', 30: 'bullfrog, Rana catesbeiana', 31: 'tree frog, tree-frog', 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', 33: 'loggerhead, loggerhead turtle, Caretta caretta', 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', 35: 'mud turtle', 36: 'terrapin', 37: 'box turtle, box tortoise', 38: 'banded gecko', 39: 'common iguana, iguana, Iguana iguana', 40: 'American chameleon, anole, Anolis carolinensis', 41: 'whiptail, whiptail lizard', 42: 'agama', 43: 'frilled lizard, Chlamydosaurus kingi', 44: 'alligator lizard', 45: 'Gila monster, Heloderma suspectum', 46: 'green lizard, Lacerta viridis', 47: 'African chameleon, Chamaeleo chamaeleon', 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 49: 'African crocodile, Nile crocodile, Crocodylus niloticus', 50: 'American alligator, Alligator mississipiensis', 51: 'triceratops', 52: 'thunder snake, worm snake, Carphophis amoenus', 53:
'ringneck snake, ring-necked snake, ring snake', 54: 'hognose snake, puff adder, sand viper', 55: 'green snake, grass snake', 56: 'king snake, kingsnake', 57: 'garter snake, grass snake', 58: 'water snake', 59: 'vine snake', 60: 'night snake, Hypsiglena torquata', 61: 'boa constrictor, Constrictor constrictor', 62: 'rock python, rock snake, Python sebae', 63: 'Indian cobra, Naja naja', 64: 'green mamba', 65: 'sea snake', 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus', 68: 'sidewinder, horned rattlesnake, Crotalus cerastes', 69: 'trilobite', 70: 'harvestman, daddy longlegs, Phalangium opilio', 71: 'scorpion', 72: 'black and gold garden spider, Argiope aurantia', 73: 'barn spider, Araneus cavaticus', 74: 'garden spider, Aranea diademata', 75: 'black widow, Latrodectus mactans', 76: 'tarantula', 77: 'wolf spider, hunting spider', 78: 'tick', 79: 'centipede', 80: 'black grouse', 81: 'ptarmigan', 82: 'ruffed grouse, partridge, Bonasa umbellus', 83: 'prairie chicken, prairie grouse, prairie fowl', 84: 'peacock', 85: 'quail', 86: 'partridge', 87: 'African grey, African gray, Psittacus erithacus', 88: 'macaw', 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', 90: 'lorikeet', 91: 'coucal', 92: 'bee eater', 93: 'hornbill', 94: 'hummingbird', 95: 'jacamar', 96: 'toucan', 97: 'drake', 98: 'red-breasted merganser, Mergus serrator', 99: 'goose', 100: 'black swan, Cygnus atratus', 101: 'tusker', 102: 'echidna, spiny anteater, anteater', 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', 104: 'wallaby, brush kangaroo', 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', 106: 'wombat', 107: 'jellyfish', 108: 'sea anemone, anemone', 109: 'brain coral', 110: 'flatworm, platyhelminth', 111: 'nematode, nematode worm, roundworm', 112: 'conch', 113: 'snail', 114: 'slug', 115: 'sea slug, nudibranch', 116:
'chiton, coat-of-mail shell, sea cradle, polyplacophore', 117: 'chambered nautilus, pearly nautilus, nautilus', 118: 'Dungeness crab, Cancer magister', 119: 'rock crab, Cancer irroratus', 120: 'fiddler crab', 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus', 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', 124: 'crayfish, crawfish, crawdad, crawdaddy', 125: 'hermit crab', 126: 'isopod', 127: 'white stork, Ciconia ciconia', 128: 'black stork, Ciconia nigra', 129: 'spoonbill', 130: 'flamingo', 131: 'little blue heron, Egretta caerulea', 132: 'American egret, great white heron, Egretta albus', 133: 'bittern', 134: 'crane', 135: 'limpkin, Aramus pictus', 136: 'European gallinule, Porphyrio porphyrio', 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana', 138: 'bustard', 139: 'ruddy turnstone, Arenaria interpres', 140: 'red-backed sandpiper, dunlin, Erolia alpina', 141: 'redshank, Tringa totanus', 142: 'dowitcher', 143: 'oystercatcher, oyster catcher', 144: 'pelican', 145: 'king penguin, Aptenodytes patagonica', 146: 'albatross, mollymawk', 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', 149: 'dugong, Dugong dugon', 150: 'sea lion', 151: 'Chihuahua', 152: 'Japanese spaniel', 153: 'Maltese dog, Maltese terrier, Maltese', 154: 'Pekinese, Pekingese, Peke', 155: 'Shih-Tzu', 156: 'Blenheim spaniel', 157: 'papillon', 158: 'toy terrier', 159: 'Rhodesian ridgeback', 160: 'Afghan hound, Afghan', 161: 'basset, basset hound', 162: 'beagle', 163: 'bloodhound, sleuthhound', 164: 'bluetick', 165: 'black-and-tan coonhound', 166: 'Walker hound, Walker foxhound', 167: 'English foxhound', 168: 'redbone', 169: 'borzoi, Russian wolfhound', 170: 'Irish wolfhound', 171: 'Italian greyhound', 172: 'whippet', 173: 
'Ibizan hound, Ibizan Podenco', 174: 'Norwegian elkhound, elkhound', 175: 'otterhound, otter hound', 176: 'Saluki, gazelle hound', 177: 'Scottish deerhound, deerhound', 178: 'Weimaraner', 179: 'Staffordshire bullterrier, Staffordshire bull terrier', 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', 181: 'Bedlington terrier', 182: 'Border terrier', 183: 'Kerry blue terrier', 184: 'Irish terrier', 185: 'Norfolk terrier', 186: 'Norwich terrier', 187: 'Yorkshire terrier', 188: 'wire-haired fox terrier', 189: 'Lakeland terrier', 190: 'Sealyham terrier, Sealyham', 191: 'Airedale, Airedale terrier', 192: 'cairn, cairn terrier', 193: 'Australian terrier', 194: 'Dandie Dinmont, Dandie Dinmont terrier', 195: 'Boston bull, Boston terrier', 196: 'miniature schnauzer', 197: 'giant schnauzer', 198: 'standard schnauzer', 199: 'Scotch terrier, Scottish terrier, Scottie', 200: 'Tibetan terrier, chrysanthemum dog', 201: 'silky terrier, Sydney silky', 202: 'soft-coated wheaten terrier', 203: 'West Highland white terrier', 204: 'Lhasa, Lhasa apso', 205: 'flat-coated retriever', 206: 'curly-coated retriever', 207: 'golden retriever', 208: 'Labrador retriever', 209: 'Chesapeake Bay retriever', 210: 'German short-haired pointer', 211: 'vizsla, Hungarian pointer', 212: 'English setter', 213: 'Irish setter, red setter', 214: 'Gordon setter', 215: 'Brittany spaniel', 216: 'clumber, clumber spaniel', 217: 'English springer, English springer spaniel', 218: 'Welsh springer spaniel', 219: 'cocker spaniel, English cocker spaniel, cocker', 220: 'Sussex spaniel', 221: 'Irish water spaniel', 222: 'kuvasz', 223: 'schipperke', 224: 'groenendael', 225: 'malinois', 226: 'briard', 227: 'kelpie', 228: 'komondor', 229: 'Old English sheepdog, bobtail', 230: 'Shetland sheepdog, Shetland sheep dog, Shetland', 231: 'collie', 232: 'Border collie', 233: 'Bouvier des Flandres, Bouviers des Flandres', 234: 'Rottweiler', 235: 'German
shepherd dog, German police dog, alsatian', 236: 'Doberman, Doberman pinscher', 237: 'miniature pinscher', 238: 'Greater Swiss Mountain dog', 239: 'Bernese mountain dog', 240: 'Appenzeller', 241: 'EntleBucher', 242: 'boxer', 243: 'bull mastiff', 244: 'Tibetan mastiff', 245: 'French bulldog', 246: 'Great Dane', 247: 'Saint Bernard, St Bernard', 248: 'Eskimo dog, husky', 249: 'malamute, malemute, Alaskan malamute', 250: 'Siberian husky', 251: 'dalmatian, coach dog, carriage dog', 252: 'affenpinscher, monkey pinscher, monkey dog', 253: 'basenji', 254: 'pug, pug-dog', 255: 'Leonberg', 256: 'Newfoundland, Newfoundland dog', 257: 'Great Pyrenees', 258: 'Samoyed, Samoyede', 259: 'Pomeranian', 260: 'chow, chow chow', 261: 'keeshond', 262: 'Brabancon griffon', 263: 'Pembroke, Pembroke Welsh corgi', 264: 'Cardigan, Cardigan Welsh corgi', 265: 'toy poodle', 266: 'miniature poodle', 267: 'standard poodle', 268: 'Mexican hairless', 269: 'timber wolf, grey wolf, gray wolf, Canis lupus', 270: 'white wolf, Arctic wolf, Canis lupus tundrarum', 271: 'red wolf, maned wolf, Canis rufus, Canis niger', 272: 'coyote, prairie wolf, brush wolf, Canis latrans', 273: 'dingo, warrigal, warragal, Canis dingo', 274: 'dhole, Cuon alpinus', 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 276: 'hyena, hyaena', 277: 'red fox, Vulpes vulpes', 278: 'kit fox, Vulpes macrotis', 279: 'Arctic fox, white fox, Alopex lagopus', 280: 'grey fox, gray fox, Urocyon cinereoargenteus', 281: 'tabby, tabby cat', 282: 'tiger cat', 283: 'Persian cat', 284: 'Siamese cat, Siamese', 285: 'Egyptian cat', 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', 287: 'lynx, catamount', 288: 'leopard, Panthera pardus', 289: 'snow leopard, ounce, Panthera uncia', 290: 'jaguar, panther, Panthera onca, Felis onca', 291: 'lion, king of beasts, Panthera leo', 292: 'tiger, Panthera tigris', 293: 'cheetah, chetah, Acinonyx jubatus', 294: 'brown bear, bruin, Ursus arctos', 295: 
'American black bear, black bear, Ursus americanus, Euarctos americanus', 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', 297: 'sloth bear, Melursus ursinus, Ursus ursinus', 298: 'mongoose', 299: 'meerkat, mierkat', 300: 'tiger beetle', 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', 302: 'ground beetle, carabid beetle', 303: 'long-horned beetle, longicorn, longicorn beetle', 304: 'leaf beetle, chrysomelid', 305: 'dung beetle', 306: 'rhinoceros beetle', 307: 'weevil', 308: 'fly', 309: 'bee', 310: 'ant, emmet, pismire', 311: 'grasshopper, hopper', 312: 'cricket', 313: 'walking stick, walkingstick, stick insect', 314: 'cockroach, roach', 315: 'mantis, mantid', 316: 'cicada, cicala', 317: 'leafhopper', 318: 'lacewing, lacewing fly', 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", 320: 'damselfly', 321: 'admiral', 322: 'ringlet, ringlet butterfly', 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', 324: 'cabbage butterfly', 325: 'sulphur butterfly, sulfur butterfly', 326: 'lycaenid, lycaenid butterfly', 327: 'starfish, sea star', 328: 'sea urchin', 329: 'sea cucumber, holothurian', 330: 'wood rabbit, cottontail, cottontail rabbit', 331: 'hare', 332: 'Angora, Angora rabbit', 333: 'hamster', 334: 'porcupine, hedgehog', 335: 'fox squirrel, eastern fox squirrel, Sciurus niger', 336: 'marmot', 337: 'beaver', 338: 'guinea pig, Cavia cobaya', 339: 'sorrel', 340: 'zebra', 341: 'hog, pig, grunter, squealer, Sus scrofa', 342: 'wild boar, boar, Sus scrofa', 343: 'warthog', 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius', 345: 'ox', 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', 347: 'bison', 348: 'ram, tup', 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', 350: 'ibex, Capra ibex', 351: 'hartebeest', 352: 'impala, Aepyceros melampus', 353: 'gazelle',
354: 'Arabian camel, dromedary, Camelus dromedarius', 355: 'llama', 356: 'weasel', 357: 'mink', 358: 'polecat, fitch, foulmart, foumart, Mustela putorius', 359: 'black-footed ferret, ferret, Mustela nigripes', 360: 'otter', 361: 'skunk, polecat, wood pussy', 362: 'badger', 363: 'armadillo', 364: 'three-toed sloth, ai, Bradypus tridactylus', 365: 'orangutan, orang, orangutang, Pongo pygmaeus', 366: 'gorilla, Gorilla gorilla', 367: 'chimpanzee, chimp, Pan troglodytes', 368: 'gibbon, Hylobates lar', 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus', 370: 'guenon, guenon monkey', 371: 'patas, hussar monkey, Erythrocebus patas', 372: 'baboon', 373: 'macaque', 374: 'langur', 375: 'colobus, colobus monkey', 376: 'proboscis monkey, Nasalis larvatus', 377: 'marmoset', 378: 'capuchin, ringtail, Cebus capucinus', 379: 'howler monkey, howler', 380: 'titi, titi monkey', 381: 'spider monkey, Ateles geoffroyi', 382: 'squirrel monkey, Saimiri sciureus', 383: 'Madagascar cat, ring-tailed lemur, Lemur catta', 384: 'indri, indris, Indri indri, Indri brevicaudatus', 385: 'Indian elephant, Elephas maximus', 386: 'African elephant, Loxodonta africana', 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', 389: 'barracouta, snoek', 390: 'eel', 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', 392: 'rock beauty, Holocanthus tricolor', 393: 'anemone fish', 394: 'sturgeon', 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus', 396: 'lionfish', 397: 'puffer, pufferfish, blowfish, globefish', 398: 'abacus', 399: 'abaya', 400: "academic gown, academic robe, judge's robe", 401: 'accordion, piano accordion, squeeze box', 402: 'acoustic guitar', 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier', 404: 'airliner', 405: 'airship, dirigible', 406: 'altar', 407: 'ambulance', 408: 'amphibian, amphibious vehicle', 409: 'analog clock', 410:
'apiary, bee house', 411: 'apron', 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', 413: 'assault rifle, assault gun', 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack', 415: 'bakery, bakeshop, bakehouse', 416: 'balance beam, beam', 417: 'balloon', 418: 'ballpoint, ballpoint pen, ballpen, Biro', 419: 'Band Aid', 420: 'banjo', 421: 'bannister, banister, balustrade, balusters, handrail', 422: 'barbell', 423: 'barber chair', 424: 'barbershop', 425: 'barn', 426: 'barometer', 427: 'barrel, cask', 428: 'barrow, garden cart, lawn cart, wheelbarrow', 429: 'baseball', 430: 'basketball', 431: 'bassinet', 432: 'bassoon', 433: 'bathing cap, swimming cap', 434: 'bath towel', 435: 'bathtub, bathing tub, bath, tub', 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', 437: 'beacon, lighthouse, beacon light, pharos', 438: 'beaker', 439: 'bearskin, busby, shako', 440: 'beer bottle', 441: 'beer glass', 442: 'bell cote, bell cot', 443: 'bib', 444: 'bicycle-built-for-two, tandem bicycle, tandem', 445: 'bikini, two-piece', 446: 'binder, ring-binder', 447: 'binoculars, field glasses, opera glasses', 448: 'birdhouse', 449: 'boathouse', 450: 'bobsled, bobsleigh, bob', 451: 'bolo tie, bolo, bola tie, bola', 452: 'bonnet, poke bonnet', 453: 'bookcase', 454: 'bookshop, bookstore, bookstall', 455: 'bottlecap', 456: 'bow', 457: 'bow tie, bow-tie, bowtie', 458: 'brass, memorial tablet, plaque', 459: 'brassiere, bra, bandeau', 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', 461: 'breastplate, aegis, egis', 462: 'broom', 463: 'bucket, pail', 464: 'buckle', 465: 'bulletproof vest', 466: 'bullet train, bullet', 467: 'butcher shop, meat market', 468: 'cab, hack, taxi, taxicab', 469: 'caldron, cauldron', 470: 'candle, taper, wax light', 471: 'cannon', 472: 'canoe', 473: 'can opener, tin opener', 474: 'cardigan', 475: 'car mirror', 476: 'carousel, carrousel, 
merry-go-round, roundabout, whirligig', 477: "carpenter's kit, tool kit", 478: 'carton', 479: 'car wheel', 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', 481: 'cassette', 482: 'cassette player', 483: 'castle', 484: 'catamaran', 485: 'CD player', 486: 'cello, violoncello', 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone', 488: 'chain', 489: 'chainlink fence', 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', 491: 'chain saw, chainsaw', 492: 'chest', 493: 'chiffonier, commode', 494: 'chime, bell, gong', 495: 'china cabinet, china closet', 496: 'Christmas stocking', 497: 'church, church building', 498: 'cinema, movie theater, movie theatre, movie house, picture palace', 499: 'cleaver, meat cleaver, chopper', 500: 'cliff dwelling', 501: 'cloak', 502: 'clog, geta, patten, sabot', 503: 'cocktail shaker', 504: 'coffee mug', 505: 'coffeepot', 506: 'coil, spiral, volute, whorl, helix', 507: 'combination lock', 508: 'computer keyboard, keypad', 509: 'confectionery, confectionary, candy store', 510: 'container ship, containership, container vessel', 511: 'convertible', 512: 'corkscrew, bottle screw', 513: 'cornet, horn, trumpet, trump', 514: 'cowboy boot', 515: 'cowboy hat, ten-gallon hat', 516: 'cradle', 517: 'crane', 518: 'crash helmet', 519: 'crate', 520: 'crib, cot', 521: 'Crock Pot', 522: 'croquet ball', 523: 'crutch', 524: 'cuirass', 525: 'dam, dike, dyke', 526: 'desk', 527: 'desktop computer', 528: 'dial telephone, dial phone', 529: 'diaper, nappy, napkin', 530: 'digital clock', 531: 'digital watch', 532: 'dining table, board', 533: 'dishrag, dishcloth', 534: 'dishwasher, dish washer, dishwashing machine', 535: 'disk brake, disc brake', 536: 'dock, dockage, docking facility', 537: 'dogsled, dog sled, dog sleigh', 538: 'dome', 539: 'doormat, welcome mat', 540: 'drilling platform, offshore rig', 541: 'drum, membranophone,
tympan', 542: 'drumstick', 543: 'dumbbell', 544: 'Dutch oven', 545: 'electric fan, blower', 546: 'electric guitar', 547: 'electric locomotive', 548: 'entertainment center', 549: 'envelope', 550: 'espresso maker', 551: 'face powder', 552: 'feather boa, boa', 553: 'file, file cabinet, filing cabinet', 554: 'fireboat', 555: 'fire engine, fire truck', 556: 'fire screen, fireguard', 557: 'flagpole, flagstaff', 558: 'flute, transverse flute', 559: 'folding chair', 560: 'football helmet', 561: 'forklift', 562: 'fountain', 563: 'fountain pen', 564: 'four-poster', 565: 'freight car', 566: 'French horn, horn', 567: 'frying pan, frypan, skillet', 568: 'fur coat', 569: 'garbage truck, dustcart', 570: 'gasmask, respirator, gas helmet', 571: 'gas pump, gasoline pump, petrol pump, island dispenser', 572: 'goblet', 573: 'go-kart', 574: 'golf ball', 575: 'golfcart, golf cart', 576: 'gondola', 577: 'gong, tam-tam', 578: 'gown', 579: 'grand piano, grand', 580: 'greenhouse, nursery, glasshouse', 581: 'grille, radiator grille', 582: 'grocery store, grocery, food market, market', 583: 'guillotine', 584: 'hair slide', 585: 'hair spray', 586: 'half track', 587: 'hammer', 588: 'hamper', 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier', 590: 'hand-held computer, hand-held microcomputer', 591: 'handkerchief, hankie, hanky, hankey', 592: 'hard disc, hard disk, fixed disk', 593: 'harmonica, mouth organ, harp, mouth harp', 594: 'harp', 595: 'harvester, reaper', 596: 'hatchet', 597: 'holster', 598: 'home theater, home theatre', 599: 'honeycomb', 600: 'hook, claw', 601: 'hoopskirt, crinoline', 602: 'horizontal bar, high bar', 603: 'horse cart, horse-cart', 604: 'hourglass', 605: 'iPod', 606: 'iron, smoothing iron', 607: "jack-o'-lantern", 608: 'jean, blue jean, denim', 609: 'jeep, landrover', 610: 'jersey, T-shirt, tee shirt', 611: 'jigsaw puzzle', 612: 'jinrikisha, ricksha, rickshaw', 613: 'joystick', 614: 'kimono', 615: 'knee pad', 616: 'knot', 617: 'lab coat, laboratory 
coat', 618: 'ladle', 619: 'lampshade, lamp shade', 620: 'laptop, laptop computer', 621: 'lawn mower, mower', 622: 'lens cap, lens cover', 623: 'letter opener, paper knife, paperknife', 624: 'library', 625: 'lifeboat', 626: 'lighter, light, igniter, ignitor', 627: 'limousine, limo', 628: 'liner, ocean liner', 629: 'lipstick, lip rouge', 630: 'Loafer', 631: 'lotion', 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', 633: "loupe, jeweler's loupe", 634: 'lumbermill, sawmill', 635: 'magnetic compass', 636: 'mailbag, postbag', 637: 'mailbox, letter box', 638: 'maillot', 639: 'maillot, tank suit', 640: 'manhole cover', 641: 'maraca', 642: 'marimba, xylophone', 643: 'mask', 644: 'matchstick', 645: 'maypole', 646: 'maze, labyrinth', 647: 'measuring cup', 648: 'medicine chest, medicine cabinet', 649: 'megalith, megalithic structure', 650: 'microphone, mike', 651: 'microwave, microwave oven', 652: 'military uniform', 653: 'milk can', 654: 'minibus', 655: 'miniskirt, mini', 656: 'minivan', 657: 'missile', 658: 'mitten', 659: 'mixing bowl', 660: 'mobile home, manufactured home', 661: 'Model T', 662: 'modem', 663: 'monastery', 664: 'monitor', 665: 'moped', 666: 'mortar', 667: 'mortarboard', 668: 'mosque', 669: 'mosquito net', 670: 'motor scooter, scooter', 671: 'mountain bike, all-terrain bike, off-roader', 672: 'mountain tent', 673: 'mouse, computer mouse', 674: 'mousetrap', 675: 'moving van', 676: 'muzzle', 677: 'nail', 678: 'neck brace', 679: 'necklace', 680: 'nipple', 681: 'notebook, notebook computer', 682: 'obelisk', 683: 'oboe, hautboy, hautbois', 684: 'ocarina, sweet potato', 685: 'odometer, hodometer, mileometer, milometer', 686: 'oil filter', 687: 'organ, pipe organ', 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO', 689: 'overskirt', 690: 'oxcart', 691: 'oxygen mask', 692: 'packet', 693: 'paddle, boat paddle', 694: 'paddlewheel, paddle wheel', 695: 'padlock', 696: 'paintbrush', 697: "pajama, pyjama, pj's, jammies", 698: 'palace', 
699: 'panpipe, pandean pipe, syrinx', 700: 'paper towel', 701: 'parachute, chute', 702: 'parallel bars, bars', 703: 'park bench', 704: 'parking meter', 705: 'passenger car, coach, carriage', 706: 'patio, terrace', 707: 'pay-phone, pay-station', 708: 'pedestal, plinth, footstall', 709: 'pencil box, pencil case', 710: 'pencil sharpener', 711: 'perfume, essence', 712: 'Petri dish', 713: 'photocopier', 714: 'pick, plectrum, plectron', 715: 'pickelhaube', 716: 'picket fence, paling', 717: 'pickup, pickup truck', 718: 'pier', 719: 'piggy bank, penny bank', 720: 'pill bottle', 721: 'pillow', 722: 'ping-pong ball', 723: 'pinwheel', 724: 'pirate, pirate ship', 725: 'pitcher, ewer', 726: "plane, carpenter's plane, woodworking plane", 727: 'planetarium', 728: 'plastic bag', 729: 'plate rack', 730: 'plow, plough', 731: "plunger, plumber's helper", 732: 'Polaroid camera, Polaroid Land camera', 733: 'pole', 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', 735: 'poncho', 736: 'pool table, billiard table, snooker table', 737: 'pop bottle, soda bottle', 738: 'pot, flowerpot', 739: "potter's wheel", 740: 'power drill', 741: 'prayer rug, prayer mat', 742: 'printer', 743: 'prison, prison house', 744: 'projectile, missile', 745: 'projector', 746: 'puck, hockey puck', 747: 'punching bag, punch bag, punching ball, punchball', 748: 'purse', 749: 'quill, quill pen', 750: 'quilt, comforter, comfort, puff', 751: 'racer, race car, racing car', 752: 'racket, racquet', 753: 'radiator', 754: 'radio, wireless', 755: 'radio telescope, radio reflector', 756: 'rain barrel', 757: 'recreational vehicle, RV, R.V.', 758: 'reel', 759: 'reflex camera', 760: 'refrigerator, icebox', 761: 'remote control, remote', 762: 'restaurant, eating house, eating place, eatery', 763: 'revolver, six-gun, six-shooter', 764: 'rifle', 765: 'rocking chair, rocker', 766: 'rotisserie', 767: 'rubber eraser, rubber, pencil eraser', 768: 'rugby ball', 769: 'rule, ruler', 770: 'running shoe', 771: 
'safe', 772: 'safety pin', 773: 'saltshaker, salt shaker', 774: 'sandal', 775: 'sarong', 776: 'sax, saxophone', 777: 'scabbard', 778: 'scale, weighing machine', 779: 'school bus', 780: 'schooner', 781: 'scoreboard', 782: 'screen, CRT screen', 783: 'screw', 784: 'screwdriver', 785: 'seat belt, seatbelt', 786: 'sewing machine', 787: 'shield, buckler', 788: 'shoe shop, shoe-shop, shoe store', 789: 'shoji', 790: 'shopping basket', 791: 'shopping cart', 792: 'shovel', 793: 'shower cap', 794: 'shower curtain', 795: 'ski', 796: 'ski mask', 797: 'sleeping bag', 798: 'slide rule, slipstick', 799: 'sliding door', 800: 'slot, one-armed bandit', 801: 'snorkel', 802: 'snowmobile', 803: 'snowplow, snowplough', 804: 'soap dispenser', 805: 'soccer ball', 806: 'sock', 807: 'solar dish, solar collector, solar furnace', 808: 'sombrero', 809: 'soup bowl', 810: 'space bar', 811: 'space heater', 812: 'space shuttle', 813: 'spatula', 814: 'speedboat', 815: "spider web, spider's web", 816: 'spindle', 817: 'sports car, sport car', 818: 'spotlight, spot', 819: 'stage', 820: 'steam locomotive', 821: 'steel arch bridge', 822: 'steel drum', 823: 'stethoscope', 824: 'stole', 825: 'stone wall', 826: 'stopwatch, stop watch', 827: 'stove', 828: 'strainer', 829: 'streetcar, tram, tramcar, trolley, trolley car', 830: 'stretcher', 831: 'studio couch, day bed', 832: 'stupa, tope', 833: 'submarine, pigboat, sub, U-boat', 834: 'suit, suit of clothes', 835: 'sundial', 836: 'sunglass', 837: 'sunglasses, dark glasses, shades', 838: 'sunscreen, sunblock, sun blocker', 839: 'suspension bridge', 840: 'swab, swob, mop', 841: 'sweatshirt', 842: 'swimming trunks, bathing trunks', 843: 'swing', 844: 'switch, electric switch, electrical switch', 845: 'syringe', 846: 'table lamp', 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle', 848: 'tape player', 849: 'teapot', 850: 'teddy, teddy bear', 851: 'television, television system', 852: 'tennis ball', 853: 'thatch, thatched roof', 854: 'theater 
curtain, theatre curtain', 855: 'thimble', 856: 'thresher, thrasher, threshing machine', 857: 'throne', 858: 'tile roof', 859: 'toaster', 860: 'tobacco shop, tobacconist shop, tobacconist', 861: 'toilet seat', 862: 'torch', 863: 'totem pole', 864: 'tow truck, tow car, wrecker', 865: 'toyshop', 866: 'tractor', 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', 868: 'tray', 869: 'trench coat', 870: 'tricycle, trike, velocipede', 871: 'trimaran', 872: 'tripod', 873: 'triumphal arch', 874: 'trolleybus, trolley coach, trackless trolley', 875: 'trombone', 876: 'tub, vat', 877: 'turnstile', 878: 'typewriter keyboard', 879: 'umbrella', 880: 'unicycle, monocycle', 881: 'upright, upright piano', 882: 'vacuum, vacuum cleaner', 883: 'vase', 884: 'vault', 885: 'velvet', 886: 'vending machine', 887: 'vestment', 888: 'viaduct', 889: 'violin, fiddle', 890: 'volleyball', 891: 'waffle iron', 892: 'wall clock', 893: 'wallet, billfold, notecase, pocketbook', 894: 'wardrobe, closet, press', 895: 'warplane, military plane', 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', 897: 'washer, automatic washer, washing machine', 898: 'water bottle', 899: 'water jug', 900: 'water tower', 901: 'whiskey jug', 902: 'whistle', 903: 'wig', 904: 'window screen', 905: 'window shade', 906: 'Windsor tie', 907: 'wine bottle', 908: 'wing', 909: 'wok', 910: 'wooden spoon', 911: 'wool, woolen, woollen', 912: 'worm fence, snake fence, snake-rail fence, Virginia fence', 913: 'wreck', 914: 'yawl', 915: 'yurt', 916: 'web site, website, internet site, site', 917: 'comic book', 918: 'crossword puzzle, crossword', 919: 'street sign', 920: 'traffic light, traffic signal, stoplight', 921: 'book jacket, dust cover, dust jacket, dust wrapper', 922: 'menu', 923: 'plate', 924: 'guacamole', 925: 'consomme', 926: 'hot pot, hotpot', 927: 'trifle', 928: 'ice cream, icecream', 929: 'ice lolly, lolly, lollipop, popsicle', 930: 'French loaf', 931: 'bagel, beigel', 932: 'pretzel', 
933: 'cheeseburger', 934: 'hotdog, hot dog, red hot', 935: 'mashed potato', 936: 'head cabbage', 937: 'broccoli', 938: 'cauliflower', 939: 'zucchini, courgette', 940: 'spaghetti squash', 941: 'acorn squash', 942: 'butternut squash', 943: 'cucumber, cuke', 944: 'artichoke, globe artichoke', 945: 'bell pepper', 946: 'cardoon', 947: 'mushroom', 948: 'Granny Smith', 949: 'strawberry', 950: 'orange', 951: 'lemon', 952: 'fig', 953: 'pineapple, ananas', 954: 'banana', 955: 'jackfruit, jak, jack', 956: 'custard apple', 957: 'pomegranate', 958: 'hay', 959: 'carbonara', 960: 'chocolate sauce, chocolate syrup', 961: 'dough', 962: 'meat loaf, meatloaf', 963: 'pizza, pizza pie', 964: 'potpie', 965: 'burrito', 966: 'red wine', 967: 'espresso', 968: 'cup', 969: 'eggnog', 970: 'alp', 971: 'bubble', 972: 'cliff, drop, drop-off', 973: 'coral reef', 974: 'geyser', 975: 'lakeside, lakeshore', 976: 'promontory, headland, head, foreland', 977: 'sandbar, sand bar', 978: 'seashore, coast, seacoast, sea-coast', 979: 'valley, vale', 980: 'volcano', 981: 'ballplayer, baseball player', 982: 'groom, bridegroom', 983: 'scuba diver', 984: 'rapeseed', 985: 'daisy', 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 987: 'corn', 988: 'acorn', 989: 'hip, rose hip, rosehip', 990: 'buckeye, horse chestnut, conker', 991: 'coral fungus', 992: 'agaric', 993: 'gyromitra', 994: 'stinkhorn, carrion fungus', 995: 'earthstar', 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', 997: 'bolete', 998: 'ear, spike, capitulum', 999: 'toilet tissue, toilet paper, bathroom tissue'}
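The dictionary above maps ImageNet class indices (0-999) to label strings. The sketch below shows how such a mapping can turn a model's highest scores into readable top-k labels. It assumes the scores for one image are available as a plain Python list of 1000 values, where the position is the class index. `decode_topk` and `IMAGENET_SUBSET` are hypothetical names for illustration and are not part of the original repository. The subset copies only five entries from the dictionary.

```python
import heapq

# A few entries copied from the ImageNet index-to-label dictionary above;
# in practice the full 1000-entry dict would be used (hypothetical name).
IMAGENET_SUBSET = {
    948: 'Granny Smith',
    949: 'strawberry',
    950: 'orange',
    951: 'lemon',
    954: 'banana',
}


def decode_topk(scores, labels, k=5):
    """Return the k highest-scoring (index, label, score) triples.

    `scores` is a sequence of per-class scores (logits or probabilities)
    whose position is the class index; `labels` maps index -> name.
    Indices missing from `labels` are reported as '<unknown>'.
    """
    top = heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])
    return [(i, labels.get(i, '<unknown>'), scores[i]) for i in top]


if __name__ == '__main__':
    # Hypothetical model output: 1000 zero scores with a few classes raised.
    scores = [0.0] * 1000
    scores[950] = 0.70  # orange
    scores[951] = 0.15  # lemon
    scores[948] = 0.08  # Granny Smith
    for idx, name, score in decode_topk(scores, IMAGENET_SUBSET, k=3):
        print(f'{idx:4d} {name:<15} {score:.2f}')
```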

Training results

1.LeNet5

basenet: lenet5 (image size: 32 * 32 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30
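The scheduler settings above describe a step-decay rule. The sketch below expresses it in plain Python. It assumes that `scheduler.step()` is called once per epoch and that epochs are counted from 0; the training script itself is not shown in this section. `multistep_lr` and `iterations_per_epoch` are hypothetical helpers that reproduce the closed-form rate of PyTorch's `MultiStepLR`. In PyTorch, the same settings would correspond to `torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)` and `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 20, 30], gamma=0.1)`.

```python
from bisect import bisect_right

# Values taken from the LeNet5 training configuration above.
BASE_LR = 0.01
GAMMA = 0.1
MILESTONES = [15, 20, 30]
EPOCHS = 30
DATASET_LEN = 50000
BATCH_SIZE = 32


def multistep_lr(epoch, base_lr=BASE_LR, milestones=MILESTONES, gamma=GAMMA):
    """Learning rate used during `epoch` (0-indexed) under step decay.

    Mirrors the closed form of MultiStepLR when scheduler.step() is called
    once per epoch: the rate is multiplied by gamma once for every
    milestone that has been reached (milestone <= epoch).
    """
    return base_lr * gamma ** bisect_right(milestones, epoch)


def iterations_per_epoch(n, batch_size):
    """Number of full batches per epoch (a final partial batch is dropped)."""
    return n // batch_size


if __name__ == '__main__':
    print('iterations per epoch:', iterations_per_epoch(DATASET_LEN, BATCH_SIZE))
    for epoch in (0, 14, 15, 19, 20, EPOCHS - 1):
        print(f'epoch {epoch:2d}: lr = {multistep_lr(epoch):.0e}')
```

Under these assumptions, the rate is 0.01 for epochs 0-14, 0.001 for epochs 15-19, and 1e-4 for epochs 20-29, so the third milestone (30) is never reached in a 30-epoch run. The `iter_size` of 1562 matches 50000 // 32. A loader that keeps the final partial batch of 16 images would run 1563 iterations instead.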

| Epochs | Total time | Mean top-1 acc (%) | Mean top-5 acc (%) |
|---|---|---|---|
| 30 | 0h11m44s | 62.21 | 95.97 |
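The table above reports mean top-1 and mean top-5 accuracy. A prediction counts as a top-k hit when the true class is among the k highest-scoring classes. The sketch below is a minimal pure-Python version of that metric. `topk_accuracy` is a hypothetical helper that computes accuracy over a single set of predictions. How the author averages these values (per batch or per epoch) is not shown in this section.

```python
def topk_accuracy(score_rows, targets, k):
    """Percentage of samples whose true class is among the k highest scores.

    score_rows: list of per-sample score lists (one score per class)
    targets:    list of true class indices, same length as score_rows
    """
    if len(score_rows) != len(targets):
        raise ValueError('score_rows and targets must have the same length')
    hits = 0
    for scores, target in zip(score_rows, targets):
        # Class indices sorted by descending score; keep the first k.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(targets)


if __name__ == '__main__':
    # Three hypothetical samples over a 10-class problem (as in CIFAR-10).
    scores = [
        [0.1, 0.9, 0, 0, 0, 0, 0, 0, 0, 0],                       # top-1: class 1
        [0.5, 0.2, 0.3, 0, 0, 0, 0, 0, 0, 0],                     # class 2 ranks second
        [0.0, 0.01, 0.1, 0.5, 0.09, 0.08, 0.07, 0.06, 0.05, 0.04],  # class 9 ranks eighth
    ]
    targets = [1, 2, 9]
    print('top-1:', topk_accuracy(scores, targets, 1))
    print('top-5:', topk_accuracy(scores, targets, 5))
```

CIFAR-10 has only 10 classes, so top-5 accuracy covers half of them, and even random guessing would score 50%. This helps explain why the top-5 figures in these tables sit far above the top-1 figures.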

2.AlexNet

basenet: alexnet (image size: 224 * 224 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30
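CIFAR-10 images are 32 x 32, yet this configuration trains AlexNet at 224 * 224 * 3. The spatial arithmetic below shows why some upscaling is needed. The sketch assumes torchvision-style AlexNet layer hyperparameters: an 11x11 stride-4 first convolution and three 3x3 stride-2 max-pools. The author's implementation is not shown in this section and may differ. `conv_out` applies the standard output-size formula floor((W + 2P - K) / S) + 1, and both helpers are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a conv/pool layer (no dilation, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) for the spatial layers of a
# torchvision-style AlexNet feature extractor (assumed layout).
ALEXNET_LAYERS = [
    ('conv1', 11, 4, 2), ('pool1', 3, 2, 0),
    ('conv2', 5, 1, 2), ('pool2', 3, 2, 0),
    ('conv3', 3, 1, 1), ('conv4', 3, 1, 1), ('conv5', 3, 1, 1),
    ('pool3', 3, 2, 0),
]


def feature_sizes(size, layers=ALEXNET_LAYERS):
    """Side length after each layer; stops early if the map vanishes."""
    sizes = []
    for name, k, s, p in layers:
        size = conv_out(size, k, s, p)
        sizes.append((name, size))
        if size <= 0:
            break
    return sizes


if __name__ == '__main__':
    for side in (224, 32):
        print(side, '->', feature_sizes(side))
```

At 224, the features end at 6 x 6 (256 x 6 x 6 = 9216 values feed the classifier in torchvision's layout). At 32, the last max-pool receives a 1 x 1 map and has no valid output. Under this layout, the network therefore cannot run on native-resolution CIFAR images without changing its layers.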

| Epochs | Total time | Mean top-1 acc (%) | Mean top-5 acc (%) |
|---|---|---|---|
| 30 | 0h22m44s | 86.27 | 99.0 |

3.VGG

basenet: vgg16 (image size: 224 * 224 * 3)
dataset: cifar
len(dataset): 50000, iter_size: 1562
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.01
epoch: 30
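To give a sense of VGG16's size, the sketch below counts its weights and biases from the standard configuration: thirteen 3x3 convolutions, five 2x2 max-pools, and three fully connected layers. It assumes no batch normalization, a 4096-unit hidden width, and a 10-class output for CIFAR-10. The author's VGG16 may differ in these details. `vgg_param_count` is a hypothetical helper, not from the repository.

```python
# VGG16 convolutional configuration: numbers are output channels of 3x3
# convolutions (padding 1), 'M' is a 2x2 max-pool with stride 2.
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']


def vgg_param_count(cfg=VGG16_CFG, image_size=224, in_channels=3,
                    hidden=4096, num_classes=10):
    """Count weights + biases of a VGG-style network (no batch norm).

    Assumes the classic head: flatten -> FC(hidden) -> FC(hidden) -> FC(classes).
    """
    params, channels, size = 0, in_channels, image_size
    for v in cfg:
        if v == 'M':
            size //= 2  # each max-pool halves the spatial side
        else:
            params += channels * v * 3 * 3 + v  # 3x3 kernel weights + bias
            channels = v
    flat = channels * size * size  # 512 * 7 * 7 = 25088 for a 224 input
    for fan_in, fan_out in ((flat, hidden), (hidden, hidden), (hidden, num_classes)):
        params += fan_in * fan_out + fan_out
    return params


if __name__ == '__main__':
    print('VGG16, 1000 classes:', vgg_param_count(num_classes=1000))
    print('VGG16, 10 classes:  ', vgg_param_count(num_classes=10))
```

With 1000 classes, this gives the commonly cited 138,357,544 parameters. With 10 classes, it gives 134,301,514. Of these, the first fully connected layer (25088 -> 4096) alone accounts for 102,764,544.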

| Epochs | Total time | Mean top-1 acc (%) | Mean top-5 acc (%) |
|---|---|---|---|
| 30 | 1h23m43s | 76.56 | 96.44 |

4.ResNet

basenet: resnet18
dataset: ImageNet
image size: 224 * 224 * 3 (customizable)
batch_size: 32
optim: SGD
scheduler: MultiStepLR
milestones: [15, 20, 30]
weight_decay: 1e-4
gamma: 0.1
momentum: 0.9
lr: 0.001
epoch: 30
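The ResNet configuration marks the image size as customizable. The sketch below shows why, assuming a torchvision-style ResNet-18 whose classifier is preceded by an adaptive average pool. The stem and the four stages shrink the feature map by a total factor of 32. The adaptive pool then reduces whatever remains to 1 x 1, so the final fully connected layer always receives 512 features. `resnet18_feature_sizes` is a hypothetical helper, and the repository's ResNet may be laid out differently.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a conv/pool layer (floor mode, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_feature_sizes(size):
    """Spatial side after the stem and each residual stage of a
    torchvision-style ResNet-18 (assumed layout)."""
    sizes = {}
    size = conv_out(size, 7, 2, 3)  # conv1: 7x7, stride 2, padding 3
    sizes['conv1'] = size
    size = conv_out(size, 3, 2, 1)  # max-pool: 3x3, stride 2, padding 1
    sizes['maxpool'] = size
    # layer1 keeps resolution; layer2-4 start with a stride-2 3x3 conv.
    for name, stride in (('layer1', 1), ('layer2', 2), ('layer3', 2), ('layer4', 2)):
        size = conv_out(size, 3, stride, 1)
        sizes[name] = size
    return sizes


if __name__ == '__main__':
    for side in (224, 32):
        print(side, resnet18_feature_sizes(side))
```

A 224 input ends at 7 x 7 and a 32 input ends at 1 x 1. In both cases, the adaptive pool yields one 512-dimensional vector per image.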
| Epoch | Total time | Top-1 acc (%) | Top-5 acc (%) |
|---|---|---|---|
| 5 | 3h49m35s | 50.21 | 75.59 |

Next chapter

CV+Deep Learning - Network Architecture Pytorch Reproduction Series - Classification (2: ResNeXt, GoogLeNet, MobileNet): https://blog.csdn.net/XiaoyYidiaodiao/article/details/125692368?csdn_share_tail=%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22125692368%22%2C%22source%22%3A%22XiaoyYidiaodiao%22%7D&ctrtid=yBcgN

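[1] LeCun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11): 2278-2324.

[2] Krizhevsky A, Sutskever I, Hinton G E. ImageNet classification with deep convolutional neural networks[J]. Advances in Neural Information Processing Systems, 2012, 25.

[3] Simonyan K, Zisserman A. Very deep convolutional networks for large-scale image recognition[J]. arXiv preprint arXiv:1409.1556, 2014.

[4] He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016: 770-778.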
