본문 바로가기
Paper Review/etc

[9] Supervised Contrastive Learning

by ga.0_0.ga 2023. 11. 23.
728x90
반응형

[Paper] https://arxiv.org/pdf/2004.11362.pdf

 

[Github] https://github.com/HobbitLong/SupContrast

 

GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally) - GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning&q...

github.com

 

 

 

1. Introduction  

이제까지 딥러닝 분류모델에서 가장 많이 사용되던 loss함수는 cross-entropy 였습니다. 하지만 cross-entropy는 noisy label에 대한 robustness 가 부족하고 학습 시 margin을 추가할 수 없어 성능 저하를 일으킨다는 문제가 있습니다. 이를 해결하기 위해 다양한 새로운 loss 함수들이 제안되었지만, large-scale dataset에서는 잘 동작하지 않는 점은 여전히 문제로 남아 있습니다. 

이러한 단점을 해결하기 위해 등장한 것이 바로 self-supervised contrastive 입니다.  이 논문에서는 supervised contrastive 를 제안하고 있는데요, 둘의 차이점을 아래 그림을 통해 간략히 설명하겠습니다.

 

먼저 self-supervised contrastive learning 입니다. 이 방법은 label이 없는 large-scale dataset을 잘 학습시키위해 등장한 방법입니다.  label 없이 의미 있는 표현을 학습하도록 합니다. representation learning의 한 종류이죠. 그 과정은 아래와 같습니다.

step 1) 학습 데이터(anchor) 하나를 설정하고,  data augmentation을 진행합니다. 이 데이터들은 positive data가 됩니다.

step 2) 나머지 이미지들을 negative data로 설정하고, 학습을 진행합니다. 

step 3) positive와 negative가 embedding vector space에서 분리되며 학습이 진행됩니다.

하지만 이 과정에는 한가지 문제점이 존재합니다. negative를 설정하는 기준이 "다른 클래스"가 아니라 "다른 이미지"이기 때문에 같은  class의 사진도 negative로 분류되어 버립니다. 그렇기 때문에 pretrain후 fine-tuning이 어렵고 추가학습 진행이 어렵습니다. 이는  label이 없는 self-supervised 방식이기 때문에 발생하는 문제입니다.

이를 해결하기 위해 등장한 것이 본 논문에서 제안하는 supervised contrastive learning 입니다. 정답 label을 알 수 있기 때문에 앞서 발생한 문제들을 해결할 수 있습니다. step 1의 data augmentation은 유지하고, 같은 label들끼리는 유사한 representation을 얻도록 학습하게 됩니다. 즉, positive 들 끼리는 "pull together", negative들 끼리는 "push apart"하게 학습하는 것이죠. 또한,  cross-entropy를 사용하는 경우에는 representation과 decision boundary를 동시에 학습했지만 본 논문에서는 따로 학습하는 방법을 사용했습니다. 이러한 과정들을 논문에서는 SupCon이라고  이름 지었네요. 이제 SupCon의 구체적인 방법에 대해 알아보겠습니다.

 

 

 

2.Method  

2.1 Representation Learning Framework  

- Augmentation Module(Aug(x)) :  데이터를 augmemtation하는 모듈입니다. 데이터의 다양한 패턴을 반영할 수 있습니다.

- Encoder Module(Enc(x)) : CNN을 적용해 feature를 추출합니다.

- Projection head Module(proj(x)) : MLP 로 구성되어 있고, L2 normalization하여 feature 를 추출합니다. 

 

 

2.2 Contrastive Loss Function  

먼저 self-supervised contrastive loss function 의 수식을 살펴보겠습니다. 논문에 “multiviewed batch”라는 단어가 등장하는데, 이는 augmentation된 2N개의 샘플들을 말합니다.

 

2.2.1 Self-Supervised Contrastive Loss

각 기호가 의미하는 바는 아래와 같습니다.

또한, 학습시에는 위 식의 분자는 최대, 분모는 최소가 되도록 합니다.

 

 

2.2.2 Supervised Contrastive Loss

Self-Supervised Contrastive Loss와 비교해서 달라진 점은 파란 네모 박스 부분입니다.  각 기호가 의미하는 바는 아래와 같습니다.

 

 

2.2.3 Connection to Triplet Loss and N-pairs Loss

Supervised contrastive learning은 triplet loss와 밀접한 관련이 있습니다. contrastive loss에서 하나의 negative와 하나의 positive가 사용될 때 triplet loss가 됩니다. 또한 하나 이상의 negative가 사용되면 N-pairs loss와 동일해집니다.

 

 

 

3. Experiments  

다양한 loss 함수들과 비교했을 때 제안하는 loss 함수가 높은 성능을 보이고 있습니다. 

 

 

 

 

 

 

 

 

 

728x90
반응형

댓글