rongxian 2022. 3. 7. 14:00

torch.optim

back-progate 함으로써 NN의 weight를 업데이트 하는 과정을 optimization이라 한다.

torch.optim module은 optimization schedule 에 대한 다양한 툴과 기능을 포함한다.

다음과 같이 torch.optim module을 이용해 optimizer를 정의할 수 있다.

import torch
opt = torch.optim.SGD(model.parameters(), lr=lr)

그리고나서 optimization 수행을 다음과 같이 하면 된다.

opt.step()
opt.zero_grad()


##### 다음과 같이 할 필요 없다.
with torch.no_grad():
	# applying the parameter updates uisng SGD
    for param in model.parameters(): param -= param.grad * lr
    model.zero_grad()