ViT는 1) 일련의 visual token에 이미지를 임베팅하고 2) 누적된 transformer block을 사용하여 그들간의 글로벌 종속성을 모델링하는 최초의 모델이라는 의의가 있다.
반면 CNN은 태생적으로 scale-invariance 및 지역성을 가진 intrinsic IB를 가지기 때문에, 여전히 vision task에서 널리 사용되는 backbone 역할을 한다.
하지만 CNN은 long-range 종속성을 모델링하는데 적합하지 않는다.
DeiT는 training을 줄이고 성능을 향상시키기 위해, knowledge distillation을 통해 teacher model에 해당하는 CNN에서 얻은 IB를 student model의 ViT로 Transfer하는 모델이다. 이를 통해 Transformer의 장점(long-range 종속성 모델링에 적합)과 CNN의 장점(intrinsic IB 풍부함)을 모두 가진다.
하지만 추가적인 traning 비용을 필요로 한다는 한계가 여전히 존재한다.
ViTAE는 CNN의 intrinsic IB를 ViT에 알려주는 것을 목표로 하며, 이를 위해 크게 Reduction Cell(RC)와 Normal Cell(NC)로 구성되어있다.
RC는 input image를 downsampling하고, 풍부한 멀티스케일 컨텍스트를 토큰에 삽입하는데 사용된다.
input image $x\in R^{H\times W\times C}$ 는 3번의 RC를 통과하여 서서히 downsampling(x4, x2, x2)되어 최종적으로는 $R^{(HW/256)\times D}$로 flatten된다.
기술적으로 RC는 두 병렬적인 가지를 가지고 있는데, 각각은 지역성(locality)와 장거리 종속성(long-range dependency)를 모델링하기 위함이다.
i번째 RC의 input feature를 $f_i\in R^{H_i\times W_i\times D_i}$ 라고 하고, RC의 첫번째 input image가 $x$라고 하자.
우선 $f_i$는 Ryramid Reduction Module(RPM)으로 들어가서 멀티스케일 컨텍스트를 추출한다.
$Conv_{ij}$는 i번째 PRM에서 j번째 convolution layer를 의미하며, i번째 RPM에서 각각의 convolution layer에서 추출한 fature map들을 channel 차원으로 모두 concat한 결과가 $f_i^{ms}$ 이다. ⇒ $f_i^{ms}\in R^{(W_i/p)\times(H_i/p)\times(|S_i|D)}$
그리고 나서 MHSA 모듈을 거쳐 장거리 종속성을 모델링한다. 여기서 Img2Seq는 feature map을 1D로 간단히 flatten하는 방법이다. MHSA을 통해 그 결과인 f_i^g는 각각 토큰의 멀티스케일 콘텍스트를 내제할 수 있다.
${PCM}_i(f_i)$는 3개의 conv layer를 거치고 Img2Seq으로 flattening된 vector를 뜻하며, 이를 앞서 구한 $f_i^g$에 더해준다.
Parallel Convolution Module(PCM)을 더해줌으로서, $f_i^g$에 지역(local) 콘텍스트 추가적으로 내제할 수 있다. 결론적으로 RC는 지역성과 스케일 불변성 IB 모두를 가지는 토큰을 만들 수 있다.
그런 다음 융합된 토큰은 FFN에 의해 처리되고 Seq2Img를 통해 token sequence(1차원)이 feature map으로 다시 변환되며, 다음 RC또는 NC의 input으로 들어가게 된다.
의 output에 우선 class token($t_{cls}$)를 붙이고, position encoding을 추가해야 첫번째 NC의 Input token으로 들어올 수 있다.
$t_{cls}$ 토큰은 training 과정에서는 무작위로 초기화된 후 weight를 update하며, inference 과정에서는 weight가 고정된다.
NC는 토큰 시퀀스 안에 지역성과 장기간 종속성을 모두 가지는 모델을 만드는데 사용된다.
NC는 PRM이 없다는 것을 제외하면 RC와 구조가 같다.
RC와 마찬가지로 MHSA와 FFN을 차례로 거친다. 한가지 주목할 점은 class token은 다른 visual token들과 공간적 연결성이 없기 때문에 PCM에서 제거된채로 MHSA와 결합되어 FFN로 전달된다는 사실이다.
NC에서는 PRM이 없기 때문에 토큰의 길이에 변함이 생기지 않으며, ViT와 마찬가지로 최종 Normal Cell의 output으로부터 추출한 class token에 대해 선형 분류 layer를 통과하여 예측 확률을 얻어 최종 분류 결과를 얻게된다.
댓글 영역