라이브러리/PyTorch
Don't forget to accumulate the Gradient
rongxian
2022. 4. 26. 11:24
ViT와 같은 self-Attention based model 또는 CNN을 Group Normalization + Weight Standardization (Batch Normalization 대신)와 사용중이라면 gradient를 accumulate하는 것을 잊지마세요!
Batch Normalization이 사용되고 있을 땐 gradient accumulation이 효과적이지 않을 지도 모르지만, 위와 같은 모델을 사용중이고 GPU VRAM이 한정적이여서 배치 사이즈를 작게 세팅하는 경우 효과적일 거야.
AMP+Gradient accumulation은 더 효과적!
from torch.cuda import amp
scaler = amp.GradScaler()
n_accumulate = 16 # n_accumulate = 1 means no gradient accumulation is applied
for epoch in epochs:
for step, (input, target) in enumerate(train_dl):
with amp.autocast(enabled=True):
output = model(input)
loss = loss_fn(output, target)
loss = loss/n_accumulate
# Accumulate gradients after scale.
scaler.scale(loss).backward()
if (step+1) % n_accumulate == 0:
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
추가적으로 설명한다면 n_accumulate
는 파라미터 업데이트 전(scaler.step(optimizer)
) 반복 횟수를 의미한다. ( scaler.scale(loss).backward()
// forward-backward passes)
예를 들어 batch_size=32
이고 gradient accumulation을 사용하지 않으면 파라미터는 batch_size (32개의 샘플)마다 업데이트 할 것이다.
그런데, 메모리가 한정적이여서 어쩔 수 없이 batch_size=4
로 두고 n_accumulation=8
로 둔다면 파라미터는 8번의 반복 이후 업데이트 즉, 4*8=32개의 샘플마다 업데이트 할 것이다.
참고
https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/217133