1. 框架介绍
Easy-Classification是一个应用于分类任务的深度学习框架,它集成了众多成熟的分类神经网络模型,可帮助使用者简单快速的构建分类训练任务。
1.1 框架功能
1.1.1 数据加载
- 文件夹形式
- 其它自定义形式,在项目应用中,参考案例编写DataSet自定义加载。如基于配置文件,csv,路径解析等。
1.1.2 扩展网络
本框架扩展支持如下网络模型,可在classification_model_enum.py枚举类中查看具体的model。
- Resnet系列,Densenet系列,VGGnet系列等所有[pretrained-models.pytorch]支持的网络
- ShuffleNetV2,[MicroNet]
1.1.3 优化器
- Adam
- SGD
- AdaBelief
- AdamW
1.1.4 学习率衰减
- ReduceLROnPlateau
- StepLR
- MultiStepLR
- SGDR
1.1.5 损失函数
- 直接调用PyTorch相关的损失函数
- 交叉熵
- Focalloss
1.1.6 其他
- Metric(acc, F1)
- 训练结果acc,loss过程图片保存
- 交叉验证
- 梯度裁剪
- Earlystop
- weightdecay
- 冻结/解冻 除最后的全连接层的特征层
1.2 框架设计
Easy-Classification是一个简单轻巧的分类框架,目前版本主要包括两大模块,框架通用模块和项目应用模块。为方便用户快速体验,框架中目前包括简单手写数字识别和验证码识别两个示例项目。
1.2.1 通用模块设计
Easy-Classification通用模块整体结构如下:

通用模块核心类/文件介绍说明:
|
目录
|
子项
|
功能说明
|
扩展说明
|
|
config
|
|
框架基础配置目录
|
|
|
|
weight
|
预训练权重模型存储目录
|
各种神经网络的模型文件,下载后存储在该目录下
|
|
|
classification_model_enum.py
|
列举出当前分类框架,目前支持的分类神经网络模型。
枚举中的神经网络名称,与配置文件中的名称一样,表示加载对应的网络模型。
|
后续新增网络时,需在该枚举类中注入
|
|
project
|
|
分类框架下的项目应用模块,详细使用参考后续项目应用模块。
|
分类项目目录名称如:验证码识别,简单手写数字识别
|
|
universe
|
|
框架通用模块主目录。
|
后续通用的功能,均可放在该目录下。
|
|
|
data_load
|
基础数据加载类
|
加载训练数据,验证数据,预测数据等
|
|
|
data_load_service.py
|
基于配置文件,加载配置路径下的基础数据,返回对应的张量信息。
不同的分类任务,用户构建DataSet模式不同,该模块提供函数,接收用户构建的DataSet对象。做统一数据加载处理。
|
目前支持目录模式加载。
|
|
|
normalize_adapter.py
|
归一化配置类
|
其他新增网络的归一化参数,可配置在此类中。
|
|
|
model
|
定义目前框架中,支持的所有分类网络模型。
|
新增网络放入到model_category目录下。
|
|
|
model_service.py
|
分类网络模型的对外暴露类,基于配置文件,可指定具体使用哪个分类网络,项目应用时,只需调用moel_service。
moel_service.py:代理者的角色。类似于java中的代理模式。
|
新增的分类网络,要注入到moel_service.py中,对所有分类网络的统一拦截,加日志等功能可在model_service中实现。
|
|
runner_config
|
|
训练配置的目录,定义训练过程中的一些配置信息。
|
定义如优化器,学习率调整,损失函数等。
深度学习运行前,配置相关的模块均可放在该目录下。
|
|
|
optimizer_adapter.py
|
优化器适配类,根据配置文件,可返回一个具体的优化器。
|
常用优化器如:Adam,AdamW,SGD,AdaBelief,Ranger
|
|
|
loss_function_adapter.py
|
自定义损失函数适配类,可基于配置文件,返回一个具体的损失函数。
|
损失函数也可使用 PyTorch中提供的。
|
|
|
scheduler_adapter.py
|
学习率调整适配类,可基于配置文件,返回具体的调整类。
|
扩展支持ReduceLROnPlateau,StepLR,MultiStepLR, SGDR
|
|
utils
|
utils.py
|
常用的工具函数,如加载文件,全连接处理等
|
一些项目通用的工具类函数,如保存acc,loss等记录。
|
配置文件是设置在具体应用项目的目录下,配置文件可根据项目需求自定义编写,但每个配置文件需包含如下关键key字段:
|
key字段
|
解释
|
参考值
|
|
model_name
|
分类网络模型名称,如mobilenetv3,efficientnet_advprop,具体值参考ClassificationModelEnum枚举类中定义的值
|
efficientnet_advprop
|
|
GPU_ID
|
多GPU时,设置的GPU编码,无GPU时,该值设置为空
|
0
|
|
class_number
|
目标输出分类数量,如简单数字识别,输出值10
|
10
|
|
random_seed
|
随机数种子
|
43
|
|
num_workers
|
DataLoad加载数据时,是否启用多个线程加载数据
|
4
|
|
train_path
|
训练图像对应的存储目录地址
|
"data/train"
|
|
val_path
|
验证图像对应的存储目录地址
|
"data/val"
|
|
test_path
|
预测图像对应的存储目录地址
|
"data/test"
|
|
pretrained
|
预加载模型权重的文件存储路径,无值时,设置为空‘’
|
'../../out/mobilenetv3.pth'
|
|
save_best_only
|
训练时,是否只保存最优的模型
|
true
|
|
target_img_size
|
图像转换为网络模型对应的目标图像尺寸,如mobilenet v3,接收图为:[224,224]
|
[224,224]
|
|
learning_rate
|
初始化学习率值
|
0.001
|
|
batch_size
|
训练时,DataLoad一次加载数据的批次数量
|
64
|
|
test_batch_size
|
预测时,DataLoad一次加载数据的批次数量
|
1
|
|
epochs
|
训练总次数
|
100
|
|
optimizer
|
优化器类型,枚举值:Ranger,AdaBelief,SGD,AdamW,Adam
|
SGD
|
|
scheduler
|
学习衰减率调整策略,枚举值:default,step,SGDR,multi
|
default
|
|
loss
|
损失函数,若使用pytorch提供的损失函数,可不管该值。使用框架提供的需配置。枚举值:CE,CE2,Focalloss
|
|
|
early_stop_patient
|
提前结束,当后续训练轮次出现N次,acc小于历史值时,就提前结束
|
7
|
|
model_path
|
模型预测时,训练生成的权重文件存储路径
|
'../../out/mobilenetv3_e22_0.97.pth'
|
|
dropout
|
为了防止过拟合,设置值,表示随机多少比例的神经元失效,取值服务[0,1]
|
0.5
|
|
class_weight
|
训练数据类别分配不均匀,防止过拟合等情况出现,设置的惩罚值。默认值设置为None。
|
调用:n.CroEntropyLoss(),设置不同类别的惩罚值,三个类别,如[0.8,0.1,0.1]。
|
|
weight_decay
|
在与梯度做运算时,当前权重先减去一定比例的大小。
|
0.01
|
1.2.2 项目应用模块设计
Easy-Classification项目应用模块整体结构如下:
