-
torch.utils.data라이브러리/PyTorch 2022. 3. 7. 14:04
torch.utils.data
torch.utils.data module하에는 torch가 제공하는 여러 데이터셋과 DataLoader class를 제공하는데, 이는 추상화와 유연하게 구현하는데 도움을 준다.
torch.utils.data.DataLoader 예시는 다음과 같다.
form torch.utils.data import TensorDataset, DataLoader train_ds = TensorDataset(X_train, y_train) train_dl = DataLoader(train_ds, batch_size=bs)
for x_batch, y_batch in train_dl: pred = model(x_batch) ###### # 다음과 같이 batch를 iterate할 필요 없다. for i in range((n-1)//bs + 1): x_batch = X_train[start_i:end_i] y_batch = y_train[start_i:end_i] pred = model(x_batch)
'라이브러리 > PyTorch' 카테고리의 다른 글
gc.collect, torch.cuda.empty_cache() (0) 2022.04.25 Tensor moduels (0) 2022.03.07 torch.optim (0) 2022.03.07 torch.nn (0) 2022.03.07 Exager execution (0) 2022.03.07