Mobile net V系列详解 理论+实战(2)

news/2024/9/18 23:15:15 标签: 人工智能, pytorch, 机器学习, 神经网络, 算法

请添加图片描述

Mobilenet 系列

  • 实践部分
  • 一、数据集介绍
  • 二、模型整体框架
  • 三、模型代码详解
  • 四、总结

实践部分

本章针对实践通过使用pytorch一个实例对这部分内容进行吸收分析。本章节采用的源代码在这里感兴趣的读者可以自行下载操作。

一、数据集介绍

可以看到数据集本身被存放在了三个文件夹下,其主要是花的图片,被分割成了验证集和训练集,模型训练主要就是采用训练集中的数据进行训练,验证集则用来对模型的性能进行测试。
请添加图片描述
为了进一步增强数据集的结构化和规范化,每个图像通常会被放置在代表其类别的文件夹中。这意味着所有同类别的图像会被存放在相同的文件夹里。这样的存放方式不仅使数据集的管理变得简单化,更重要的是,为使用自动化工具提供了便利。例如,图像数据集的这种标准存放形式完美支持了 PyTorch 中的DatasetFolder工具直接进行处理。请添加图片描述
前几章节在实战部分讲述过,可以省却重复编码自定义Dataset类的复杂过程。DatasetFolder工具能够直观地从这种组织形式的数据集中加载图像及其对应标签,大幅简化了数据预处理和加载的步骤。

二、模型整体框架

在深度学习模型的训练和部署过程中,整个工程项目通常围绕着以下三个核心文件进行组织,进而构建起模型的完整架构。这些文件分别负责不同的任务,协同工作以实现模型的训练、评估和应用。

  1. 模型模块(Model Module) - 位于心脏位置的模型模块,负责存放模型的主体架构。它定义了模型的各个层、前向传播逻辑以及计算过程,是整个深度学习任务的基础和核心。

  2. 训练文件(Training ) - 这个脚本文件负责驱动模型的训练过程。它通过调用先前准备好的数据集及模型模块,以特定的训练策略(例如学习率调整、批处理大小选择等)对模型进行训练。该文件通常会包含模型训练、验证过程,并输出训练过程中的性能指标,如损失和准确率等。

  3. 预测模块(Prediction) - 一旦模型被训练并优化到满意的状态,预测模块则负责将这个训练好的模型导入并应用到后续的任务中。无论是用于进一步的分析、应对实时的预测请求,还是集成至更广阔的系统中,预测模块都为模型的实际使用提供了接口。

将围绕这三个文件对整个模型的框架进行展开讲解。
请添加图片描述

三、模型代码详解

首先看下模型所需要的函数部分:

import os # 文件和文件夹提供一系列操作的工具,当前文件中主要用来查找模块文件的路径地址
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim # 优化方法Adam之类的优化算法
from torchvision import transforms, datasets # 数据集操作
from tqdm import tqdm # 进度条
from model_v2 import MobileNetV2 # 编写的模型主题框架文件

接下来看train的主体文件:

def main(): # 主函数在当前文件下直接执行
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 判断下GPU是否有效
    print("using {} device.".format(device)) # 输出下在什么设备上运行的

    batch_size = 16 # 批大小
    epochs = 5 # 全部周期

    data_transform = {
    # 即对打开的图片如何处理再送入模型,数据增强技术 .Compose将做种方式进行整合,可以按照字典的方式进行调取使用
        "train":  transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

transforms.Compose是PyTorch的torchvision.transforms模块中的一个功能,用于组合多个图像变换操作。以下是这一系列变换操作的具体作用解释:

  1. transforms.RandomResizedCrop(224):

    • 这个变换随机地对图像进行裁剪,并将裁剪后的图像缩放到给定的大小(在这个例子中是224x224像素)。这种变换能够在一定程度上减少模型对图像特定部分的依赖,提高模型对于图像位置变化的鲁棒性,常用于数据增强。
  2. transforms.RandomHorizontalFlip():

    • 随机地水平翻转图像。对于每个图像,它有50%的概率被翻转。这种变换能够增加数据的多样性,帮助模型学习到对于水平方向不变性的特征,减少过拟合。
  3. transforms.ToTensor():

    • 将PIL图像或NumPy的ndarray转换为PyTorch的Tensor。这个操作还会自动将图像的数据从0到255的整数映射到0到1的浮点数,标准化图像的数据范围。
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):

    • 对图像进行标准化,即减去均值(mean)后再除以标准差(std)进行归一化。这里的均值[0.485, 0.456, 0.406]和标准差[0.229, 0.224, 0.225]是针对每一个通道的(通常为RGB通道)。这样的归一化有助于加速训练过程,减少模型对原始图像灰度尺度的依赖。
    • 这组特定的均值和标准差来自ImageNet数据集的统计,是很多预训练模型使用的标准化参数。如果你使用这些预训练模型,采用相同的归一化参数可以保持数据的一致性。

训练集合中这一组变换操作首先对图像进行了数据增强(通过随机裁剪和随机水平翻转),然后转换为了模型训练需要的Tensor格式,并且对图像进行了标准化处理,以便用于模型的训练。这些步骤是进行模型训练时常见的图像预处理流程。

测试集合中操作集合:

  1. transforms.Resize(256):

    • 首先对图像进行缩放,使其最短边的长度为256像素。这步是为了保证图像的尺寸一致性,为后续的裁剪操作做准备。
  2. transforms.CenterCrop(224):

    • 接下来执行中心裁剪,从缩放后的图像中裁切出一个大小为224x224像素的中心区域。中心裁剪通常用在验证和测试集的图像预处理中,旨在减少模型对图像边缘部分的依赖,同时保留图像最关键的内容区域。
  3. transforms.ToTensor():

    • 然后将处理过的图像转换为PyTorch Tensor,并自动将数值范围从[0, 255]归一化到[0, 1]。这是为了使图像数据适配PyTorch模型的输入要求。
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):

    • 最后,对图像的每个通道执行标准化操作。具体来说,使用给定的均值([0.485, 0.456, 0.406])和标准差([0.229, 0.224, 0.225])对图像的RGB通道进行标准化。这一步骤是基于ImageNet数据集的图像统计特性,可以进一步提升模型的泛化能力。标准化有助于加速模型训练,提高模型性能。

os.getcwd() 是Python中的一个函数,隶属于os(操作系统)模块。getcwdget current working directory的缩写,这个函数的作用是返回当前工作目录的绝对路径。

在Python程序中,当前工作目录指的是执行当前代码时所在的文件系统目录。

以下是一个简单的使用例子:

import os

# 获取并打印当前工作目录
current_directory = os.getcwd()
print("当前工作目录是:", current_directory)

下述代码找目录,就是找数据集的位置,用来传数据集,由于其为通用代码所以作者为了减少用户修改代码的必要再次进行模型自动调用。
如果你在命令行中运行上述Python脚本,它会打印出从哪个目录运行了Python解释器。了解当前的工作目录对于执行与文件路径操作相关的任务非常有用,比如读取或写入到相对路径的文件。
通过和"…/…"拼接找上两级的菜单作为当前图片的路径信息,如果要运行就自行修改。

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path

使用断言,如果这个路径不存在就报错,确保有数据集

    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

这个功能和pytorch中的另一个模块比较像:
在PyTorch中,ImageFolderDatasetFolder是两个用来加载数据的类,它们确实有相似之处,但也有一些关键区别。详细地解析一下:

相似之处

  • 目的相同:两者都用于加载数据集,特别是那些按文件夹组织的数据集,其中每个文件夹包含一个类别的数据。
  • 简化数据加载:它们提供了简洁的接口来加载数据,减少了编写自定义加载逻辑的需要,通过transforms参数,还可以很方便地对数据进行预处理和增强。

关键区别

  1. 使用场景

    • ImageFolder特别适用于图像数据,它假定数据集是以文件夹方式组织的,其中每个文件夹对应一个类别的图像。它自动将文件夹的名字作为类别的标签。
    • DatasetFolder则更为通用,可以用来加载任何类型的数据,只要数据是按类别组织在不同文件夹中。它允许通过loader参数自定义如何加载数据,这意味着您可以定义加载图像、文本文件或其他类型文件的函数。
  2. 灵活性:#实际上是DatasetFolder的一个图片领域的应用,即在DatasetFolder中要规定如何打开这个数据,则这应用特例则直接内部定义好了,极简化处理

    • ImageFolder内部实际上是DatasetFolder一个具体实现,特化于处理图像文件,并且预设了使用PIL库来加载图像。这使得ImageFolder使用起来更加简单直观,特别是对于图像数据。
    • DatasetFolder提供了更多的自定义选项,比如自定义加载函数(loader)和数据后缀(extensions),从而可以更灵活地加载不同类型的文件数据。

示例

使用ImageFolder加载图像数据:

from torchvision.datasets import ImageFolder
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

dataset = ImageFolder(root='path/to/data', transform=transform)

使用DatasetFolder加载非图像类型的数据集:

from torchvision.datasets import DatasetFolder
from torchvision import transforms
from my_custom_loader import custom_loader_function

dataset = DatasetFolder(root='path/to/data', loader=custom_loader_function, extensions=('txt',), transform=some_transforms)

总之,虽然ImageFolderDatasetFolder有相似之处,它们都提供了用于加载和处理以文件夹为单位组织的数据集的便捷方法,但DatasetFolder的设计更为通用,提供了更大的灵活性,而ImageFolder则专门用于处理图像数据,使用起来更加方便简洁。

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train") 
                                           transform=data_transform["train"])
    train_num = len(train_dataset) # 判断下数据集的长度
#获取属性到类别的映射
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4) #将python对象编码成Json字符串 indent:参数根据数据格式缩进显示,读起来更加清晰。
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
# 具体流程就是通过使用class_to_idx得到索引映射信息,使用for进行辩论获取。反转位置将文件写入一个json字符串中,并创建一个文件夹对这部分数据进行保存。
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers 线程数量 计算单个批次损失你多个size就可以一起运行,多个size在不同的核上使用相同的模型计算,得到损失更新参数。
    print('Using {} dataloader workers every process'.format(nw)) # 输出最终决定使用的线程数量

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
                                               # 创建加载器。迭代数据集

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # create model
    net = MobileNetV2(num_classes=5) # 实例化模型仅有最终类别需要进行设置

    # load pretrain weights
    # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    model_weight_path = "./mobilenet_v2.pth"
    assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)
    pre_weights = torch.load(model_weight_path, map_location='cpu')

    # delete classifier weights
    pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

    # freeze features weights
    for param in net.features.parameters():
        param.requires_grad = False

    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

四、总结

论文部分介绍的是mobilenet V1代码部分则是V2下一章节将对这差异部分进行详细的分析,及其模型核心代码的改变进行详细的指出,加油加油,明天就发。


http://www.niftyadmin.cn/n/5664680.html

相关文章

Day02Day03

1. 为什么拦截器不会去拦截/admin/login上,是因为在SpringMvc中清除了这种可能。 2.使用自己定义注解,实现AOP(insert ,update) 3.使用update最好使用动态语句,可以使用多次 4.使用阿里云的OSS存储。用common类 5.在写…

动手学深度学习(pytorch)学习记录28-使用块的网络(VGG)[学习记录]

目录 VGG块VGG网络训练模型 VGG块 定义了一个名为vgg_block的函数来实现一个VGG块 import torch from torch import nn from d2l import torch as d2ldef vgg_block(num_convs, in_channels, out_channels):layers []for _ in range(num_convs):layers.append(nn.Conv2d(in_…

Go 1.19.4 路径和目录-Day 15

1. 路径介绍 存储设备保存着数据,但是得有一种方便的模式让用户可以定位资源位置,操作系统采用一种路径字符 串的表达方式,这是一棵倒置的层级目录树,从根开始。 相对路径:不是以根目录开始的路径,例如 a/b…

ant-design表格自动合并相同内容的单元格

表格自动合并相同内容的单元格 合并hooks import { TableColumnProps } from antdexport const useAutoMergeTableCell <T extends object>(dataSource: Array<T>,columns: Array<TableColumnProps> | Array<keyof T> ): Map<keyof T, Array<…

【运维方案】软件运维服务方案(word)

1.项目情况 2.服务简述 2.1服务内容 2.2服务方式 2.3服务要求 2.4服务流程 2.5工作流程 2.6业务关系 2.7培训 3.资源提供 3.1项目组成员 3.2服务保障 进主页学习更多获取更多资料&#xff5e;

字节飞书-测开日常实习-部分手撕代码题

之前的文章提到了一道高频题&#xff1a;最长不重复的字串&#xff0c;用到动态窗口。解法就在之前的文章。这篇文章从牛客上找了一些手撕题&#xff0c;在这里记录分享一下。 1.将给定的字符串中的每个单词的首字母转化为大小字母【简单】 首字母大写__牛客网 不难 就是考察…

虚拟DOM介绍

工作流程 虚拟 DOM 并不直接发生在用户界面构建之前&#xff0c;而是作为构建用户界面过程中的一个重要部分。具体来说&#xff0c;虚拟 DOM 的工作流程如下&#xff1a; 初始化阶段&#xff1a; 组件定义&#xff1a;在应用程序开发过程中&#xff0c;开发者首先定义组件和它…

CMakeLists.txt的学习了解

CMakeLists.txt 是 CMake 构建系统中的配置文件&#xff0c;用于定义项目的编译规则和依赖关系。CMake 是一种跨平台的构建系统&#xff0c;支持从源代码生成编译脚本&#xff08;如 Makefile 或 Visual Studio 工程文件&#xff09;。CMakeLists.txt 通过指定项目信息、源文件…