🔌添加注意力机制

各种注意力模块优缺点

优点: 可以通过学习自适应的通道权重,使得模型更加关注有用的通道信息。 缺点: SE注意力机制只考虑了通道维度上的注意力,无法捕捉空间维度上的注意力,适用于通道数较多的场景,但对于通道数较少的情况可能不如其他注意力机制。

在模型中添加注意力模块

  1. 设置备选注意力模块,在nets/yolo.py里添加如下代码:

# -----------------------------------------------#
#   备选注意力模块列表
# -----------------------------------------------#
attention_bocks = [se_block, cbam_block, eca_block, CA_Block]
  1. 添加加载模型时所需的选取注意力的参数,在nets/yolo.py里YoloBody类添加phi_attention参数:(例如本例中,phi_attention为0代表不使用注意力机制,1-4代表上述备选列表中的4个模块)

def __init__(self, anchors_mask, num_classes, phi, pretrained=False, phi_attention=0, pruned=1):
  1. 在模型中添加注意力模块(本例所示为在P3层添加注意力模块,类型通过phi_attention参数指定,您可自行根据需要修改):

    • 在__init__函数中进行注意力模块的初始化

      # -----------------------------------------------#
      #     注意力初始化
      # -----------------------------------------------#
      self.phi_attention = phi_attention
      
      if phi_attention >= 1 and phi_attention <= 4:
          self.P3_attention = attention_bocks[phi_attention - 1](128) # 128为通道数
    • 在forward中选择需要插入注意力的位置,例如在特征金字塔P3层后添加注意力机制

      # 80, 80, 256 => 80, 80, 128
      P3 = self.conv3_for_upsample2(P3)
      
      if 1 <= self.phi <= 4:
          P3 = self.P3_attention(P3)
      
      # 80, 80, 128 => 40, 40, 256
      P3_downsample = self.down_sample1(P3)

使用注意力模块改进的模型

  1. 在train.py中模型创建部分修改phi_attention参数,按照上节所述步骤训练模型。

    #------------------------------------------------------#
    #   创建yolo模型
    #------------------------------------------------------#
    model = YoloBody(anchors_mask, num_classes, phi, pretrained=pretrained, phi_attention=0)

  2. 在yolo.py中修改phi_attention参数,加载训练后的权重即可正常使用改进后的模型。

    #---------------------------------------------------#
    #   生成模型
    #---------------------------------------------------#
    def generate(self, onnx=False):
        #---------------------------------------------------#
        #   建立yolo模型,载入yolo模型的权重
        #---------------------------------------------------#
        self.net    = YoloBody(self.anchors_mask, self.num_classes, self.phi, pretrained=False, phi_attention=0)Py

Last updated