背景

最近在做工业场景下结合深度图和RGB图像的实例分割,需要做特征的融合,添加一些注意力模块。之前使用过 detectron2 ,该框架主要使用 Mask R-CNN 进行实例分割,若对网络做修改,添加自己需要的模块不是那么方便,而且其中文社区涉及的内容也比较浅。基于此的 mask2former 的官方实现的 star 很少,issues 中的很多问题都没有得到有效的解决。相比之下,mmlab 的 mmdetection 提供更多的网络模块,更便于修改添加,中文社区更加完善,而且更新的频率也很高,所以果断选择了 mmdetection。

查看CUDA版本

查看 cuda 版本,我的是 CUDA 11.2,一开始安装的确实也是 CUDA 11.2,但是显卡监控里显示是 11.4,不知道什么时候自动更新的,但我依然安装 11.2 进行后续的环境配置:

nvcc --version

虚拟环境创建

conda create -n mmlab python=3.8 -y
conda activate mmlab
# 11.2 版本的不好找,但是可以向下兼容的,所以选择11.1版本的安装
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
conda install --channel https://conda.anaconda.org/Zimmf cudatoolkit=11.1

通过下面这个链接查看官方提供的 cu111-torch1.9.0 下的预编译的 mmcv 安装包有哪些版本:
https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html

官方文档里写到:

PyTorch 在 1.x.0 和 1.x.1 之间通常是兼容的,故 mmcv-full 只提供 1.x.0 的编译包。如果你的 PyTorch 版本是 1.x.1,你可以放心地安装在 1.x.0 版本编译的 mmcv-full。

这里我选择最高的版本 1.5.0:

pip install mmcv-full==1.5.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html

至此 mmcv-full 安装完毕。


接下来安装 mmdetection:

git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -r requirements/build.txt
pip install -v -e .

验证环境

demo/demo.jpg 在 mmdetection 这个项目下,接着从http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth 下载预训练权重,放置在新建的 checkpoints/ 文件夹下,python 文件中写入如下代码,然后在虚拟环境下执行即可。

from mmdet.apis import init_detector, inference_detector, show_result_pyplot

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
# 从 model zoo 下载 checkpoint 并放在 `checkpoints/` 文件下
# 网址为: http://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
device = 'cuda:0'
# 初始化检测器
model = init_detector(config_file, checkpoint_file, device=device)
# 推理演示图像
result = inference_detector(model, 'demo/demo.jpg')
show_result_pyplot(model, 'demo/demo.jpg', result, score_thr=0.3) # 官方指南没有 show 结果图像这行代码

推理结果如下:
在这里插入图片描述

查看显存占用和显卡正在执行的程序也正常:
在这里插入图片描述

参考链接

[1] mmcv 官方安装指南
[2] mmdetection 官方安装指南