Processing math: 100%

논문 리뷰

[NeurIPS 2020] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

Yejin Kim 2024. 4. 18. 20:30

 

 

이 논문은 NeurIPS 2020에 출판되었으며, 저자는 Kihyuk Shon*, David Berthelot*, Chun-Liang Li,  Zizhao Zhang, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Han Zhang, Colin Raffel로 구성되어 있다.

 

Motivation

Deep neural network는 큰 데이터셋을 활용할 수록 더 높은 성능을 달성할 수 있다는 것이 잘 알려져 있지만, data를 labeling하는 것의 cost가 굉장히 비싸기 때문에 큰 labeled 데이터셋을 만들기가 어렵다는 문제가 있다. 따라서 소수의 labeled data와 많은 양의 unlabeled data를 활용하여 학습을 진행하는 semi-supervised learning (SSL) 이 제안되었다.

 

Labeled data는 label이 있기 때문에 모델에게 줄 수 있는 supervision이 비교적 명확한 반면, unlabeled data는 그렇지 않다. 따라서 이 논문에서는 어떻게 unlabeled data에 대해 supervision을 줄 수 있을지를 연구하고, 이를 활용한 새로운 SSL 알고리즘을 제안한다.

 

Background

FixMatch는 기존의 두 가지 기법을 적절히 결합하여 SSL 문제를 해결한다. 사용한 기법 중 첫번째 기법은 pseudo-labeling이다.

Pseudo-labeling 기법

Pseudo-labeling은 다음과 같은 방식으로 모델을 학습시킨다.

  • 우선 model을 labeled data로 학습시킨 뒤, unlabeled data에 대해 label을 예측하도록 한다. 이렇게 나온 prediction을 가공 (e.g. sharpening 등) 하고 이를 pseudo-label로 삼는다. 이렇게 만들어진 unlabeled data에 대한 label과 기존의 labeled data를 활용하여 model을 다시 학습시킨다.

Prediction을 어떻게 가공할 것인가는 하나의 design choice라고 볼 수 있으며, FixMatch에서는 가장 높은 score를 가지는 class에 대해 one-hot label을 만드는 방향으로 가공을 진행한다.

 

두번째 기법은 consistency regularization이다. 

 

Consistency regularization의 철학은 모델이 특정 이미지를 다양한 방식으로 변형한 이미지들에 대해서 모두 비슷한 예측을 해야한다는 것이다. 이를 어떻게 활용했는가에 대해서는 메소드 섹션에서 더 자세히 설명하도록 하겠다.

 

Method

아래 그림은 FixMatch의 overview로, 위에서 설명한 두 기법 (pseudo-labeling, consistency regularization)을 결합하여 unlabeled data를 학습시키는 방법을 보여주고 있다.

Unlabeled image 학습 방법

학습 방법에 대해서는 loss function을 통해 더 자세히 설명하도록 하겠다.

 

 Loss function 1: supervised loss 

우선 labeled data를 통해 학습을 진행하는 supervised loss는 위와 같다.

 

이는 기존의 supervised learning에서 사용하는 loss function과 동일한 형태이다. α()는 weak augmentation을 의미하며, augmented된 이미지에 대한 model의 prediction과 ground truth (GT)인 pb 사이의 cross entropy를 계산하여 해당 sample에 대한 loss를 계산한다. 이렇게 계산한 sample들의 loss를 평균냄으로써 최종 loss를 계산한다.

 

참고로 weak augmentation을 한 image에 대해 supervision을 주는 이유는, 모델이 original image를 계속해서 반복해서 보게되면 해당 sample들을 외워버리는 overfitting 문제가 발생할 수 있기 때문이다. 매 epoch마다 해당 image에 약간의 변형을 줌으로써 overfitting 문제를 완화한다.

 

 Loss function 2: unsupervised loss 

Unlabeled data에 대해 학습을 진행하는 unsupervised loss는 위와 같다. 이러한 형태의 loss는 weakly augmented version과 strongly augmented version에 대한 모델의 prediction이 유사해야 한다는 consistency regularization의 철학을 기반으로 하고 있다.

 

FixMatch에서는 artificial label qb를 얻기 위해, 먼저 unlabeled data ub의 weakly augmented version에 대한 모델의 prediction을 계산한다. 이 과정은 qb=pm(y|α(xb))로 나타낼 수 있다. 그리고 이를 one-label로 가공한 ˆqb=argmax(qb)pseudo-label로 활용한다. 이렇게 만든 pseudo-label을 strongly augmented version에 대한 model의 prediction에 대한 supervision으로 주어 unsupervised loss를 계산한다.

 

이 때, 학습 초반에 낮은 quality의 pseudo-label이 잘못된 supervision을 제공할 수 있다는 문제를 방지하기 위해 confidence thresholding을 진행한다. 여기서 τ는 confidence threshold로, 논문에서는 이 값을 0.95로 활용하고 있다. 이를 통해 모델이 해당 sample의 weakly augmented version에 대해 0.95 이상의 confidence를 갖지 않는 경우, 해당 sample을 학습에서 배제할 수 있도록 한다.

 

FixMatch의 알고리즘은 다음과 같다.

Experiments

실험은 CIFAR-10, CIFAR-100, SVHN, STL-10에 대해 진행했다. CIFAR-10과 SVHN에 대해서는 WRN-28-2를, CIFAR-100에 대해서는 WRN-28-8를, STL-10에 대해서는 WRN-37-2를 backbone 모델로 사용하였다.

 

FixMatch는 대부분의 세팅에서 SOTA 성능을 달성하고 있다. 하지만 ReMixMatch가 CIFAR-100에서는 조금 더 나은 성능을 보이고 있다. 저자들은 이러한 결과가 ReMixMatch의 Distribution Alignment (DA) 기법 때문이라고 추측한다. DA는 unlabeled data에 대한 pseudo-label이 labeled data와 비슷한 분포를 따를 수 있도록 logit을 가공하는 기법으로, 여기에는 unlabeled data가 labeled data와 비슷한 distribution을 가진다는 assumption이 사용되었다. 저자들은 실험의 세팅이 모두 이러한 assumption과 일치하고 있기 때문에 DA가 효과적이었을 것이라 생각하였고, FixMatch에 DA를 추가한 형태로도 실험을 진행하였다. 실제로 FixMatch에 DA를 적용하면 CIFAR-100 (400 labels case) 에 대해 40.14%의 error rate을 보이며 ReMixMatch보다 향상된 성능을 보인다고 한다.

 

그 다음으로 저자들은 Barely Supervised Learning에 대한 실험을 진행하였다. 이는 CIFAR-10에 대한 실험이며, 한 class 당 labeled data가 오직 1장만 있는 극한의 세팅에서 FixMatch 알고리즘이 어떻게 작동하는지를 테스트하였다.

랜덤하게 각 class 당 1개의 sample을 선택하는 형태로, 총 4개의 dataset을 만들고 총 4번의 실험을 진행하였다. 이 때 test accuracy는 최저 48.58%, 최대 85.32% 였으며, 중앙값은 64.28% 이었다고 한다. 저자들은 이렇게 성능이 천차만별인 이유가 10개의 labeled example의 quality 때문일 것이라 가설을 세웠다. 이를 구체적으로 확인하기 위해 "prototypicality"를 기반으로 데이터셋을 다시 구축하였다.

"prototypicality"를 기반으로 구축한 dataset

이러한 ordering은 논문 "Distribution Density, Tails, and Outliers in Machine Learning: Metrics and Applications"에서 CIFAR-10에 대해 ordering한 prototypicality를 기반으로 정하였다. 자세한 내용이 궁금하다면 위의 논문을 참고하는 것을 추천한다. 위의 그림은 prototypicality를 기반으로 나눈 데이터셋의 모습이다. 첫번째 row는 가장 representative한 sample들이 모인 set으로 이를 Dataset 0으로, 가장 마지막인 8번째 row는 가장 outlier에 가까운 sample들이 모인 set으로 이를 Dataset 7로 둔다. 각 set을 labeled set으로 하여 학습을 진행했을 때의 결과는 아래의 그래프를 통해 확인할 수 있다.

 

Most prototypical한 example들을 labeled data로 가진 경우 최대 84%, 중앙값 78% 정도의 성능에 도달했으며, distribution의 중간 즈음 위치한 example들을 활용한 경우 약 65%의 accuracy를, outlier에 대해서는 약 10% 정도로 거의 수렴하지 못하는 모습을 보였다고 한다.

 Ablation Study 

1. Sharpening and Thresholding

왼쪽의 그래프는 confidence threshold τ의 값에 따른 모델의 error rate을 보이고 있다. τ가 0.95 정도일 때 가장 좋은 성능을 보였으며, 이를 통해 높은 성능을 달성하기 위해서는 pseudo-label의 양보다 질이 더 중요하다는 것을 짐작할 수 있다. 또한 sharpening은 pseudo-label을 one-hot으로 만들지 않고 Temperature T를 활용하여 sharpened된 버전으로 만드는 방식으로, one-hot label의 soft version이라고 볼 수 있다. 이 때, T가 0이면 fixmatch의 one-hot label과 같은 pseudo-label이 만들어진다. Sharpening은 confidence threshold를 사용했을 때, 성능에 큰 영향을 미치지 않는다는 것을 실험을 통해 확인할 수 있고 따라서 저자들은 이를 메소드에 활용하지 않았다.

 

2. Augmentation Strategy

저자들은 strong augmentation policy에 대해 RandAugment와 CTAugment를 모두 활용하여 실험을 진행해 보았다. 실험결과 CTAugment에 Cutout을 모두 활용하는 것이 가장 좋은 성능을 보였다고 한다.

 

또한 pseudo-label 생성과 prediction에 대해 weak / strong augmentation의 다양한 조합을 적용하여 실험을 진행했다.

  • Strongly augmented version으로 label guessing을 진행하는 경우: 학습 초반에 발산함
  • No augmented version으로 label guessing을 진행하는 경우: 추정한 unlabeled label에 overfit함
  • Strongly augmentation 대신 weak augmentation으로 model의 prediction을 생성하는 경우: 최대 45%까지 성능이 향상되지만 성능이 안정화되지 않고 점차 12%까지 하락함

Conclusion

  • 간단한 SSL 알고리즘을 통해 많은 데이터셋에 대해 SOTA 성능을 달성함
  • 메소드의 간단함 덕분에 FixMatch의 동작을 보다 면밀히 조사할 수 있음
  • 간단하면서도 높은 성능을 내는 SSL 알고리즘 FixMatch는 label을 구하기 어려운 real-world의 여러 도메인에서 많이 활용될 수 있을 것으로 기대함