라이브러리/PyTorch

Don't forget to accumulate the Gradient

rongxian 2022. 4. 26. 11:24

ViT와 같은 self-Attention based model 또는 CNN을 Group Normalization + Weight Standardization (Batch Normalization 대신)와 사용중이라면 gradient를 accumulate하는 것을 잊지마세요!

Batch Normalization이 사용되고 있을 땐 gradient accumulation이 효과적이지 않을 지도 모르지만, 위와 같은 모델을 사용중이고 GPU VRAM이 한정적이여서 배치 사이즈를 작게 세팅하는 경우 효과적일 거야.

AMP+Gradient accumulation은 더 효과적!

from torch.cuda import amp

scaler = amp.GradScaler()
n_accumulate = 16 # n_accumulate = 1 means no gradient accumulation is applied

for epoch in epochs:
    for step, (input, target) in enumerate(train_dl):
        with amp.autocast(enabled=True):
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss/n_accumulate

           # Accumulate gradients after scale.
           scaler.scale(loss).backward()

           if (step+1) % n_accumulate == 0:
                   scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

추가적으로 설명한다면 n_accumulate는 파라미터 업데이트 전(scaler.step(optimizer)) 반복 횟수를 의미한다. ( scaler.scale(loss).backward() // forward-backward passes)

예를 들어 batch_size=32 이고 gradient accumulation을 사용하지 않으면 파라미터는 batch_size (32개의 샘플)마다 업데이트 할 것이다.

그런데, 메모리가 한정적이여서 어쩔 수 없이 batch_size=4로 두고 n_accumulation=8로 둔다면 파라미터는 8번의 반복 이후 업데이트 즉, 4*8=32개의 샘플마다 업데이트 할 것이다.

참고

https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/217133

https://arxiv.org/abs/1804.07612

https://arxiv.org/abs/1711.00489