CenterNet源码结构解析

本文深入解析了CenterNet目标检测算法的源代码,重点介绍了数据集加载、数据增强和损失函数计算的过程。通过CtdetDataset类的getitem()函数,详细阐述了如何从原始标注生成网络所需的输入数据,包括热力图、宽高、中心点偏移量等。同时,文章探讨了如何理解多继承和面向对象编程在源码中的应用,以及如何在不支持DCNv2的情况下修改代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CenterNet :Objects as Points

网上已经有很多关于CenterNet理论方面的解读,我就不再搬运了,我只是发现几乎大神们都忽略了一个事实从公式到代码实现其实并不总是一件很简单的事情,所以我试着从源码的整体实现框架进行解析。第一次写这类型的文章,如有解释不妥或不清晰的地方还请指出,我来修改。由于官方源码其实是实现了几个不同的任务,本篇以目标检测为例进行解析,其余的代码结构几乎都一致,相当于官方源码给出了一个通用框架,我们平时做项目的时候也可以借鉴。

  1. CenterNet的源码中使用到了DCNv2,所以在运行前请先根据上篇文章编译。如果只想跑通代码,不想纠结DCNv2的编译,后面会提到在哪里修改成正常的卷积。
  2. 官方的代码写的很好,结构清晰,值得学习的地方很多,即便不作为深度学习的模型demo,仅仅当做一个面向对象的程序设计也是教科书般的例子。优点就会伴随着“缺点”,由于大量用到了类的继承,可能对初学者不算太友好,但是了解代码结构后会比较容易阅读,这份源码值得花时间仔细研究,本篇文章以官方代码为例进行解析。
  3. 还有一个第三方的实现,核心代码与官方的一样,但是去掉了大量的继承和重载,重写了数据处理类和训练类,对不熟悉OOP的童鞋比较友好,可以作为重要参考。

源码结构总览

  1. 所有的功能源码都在根目录下的’src‘目录下
  2. 阅读的时候时刻记住OOP思想,由于作者在一个项目中需要实现不同的网络、不同的任务和不同的数据集,所以都是以工厂模式提供的。先写一个基类,定义好接口和共同的功能后再由各个子类来完成各自的任务,最后经一个xxx_factory.py文件对外提供
  3. 项目所有的配置都在src/lib/opts.py文件中,在代码中遇到opt.xxx不知道什么意思的时候可以在这个文件中直接搜索
    在这里插入图片描述

核心模块解读

tools文件夹下的文件基本上都是各个大类的辅助函数,基本上根据文件名和函数名就能推断出作用,就不再挨个说明。主要解读datasets、detectors和trains三个文件夹,这也是整个项目的核心模块。

datasets

dataset_factory.py
dataset_factory = {
  'coco': COCO,
  'pascal': PascalVOC,
  'kitti': KITTI,
  'coco_hp': COCOHP
}

_sample_factory = {
  'exdet': EXDetDataset,
  'ctdet': CTDetDataset,
  'ddd': DddDataset,
  'multi_pose': MultiPoseDataset
}

def get_dataset(dataset, task):
  class Dataset(dataset_factory[dataset], _sample_factory[task]):
    pass
  return Dataset

dataset_factory : 定义了数据集字典,根据配置选择相应的数据集,后面以COCO数据集为例;
_sample_factory :任务字典,目标检测、肢体识别等,配置文件默认为目标检测,即取值为CTDetDataset
get_dataset: 相当于对数据集和任务类做了一个封装
这里的class Dataset(dataset_factory[dataset], _sample_factory[task])是一个python的多继承,即Dataset这个类继承了COCO和CTDetDataset,所以在main.py中可以看到

train_loader = torch.utils.data.DataLoader(
      Dataset(opt, 'train'), 
      batch_size=opt.batch_size, 
      shuffle=True,
      num_workers=opt.num_workers,
      pin_memory=True,
      drop_last=True
  )

这样的写法,其中“Dataset(opt, ‘train’)”其实就是使用了COCO类的构造函数,可以去看src/lib/datasets/dataset/coco.py的构造函数,就是需要opt和split两个参数,split用于区分’tarin’和’val’,顺带说一句,训练时每个epcho中需要的数据迭代器由src/lib/datasets/sample文件夹下的类实现。

dataset/coco.py

这个文件比较简单,主要是一些参数的定义,比如总共多少个类、类别名称、默认的图片大小、数据集的均值和方差等。这里唯一要注意的地方是,如果换成自己的数据集,除了num_classes要改之外,均值和方差也需要根据自己的数据集计算,而不是直接使用默认值!

sample/ctdet.py

核心类CTDetDataset,主要实现了训练时需要的数据迭代器。核心函数:getitem()

  1. 在函数内部出现了self.opt.xxx和self.coco.xxx这种类型的调用,但是仔细看CTDetDataset类,却没有__init__函数,更找不到这两个变量的定义。别忘了,在dataset_factory.py中Dataset类时继承了COCO和CTDetDataset两个类的,而且train_loader的定义中用的是Dataset类,所以在实际使用中,即epcho中这个__getitem__()是由Dataset调用的,所以这里的self.opt和self.coco在coco.py中定义!
  2. 函数内部可以分成三个部分加载数据、数据增强、生成gt
  • 加载数据
		img_id = self.images[index]
        file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
        img_path = os.path.join(self.img_dir, file_name)
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        anns = self.coco.loadAnns(ids=ann_ids)
        num_objs = min(len(anns), self.max_objs)

        img = cv2.imread(img_path)
  • 数据增强
height, width = img.shape[0], img.shape[1]
c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
if self.opt.keep_res:
    input_h = (height | self.opt.pad) + 1
    input_w = (width | self.opt.pad) + 1
    s = np.array([input_w, input_h], dtype=np.float32)
else:
    s = max(img.shape[0], img.shape[1]) * 1.0
    input_h, input_w = self.opt.input_h, self.opt.input_w

flipped = False
if self.split == 'train':
    if not self.opt.not_rand_crop:
        s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
        w_border = self._get_border(128, img.shape[1])
        h_border = self._get_border(128, img.shape[0])
        c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
        c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
    else:
        sf = self.opt.scale
        cf = self.opt.shift
        c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
        c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
        s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)

    if np.random.random() < self.opt.flip:
        flipped = True
        img = img[:, ::-1, :]
        c[0] = width - c[0] - 1

trans_input = get_affine_transform( c, s, 0, [input_w, input_h])
inp = cv2.warpAffine(img, trans_input,
                     (input_w, input_h),
                     flags=cv2.INTER_LINEAR)
inp = (inp.astype(np.float32) / 255.)
if self.split == 'train' and not self.opt.no_color_aug:
    color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
inp = (inp - self.mean) / self.std
inp = inp.transpose(2, 0, 1)
	解释一下,作者这里用了很多的缩写,其实如果写成全拼会更好理解。
	inp: input,就是网络的输入图像了,也是做过数据增加的图像
	c:center,图像的中心坐标
	s:scale,随机缩放比例
	如果对imgaug包熟悉的话,这部分可以用imgaug提供的功能做替换,当然后续关于bbox的变换也要做相应改变。
  • 生成gt
output_h = input_h // self.opt.down_ratio
output_w = input_w // self.opt.down_ratio
num_classes = self.num_classes
trans_output = get_affine_transform(c, s, 0, [output_w, output_h])

hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
wh = np.zeros((self.max_objs, 2), dtype=np.float32)
dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
reg = np.zeros((self.max_objs, 2), dtype=np.float32)
ind = np.zeros((self.max_objs), dtype=np.int64)
reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
cat_spec_wh = np.zeros((self.max_objs, num_classes * 2), dtype=np.float32)
cat_spec_mask = np.zeros((self.max_objs, num_classes * 2), dtype=np.uint8)

draw_gaussian = draw_msra_gaussian if self.opt.mse_loss else \
    draw_umich_gaussian

gt_det = []
for k in range(num_objs):
    ann = anns[k]
    bbox = self._coco_box_to_bbox(ann['bbox'])
    cls_id = int(self.cat_ids[ann['category_id']])
    if flipped:
        bbox[[0, 2]] = width - bbox[[2, 0]] - 1
    bbox[:2] = affine_transform(bbox[:2], trans_output)
    bbox[2:] = affine_transform(bbox[2:], trans_output)
    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
    h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
    if h > 0 and w > 0:
        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
        radius = max(0, int(radius))
        radius = self.opt.hm_gauss if self.opt.mse_loss else radius
        ct = np.array(
            [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
        ct_int = ct.astype(np.int32)
        draw_gaussian(hm[cls_id], ct_int, radius)
        wh[k] = 1. * w, 1. * h
        ind[k] = ct_int[1] * output_w + ct_int[0]
        reg[k] = ct - ct_int
        reg_mask[k] = 1
        cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]
        cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1
        if self.opt.dense_wh:
            draw_dense_reg(dense_wh, hm.max(axis=0), ct_int, wh[k], radius)
        gt_det.append([ct[0] - w / 2, ct[1] - h / 2,
                       ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])

ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh}
  • hm:理论中需要热力图,形状如下图
    在这里插入图片描述
    然后在后续的for循环中,对图片上的每个obj分别生成对应的热图:
for k in range(num_objs):
	....
	if h > 0 and w > 0:
		draw_gaussian(hm[cls_id], ct_int, radius)
  • wh:网络最后需要回归出的目标的宽和高,形状:[128, 2]。赋值的时候需要注意是在标注好的bbox做完相应的数据增加相关操作后进行的计算。
# 先对bbox做对应的数据增强变换
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
# 做完后再计算对应的h,w
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h > 0 and w > 0:
    radius = gaussian_radius((math.ceil(h), math.ceil(w)))
    radius = max(0, int(radius))
    radius = self.opt.hm_gauss if self.opt.mse_loss else radius
    ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
    ct_int = ct.astype(np.int32)
    draw_gaussian(hm[cls_id], ct_int, radius)
    wh[k] = 1. * w, 1. * h
  • reg:网络需要回归出的中心点偏移量,即在理论部分看到的由于降采样和取整后造成的偏移量。
 if h > 0 and w > 0:
     ....
     ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
     ct_int = ct.astype(np.int32)
     ......
     reg[k] = ct - ct_int
  • ind:一个索引数组,根据它的赋值能看出实际用处
ind[k] = ct_int[1] * output_w + ct_int[0]

是不是就是初学编程时老师教的二维数组转一维数组时的下标计算。至于为什么要这样写,可以从trains/ctdet.py中一步步的找到答案。找到CtdetLoss,这个类后面解析trains时还会详细说,这里只打开这个文件,在__init__函数中有wh和reg的损失计算函数定义,在计算wh的损失时有个NormRegL1Loss()类,跟进去可以看到forward()函数中有个_transpose_and_gather_feat()函数,这个函数中用到了这里生成的inds

def _gather_feat(feat, ind, mask=None):
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _transpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat

在_transpose_and_gather_feat()函数第一行将tensor的形状变成了(batch_size, w, h, c),然后reshape成(batch_size, wh, c)然后传给_gather_feat()函数。在这一步将二维的feature map变成了wh的行向量,然后再在_gather_feat()函数中使用tensor.gather()函数在dim=1的维度上进行聚合。即在w*h这个维度上选取ind指定的数据。如果这里理解有困难,建议看一下pytorch中gather函数解析

  • 到这里加载数据集部分的核心代码基本上就完了,sample文件夹下的文件内容基本上都和ctdet.py差不多,应该都可以看明白了。

trains

train_factory.py

同样,只是简单的提供一个可以根据配置文件选择的字典,依然使用CtdetTrainer来举例。

base_trainer.py

这个就是之前提到过的那个基类了。BaseTrainer()核心基类,主要功能如下:

  1. 实现了每个训练器都需要用到了run_epoch(),在这个函数内部实现了网络的迭代训练
  2. 定义了各个子类需要实现的接口,特别是self._get_losses()这个函数,这个是核心函数,而且可以看到具体的loss计算是由ModelWithLoss()这个类定义并完成的。但是,在基类中定位到_get_losses()时,发现仅仅返回了一个NotImplementedError,如果学过C++这里就是一个虚函数,让子类必须自己来实现功能
  3. ModelWithLoss类的forward函数中的loss计算也仅仅是个定义,具体计算过程在子类中实现
  4. ok,这个基类需要了解的就这么多。接着看目标检测训练器的实现
trains/ctdet.py

目标检测网络训练的具体实现。文件中最核心的是损失函数的计算类CtdetLoss(),另一个类CtdetTrainer()继承自BaseTrainer,主要是基类中几个函数模板的功能实现。顺带说一句,如果我们自己写训练代码,也可以完全照搬他这套框架,只需要在train.py中用一个类继承BaseTrainer,再写一个类实现自己的loss函数。loss的具体实现过程,基本和理论一一对应,就不在赘述了。如果有不明白的地方,欢迎一起学习。

detectors

如果您看到这里,并且已经理解上面的内容的话,这个文件夹下的文件基本上也都能自己看懂了。

  1. 老套路,detector_factory.py提供各种检测器的集合,根据配置文件来选择具体运行的时候是哪一个。
  2. base_detector.py各个检测器的基类,主要提供了一个公共函数的实现,比如pre_process()图像输入网络前的预处理函数,对图像做一系列变换比如缩放,减均值除方差操作,最后输出符合网络输入的数据。
  3. 在detectors/ctdet.py中实现具体的检测过程。CtdetDetector()继承自BaseDetector,在process中做预测预测输出结果。到这里应该基本上不存在代码逻辑上的理解困难了。

关于dcnv2

在src/lib/models下有个model.py文件,里面定义了_model_factory字典,根据名称可以猜到get_dlav0是没有用到dcnv的,而名字后面带_dcn的都是用到dcnv2的网络。假如自己没编译通过,又想用这些网络,其实也简单。比如get_dla_dcn这个函数中的网络DLASeg,直接在pose_dla_dcn.py这个文件中搜索“DCN",可以看到只在DeformConv这个类中用到了,把DCN改成正常的卷积函数nn.Conv2d()。

class DeformConv(nn.Module):
    def __init__(self, chi, cho):
        super(DeformConv, self).__init__()
        self.actf = nn.Sequential(
            nn.BatchNorm2d(cho, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=True)
        )
        self.conv = DCN(chi, cho, kernel_size=(3, 3), stride=1, padding=1, dilation=1, deformable_groups=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.actf(x)
        return x
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值