라이브러리/PyTorch
torch.utils.data
rongxian
2022. 3. 7. 14:04
torch.utils.data
torch.utils.data module하에는 torch가 제공하는 여러 데이터셋과 DataLoader class를 제공하는데, 이는 추상화와 유연하게 구현하는데 도움을 준다.
torch.utils.data.DataLoader 예시는 다음과 같다.
form torch.utils.data import TensorDataset, DataLoader
train_ds = TensorDataset(X_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs)
for x_batch, y_batch in train_dl:
pred = model(x_batch)
######
# 다음과 같이 batch를 iterate할 필요 없다.
for i in range((n-1)//bs + 1):
x_batch = X_train[start_i:end_i]
y_batch = y_train[start_i:end_i]
pred = model(x_batch)