-
torch.optim라이브러리/PyTorch 2022. 3. 7. 14:00
torch.optim
back-progate 함으로써 NN의 weight를 업데이트 하는 과정을 optimization이라 한다.
torch.optim module은 optimization schedule 에 대한 다양한 툴과 기능을 포함한다.
다음과 같이 torch.optim module을 이용해 optimizer를 정의할 수 있다.
import torch opt = torch.optim.SGD(model.parameters(), lr=lr)
그리고나서 optimization 수행을 다음과 같이 하면 된다.
opt.step() opt.zero_grad() ##### 다음과 같이 할 필요 없다. with torch.no_grad(): # applying the parameter updates uisng SGD for param in model.parameters(): param -= param.grad * lr model.zero_grad()
'라이브러리 > PyTorch' 카테고리의 다른 글
gc.collect, torch.cuda.empty_cache() (0) 2022.04.25 Tensor moduels (0) 2022.03.07 torch.utils.data (0) 2022.03.07 torch.nn (0) 2022.03.07 Exager execution (0) 2022.03.07