[Week 7 / 임종우 / Paper Review] ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation

2023 Summer Session/CV Team 2

by imngooh 2023. 8. 19. 15:20


  • Following last week's review of the ViTPose paper, this post goes through the code in the official GitHub repository.
  • Official GitHub:

GitHub - ViTAE-Transformer/ViTPose: The official repo for [NeurIPS'22] "ViTPose: Simple Vision Transformer Baselines for Human Pose Estimation" and [Arxiv'22] "ViTPose+: Vision Transformer Foundation Model for Generic Body Pose Estimation"

1. Backbone - ViT

  • Patch Embedding

      class PatchEmbed(nn.Module):
          """ Image to Patch Embedding
          """
          def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
              super().__init__()
              img_size = to_2tuple(img_size)
              patch_size = to_2tuple(patch_size)
              num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
              self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
              self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
              self.img_size = img_size
              self.patch_size = patch_size
              self.num_patches = num_patches
    
              self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
    
          def forward(self, x, **kwargs):
              B, C, H, W = x.shape
              x = self.proj(x)
              Hp, Wp = x.shape[2], x.shape[3]
    
              x = x.flatten(2).transpose(1, 2)
              return x, (Hp, Wp)
  • MLP

      class Mlp(nn.Module):
          def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
              super().__init__()
              out_features = out_features or in_features
              hidden_features = hidden_features or in_features
              self.fc1 = nn.Linear(in_features, hidden_features)
              self.act = act_layer()
              self.fc2 = nn.Linear(hidden_features, out_features)
              self.drop = nn.Dropout(drop)
    
          def forward(self, x):
              x = self.fc1(x)
              x = self.act(x)
              x = self.fc2(x)
              x = self.drop(x)
              return x
  • Attention (multi-head)

      class Attention(nn.Module):
          def __init__(
                  self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                  proj_drop=0., attn_head_dim=None,):
              super().__init__()
              self.num_heads = num_heads
              head_dim = dim // num_heads
              self.dim = dim
    
              if attn_head_dim is not None:
                  head_dim = attn_head_dim
              all_head_dim = head_dim * self.num_heads
    
              self.scale = qk_scale or head_dim ** -0.5
    
              self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
    
              self.attn_drop = nn.Dropout(attn_drop)
              self.proj = nn.Linear(all_head_dim, dim)
              self.proj_drop = nn.Dropout(proj_drop)
    
          def forward(self, x):
              B, N, C = x.shape
              qkv = self.qkv(x)
              qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
              q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
    
              q = q * self.scale
              attn = (q @ k.transpose(-2, -1))
    
              attn = attn.softmax(dim=-1)
              attn = self.attn_drop(attn)
    
              x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
              x = self.proj(x)
              x = self.proj_drop(x)
    
              return x
  • Block (Transformer block)

      class Block(nn.Module):
    
          def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, 
                       drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, 
                       norm_layer=nn.LayerNorm, attn_head_dim=None
                       ):
              super().__init__()
    
              self.norm1 = norm_layer(dim)
              self.attn = Attention(
                  dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
                  )
    
              # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
              self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
              self.norm2 = norm_layer(dim)
              mlp_hidden_dim = int(dim * mlp_ratio)
              self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    
          def forward(self, x):
              x = x + self.drop_path(self.attn(self.norm1(x)))
              x = x + self.drop_path(self.mlp(self.norm2(x)))
              return x
  • ViT

      class ViT(BaseBackbone):
    
          def __init__(self,
                       img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                       num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                       drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, 
                       frozen_stages=-1, ratio=1, last_norm=True,
                       patch_padding='pad', freeze_attn=False, freeze_ffn=False,
                       ):
              # Protect mutable default arguments
              super(ViT, self).__init__()
              norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
              self.num_classes = num_classes
              self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
              self.frozen_stages = frozen_stages
              self.use_checkpoint = use_checkpoint
              self.patch_padding = patch_padding
              self.freeze_attn = freeze_attn
              self.freeze_ffn = freeze_ffn
              self.depth = depth
    
              if hybrid_backbone is not None:
                  self.patch_embed = HybridEmbed(
                      hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
              else:
                  self.patch_embed = PatchEmbed(
                      img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
              num_patches = self.patch_embed.num_patches
    
              # since the pretraining model has class token
              self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    
              dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
    
              self.blocks = nn.ModuleList([
                  Block(
                      dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                      )
                  for i in range(depth)])
    
              self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
    
              if self.pos_embed is not None:
                  trunc_normal_(self.pos_embed, std=.02)
    
              self._freeze_stages()
    
          def _freeze_stages(self):
              """Freeze parameters."""
              if self.frozen_stages >= 0:
                  self.patch_embed.eval()
                  for param in self.patch_embed.parameters():
                      param.requires_grad = False
    
              for i in range(1, self.frozen_stages + 1):
                  m = self.blocks[i]
                  m.eval()
                  for param in m.parameters():
                      param.requires_grad = False
    
              if self.freeze_attn:
                  for i in range(0, self.depth):
                      m = self.blocks[i]
                      m.attn.eval()
                      m.norm1.eval()
                      for param in m.attn.parameters():
                          param.requires_grad = False
                      for param in m.norm1.parameters():
                          param.requires_grad = False
    
              if self.freeze_ffn:
                  self.pos_embed.requires_grad = False
                  self.patch_embed.eval()
                  for param in self.patch_embed.parameters():
                      param.requires_grad = False
                  for i in range(0, self.depth):
                      m = self.blocks[i]
                      m.mlp.eval()
                      m.norm2.eval()
                      for param in m.mlp.parameters():
                          param.requires_grad = False
                      for param in m.norm2.parameters():
                          param.requires_grad = False
    
          def init_weights(self, pretrained=None):
              """Initialize the weights in backbone.
              Args:
                  pretrained (str, optional): Path to pre-trained weights.
                      Defaults to None.
              """
              super().init_weights(pretrained, patch_padding=self.patch_padding)
    
              if pretrained is None:
                  def _init_weights(m):
                      if isinstance(m, nn.Linear):
                          trunc_normal_(m.weight, std=.02)
                          if isinstance(m, nn.Linear) and m.bias is not None:
                              nn.init.constant_(m.bias, 0)
                      elif isinstance(m, nn.LayerNorm):
                          nn.init.constant_(m.bias, 0)
                          nn.init.constant_(m.weight, 1.0)
    
                  self.apply(_init_weights)
    
          def get_num_layers(self):
              return len(self.blocks)
    
          @torch.jit.ignore
          def no_weight_decay(self):
              return {'pos_embed', 'cls_token'}
    
          def forward_features(self, x):
              B, C, H, W = x.shape
              x, (Hp, Wp) = self.patch_embed(x)
    
              if self.pos_embed is not None:
                  # fit for multiple GPU training
                  # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
                  x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
    
              for blk in self.blocks:
                  if self.use_checkpoint:
                      x = checkpoint.checkpoint(blk, x)
                  else:
                      x = blk(x)
    
              x = self.last_norm(x)
    
              xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
    
              return xp
    
          def forward(self, x):
              x = self.forward_features(x)
              return x
    
          def train(self, mode=True):
              """Convert the model into training mode."""
              super().train(mode)
              self._freeze_stages()
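  • Shape check (not from the post): a minimal sketch, assuming only PyTorch, of the backbone's input/output contract for a ViTPose-B-style setting (256x192 crop, 16x16 patches, embed_dim=768). It mirrors what PatchEmbed and the final reshape in forward_features do; the real PatchEmbed uses a slightly padded conv, but for ratio=1 it yields the same 16x12 grid.

      import torch
      import torch.nn as nn

      B, embed_dim, patch = 2, 768, 16
      img = torch.randn(B, 3, 256, 192)                    # (B, C, H, W) person crop

      # Stand-in for PatchEmbed.proj: one conv that cuts the image into 16x16 patches.
      proj = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)
      feat = proj(img)                                      # (B, 768, 16, 12)
      Hp, Wp = feat.shape[2], feat.shape[3]

      tokens = feat.flatten(2).transpose(1, 2)              # (B, Hp*Wp, C) = (B, 192, 768)
      # ... pos_embed is added and the stack of Blocks runs on `tokens` here ...

      xp = tokens.permute(0, 2, 1).reshape(B, -1, Hp, Wp)   # back to (B, 768, 16, 12) for the head
      print(tokens.shape, xp.shape)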

2. Head

  • code

      # Copyright (c) OpenMMLab. All rights reserved.
      import torch
      import torch.nn as nn
      from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                            constant_init, normal_init)
    
      from mmpose.core.evaluation import pose_pck_accuracy
      from mmpose.core.post_processing import flip_back
      from mmpose.models.builder import build_loss
      from mmpose.models.utils.ops import resize
      from ..builder import HEADS
      import torch.nn.functional as F
      from .topdown_heatmap_base_head import TopdownHeatmapBaseHead
    
      @HEADS.register_module()
      class TopdownHeatmapSimpleHead(TopdownHeatmapBaseHead):
          """Top-down heatmap simple head. paper ref: Bin Xiao et al. ``Simple
          Baselines for Human Pose Estimation and Tracking``.
    
          TopdownHeatmapSimpleHead consists of (>= 0) deconv layers
          and a simple conv2d layer.
    
          Args:
              in_channels (int): Number of input channels
              out_channels (int): Number of output channels
              num_deconv_layers (int): Number of deconv layers.
                  num_deconv_layers should be >= 0. Note that 0 means
                  no deconv layers.
              num_deconv_filters (list|tuple): Number of filters. If
                  num_deconv_layers > 0, its length should equal
                  num_deconv_layers.
              num_deconv_kernels (list|tuple): Kernel sizes.
              in_index (int|Sequence[int]): Input feature index. Default: 0
              input_transform (str|None): Transformation type of input features.
                  Options: 'resize_concat', 'multiple_select', None.
                  Default: None.
    
                  - 'resize_concat': Multiple feature maps will be resized to the
                      same size as the first one and then concatenated.
                      Usually used in the FCN head of HRNet.
                  - 'multiple_select': Multiple feature maps will be bundled into
                      a list and passed into the decode head.
                  - None: Only one select feature map is allowed.
              align_corners (bool): align_corners argument of F.interpolate.
                  Default: False.
              loss_keypoint (dict): Config for keypoint loss. Default: None.
          """
    
          def __init__(self,
                       in_channels,
                       out_channels,
                       num_deconv_layers=3,
                       num_deconv_filters=(256, 256, 256),
                       num_deconv_kernels=(4, 4, 4),
                       extra=None,
                       in_index=0,
                       input_transform=None,
                       align_corners=False,
                       loss_keypoint=None,
                       train_cfg=None,
                       test_cfg=None,
                       upsample=0,):
              super().__init__()
    
              self.in_channels = in_channels
              self.loss = build_loss(loss_keypoint)
              self.upsample = upsample
    
              self.train_cfg = {} if train_cfg is None else train_cfg
              self.test_cfg = {} if test_cfg is None else test_cfg
              self.target_type = self.test_cfg.get('target_type', 'GaussianHeatmap')
    
              self._init_inputs(in_channels, in_index, input_transform)
              self.in_index = in_index
              self.align_corners = align_corners
    
              if extra is not None and not isinstance(extra, dict):
                  raise TypeError('extra should be dict or None.')
    
              if num_deconv_layers > 0:
                  self.deconv_layers = self._make_deconv_layer(
                      num_deconv_layers,
                      num_deconv_filters,
                      num_deconv_kernels,
                  )
              elif num_deconv_layers == 0:
                  self.deconv_layers = nn.Identity()
              else:
                  raise ValueError(
                      f'num_deconv_layers ({num_deconv_layers}) should >= 0.')
    
              identity_final_layer = False
              if extra is not None and 'final_conv_kernel' in extra:
                  assert extra['final_conv_kernel'] in [0, 1, 3]
                  if extra['final_conv_kernel'] == 3:
                      padding = 1
                  elif extra['final_conv_kernel'] == 1:
                      padding = 0
                  else:
                      # 0 for Identity mapping.
                      identity_final_layer = True
                  kernel_size = extra['final_conv_kernel']
              else:
                  kernel_size = 1
                  padding = 0
    
              if identity_final_layer:
                  self.final_layer = nn.Identity()
              else:
                  conv_channels = num_deconv_filters[
                      -1] if num_deconv_layers > 0 else self.in_channels
    
                  layers = []
                  if extra is not None:
                      num_conv_layers = extra.get('num_conv_layers', 0)
                      num_conv_kernels = extra.get('num_conv_kernels',
                                                   [1] * num_conv_layers)
    
                      for i in range(num_conv_layers):
                          layers.append(
                              build_conv_layer(
                                  dict(type='Conv2d'),
                                  in_channels=conv_channels,
                                  out_channels=conv_channels,
                                  kernel_size=num_conv_kernels[i],
                                  stride=1,
                                  padding=(num_conv_kernels[i] - 1) // 2))
                          layers.append(
                              build_norm_layer(dict(type='BN'), conv_channels)[1])
                          layers.append(nn.ReLU(inplace=True))
    
                  layers.append(
                      build_conv_layer(
                          cfg=dict(type='Conv2d'),
                          in_channels=conv_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=1,
                          padding=padding))
    
                  if len(layers) > 1:
                      self.final_layer = nn.Sequential(*layers)
                  else:
                      self.final_layer = layers[0]
    
          def get_loss(self, output, target, target_weight):
              """Calculate top-down keypoint loss.
    
              Note:
                  - batch_size: N
                  - num_keypoints: K
                  - heatmaps height: H
                  - heatmaps width: W
    
              Args:
                  output (torch.Tensor[N,K,H,W]): Output heatmaps.
                  target (torch.Tensor[N,K,H,W]): Target heatmaps.
                  target_weight (torch.Tensor[N,K,1]):
                      Weights across different joint types.
              """
    
              losses = dict()
    
              assert not isinstance(self.loss, nn.Sequential)
              assert target.dim() == 4 and target_weight.dim() == 3
              losses['heatmap_loss'] = self.loss(output, target, target_weight)
    
              return losses
    
          def get_accuracy(self, output, target, target_weight):
              """Calculate accuracy for top-down keypoint loss.
    
              Note:
                  - batch_size: N
                  - num_keypoints: K
                  - heatmaps height: H
                  - heatmaps width: W
    
              Args:
                  output (torch.Tensor[N,K,H,W]): Output heatmaps.
                  target (torch.Tensor[N,K,H,W]): Target heatmaps.
                  target_weight (torch.Tensor[N,K,1]):
                      Weights across different joint types.
              """
    
              accuracy = dict()
    
              if self.target_type == 'GaussianHeatmap':
                  _, avg_acc, _ = pose_pck_accuracy(
                      output.detach().cpu().numpy(),
                      target.detach().cpu().numpy(),
                      target_weight.detach().cpu().numpy().squeeze(-1) > 0)
                  accuracy['acc_pose'] = float(avg_acc)
    
              return accuracy
    
          def forward(self, x):
              """Forward function."""
              x = self._transform_inputs(x)
              x = self.deconv_layers(x)
              x = self.final_layer(x)
              return x
    
          def inference_model(self, x, flip_pairs=None):
              """Inference function.
    
              Returns:
                  output_heatmap (np.ndarray): Output heatmaps.
    
              Args:
                  x (torch.Tensor[N,K,H,W]): Input features.
                  flip_pairs (None | list[tuple]):
                      Pairs of keypoints which are mirrored.
              """
              output = self.forward(x)
    
              if flip_pairs is not None:
                  output_heatmap = flip_back(
                      output.detach().cpu().numpy(),
                      flip_pairs,
                      target_type=self.target_type)
                  # feature is not aligned, shift flipped heatmap for higher accuracy
                  if self.test_cfg.get('shift_heatmap', False):
                      output_heatmap[:, :, :, 1:] = output_heatmap[:, :, :, :-1]
              else:
                  output_heatmap = output.detach().cpu().numpy()
              return output_heatmap
    
          def _init_inputs(self, in_channels, in_index, input_transform):
              """Check and initialize input transforms.
    
              The in_channels, in_index and input_transform must match.
              Specifically, when input_transform is None, only single feature map
              will be selected. So in_channels and in_index must be of type int.
              When input_transform is not None, in_channels and in_index must be
              list or tuple, with the same length.
    
              Args:
                  in_channels (int|Sequence[int]): Input channels.
                  in_index (int|Sequence[int]): Input feature index.
                  input_transform (str|None): Transformation type of input features.
                      Options: 'resize_concat', 'multiple_select', None.
    
                      - 'resize_concat': Multiple feature maps will be resized to the
                          same size as the first one and then concatenated.
                          Usually used in the FCN head of HRNet.
                      - 'multiple_select': Multiple feature maps will be bundled into
                          a list and passed into the decode head.
                      - None: Only one select feature map is allowed.
              """
    
              if input_transform is not None:
                  assert input_transform in ['resize_concat', 'multiple_select']
              self.input_transform = input_transform
              self.in_index = in_index
              if input_transform is not None:
                  assert isinstance(in_channels, (list, tuple))
                  assert isinstance(in_index, (list, tuple))
                  assert len(in_channels) == len(in_index)
                  if input_transform == 'resize_concat':
                      self.in_channels = sum(in_channels)
                  else:
                      self.in_channels = in_channels
              else:
                  assert isinstance(in_channels, int)
                  assert isinstance(in_index, int)
                  self.in_channels = in_channels
    
          def _transform_inputs(self, inputs):
              """Transform inputs for decoder.
    
              Args:
                  inputs (list[Tensor] | Tensor): multi-level img features.
    
              Returns:
                  Tensor: The transformed inputs
              """
              if not isinstance(inputs, list):
                  if self.upsample > 0:
                      inputs = resize(
                          input=F.relu(inputs),
                          scale_factor=self.upsample,
                          mode='bilinear',
                          align_corners=self.align_corners)
                  return inputs
    
              if self.input_transform == 'resize_concat':
                  inputs = [inputs[i] for i in self.in_index]
                  upsampled_inputs = [
                      resize(
                          input=x,
                          size=inputs[0].shape[2:],
                          mode='bilinear',
                          align_corners=self.align_corners) for x in inputs
                  ]
                  inputs = torch.cat(upsampled_inputs, dim=1)
              elif self.input_transform == 'multiple_select':
                  inputs = [inputs[i] for i in self.in_index]
              else:
                  inputs = inputs[self.in_index]
    
              return inputs
    
          def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
              """Make deconv layers."""
              if num_layers != len(num_filters):
                  error_msg = f'num_layers({num_layers}) ' \
                              f'!= length of num_filters({len(num_filters)})'
                  raise ValueError(error_msg)
              if num_layers != len(num_kernels):
                  error_msg = f'num_layers({num_layers}) ' \
                              f'!= length of num_kernels({len(num_kernels)})'
                  raise ValueError(error_msg)
    
              layers = []
              for i in range(num_layers):
                  kernel, padding, output_padding = \
                      self._get_deconv_cfg(num_kernels[i])
    
                  planes = num_filters[i]
                  layers.append(
                      build_upsample_layer(
                          dict(type='deconv'),
                          in_channels=self.in_channels,
                          out_channels=planes,
                          kernel_size=kernel,
                          stride=2,
                          padding=padding,
                          output_padding=output_padding,
                          bias=False))
                  layers.append(nn.BatchNorm2d(planes))
                  layers.append(nn.ReLU(inplace=True))
                  self.in_channels = planes
    
              return nn.Sequential(*layers)
    
          def init_weights(self):
              """Initialize model weights."""
              for _, m in self.deconv_layers.named_modules():
                  if isinstance(m, nn.ConvTranspose2d):
                      normal_init(m, std=0.001)
                  elif isinstance(m, nn.BatchNorm2d):
                      constant_init(m, 1)
              for m in self.final_layer.modules():
                  if isinstance(m, nn.Conv2d):
                      normal_init(m, std=0.001, bias=0)
                  elif isinstance(m, nn.BatchNorm2d):
                      constant_init(m, 1)
  • Config for the classic head (two deconv layers)

      keypoint_head=dict(
              type='TopdownHeatmapSimpleHead',
              in_channels=768,
              num_deconv_layers=2,
              num_deconv_filters=(256, 256),
              num_deconv_kernels=(4, 4),
              extra=dict(final_conv_kernel=1, ),
              out_channels=channel_cfg['num_output_channels'],
              loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True))
  • Config for the simple head (no deconv layers, bilinear upsample); a shape comparison of the two heads follows the config.

      keypoint_head=dict(
              type='TopdownHeatmapSimpleHead',
              in_channels=768,
              num_deconv_layers=0,
              num_deconv_filters=[],
              num_deconv_kernels=[],
              upsample=4,
              extra=dict(final_conv_kernel=3, ),
              out_channels=channel_cfg['num_output_channels'],
              loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
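  • The two head configs differ only in how the 1/16-resolution backbone feature is brought up to the 1/4-resolution heatmap. A rough sketch of that difference, using plain PyTorch stand-ins for the mmcv-built layers (shapes assume a 256x192 input, i.e. a (B, C, 16, 12) feature):

      import torch
      import torch.nn as nn
      import torch.nn.functional as F

      B, C, K = 2, 768, 17
      feat = torch.randn(B, C, 16, 12)

      # Classic head: two stride-2 deconvs (each doubles H and W), then a 1x1 conv.
      classic = nn.Sequential(
          nn.ConvTranspose2d(C, 256, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm2d(256), nn.ReLU(inplace=True),
          nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm2d(256), nn.ReLU(inplace=True),
          nn.Conv2d(256, K, kernel_size=1),
      )
      print(classic(feat).shape)   # torch.Size([2, 17, 64, 48])

      # Simple head: bilinear 4x upsample (upsample=4 in the config), then a 3x3 conv.
      up = F.interpolate(F.relu(feat), scale_factor=4, mode='bilinear', align_corners=False)
      simple = nn.Conv2d(C, K, kernel_size=3, padding=1)
      print(simple(up).shape)      # torch.Size([2, 17, 64, 48])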

3. Detector (total model)

  • Top-down method: person boxes are produced first (by a separate detector or from ground truth), and the pose model estimates keypoints on each cropped person.

  • code

      # Copyright (c) OpenMMLab. All rights reserved.
      import warnings
    
      import mmcv
      import numpy as np
      from mmcv.image import imwrite
      from mmcv.utils.misc import deprecated_api_warning
      from mmcv.visualization.image import imshow
    
      from mmpose.core import imshow_bboxes, imshow_keypoints
      from .. import builder
      from ..builder import POSENETS
      from .base import BasePose
    
      try:
          from mmcv.runner import auto_fp16
      except ImportError:
          warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0'
                        'Please install mmcv>=1.1.4')
          from mmpose.core import auto_fp16
    
      @POSENETS.register_module()
      class TopDown(BasePose):
          """Top-down pose detectors.
    
          Args:
              backbone (dict): Backbone modules to extract feature.
              keypoint_head (dict): Keypoint head to process feature.
              train_cfg (dict): Config for training. Default: None.
              test_cfg (dict): Config for testing. Default: None.
              pretrained (str): Path to the pretrained models.
              loss_pose (None): Deprecated arguments. Please use
                  `loss_keypoint` for heads instead.
          """
    
          def __init__(self,
                       backbone,
                       neck=None,
                       keypoint_head=None,
                       train_cfg=None,
                       test_cfg=None,
                       pretrained=None,
                       loss_pose=None):
              super().__init__()
              self.fp16_enabled = False
    
              self.backbone = builder.build_backbone(backbone)
    
              self.train_cfg = train_cfg
              self.test_cfg = test_cfg
    
              if neck is not None:
                  self.neck = builder.build_neck(neck)
    
              if keypoint_head is not None:
                  keypoint_head['train_cfg'] = train_cfg
                  keypoint_head['test_cfg'] = test_cfg
    
                  if 'loss_keypoint' not in keypoint_head and loss_pose is not None:
                      warnings.warn(
                          '`loss_pose` for TopDown is deprecated, '
                          'use `loss_keypoint` for heads instead. See '
                          'https://github.com/open-mmlab/mmpose/pull/382'
                          ' for more information.', DeprecationWarning)
                      keypoint_head['loss_keypoint'] = loss_pose
    
                  self.keypoint_head = builder.build_head(keypoint_head)
    
              self.init_weights(pretrained=pretrained)
    
          @property
          def with_neck(self):
              """Check if has neck."""
              return hasattr(self, 'neck')
    
          @property
          def with_keypoint(self):
              """Check if has keypoint_head."""
              return hasattr(self, 'keypoint_head')
    
          def init_weights(self, pretrained=None):
              """Weight initialization for model."""
              self.backbone.init_weights(pretrained)
              if self.with_neck:
                  self.neck.init_weights()
              if self.with_keypoint:
                  self.keypoint_head.init_weights()
    
          @auto_fp16(apply_to=('img', ))
          def forward(self,
                      img,
                      target=None,
                      target_weight=None,
                      img_metas=None,
                      return_loss=True,
                      return_heatmap=False,
                      **kwargs):
              """Calls either forward_train or forward_test depending on whether
              return_loss=True. Note this setting will change the expected inputs.
              When `return_loss=True`, img and img_meta are single-nested (i.e.
              Tensor and List[dict]), and when `return_loss=False`, img and img_meta
              should be double nested (i.e. List[Tensor], List[List[dict]]), with
              the outer list indicating test time augmentations.
    
              Note:
                  - batch_size: N
                  - num_keypoints: K
                  - num_img_channel: C (Default: 3)
                  - img height: imgH
                  - img width: imgW
                  - heatmaps height: H
                  - heatmaps width: W
    
              Args:
                  img (torch.Tensor[NxCximgHximgW]): Input images.
                  target (torch.Tensor[NxKxHxW]): Target heatmaps.
                  target_weight (torch.Tensor[NxKx1]): Weights across
                      different joint types.
                  img_metas (list(dict)): Information about data augmentation
                      By default this includes:
    
                      - "image_file: path to the image file
                      - "center": center of the bbox
                      - "scale": scale of the bbox
                      - "rotation": rotation of the bbox
                      - "bbox_score": score of bbox
                  return_loss (bool): Option to return loss. `return_loss=True`
                      for training, `return_loss=False` for validation & test.
                  return_heatmap (bool): Option to return heatmap.
    
              Returns:
                  dict|tuple: if `return loss` is true, then return losses. \
                      Otherwise, return predicted poses, boxes, image paths \
                      and heatmaps.
              """
              if return_loss:
                  return self.forward_train(img, target, target_weight, img_metas,
                                            **kwargs)
              return self.forward_test(
                  img, img_metas, return_heatmap=return_heatmap, **kwargs)
    
          def forward_train(self, img, target, target_weight, img_metas, **kwargs):
              """Defines the computation performed at every call when training."""
              output = self.backbone(img)
              if self.with_neck:
                  output = self.neck(output)
              if self.with_keypoint:
                  output = self.keypoint_head(output)
    
              # if return loss
              losses = dict()
              if self.with_keypoint:
                  keypoint_losses = self.keypoint_head.get_loss(
                      output, target, target_weight)
                  losses.update(keypoint_losses)
                  keypoint_accuracy = self.keypoint_head.get_accuracy(
                      output, target, target_weight)
                  losses.update(keypoint_accuracy)
    
              return losses
    
          def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
              """Defines the computation performed at every call when testing."""
              assert img.size(0) == len(img_metas)
              batch_size, _, img_height, img_width = img.shape
              if batch_size > 1:
                  assert 'bbox_id' in img_metas[0]
    
              result = {}
    
              features = self.backbone(img)
              if self.with_neck:
                  features = self.neck(features)
              if self.with_keypoint:
                  output_heatmap = self.keypoint_head.inference_model(
                      features, flip_pairs=None)
    
              if self.test_cfg.get('flip_test', True):
                  img_flipped = img.flip(3)
                  features_flipped = self.backbone(img_flipped)
                  if self.with_neck:
                      features_flipped = self.neck(features_flipped)
                  if self.with_keypoint:
                      output_flipped_heatmap = self.keypoint_head.inference_model(
                          features_flipped, img_metas[0]['flip_pairs'])
                      output_heatmap = (output_heatmap +
                                        output_flipped_heatmap) * 0.5
    
              if self.with_keypoint:
                  keypoint_result = self.keypoint_head.decode(
                      img_metas, output_heatmap, img_size=[img_width, img_height])
                  result.update(keypoint_result)
    
                  if not return_heatmap:
                      output_heatmap = None
    
                  result['output_heatmap'] = output_heatmap
    
              return result
    
          def forward_dummy(self, img):
              """Used for computing network FLOPs.
    
              See ``tools/get_flops.py``.
    
              Args:
                  img (torch.Tensor): Input image.
    
              Returns:
                  Tensor: Output heatmaps.
              """
              output = self.backbone(img)
              if self.with_neck:
                  output = self.neck(output)
              if self.with_keypoint:
                  output = self.keypoint_head(output)
              return output
    
          @deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
                                  cls_name='TopDown')
          def show_result(self,
                          img,
                          result,
                          skeleton=None,
                          kpt_score_thr=0.3,
                          bbox_color='green',
                          pose_kpt_color=None,
                          pose_link_color=None,
                          text_color='white',
                          radius=4,
                          thickness=1,
                          font_scale=0.5,
                          bbox_thickness=1,
                          win_name='',
                          show=False,
                          show_keypoint_weight=False,
                          wait_time=0,
                          out_file=None):
              """Draw `result` over `img`.
    
              Args:
                  img (str or Tensor): The image to be displayed.
                  result (list[dict]): The results to draw over `img`
                      (bbox_result, pose_result).
                  skeleton (list[list]): The connection of keypoints.
                      skeleton is 0-based indexing.
                  kpt_score_thr (float, optional): Minimum score of keypoints
                      to be shown. Default: 0.3.
                  bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
                  pose_kpt_color (np.array[Nx3]): Color of N keypoints.
                      If None, do not draw keypoints.
                  pose_link_color (np.array[Mx3]): Color of M links.
                      If None, do not draw links.
                  text_color (str or tuple or :obj:`Color`): Color of texts.
                  radius (int): Radius of circles.
                  thickness (int): Thickness of lines.
                  font_scale (float): Font scales of texts.
                  win_name (str): The window name.
                  show (bool): Whether to show the image. Default: False.
                  show_keypoint_weight (bool): Whether to change the transparency
                      using the predicted confidence scores of keypoints.
                  wait_time (int): Value of waitKey param.
                      Default: 0.
                  out_file (str or None): The filename to write the image.
                      Default: None.
    
              Returns:
                  Tensor: Visualized img, only if not `show` or `out_file`.
              """
              img = mmcv.imread(img)
              img = img.copy()
    
              bbox_result = []
              bbox_labels = []
              pose_result = []
              for res in result:
                  if 'bbox' in res:
                      bbox_result.append(res['bbox'])
                      bbox_labels.append(res.get('label', None))
                  pose_result.append(res['keypoints'])
    
              if bbox_result:
                  bboxes = np.vstack(bbox_result)
                  # draw bounding boxes
                  imshow_bboxes(
                      img,
                      bboxes,
                      labels=bbox_labels,
                      colors=bbox_color,
                      text_color=text_color,
                      thickness=bbox_thickness,
                      font_scale=font_scale,
                      show=False)
    
              if pose_result:
                  imshow_keypoints(img, pose_result, skeleton, kpt_score_thr,
                                   pose_kpt_color, pose_link_color, radius,
                                   thickness)
    
              if show:
                  imshow(img, win_name, wait_time)
    
              if out_file is not None:
                  imwrite(img, out_file)
    
              return img
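  • For reference, the whole TopDown detector can also be driven through mmpose's high-level inference API (as packaged in the ViTPose fork). A hedged sketch; the config/checkpoint paths and the box below are placeholders, and exact argument names may vary across mmpose versions:

      from mmpose.apis import (init_pose_model, inference_top_down_pose_model,
                               vis_pose_result)

      # Placeholder paths: adjust to the actual config / released checkpoint.
      config_file = 'configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py'
      checkpoint_file = 'vitpose-b.pth'

      model = init_pose_model(config_file, checkpoint_file, device='cuda:0')

      # Top-down: person boxes come from a detector (or GT); here one dummy xywh box.
      person_results = [{'bbox': [50, 50, 150, 300]}]
      pose_results, _ = inference_top_down_pose_model(
          model, 'person.jpg', person_results, format='xywh',
          dataset='TopDownCocoDataset')

      vis_pose_result(model, 'person.jpg', pose_results, out_file='vis_person.jpg')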

4. Config

  • code

      _base_ = [
          '../../../../_base_/default_runtime.py',
          '../../../../_base_/datasets/coco.py'
      ]
      evaluation = dict(interval=10, metric='mAP', save_best='AP')
    
      optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1,
                       constructor='LayerDecayOptimizerConstructor', 
                       paramwise_cfg=dict(
                                          num_layers=32, 
                                          layer_decay_rate=0.85,
                                          custom_keys={
                                                  'bias': dict(decay_multi=0.),
                                                  'pos_embed': dict(decay_mult=0.),
                                                  'relative_position_bias_table': dict(decay_mult=0.),
                                                  'norm': dict(decay_mult=0.)
                                                  }
                                          )
                      )
    
      optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
    
      # learning policy
      lr_config = dict(
          policy='step',
          warmup='linear',
          warmup_iters=500,
          warmup_ratio=0.001,
          step=[170, 200])
      total_epochs = 210
      target_type = 'GaussianHeatmap'
      channel_cfg = dict(
          num_output_channels=17,
          dataset_joints=17,
          dataset_channel=[
              [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
          ],
          inference_channel=[
              0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
          ])
    
      # model settings
      model = dict(
          type='TopDown',
          pretrained=None,
          backbone=dict(
              type='ViT',
              img_size=(256, 192),
              patch_size=16,
              embed_dim=1280,
              depth=32,
              num_heads=16,
              ratio=1,
              use_checkpoint=False,
              mlp_ratio=4,
              qkv_bias=True,
              drop_path_rate=0.55,
          ),
          keypoint_head=dict(
              type='TopdownHeatmapSimpleHead',
              in_channels=1280,
              num_deconv_layers=2,
              num_deconv_filters=(256, 256),
              num_deconv_kernels=(4, 4),
              extra=dict(final_conv_kernel=1, ),
              out_channels=channel_cfg['num_output_channels'],
              loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
          train_cfg=dict(),
          test_cfg=dict(
              flip_test=True,
              post_process='default',
              shift_heatmap=False,
              target_type=target_type,
              modulate_kernel=11,
              use_udp=True))
    
      data_cfg = dict(
          image_size=[192, 256],
          heatmap_size=[48, 64],
          num_output_channels=channel_cfg['num_output_channels'],
          num_joints=channel_cfg['dataset_joints'],
          dataset_channel=channel_cfg['dataset_channel'],
          inference_channel=channel_cfg['inference_channel'],
          soft_nms=False,
          nms_thr=1.0,
          oks_thr=0.9,
          vis_thr=0.2,
          use_gt_bbox=False,
          det_bbox_thr=0.0,
          bbox_file='data/coco/person_detection_results/'
          'COCO_val2017_detections_AP_H_56_person.json',
      )
    
      train_pipeline = [
          dict(type='LoadImageFromFile'),
          dict(type='TopDownRandomFlip', flip_prob=0.5),
          dict(
              type='TopDownHalfBodyTransform',
              num_joints_half_body=8,
              prob_half_body=0.3),
          dict(
              type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
          dict(type='TopDownAffine', use_udp=True),
          dict(type='ToTensor'),
          dict(
              type='NormalizeTensor',
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]),
          dict(
              type='TopDownGenerateTarget',
              sigma=2,
              encoding='UDP',
              target_type=target_type),
          dict(
              type='Collect',
              keys=['img', 'target', 'target_weight'],
              meta_keys=[
                  'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
                  'rotation', 'bbox_score', 'flip_pairs'
              ]),
      ]
    
      val_pipeline = [
          dict(type='LoadImageFromFile'),
          dict(type='TopDownAffine', use_udp=True),
          dict(type='ToTensor'),
          dict(
              type='NormalizeTensor',
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]),
          dict(
              type='Collect',
              keys=['img'],
              meta_keys=[
                  'image_file', 'center', 'scale', 'rotation', 'bbox_score',
                  'flip_pairs'
              ]),
      ]
    
      test_pipeline = val_pipeline
    
      data_root = 'data/coco'
      data = dict(
          samples_per_gpu=64,
          workers_per_gpu=4,
          val_dataloader=dict(samples_per_gpu=32),
          test_dataloader=dict(samples_per_gpu=32),
          train=dict(
              type='TopDownCocoDataset',
              ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
              img_prefix=f'{data_root}/train2017/',
              data_cfg=data_cfg,
              pipeline=train_pipeline,
              dataset_info={{_base_.dataset_info}}),
          val=dict(
              type='TopDownCocoDataset',
              ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
              img_prefix=f'{data_root}/val2017/',
              data_cfg=data_cfg,
              pipeline=val_pipeline,
              dataset_info={{_base_.dataset_info}}),
          test=dict(
              type='TopDownCocoDataset',
              ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
              img_prefix=f'{data_root}/val2017/',
              data_cfg=data_cfg,
              pipeline=test_pipeline,
              dataset_info={{_base_.dataset_info}}),
      )
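  • Note the optimizer: `LayerDecayOptimizerConstructor` gives each transformer block its own learning rate, decaying geometrically from the last block toward the patch embedding. An illustrative calculation (BEiT-style layer decay; the exact layer-id convention inside mmpose may differ slightly):

      # Illustrative only: geometric layer-wise lr scaling with the values from this
      # config (base lr 5e-4, layer_decay_rate 0.85, num_layers 32).
      base_lr, decay, num_layers = 5e-4, 0.85, 32

      for layer_id in range(num_layers + 2):   # 0 = patch/pos embed, ..., num_layers + 1 = head
          scale = decay ** (num_layers + 1 - layer_id)
          print(f'layer {layer_id:2d}: lr = {base_lr * scale:.2e}')
      # layer 0 (patch embed) ends up around 2.3e-6 while the head trains at the full 5e-4,
      # so early blocks stay close to the pretrained weights and the decoder adapts freely.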

5. Fine-tuning

https://github.com/ViTAE-Transformer/ViTPose/tree/main

  • Pretrained weights are released as .pth files, so they can be loaded directly for fine-tuning; a minimal loading sketch follows.
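  • A minimal sketch of initializing the model from one of those checkpoints (the filename is a placeholder; mmpose-style .pth files usually wrap the weights in a 'state_dict' entry):

      import torch
      from mmcv import Config
      from mmpose.models import build_posenet

      cfg = Config.fromfile('configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/'
                            'ViTPose_base_coco_256x192.py')
      model = build_posenet(cfg.model)

      ckpt = torch.load('vitpose-b.pth', map_location='cpu')   # placeholder filename
      state_dict = ckpt.get('state_dict', ckpt)
      missing, unexpected = model.load_state_dict(state_dict, strict=False)
      print('missing keys:', len(missing), '| unexpected keys:', len(unexpected))

      # Fine-tuning itself then goes through the usual mmpose entry point, e.g.
      #   bash tools/dist_train.sh <config> <num_gpus>
      # (see the repository README for the exact command and options).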
