Computer Vision/Object Detection, Segmentation

Meta AI의 SAM(Segment Anything Model) 리뷰

Meta AI에서 개발한 SAM(Segment Anything Model)은 자연어처리 분야의 ChatGPT 같은 *Foundation Model을 컴퓨터비전 분야에서 만들어보려고 한 시도이다.

* Foundation Model : 대규모의 광범위한 데이터에 대해 훈련되고 광범위한 다운스트림 작업에 적응할 수 있는 모델. 즉, 하나의 task로 학습시킨 모델이 학습하지도 않은 다야한 분야에 적용될 수 있는 general한 모델을 의미함.

 

SAM 논문 3줄 요약

  • Next token prediction으로 학습한 GPT가 온갖 task를 잘한다.
  • Computer Vision에서도 이런 만능 모델을 만들고 싶어서 새로운 task, model, data를 개발했다.
  • Segmentation은 당연히 잘하고, 다른 테스크들에서도 어느 정도 성능을 보여줬다.

 

Image Segmentation이란?

자율주행차나 의료 AI 등의 분야에서 많이 활용되며 난이도가 매우 높음

 

Intro

  • Task: 어떤 task로 모델을 학습시켜야 GPT처럼 general 한 vision 모델을 만들 수 있을까? 

       → Promptable Segmentation 이라는 새로운 vision task 정의함

  • Data: 이 모델을 학습시키려면 어떤 데이터들이 필요할까?

       → SA-1B (1100만장 이미지에 대한 10억개의 마스크)

  • Model: 이 task를 잘 수행하면서도 general하려면 어떤 모델 구조여야 할까?

       → Image Encoder, Prompt Encoder, Mask Decoder

 

 

Task

GPT는 다음에 올 단어를 예측하는 방식으로 학습되었음(next token prediction). 한번도 학습한적 없는 테스크들을 잘 수행하면서 NLP에 혁신을 가져옴 

 

저자들은 Computer Vision에서는 어떤 단일 task로 모델을 학습시키면 GPT처럼 여러 테스크를 잘하는 모델을 만들 수 있을까 고민. 그 결과 Promptable Segmentation이라는 새로운 테스크를 정의함. 

 

 

 

 

Promptable Segmentation

Promptable Segmentation은 마스크를 생성하고자 하는 대상을 유연하게 prompt로 지정할 수 있는 테스크이다. Prompt로는 single point, multiple point, bounding box, bounding box + test 등이 사용될 수 있다. 이처럼 프롬프트를 유연하게 지정할 수 있게끔 해주는 게 핵심

 

인풋 아웃풋 관점에서 보면 이미지와 프롬프트를 입력받아서 물체에 해당하는 마스크를 출력하는 테스크이다.

 

Data

유연하게 동작하는 general한 AI → 상상을 초월하는 데이터가 필요GPT의 경우, 웹 상의 텍스트들을 크롤링한 Common Crawl을 주요 데이터셋으로 사용. GPT 3 모델을 학습시킨 데이터는 570GB 정도..그런데 segmentation 모델을 학습시키기 위해선 마스크 라벨이 붙어있는 데이터가 필요함. 이는 단순 크롤링으로 해결이 안될 뿐더러 제작에 엄청난 수고가 들어감. 때문에 자체 data engine을 만들어서 전례없는 규모의 데이터셋을 직접 만듦.그 과정에서 수작업과 AI를 적절히 사용했으며, 공개한 SA-1B 데이터셋은 순수 AI가 만든 데이터셋

 

SA-1B 데이터 셋은 3가지 단계로 나눠서 구축. 처음에는 사람이 라벨링하고, AI가 보조하다가 점점 AI가 스스로 라벨링을 한다. 각 단계별 디테일은 다음과 같다.

 

 

1. Assisted Manual 단계

기존에 공개되어 있는 segmentation 데이터 셋으로 먼저 SAM 모델을 학습시켰다고 함. 그리고 새로운 데이터에 대해서 AI가 먼저 segmentation을 해놓으면 사람이 이를 수정했다고 함. 이런 방식으로 430만개의 마스크를 라벨링하였으며, 라벨링 하는 와중에도 데이터 쌓이면 모델을 꾸준히 재학습 시켰다고 함.

 

 

2. Semi-automated 단계

기존의 segmentation 데이터 셋은 배제하고 1단계에서 모은 데이터 셋 만으로 SAM 모델 학습시켰다고 함. 그 다음, AI가 먼저 segmentation을 해놓으면 사람이 빠진 것들만 채워넣음. 이미 segmentation 성능이 뛰어나서 사람이 일일이 수정하지 않고, 마스크가 아예 빠져있을 경우에만 추가해줌. 이런 방식으로 마스크 590만개 추가 라벨링 (도합 1020만개)

 

 

3. Fully-automated 단계

1, 2 단계에서 모은 마스크 1020만개를 가지로 SAM 모델을 학습시킴. 이걸 가지고 이미지 1100만장에 대해 11억개의 마스크 라벨을 생성하도록 계속 inference를 한게 SA-1B 데이터 셋이다.

 

Model

SAM 모델은 Image Encoder, Prompt Encoder, Mask Decoder로 구성되어 있다. 이미지가 있으면 인코더를 통해 embedding을 뽑고, prompt도 embedding을 뽑는다. 그 다음에 image embedding과 prompt embedding을 가지고 mask를 생성하는 mask decoder를 돌려서 mask 3개를 생성함.

Image Encoder

먼저 Image Encoder는 MAE(Masked auto-encoder) 방식으로 학습시킨 Vision transformer를 사용한다. MAE는 이미지를 일정한 크기의 그리드로 나누고 랜덤하게 가린 뒤, 복원하도록 모델을 학습시키는 기법이다. 아래 모델에서 decoder는 제외하고 encoder만 가져왔다고 합다. 아주 강력한 임베딩 모델로 사용할 수 있다고 함.

 

Prompt Encoder

그 다음 Prompt Encoder는 각 prompt 타입에 맞는 인코딩 방식을 적용했다고 한다. 점이나 바운딩 박스는 positional encoding을 사용했다고 하고, text는 CLIP 멀티모달 임베딩을 가져왔다고 한다.

  • Mask: 컨볼루션으로 차원 맞춰준 뒤, image embedding에 pixel wise sum
  • Point & Bounding Box: positional encoding 으로 표현
  • Text: CLIP 모델의 text encoder를 가져와 임베딩

Mask Decoder

이미지 임베딩과 프롬프트 임베딩 간의 cross attention 메커니즘을 적용해준 뒤, 마스크와 마스크를 얼만큼 신뢰할 수 있는지 나타내는 IoU scores를 리턴한다.

 

Ambiguity

promptable segmentation 테스크의 경우 고려해야할 상황이 하나 더 있다. 바로 하나의 prompt에 정답이 여러개가 될 수 있는 ambiguity이다.

 

예를 들어 아래 예시 이미지에서 첫번째 열의 타조 이미지를 보면 머리 부분에 점이 찍혀있다. 이는 전체 타조를 세그맨테이션 해달라는 요청일 수도 있고, 타조의 머리 부분만 세그맨테이션을 해달라는 요청일 수도 있다. 

 

이런 애매한 상황을 해결하기 위해서 SAM 모델은 애초에 하나의 프롬프트에 대해 3개의 마스크를 리턴한다. 그리고 3개 중에 가장 loss 값이 작은 것만 역전파 시키는 방식으로 해결한다.

 

Zeroshot Experiments

이렇게 학습시킨 SAM 모델이 과연 GPT 처럼 한번도 학습한 적 없는 테스크들을 잘 수행하는지 실험을 진행했다. 논문에서는 총 5가지의 테스크에 대해 실험을 했고, 그 결과 상당히 준수한 성능을 보여주었다.

1.  Single point Segmentation

먼저 SAM 모델이 학습한 적 없는 기존 23개의 segmentation 데이터 셋들에 대해서
점 찍었을 때 마스크를 얼마나 잘 생성하는지를 기존 SOTA 모델인 RITM과 비교했다. 주황색과 파란색 막대로 표시된 부분은 각각 SAM과 RITM이 더 뛰어난 성능을 보여준 지표들이고, 점선과 함께 연한 점으로 표시된 지표는 Ambiguity를 고려했을 때 SAM 모델이 얼만큼 더 뛰어난 성능을 보여주었는지 나타낸다.

생성된 마스크의 퀄리티를 사람 눈으로 비교했을 때에도 모든 부문에서 SAM이 더 뛰어났다는걸 확인할 수 있다.

2. Edge Detection

이미지가 주어졌을 때, 테두리를 추출하는 테스크로 기존 edge detection 데이터 셋에 대해서 SOTA 모델들과 비교해보았다.

segmentation 모델을 edge detection에 사용하기 위해서 인퍼런스를 살짝 변형함.

그 결과 기존 edge detection SOTA 모델을 뛰어넘지는 못했지만, 비빌만한 성능올 보여주었다고 함.

3. Object Proposal

object detection 분야에서 많이 연구된 분야로 물체가 있을만한 후보 영역을 찾는 테스크이다.

 

마찬가지로 object proposal을 수행하기 위해서 인퍼런스 과정을 살짝 변형함.

그 결과 작은 물체를 제외하고는 SOTA 시스템과 비슷한 성능을 보였다고 한다.
그리고 rare한 물체는 더 잘 찾는 것으로 보아 더 general한 특징이 있었다고 함.

 

4. Instance Segmentation

Object Detection 결과로 출력된 물체에 대해서 Segmentation을 하는 테스크이다.

정량적인 metric에서는 다소 밀리나, 정성적인 마스크 퀄리티 평가에서는 압도했다. 특히 데이터 셋의 품질이 안좋을수록 SAM 모델이 낮은 평가를 받았다고 한다.

 

5. Text to mask

마지막으로 가장 모험적인 시도인 프롬프트에 텍스트 정보를 포함해 세그멘테이션 하는 것이다. SAM 모델이 text-aware하게 만들기 위해서 training 과정을 살짝 수정하였다. CLIP 모델을 이용해서 별도의 텍스트 라벨 없이도 텍스트를 인식하도록 학습하였다. 

어느 정도 텍스트 기반의 세그멘테이션 구현하였으며, 다른 프롬프트와도 결합 가능해 보였다. 그러나 정량적인 평가가 논문에서 빠졌고, 데모에서도 제외된 것으로 보아 성능이 뛰어나진 않았을 것으로 추측함.

정리

정리해보면 메타의 연구진들은 Computer Vision계의 Foundation Model을 만들자는 문제 의식에서 출발해서 Task, Data, Model을 직접 고안해냈다. 그리고 그 결과로 나온 것이 SAM이다. 하나의 모델이 다양한 테스크를 수행했다기 보단, 많은 테스크를 포괄하는 테스크로 모델을 학습시키고 subtask들로 실험한 느낌이라 범용성 측면에서 아쉽다. 하지만 대규모 데이터셋을 만들어서 공개했고, general 한 모델을 만들려고 시도한 것 자체는 의미가 있다.