Using the pretrainedmodels library

pretrainedmodels is a library of pretrained models built on PyTorch. It offers a wider selection of pretrained models than the official torchvision package.

GitHub repository: https://github.com/Cadene/pretrained-models.pytorch

Installation

pip install pretrainedmodels

List the names of all available pretrained models

import pretrainedmodels

print(pretrainedmodels.model_names)

# result
> ['fbresnet152', 'bninception', 'resnext101_32x4d', 'resnext101_64x4d', 'inceptionv4', 'inceptionresnetv2', 'alexnet', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'inceptionv3', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19', 'nasnetalarge', 'nasnetamobile', 'cafferesnet101', 'senet154',  'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'cafferesnet101', 'polynet', 'pnasnet5large']
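The printed list is an ordinary Python list of strings, so you can filter it with plain Python. The sketch below works on a copy of the output shown above, which lets it run without the library installed. The printed list contains 'cafferesnet101' twice, so the sketch first removes duplicates while keeping the original order, then picks out the SE-Net family by name prefix. The prefixes are an assumption based only on the names shown. With the library installed, you would pass pretrainedmodels.model_names instead of the hard-coded list. Newer releases may list more models.

```python
# Copy of the list printed above (it contains 'cafferesnet101' twice).
model_names = [
    'fbresnet152', 'bninception', 'resnext101_32x4d', 'resnext101_64x4d',
    'inceptionv4', 'inceptionresnetv2', 'alexnet', 'densenet121',
    'densenet169', 'densenet201', 'densenet161', 'resnet18', 'resnet34',
    'resnet50', 'resnet101', 'resnet152', 'inceptionv3', 'squeezenet1_0',
    'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16',
    'vgg16_bn', 'vgg19_bn', 'vgg19', 'nasnetalarge', 'nasnetamobile',
    'cafferesnet101', 'senet154', 'se_resnet50', 'se_resnet101',
    'se_resnet152', 'se_resnext50_32x4d', 'se_resnext101_32x4d',
    'cafferesnet101', 'polynet', 'pnasnet5large',
]

# dict.fromkeys drops duplicates while keeping first-seen order
unique_names = list(dict.fromkeys(model_names))

# Select one model family by name prefix (str.startswith accepts a tuple)
se_family = [n for n in unique_names if n.startswith(('senet', 'se_'))]

print(len(model_names), len(unique_names))  # 39 38
print(se_family)
```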

Load a model pretrained on ImageNet

model_name = 'nasnetalarge' # could be fbresnet152 or inceptionresnetv2

model = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
model.eval()
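`pretrainedmodels.__dict__[model_name]` is a plain attribute lookup on the module. It returns the same object as `getattr(pretrainedmodels, model_name)`. A typo in the name therefore surfaces as a bare KeyError. The sketch below wraps the lookup in a hypothetical helper, `get_constructor`, that checks the name against a list of allowed names first. So that it runs without the library, the stdlib `math` module stands in for `pretrainedmodels` and a short list stands in for `pretrainedmodels.model_names`.

```python
import math

def get_constructor(module, name, available):
    """Hypothetical helper: look up `name` in `module`, failing early with a readable error."""
    if name not in available:
        raise ValueError(f'unknown name {name!r}; choose one of {sorted(available)}')
    # module.__dict__[name] and getattr(module, name) return the same object
    return getattr(module, name)

# Stand-ins: `math` plays the role of `pretrainedmodels`,
# and this list plays the role of `pretrainedmodels.model_names`.
available = ['sqrt', 'floor']
fn = get_constructor(math, 'sqrt', available)
print(fn is math.__dict__['sqrt'])  # True
print(fn(16.0))                     # 4.0
```

With the library installed, the analogous call would be `get_constructor(pretrainedmodels, model_name, pretrainedmodels.model_names)(num_classes=1000, pretrained='imagenet')`.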

Alternatively, import the model constructor directly and load weights you have already downloaded:

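import torch
from pretrainedmodels.models.senet import senet154

# pretrained=None builds the architecture only, without downloading weights
model = senet154(num_classes=1000, pretrained=None)
# Load the weights yourself from the local path of the pretrained file
model.load_state_dict(torch.load('../pre-trained_model/senet154-c7b49a05.pth', map_location='cpu'))
model.eval()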
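Loading local weights is a `torch.save` / `torch.load` / `load_state_dict` round trip. The sketch below demonstrates it with a tiny hypothetical network instead of senet154, so it runs quickly and needs no download. The file name `toy-weights.pth` is illustrative. `map_location='cpu'` lets a checkpoint saved on a GPU load on a CPU-only machine. `load_state_dict` is strict by default, so any mismatched keys raise an error rather than being silently ignored.

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model(num_classes=10):
    # Hypothetical stand-in for senet154(num_classes=..., pretrained=None)
    return nn.Sequential(nn.Flatten(), nn.Linear(12, num_classes))

src = build_model()  # plays the role of the model whose weights were saved
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy-weights.pth')
    torch.save(src.state_dict(), path)

    dst = build_model()  # fresh, randomly initialized copy of the architecture
    state_dict = torch.load(path, map_location='cpu')
    dst.load_state_dict(state_dict)  # strict=True: keys must match exactly
    dst.eval()                       # switch to inference mode before predicting
```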

Prediction

import torch
import pretrainedmodels.utils as utils

load_img = utils.LoadImage()

# transformations depending on the model
# rescale, center crop, normalize, and others (ex: ToBGR, ToRange255)
tf_img = utils.TransformImage(model) 

path_img = 'data/cat.jpg'

input_img = load_img(path_img)
input_tensor = tf_img(input_img)         # 3x400x225 -> 3x299x299 size may differ
input_tensor = input_tensor.unsqueeze(0) # 3x299x299 -> 1x3x299x299
# torch.autograd.Variable is deprecated; disable gradient tracking with torch.no_grad() instead
with torch.no_grad():
    output_logits = model(input_tensor) # 1x1000
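The model returns raw logits, one score per ImageNet class. To turn them into a ranked prediction, apply a softmax and take the top k. The sketch below uses a random 1x1000 tensor in place of the real `output_logits` so it runs without loading a model. Mapping the indices to human-readable class names requires a label file, which is not shown here.

```python
import torch

torch.manual_seed(0)
# Stand-in for `output_logits = model(input_tensor)`: 1 image x 1000 class scores
output_logits = torch.randn(1, 1000)

probs = torch.softmax(output_logits, dim=1)  # each row now sums to 1
top_probs, top_idx = probs.topk(5, dim=1)    # 5 most likely classes, sorted descending

for p, i in zip(top_probs[0].tolist(), top_idx[0].tolist()):
    print(f'class {i}: {p:.4f}')
```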

Fine-tuning

import torch.nn as nn

# fine tuning: replace the classifier with a new layer sized for your own classes
dim_feats = model.last_linear.in_features
nb_classes = 4
model.last_linear = nn.Linear(dim_feats, nb_classes)
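A common variant of this step freezes the pretrained layers and trains only the new classifier. The sketch below assumes only that the network exposes its classifier as `last_linear`, as the snippet above does. `ToyNet` is a hypothetical small stand-in for a real pretrainedmodels network. It is used so the example runs without downloading weights. The new `nn.Linear` layer is created with `requires_grad=True`, so after freezing everything else only its weight and bias reach the optimizer.

```python
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    """Hypothetical stand-in: a feature extractor followed by `last_linear`."""
    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.last_linear = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.last_linear(self.features(x))

model = ToyNet(num_classes=1000)

# 1. Freeze every existing ("pretrained") parameter
for p in model.parameters():
    p.requires_grad = False

# 2. Replace the classifier; the new layer's parameters are trainable by default
nb_classes = 4
model.last_linear = nn.Linear(model.last_linear.in_features, nb_classes)

# 3. Give the optimizer only the parameters that will actually be updated
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3, momentum=0.9)

logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 4])
```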

Using only the architecture

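If you do not need the pretrained weights, you can still use the model's architecture on its own and set the number of neurons in the last layer to match your number of classes. Pass pretrained=None so that no weights are downloaded.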

model = pretrainedmodels.__dict__['nasnetalarge'](num_classes=54, pretrained=None)
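A model built with pretrained=None starts from random weights, so it has to be trained from scratch with an ordinary training loop. The sketch below shows a single optimization step with cross-entropy loss over 54 classes. A hypothetical small `ToyNet` with a `last_linear` head replaces nasnetalarge so that the step runs in milliseconds. The batch of random images and labels is illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyNet(nn.Module):
    """Hypothetical stand-in for an untrained network with a `last_linear` head."""
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.last_linear = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.last_linear(self.features(x))

model = ToyNet(num_classes=54)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

images = torch.randn(8, 3, 32, 32)     # illustrative batch of 8 images
labels = torch.randint(0, 54, (8,))    # one class index per image

model.train()
optimizer.zero_grad()
logits = model(images)                 # shape (8, 54)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
print(logits.shape, loss.item())
```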