基于MMDetection工具箱使用Faster RCNN模型对PWISeg数据集进行训练与验证的流程记录

最近在进行机器学习的入门学习,被李沐的课程搞到头大,故想先找一些项目进行实操,经过了解和友人点拨,选定了PWISeg训练集作为入门。

0. 基本情况介绍

0.1 数据集、工具箱、模型介绍

PWISeg是Point-based Weakly-supervised Instance Segmentation for Surgical Instruments的简称,由Zhen Sun, Huan Xu, Jinlin Wu, Zhen Chen, Zhen Lei, Hongbin Liu发表。

论文地址:arXiv

数据集下载:Google Drive

MMDetection是一个目标检测工具箱,包含了丰富的目标检测、实例分割、全景分割算法以及相关的组件和模块。

项目网站:MMDetection

Faster R-CNN是Region-based Convolutional Neural Network的简称,是一种用于目标检测的深度学习模型。它在之前的 R-CNN 和 Fast R-CNN 方法的基础上进行了改进,将特征提取、候选区域生成、边界框回归和分类整合到一个网络中。

0.2 环境介绍

本次训练使用的操作系统为Ubuntu 22.04.3 LTS on Windows 10 x86_64,使用miniconda3进行虚拟环境管理。

配置Python运行环境对我来说属实是一件很折磨的事情,长话短说,下面是本次使用的各库版本(主要):

1
2
3
4
5
6
Python == 3.11.9
mmcv == 2.1.0
mmdet == 3.3.0
mmengine == 0.10.4
setuptools == 60.2.0
pytorch=2.4.0=py3.11_cuda12.4_cudnn9.1.0_0

如果你的CUDA版本>=12.5,推荐不要按照官方文档的内容安装Pytorch,由于版本差异,会导致不能使用GPU进行训练。

同时一定注意mmcv的版本问题,若安装版本号为2.2.0的版本,mmengine会提示版本号过高,而2.0.0版本又会与其他组件发生兼容性问题。经反复尝试(废了我好几个环境)后,2.1.0版本是可以正常运行的。

其余各类依赖和库版本使用pip自动处理即可解决。

1. 训练开始

1.1 环境检查

在配置环境完成后,我们首先需要确认Pytorch是否可以使用CUDA,在Python命令行中使用:

1
2
3
4
5
6
Python 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.cuda.is_available()
True
>>>

出现True字样则为正常,否则请自行搜索解决Pytorch无法使用CUDA的解决方案。

1.2 配置文件编写

本配置文件参考:
在标准数据集上训练预定义的模型-准备配置文件

以及

学习配置文件

下面是配置文件内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
_base_ = '/home/zhacai/mmdetection/configs/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'

data_root = '/home/zhacai/pwiseg_dataset'

metainfo = {
'classes': ('bending_shear','rongeur_forceps_1','wire_grabbing_pliers','circular_spoon','tweezers','artery_forceps','scalpel','stripping','aspirator','rongeur_forceps_2','core_needle','fine_needle',),
'palette': [
(255, 0, 0), # 红色
(255, 165, 0), # 橙色
(255, 255, 0), # 黄色
(0, 255, 0), # 绿色
(0, 255, 255), # 青色
(0, 0, 255), # 蓝色
(139, 0, 255), # 紫色(偏蓝)
(255, 0, 255), # 洋红色
(255, 128, 0), # 深橙色
(128, 128, 0), # 深黄色(橄榄绿)
(0, 128, 0), # 深绿色
(0, 0, 128), # 深蓝色
]
}

train_dataloader = dict(
batch_size = 4,
dataset=dict(
data_root = data_root,
ann_file='train_dataset/train.json',
metainfo = metainfo,
data_prefix=dict(img='train_dataset/images/')
)
)

val_dataloader = dict(
dataset=dict(
data_root = data_root,
ann_file = 'val_dataset/val.json',
metainfo = metainfo,
data_prefix = dict(img='val_dataset/images/')
)
)

val_evaluator = dict(ann_file='/home/zhacai/pwiseg_dataset/val_dataset/val.json')
test_evaluator = dict(ann_file='/home/zhacai/pwiseg_dataset/test_dataset/test.json')
test_dataloader = dict(
dataset=dict(
data_root = data_root,
ann_file = 'test_dataset/test.json',
metainfo = metainfo,
data_prefix = dict(img='test_dataset/images')
)
)


其中,请将标记文件(annotations)和图片目录换成自己的PWISeg数据集路径,同时调整batch_size以适应实际情况,metainfo内palette是输出图片中各框选内容的颜色,数量对应classes中类别数量。

1.3 开始训练

一切准备就绪后,使用该命令运行训练:
python tools/train.py work_dirs/faster-rcnn_r50_fpn_2x_coco/config/pwiseg_config.py

请在mmdetection目录下运行该命令,同时将pwiseg_config.py文件路径替换为你的config文件。

由于使用了模版 faster-rcnn_r50_fpn_1x_coco.py,故默认只有12个Epoch,速度较快。

2. 训练后验证

训练后,会在mmdetection目录下的work_dir目录中的以config文件名命名的文件夹中找到pth后缀的模型文件。

使用 python tools/test.py work_dirs/faster-rcnn_r50_fpn_2x_coco/config/pwiseeg_config.py work_dirs/pwiseeg_config/epoch_12.pth --show命令以可视化的形式查看训练结果,也可以取消 --show参数直接进行验证,速度更快。其中模型文件路径和config文件路径请替换为自己的实际路径。

3. 结语

由于是机器学习初学,大部分原理尚在学习之中,说不出所以然,本文章仅供参考。

本文尚缺乏对于部分命令的精细解释,这部分均可以在MMDetection的官方文档内找到。