논문 리뷰

[NeurIPS 2021] ABC: Auxiliary Balanced Classifier for Class-imbalanced Semi-supervised Learning

Yejin Kim 2024. 3. 31. 21:51

 

 

이 논문은 NeurIPS 2021에 출판되었으며, 저자는 Hyuck Lee, Seungjae Shin, Heeyoung Kim로 구성되어 있다.

 

Motivation


대부분의 Semi-supervised Learning (SSL) 알고리즘은 데이터 불균등을 가정하고 있는 경우가 많다. 하지만 실제로 데이터의 분포는 불균등한 경우가 많아 이러한 기술들을 적용하기가 어렵다. 불균등한 데이터셋에서 학습한 classifier는 majority class에 편향된 예측을 한다는 것이 잘 알려져 있는데, 학습한 labeled data를 기반으로 unlabeled data를 학습하는 SSL 세팅에서는 '확증 편향' 현상이 발생할 수 있기 때문이다. 이 논문은 이러한 편향 문제를 해결하여 LTSSL 문제를 풀고자 한다.

Method

ABC의 구조는 굉장히 간단하다. 기존의 DNN 기반 SSL 알고리즘을 backbone으로 하고 single layer의 auxiliary balanced classifier (ABC)을 추가한 형태이다. 즉, re-balanced branch와 biased-branch가 dual-branch를 이루고 있는 형태로 이해할 수 있다.

ABC는 rebalancing된 classification loss와 consistency regularization loss를 통해 학습하며, 기존의 DNN 기반 SSL 알고리즘의 backbone을 동시에 활용함으로써 minibatch의 모든 데이터로부터 학습한 좋은 퀄리티의 representation을 활용할 수 있도록 한다.

 

1.  Classification Loss 

Classification loss는 labeled data를 활용한 학습이다. 기존의 classification loss들과 동일하게 one-hot label과의 Cross-entropy를 계산하여 loss를 측정하는데, 0/1 masking을 함으로써 re-balance 효과를 얻는다.

위의 수식에서 \(\mathcal{B}\)는 베르누이 분포를, \(N_L\)은 L번째 class의 sample 수를, \(N_{y_b}\)는 sample \(x_b\)가 속한 class의 sample 수를 의미한다. 참고로 이 논문에서 class는 1부터 L까지 존재하며, \(N_1 \ge N_2 \ge \cdots \ge N_L \) 라고 가정한다.

따라서 sample \(x_b\)에 대한 베르누이 분포의 파라미터를 \(\frac{N_L}{N_{y_b}}\)로 두어 majority class일 수록 파라미터 값이 작아지고 따라서 mask의 값이 0이 나오기 쉬워진다. 반면 minority class일 수록 파라미터의 값이 커지므로 mask의 값은 1이 되기 쉽다. 이를 통해 mini-batch 내에서 minority sample에 대한 학습 반영률을 높인다.

 

2.  Consistency Regularization Loss 

Consistency Regularization은 unlabeled data를 활용한 학습 방법이다. 기본적으로는 특정 데이터에 weak augmentation을 가하든, strong augmentation을 가하든 모델의 예측이 일치해야 한다는 생각으로 고안된 메소드이다.

기존의 consistency regularization과의 차이점은 마찬가지로 0/1 mask를 사용했다는 것이다. classification loss와 달리 여기서는 unlabeled data를 활용하기 때문에 베르누이 분포의 파라미터는 \(u_b\)에 대한 모델의 예측 class를 기반으로 결정한다.

다만 저자들은 학습 초반에는 confidence가 confidence threshold에 대한 hyperparameter \(\tau\)를 넘는 일이 많지 않기 때문에, mask를 1로 둠으로써 confidence score가 높은 데이터를 충분히 활용하도록 하며, 학습을 진행함에 따라 이를 \(\frac{N_L}{N_{\hat{q_b}}}\)에 가까워지도록 점진적으로 줄였다고 한다.

또한 hard pseudo-label이 아닌 soft pseudo-label을 사용하였다. 저자들은 Entropy minimization은 특정 class에 편향된 classification하는 문제를 가속할 수 있기 때문에 soft pseudo-label을 사용했다고 설명하고 있다.

3.  End-to-end Loss 

전체 모델을 학습시키기 위한 loss는 위와 같다. 앞의 두 loss는 ABC에 대한 loss이며, \(L_{back}\)은 backbone 모델을 학습시키는 기존의 loss를 의미한다.

Experiments

실험은 CIFAR10, CIFAR100, SVHN, LSUN에 대해 진행했으며 비교를 진행하는 베이스라인 메소드는 다음과 같다.

  • Deep CNN (Vanilla)
  • BALMS (Class Imbalance Learning; CIL)
  • VAT, ReMixMatch, FixMatch (SSL)
  • FixMatch+CReST+PDA, ReMixMatch+CReST+DARP (Class Imbalance Semi-supervised Learning; CISSL)
  • ReMixMatch+DARP, FixMatch+DARP (CISSL)
    - refine the pseudo-labels
  • ReMixMatch+DARP+cRT, FixMatch+DARP+cRT (CISSL)
    - finetune the classifier using cRT

또한 Long-tail 세팅은 다음의 두 가지를 활용하였다.

(a)는 majority class부터 minority class까지 sample 수가 기하급수적으로 줄어드는 long-tail, (b)는 절반의 majority class가 동일한 sample 수를, 나머지 절반의 minority class가 동일한 sample의 수를 가지는 형태의 long-tail이다.

Main setting으로는 기하급수적으로 클래스의 sample 수가 감소하는 세팅을 고려한다. 여기서 \(\gamma\)는 imbalance ratio를, \(\beta\)는 labeled data의 비율을 의미한다. 위의 세팅에서는 ABC가 overall accuracy와 minority-class-accuracy에 대해 모두 좋은 성능을 보였다.

놀라운 점은 SSL 기법인 VAT는 vanilla 알고리즘과 거의 동일한 성능을 보였다는 것이다. 비슷하게 class imbalance를 고려하지 않은 FixMatch, ReMixMatch를 적용하는 경우에는 minority class에 대해 굉장히 낮은 성능을 보이고 있음을 알 수 있다. BALMS는 class imbalance는 고려하지만 unlabeled data를 활용하지 않아 낮은 성능을 보이고 있는 것으로 생각된다.

저자들은 다른 CISSL 기법들과 다르게 ABC는 0/1 mask를 통해 class imbalance를 완화하면서도 FixMatch (ReMixMatch)로부터 학습한 좋은 퀄리티의 representation을 활용함으로써 좋은 성능을 달성할 수 있었을 것이라 주장한다.

 

Class imbalance ratio와 labeled data의 비율을 변경하여 실험을 한 경우에도 ABC가 다른 CISSL 메소드들의 성능을 능가하고 있는 모습을 확인할 수 있다.

Table 3는 majority class와 minority class의 sample 분포가 step function의 형태를 이루는 step imbalance setting에 대한 실험 결과이다. 저자들은 이러한 형태의 imbalance setting에서는 절반의 minority 클래스가 모두 극심하게 적은 데이터를 갖게 되므로 더욱 모델이 학습하기 어려운 상황일 것이라 추정하였다. 이러한 상황에서도 ABC는 상대적으로 좋은 성능을 보이고 있다.

 

마지막으로는 LSUN 데이터셋이다. LSUN 데이터셋은 256 x 256 이미지 7.5M개를 포함하고 있는 large-scale 데이터셋으로, 데이터셋 자체가 이미 imbalance하다고 한다. Table 4에서는 Table 2, 3과 다르게 CReST를 적용한 방식은 보고하지 않았는데, 저자들은 CReST가 pseudo-label을 업데이트할 때 전체 unlabeled data를 load하는 과정이 필요한데 이는 large-scale 데이터셋이 적합하지 않기 때문이라고 제외하였다고 얘기한다. 대신, 추가적으로 FixMatch+cRT와 ReMixMatch+cRT를 비교대상으로 추가하였다.

주어진 베이스라인들과 비교했을 때 마찬가지로 ABC 방식이 가장 좋은 성능을 보이고 있음을 확인할 수 있다.

 Qualitative analysis 

 

Figure 3는 main setting 하에서 CIFAR-10 test set의 representation을 확인한 t-SNE 결과이다. SSL backbone이 없는 ABC는 class 간 분리가 잘 되는 representation을 배우는데 실패했다. 저자들은 0/1 mask를 사용하여 학습하기 때문에 충분히 데이터를 활용하지 못해 이러한 문제가 발생한다고 설명한다. 반면 ABC를 backbone과 함께 사용하는 경우, 높은 퀄리티의 representation을 배울 수 있음과 동시에 전체 데이터를 사용할 수 있다고 주장했다.

 

Figure 4를 통해서는 ABC가 class imbalance 문제를 해결한다는 것을 보여준다. 실험은 마찬가지로 main setting에서 CIFAR-10에 대해 학습한 모델들을 비교하였다. 여기서 i 행 j 열의 값은 i 번째 class에 속하면서 j 번째 class로 예측된 데이터의 양을 나타낸다. ABC를 활용하는 경우, 다른 경우에 비해 minority class를 majority class로 예측하는 수가 확연히 감소했음을 확인할 수 있다.

 

 Ablation study 

Ablation study 또한 main setting에서 CIFAR-10-LT를 활용하였으며, 실험 결과를 통해 제안된 메소드의 각 요소들이 class imbalance 문제를 해결하여 좋은 성능을 내는데 기여하고 있음을 보여준다.

Conclusion

  • CISSL을 위해 사용할 수 있는 ABC 기법을 소개함. 이는 기존의 SOTA SSL 알고리즘에 쉽게 붙여 사용할 수 있는 기술임.
  • Class-balanced한 예측을 하도록 학습하면서도 backbone에 의해 학습된 높은 퀄리티의 representaiton을 활용할 수 있도록 함.
  • 다양한 세팅 하에서의 실험을 진행하여, 제안된 알고리즘이 기존 알고리즘의 성능을 능가함을 보였음.

제안된 알고리즘은 labeled data와 unlabeled data가 동일한 분포를 따른다는 가정을 하고 있다. 따라서 저자들은 이러한 방향으로 향후 연구가 가능할 것으로 이야기한다.