ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • torch.utils.data
    라이브러리/PyTorch 2022. 3. 7. 14:04

    torch.utils.data

    torch.utils.data module하에는 torch가 제공하는 여러 데이터셋과 DataLoader class를 제공하는데, 이는 추상화와 유연하게 구현하는데 도움을 준다.

    torch.utils.data.DataLoader 예시는 다음과 같다.

    form torch.utils.data import TensorDataset, DataLoader
    train_ds = TensorDataset(X_train, y_train)
    train_dl = DataLoader(train_ds, batch_size=bs)
    for x_batch, y_batch in train_dl:
    	pred = model(x_batch)
        
    ######
    # 다음과 같이 batch를 iterate할 필요 없다.
    for i in range((n-1)//bs + 1):
    	x_batch = X_train[start_i:end_i]
        y_batch = y_train[start_i:end_i]
        pred = model(x_batch)

    '라이브러리 > PyTorch' 카테고리의 다른 글

    gc.collect, torch.cuda.empty_cache()  (0) 2022.04.25
    Tensor moduels  (0) 2022.03.07
    torch.optim  (0) 2022.03.07
    torch.nn  (0) 2022.03.07
    Exager execution  (0) 2022.03.07

    댓글

Designed by Tistory.