[9] Supervised Contrastive Learning

2023. 11. 23. 21:51·Paper Review/etc
728x90
반응형

[Paper] https://arxiv.org/pdf/2004.11362.pdf

 

[Github] https://github.com/HobbitLong/SupContrast

 

GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally) - GitHub - HobbitLong/SupContrast: PyTorch implementation of "Supervised Contrastive Learning&q...

github.com

 

 

 

1. Introduction  

이제까지 딥러닝 분류모델에서 가장 많이 사용되던 loss함수는 cross-entropy 였습니다. 하지만 cross-entropy는 noisy label에 대한 robustness 가 부족하고 학습 시 margin을 추가할 수 없어 성능 저하를 일으킨다는 문제가 있습니다. 이를 해결하기 위해 다양한 새로운 loss 함수들이 제안되었지만, large-scale dataset에서는 잘 동작하지 않는 점은 여전히 문제로 남아 있습니다. 

이러한 단점을 해결하기 위해 등장한 것이 바로 self-supervised contrastive 입니다.  이 논문에서는 supervised contrastive 를 제안하고 있는데요, 둘의 차이점을 아래 그림을 통해 간략히 설명하겠습니다.

 

먼저 self-supervised contrastive learning 입니다. 이 방법은 label이 없는 large-scale dataset을 잘 학습시키위해 등장한 방법입니다.  label 없이 의미 있는 표현을 학습하도록 합니다. representation learning의 한 종류이죠. 그 과정은 아래와 같습니다.

step 1) 학습 데이터(anchor) 하나를 설정하고,  data augmentation을 진행합니다. 이 데이터들은 positive data가 됩니다.

step 2) 나머지 이미지들을 negative data로 설정하고, 학습을 진행합니다. 

step 3) positive와 negative가 embedding vector space에서 분리되며 학습이 진행됩니다.

하지만 이 과정에는 한가지 문제점이 존재합니다. negative를 설정하는 기준이 "다른 클래스"가 아니라 "다른 이미지"이기 때문에 같은  class의 사진도 negative로 분류되어 버립니다. 그렇기 때문에 pretrain후 fine-tuning이 어렵고 추가학습 진행이 어렵습니다. 이는  label이 없는 self-supervised 방식이기 때문에 발생하는 문제입니다.

이를 해결하기 위해 등장한 것이 본 논문에서 제안하는 supervised contrastive learning 입니다. 정답 label을 알 수 있기 때문에 앞서 발생한 문제들을 해결할 수 있습니다. step 1의 data augmentation은 유지하고, 같은 label들끼리는 유사한 representation을 얻도록 학습하게 됩니다. 즉, positive 들 끼리는 "pull together", negative들 끼리는 "push apart"하게 학습하는 것이죠. 또한,  cross-entropy를 사용하는 경우에는 representation과 decision boundary를 동시에 학습했지만 본 논문에서는 따로 학습하는 방법을 사용했습니다. 이러한 과정들을 논문에서는 SupCon이라고  이름 지었네요. 이제 SupCon의 구체적인 방법에 대해 알아보겠습니다.

 

 

 

2.Method  

2.1 Representation Learning Framework  

- Augmentation Module(Aug(x)) :  데이터를 augmemtation하는 모듈입니다. 데이터의 다양한 패턴을 반영할 수 있습니다.

- Encoder Module(Enc(x)) : CNN을 적용해 feature를 추출합니다.

- Projection head Module(proj(x)) : MLP 로 구성되어 있고, L2 normalization하여 feature 를 추출합니다. 

 

 

2.2 Contrastive Loss Function  

먼저 self-supervised contrastive loss function 의 수식을 살펴보겠습니다. 논문에 “multiviewed batch”라는 단어가 등장하는데, 이는 augmentation된 2N개의 샘플들을 말합니다.

 

2.2.1 Self-Supervised Contrastive Loss

각 기호가 의미하는 바는 아래와 같습니다.

또한, 학습시에는 위 식의 분자는 최대, 분모는 최소가 되도록 합니다.

 

 

2.2.2 Supervised Contrastive Loss

Self-Supervised Contrastive Loss와 비교해서 달라진 점은 파란 네모 박스 부분입니다.  각 기호가 의미하는 바는 아래와 같습니다.

 

 

2.2.3 Connection to Triplet Loss and N-pairs Loss

Supervised contrastive learning은 triplet loss와 밀접한 관련이 있습니다. contrastive loss에서 하나의 negative와 하나의 positive가 사용될 때 triplet loss가 됩니다. 또한 하나 이상의 negative가 사용되면 N-pairs loss와 동일해집니다.

 

 

 

3. Experiments  

다양한 loss 함수들과 비교했을 때 제안하는 loss 함수가 높은 성능을 보이고 있습니다. 

 

 

 

 

 

 

 

 

 

728x90
반응형
저작자표시 (새창열림)

'Paper Review > etc' 카테고리의 다른 글

[11] Visual Instruction Tuning (LLaVA: Large Language and Vision Assistant)  (0) 2025.05.11
[10] CrossViT: Cross-Attention Multi-Scale Vision Transformer for ImageClassification  (2) 2023.12.26
[8] MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE,AND MOBILE-FRIENDLY VISION TRANSFORMER  (2) 2023.10.15
[7] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale  (0) 2023.09.12
[6] MobileOne: An Improved One millisecond Mobile Backbone  (0) 2023.08.06
'Paper Review/etc' 카테고리의 다른 글
  • [11] Visual Instruction Tuning (LLaVA: Large Language and Vision Assistant)
  • [10] CrossViT: Cross-Attention Multi-Scale Vision Transformer for ImageClassification
  • [8] MOBILEVIT: LIGHT-WEIGHT, GENERAL-PURPOSE,AND MOBILE-FRIENDLY VISION TRANSFORMER
  • [7] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
ga.0_0.ga
ga.0_0.ga
    반응형
    250x250
  • ga.0_0.ga
    ##뚝딱뚝딱 딥러닝##
    ga.0_0.ga
  • 전체
    오늘
    어제
    • 분류 전체보기 (181)
      • Paper Review (51)
        • Video Scene Graph Generation (6)
        • Image Scene Graph Generation (18)
        • Graph Model (5)
        • Key Information Extraction (4)
        • Fake Detection (2)
        • Text to Image (1)
        • Diffusion Personalization (4)
        • etc (11)
      • AI Research (49)
        • Deep Learning (30)
        • Artificial Intelligence (15)
        • Data Analysis (4)
      • Pytorch (10)
      • ONNX (5)
      • OpenCV (2)
      • Error Note (34)
      • Linux (2)
      • Docker (3)
      • Etc (7)
      • My Study (16)
        • Algorithm (10)
        • Project (4)
        • Interview (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    나이브 베이즈 분류
    JNI
    forch.nn.functional
    활성화 함수
    3dinput
    i3d
    linear regression
    pytorch
    TypeError
    torch.nn
    HRNet
    dataset
    오차 역전파
    contiguous
    pandas
    그래프신경망
    fine tuning
    차원의 저주
    transformer
    알고리즘
    ONNX
    정규화
    GCN
    Inductive bias
    Logistic regression
    Activation Function
    permute
    tensorflow
    dataloader
    RuntimeError
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
ga.0_0.ga
[9] Supervised Contrastive Learning
상단으로

티스토리툴바