Loading the weights of a model trained on multiple GPUs onto a single GPU (or CPU)
rongxian
2022. 9. 12. 03:38
import torch
import torchvision

device = torch.device("cpu")
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
# Replace the final classifier conv so the model outputs 1 channel instead of the default 21
model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)
model.to(device)
# map_location remaps tensors that were saved on a GPU onto the CPU (on Apple Silicon, map_location="mps" also works)
checkpoint = torch.load("deeplabv3_resnet50_random.pt", map_location=device)
# nn.DataParallel saves every key with a "module." prefix; strip it (only where present)
# so the keys match the unwrapped, single-device model
state_dict_cpu = {(k[len("module."):] if k.startswith("module.") else k): v
                  for k, v in checkpoint["state_dict"].items()}
model.load_state_dict(state_dict_cpu)
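To see where the "module." prefix comes from, here is a small self-contained sketch. It uses a tiny nn.Linear as a stand-in for the segmentation model, and the helper name strip_module_prefix is hypothetical (not a PyTorch API). It assumes that nn.DataParallel, when no GPU is available, simply holds the wrapped module, which is how recent PyTorch versions behave.

```python
import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of state_dict with `prefix` removed
    from every key that starts with it. Keys without the prefix are kept as-is,
    so it is safe on checkpoints saved from wrapped or unwrapped models."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Tiny stand-in model; any nn.Module behaves the same way.
single = nn.Linear(4, 2)

# DataParallel registers the original model as its submodule "module",
# so every key in the wrapper's state_dict gains a "module." prefix.
wrapped = nn.DataParallel(single)
wrapped_keys = list(wrapped.state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Loading the prefixed dict into a plain model fails (strict=True by default)...
fresh = nn.Linear(4, 2)
try:
    fresh.load_state_dict(wrapped.state_dict())
except RuntimeError as err:
    print("strict load failed:", type(err).__name__)  # strict load failed: RuntimeError

# ...but succeeds once the prefix is stripped.
fresh.load_state_dict(strip_module_prefix(wrapped.state_dict()))
print(torch.equal(fresh.weight, single.weight.cpu()))  # True

# Saving wrapped.module.state_dict() at training time avoids the prefix entirely.
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']
```

Checking `startswith` instead of blindly slicing `k[7:]` keeps the code correct if some checkpoints were saved without DataParallel. Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which does the same stripping in place.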
Reference
https://discuss.pytorch.org/t/how-to-switch-model-trained-on-2-gpus-to-1-gpu/20039/4