PR-1 / TENT : Fully test-time adaptation with Entropy Minimization
Introduction
딥 네트워크는 동일한 분포를 테스트 데이터가 들어올 때 높은 정확도를 보인다. 하지만 기존 데이터와는 다른 새로운 데이터에 대한 일반화 성능은 어느 정도 한계가 있다. 훈련 데이터(Source)와 테스트 데이터(Target)이 많이 다를 수록 그 정확도가 낮아지는데, 이러한 가정 상황을 두고 _Domain Shift / Dataset Shift_라고 합니다. 여기서는 이러한 Shift를 _natural variations or corruptions_라고도 부른다.
테스트 단계에서 적응을 위해 저자는 모델 예측 값에 대한 엔트로피를 최소화하기로 했다. 이때 이 엔트로피에 대한 목적을 Test entropy라고 하고, 여기에 적용된 방법론을 tent라고 한다. 그리고 에러와 Shift를 엔트로피와 연관지어 생각을 했는데, 간단하게 정리하면 다음과 같다.
- Entropy = Error = Shift = Corruptions
- Confidence = Accuracy
엔트로피를 최소화하기 위해 TENT는 타겟 데이터에 대한 평균과 분산 통계값을 추정하고, 배치마다 affine 파라미터를 최적화시키는 방식으로 inference를 정규화하고 transform시킨다. 이런 low-dimensional, channel-wise, feature modulation과 같은 방식을 사용하면 테스트 과정에 적응시킬 때 효율적이라고 한다. (특히 온라인으로 진행돼도 좋음)
본 논문은 (Image Classification) 원본 이미지가 손상된 상황에서의 이미지 식별, (Digit Recognition) Domain shift 상황에서의 숫자 구분, (Semantic Segmentation) Simulation-to-Real 상황에서 의미론적 분할, 일련의 3가지 과정에서의 일반화 성능을 평가한다.
본 논문의 Contribution은 다음과 같다.
- 소스 데이터없이 타겟 데이터 만을 이용하는 Fully Test-time adaptation 환경을 구축했다.
- Inference 단계에서 효과적임을 입증하기 위해 Offline과 Online updates 두 상황을 가정하고 benchmark.
- Adaptation objective로서 Entropy를 사용하여 Test-time Entropy Minimization 전략을 실험했다.
- 테스트 데이터에 대한 모델의 예측 엔트로피(에러)가 줄어듦으로서 일반화 성능을 개선할 수 있었다.
- Corruptions에 대한 강인함을 평가했을 때, TENT는 ImageNet-C에서 44.0%의 에러율을 보여줌.
- Domain Adaptation 측면에서, TENT는 Online으로도 가능하며, Classification과 같은 task에서 Source-free 적응이 가능하다. 즉, 소스 데이터를 사용하고 최적화 과정이 더 많은 다른 모델과 견줄 수 있다는 얘기
SETTING: FULLY TEST-TIME ADAPTATION
Adaptation이란 Source에서 Target에 대한 일반화를 하는 것이라 할 수 있다. 기존 방법들과 달리 Fully test-time adaptation은 Inference 중 adaptation 과정에서 모델 `f_\theta`와 unlabeled 타겟 데이터 `x^t`만 요구한다. 기존 적응 방법론들은 저마다 목적이 있겠지만, Source와 Target 또는 Supervision이 동시에 사용 불가능한 모든 실제 상황을 커버하지 못한다고 할 수 있다.
테스트 단계에서 예상못한 타겟 데이터가 들어올 경우 test-time adaptation이 필요하다. 여기선 TTT(Test-time Training)와 TENT를 비교하는데, 두 방법론 모두 테스트 중에 비지도 손실 함수 `L(x^t)`를 최적화시켜 모델 적응을 시도한다는 점에선 동일하나, TTT는 훈련 데이터를 활용하지만 TENT의 Fully test-time adaptation은 훈련 데이터를 사용하지 않고 그에 따라 training loss도 필요없이 test loss만 사용한다는 점에서 큰 차이가 있다.
당연하게도 사용하는 데이터의 수도 줄어들었고, 손실 함수 계산도 필요없어지니 계산 효율성이 매우 개선되었다고 할 수 있다. 결론은 훈련 과정에서 아무런 것도 건드리지 않고 더 적은 데이터와 연산량만을 필요로 하게 되었다.
METHOD: TEST ENTROPY MINIMIZATION VIA FEATURE MODULATION
저자는 테스트 단계에서 예측값의 특징들(features)를 조정함(modulating)으로서 엔트로피를 최소화하여 모델을 최적화한다. 이게 본 방법론의 이름이 TENT(Test-ENTropy)라고 불리게 된 이유다.
들어가기 전 아래와 같은 가정만 유념하자.
- The model must already be trained. : 사전 학습 모델이 필요함
- The model must be probabilistic : Entropy를 추정하려면, 예측값에 대한 분포가 필요하므로 모델은 확률적이어야함
- The model must be differentiable : 빠르게 반복되는 최적화를 위해서는 기울기가 필요하므로 모델을 미분 가능해야함
ENTROPY OBJECTIVE
Test-time objective, `L(x_t)`는 Model prediction, `\hat y=f_\theta(x^t)`에 대한 entropy, `H(\hat y)`를 최소화하는 것이다.
이때 entropy는 `Shannon entropy`를 사용한다.
- `H(\hat y)=-\sum_c p(\hat y_c)log p(\hat y_c)` where the probabilty `\hat y_c` of class `c`
단일 예측을 최적화하고자 할 때, 가장 확률이 높은 클래스에 모든 확률을 할당하는 너무 간단한 방법이 있는데, 여기선 이를 방지하기 위해 배치 전체에서 공유되는 매개변수에 대한 배치 별 predictions을 다같이 최적화하는 방식을 사용한다.
엔트로피가 비지도 학습 목표이기는 하지만, 그것이 예측값의 불확실성을 측정하는 방법으로 작용하기 때문에, 실제로는 supervised task와 밀접한 관련이 있다. 즉, 모델이 어떤 데이터 포인트에 대해 얼마나 "확신"을 가지고 있는지를 나타내므로, 결과적으로는 지도 학습에서이 성능과 관련이 있다.
SSL대비 장점도 얘기하고 있지만, 패스
MODULATION PARAMETERS
모델 파라미터인 `\theta`는 이전 연구들에서 train-time entropy 최소화를 위해 많이 사용되었기 때문에 test-time optimization에서 사용하는 것도 자연스러운 선택이 될 수 있다. 하지만 `\theta`는 본 방법론에서 훈련 데이터에 대한 유일한 표현식이기 때문에 `\theta`를 변경하면 모델이 크게 벗어날 수 있다. 게다가 `f`는 비선형일 수 있고, `\theta`는 고차원일 수 있어서 최적화가 테스트 단계에서 너무 민감하고 비효율적일 수 있다. 고로 안정성과 효율성을 위해 대신 Linear(Scale & Shift) 부분과 low-dimesional(Channel-wise) 기능 부분의 변조만 업데이트하기한다.
_TENT는 정규화 통계값 `\mu, \sigma`를 추정하고 변환 매개변수 `\gamma, \beta`를 최적화하여 테스트 단계에서 Feature를 조정한다. (_이 정규화 및 변환 과정을 통해 feature들에 channel-wise scales 및 shift를 적용함). `\gamma, \beta`는 전체 모델 파라미터의 1% 미만을 차지하므로 이를 조정하는 것이 더 효율적이라고 함
통계값 `\mu, \sigma`는 데이터에서 추정되는 반면, 매개변수 `\gamma, \beta`는 손실에 의해 최적화된다.
구현을 위해선 소스 모델의 정규화 레이어의 용도만 변경하면 된다. 테스트 중에 모든 레이어와 채널에 대한 정규화 통계와 아핀 매개변수를 업데이트한다.
ALGORITHM
INITIALIZATION 옵티마이저는 소스 모델의 각 정규화 레이어 `l`과 채널 `k`에 대해 아핀 변환 매개변수를 수집한다.
- `\theta` \ `\{\gamma_{l,k}, \beta_{l,k}\}`는 남아있고 정규화 통계 `\{\mu_{l,k}, \sigma _{l,k} \}`는 사라진다.
ITERATION 각 스텝마다 데이터 배치에 대한 정규화 통계 및 변환 매개변수를 업데이트한다.
- (Forward pass) 정규화 통계값들은 순방향 전달 중에 각 레이어에 대해 차례로 추정된다.
- (Backward pass) 변환 매개변수 `\gamma, \beta`는 역방향 전달 중에 엔트로피 `∇H(\hat y)`의 기울기에 의해 업데이트된다.
변환 매개변수 업데이트는 현재 배치에 대한 예측값을 따르므로 다음 배치에만 영향을 미친다. 그리고 추가 계산 포인트 당 하나의 Gradient만 필요하므로 효율성을 위해 기본적으로 이 방식을 따른다.
TERMINATION
- Online Adaptation의 경우 테스트 데이터가 있는 한 종료되지 않고 계속 반복된다.
- Offline Adaptation의 경우 모델이 먼저 업데이트된 다음 inference가 반복된다.
EXPERIMENTS
(이어서)
댓글