논문 리뷰

[ICLR 2021] Long-tail Learning via Logit Adjustment

Yejin Kim 2024. 3. 23. 23:51

 

이 논문은 ICLR 2021에 출판되었으며, 저자는 Aditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat, Himanshu Jain, Andreas Veit, Sanjiv Kumar 로 구성되어 있다.

 

Motivation

기존 Long-tail Learning에서는 다음의 두 가지 방식으로 문제를 해결하고 있었다.

  1. Classification weight의 post-hoc normalization
  2. Class별 페널티가 적용된 loss를 통한 학습

위의 두 가지 방식의 문제점을 먼저 살펴보도록 하자.

 

1.  Classification weight의 post-hoc normalization 

\( w_y^T\Phi(x) \)에서 \( w_y \)는 모델의 마지막 layer의 weight이며, \( \Phi(x) \)는 마지막 layer의 input으로 들어오는 x의 feature다.

이 공식에서는 \( \nu_y \)를 어떻게 설정하느냐에 따라 모델의 logit 가공방식이 달라진다.

\( \tau \)는 양수의 값을 갖는 hyperparameter이며 \( \nu_y \)는 \( \mathbb{P}(y) \), \( ||w_y||_2\)와 같이 세팅을 한다. 후자의 경우 \( ||w_y||_2\)가 \( \mathbb{P}(y) \)와 correlate 되어 있다는 관찰을 통해 고안된 방식이다. 하지만 이러한 assumption은 optimizer를 무엇으로 선택하느냐에 따라 쉽게 깨진다는 문제가 있다.

Figure 1에 주어진 실험을 살펴보자. Class 0은 가장 Major한 class를 의미하며 Class number가 커질수록 minor한 class라고 이해하면 된다. SGD with momentum을 사용하는 경우 \( ||w_y||_2\)이 class 내의 sample 수와 굉장히 correlate되어 있지만, Adam을 사용하는 경우에는 correlate 되어 있지 않음을 확인할 수 있다.

 

2.  Class별 페널티가 적용된 loss를 통한 학습 

이러한 방식의 예시는 다음과 같은 것들이 있다.

Katharina Morik, Peter Brockhausen, and Thorsten Joachims. Combining statistical learning with a knowledge-based approach - a case study in intensive care monitoring. In Proceedings of the Sixteenth International Conference on Machine Learning (ICML), pages 268–277, San Francisco, CA, USA, 1999. Morgan Kaufmann Publishers Inc. ISBN 1-55860-612-2.
Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, and Tengyu Ma. Learning imbalanced datasets with label-distribution-aware margin loss. In Advances in Neural Information Processing Systems, 2019.

 

이러한 loss들은 모델의 decision boundary가 rare class로부터 멀어지도록 만들 수 있다. Rare class의 경우 majority class보다 uncertainty가 높기 때문에 decision boundary에 margin을 줌으로써 문제를 해결하는 것이다.

하지만 이러한 loss를 최소화하는 포인트가 minimal balanced error를 도출해내지 않는다는 문제가 있다. 즉, Fisher consistent 관점에서 consistent 하지 않다는 것이다.

더보기

Fisher consistency는 estimator나 loss function의 설계 및 평가에 관한 개념으로, 특정 조건 하에서 추정 방법이나 학습 알고리즘이 전체 모집단에 적용될 때 정확한 모델이나 파라미터 값을 생성할 것이 보장되는 성질을 말한다.

 

Estimator나 model이 특정 문제에 대해 fisher consistent하다는 것은 데이터 양이 무한히 많아질 때(샘플의 크기가 전체 모집단 크기에 접근할 때), 해당 방법이 데이터를 생성한 true model이나 파라미터를 복구할 수 있다는 것을 의미한다. 이러한 개념은 확률론의 대수의 법칙과 밀접한 관련이 있으며, 샘플 크기가 증가함에 따라 샘플의 평균이 전체 모집단의 평균에 가까워진다는 아이디어에 기반한다고 한다.

 

Loss function이 fisher consistent하다는 것은, 이 loss를 true data distribution에 대해 최소화함으로써 얻는 model이 true model과 일치하거나 true distribution에 따라 평균적으로 정확한 예측을 한다는 것이다.

Method

이 논문에서는 위에서 언급한 post-hoc 방식과 loss 수정 방식에 대해 모두 방법을 제안한다. 

 

1.  Post-hoc logit adjustment 

Weight normalization 방식이 logit을 multiplicative하게 업데이트하는 반면, 위의 공식을 통해 제안한 메소드는 additive하게 업데이트를 진행한다. 만약 rare label \( y\)가 biased predictor에 의해 negative score \( w_y^T\Phi(x) < 0 \)를 가지는 경우, weight normalization으로는 \(y\)가 highest score를 가지도록 logit을 수정할 수가 없다. 하지만 logit adjustment \( w_y^T\Phi(x) - ln\pi_y \)는 additive하게 update하므로 기존 logit의 부호와 관계없이 dominant class의 score를 적절히 낮출 수 있다.

2.  The logit adjusted softmax cross-entropy 

Logit adjusted softmax CE 방식은 위와 같이 제안한다. 기존의 standard softmax CE와 달리, 이 loss는 각 logit에 label-dependent offset을 추가하였다.

 

저자들은 기존의 pairwise margin을 활용하여 구성한 loss들을 general하게 나타낼 수 있는 공식을 다음과 같이 나타내었다.

그리고 이러한 loss 형태에 대해 다음의 조건을 충족시키는 경우에 pairwise loss가 Fisher consistent 하다고 한다. 위에서 제안한 loss는 아래의 조건에서 \( \delta_y = \pi_y \)인 경우에 해당하므로 consistent 함을 알 수 있다.

Experiments

실험은 크게 두 가지 형태의 데이터셋에 대해 진행했다.

  1. Simple binary problem
    • Class는 +1, -1로 구성되어 있으며 isotropic covariance와 \(\mu_y = y \cdot (+1, +1)\)의 평균값을 가지는 2D 가우시안으로부터 sample을 추출
    • 10,000개의 sample로 구성된 test set에 대해 100번의 독립시행을 하여 balanced error를 측정
  2. Real-world dataset
    • CIFAR-10, CIFAR-100, ImageNet, iNaturalist에 대해 실험을 진행
    • CIFAR에 대해서는 ResNet-32를, ImageNet과 iNaturalist에 대해서는 ResNet-50을 사용
    • 모든 모델은 SGD with momentum을 사용하여 학습을 진행

1.  Simple binary problem 

Bayes optimization을 통해 bayes-optimal solution이 우리가 달성하고자 하는 목표라고 볼 수 있다. Figure 2의 왼쪽 그림과 가운데 그림을 보면, logit adjusted 방식이 수치적으로도 기하적으로도 bayes 방식과 가장 유사한 형태를 보이고 있음을 확인할 수 있다. 또한 오른쪽 그림을 보면 weight normalization과 달리 logit adjustment는 적절히 \(\tau\)를 scaling하면 bayes-optimal solution에 달성할 수 있다. 이 말은 즉, logit adjustment가 fisher consistent하다는 것을 의미한다.

2.  Real-world dataset 

Table 3는 CIFAR-10, CIFAR-100, ImageNet, iNaturalist에 대한 실험 결과이다. Logit adjustment loss를 했을 때, 다른 방식보다 낮은 balanced error를 보이고 있음을 확인할 수 있다.

또한 Figure 3을 통해, \(\tau\) scaling을 follow-up 하더라도, post-hoc adjustment가 일관적으로 weight normalization을 outperform 하고 있는 것을 확인할 수 있다.

Figure 4는 class별 error를 측정한 결과 이다. CIFAR-100과 iNaturalist에 대해서는 보기 쉽게 시각화를 하기 위해 frequency-sorted order(sample이 많은 순으로 정렬한 순서)를 기반으로 10개의 class로 grouping한 뒤 결과를 보고하였다. 그림에서 알 수 있듯이 logit adjusted 방식이 minority class에서 가장 낮은 error를 보이고 있다.

Conclusion

  • Long-tail learning에서 post-hoc 단계와 training 단계에서 적용할 수 있는 logit adjustment 방식을 제안함
  • 기존의 방식들이 fisher consistent하지 않다는 한계점을 해결한 logit adjustment를 만듦
  • 제안된 기술을 real-world dataset에 적용하고, 그 효과성을 입증함