【python深度学习】——torchvision.models
torchvision是PyTorch生态系统中的一个包,专门用于计算机视觉任务。它提供了一系列用于加载、处理和预处理图像和视频数据的工具,以及常用的计算机视觉模型和数据集。关于此模块的官网介绍在这里。这个模块包含许多常用的预训练计算机视觉模型,例如ResNetAlexNetVGG等分类、分割等模型。在官网示例中可以看到, 在0.14版本之后, 可以通过调用list_modelm1 = get_m
·
【python深度学习】——torchvision.models
1. torchvision简介
torchvision是PyTorch生态系统中的一个包,专门用于计算机视觉任务。它提供了一系列用于加载、处理和预处理图像和视频数据的工具,以及常用的计算机视觉模型和数据集。
2. models模块
2.1 简介
关于此模块的官网介绍在这里。
这个模块包含许多常用的预训练计算机视觉模型,例如ResNet、AlexNet、VGG等分类、分割等模型。在官网示例中可以看到, 在0.14版本之后, 可以通过调用list_model函数来查看torch vision中提供的模型:
# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)
# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None) #不带预训练权重
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT
weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights
weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2
这些模型可以直接用于特定的任务,也可以进行微调以适应用户的特定需求。
2.1.1 模型加载——不带预训练权重
from torchvision.models import resnet50
model = resnet50(weights=None)
2.1.2 模型加载——指定预训练权重版本
从官网的介绍中,我们可以看到各个模型都有许多可以加载的预训练权重版本,如下图所示(部分):
那么我们就可以通过全称或者简称来加载模型, 如下所示:
from torchvision.models import resnet50, ResNet50_Weights
# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
2.1.3 模型加载——加载量化版本的模型
此外, 还可以指定加载量化的模型, 只需要在加载模型时指定“quantize=True”
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献3条内容
所有评论(0)