코딩/PyTorch

[PyTorch] Datasets & DataLoaders

guungyul 2025. 1. 10. 00:15

PyTorch는 모델 학습 코드와 데이터셋 코드를 분리시켜 더 읽기 쉽고 모듈성이 뛰어나게 만들고자 한다.

PyTorch에서 이미 로드된 데이터셋이나 나의 데이터셋을 사용하기 위해 두가지 데이터 primitives를 제공한다:

torch.utils.data.DataLoader 와 torch.utils.data.Dataset 이다.

 

Dataset은 샘플들과 대응하는 라벨들을 저장하고 DataLoader는 Dataset을 iterable로 감싸 쉽게 샘플들에 접근할 수 있도록 해준다.

 

Loading a Dataset

아래는 Fashion-MNIST 데이터셋을 불러오는 과정이다.

FashionMNIST 데이터셋은 다음과 같은 parameter로 불러진다:

  • root: train/test 데이터가 저장된 위치
  • train: training 또는 test 데이터셋 특정
  • download=True: root에 데이터셋이 없다면 인터넷에서 다운로드
  • transform, target_transform: feature나 label 변환
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Iterating and Visualizing the Dataset

Dataset은 list처럼 index되어 사용될 수 있다.

예를 들어 training_data[index] 형식으로 한 element에 접근할 수 있다.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()


Creating a Custom Dataset for your files

사용자 정의 Dataset class는 세 가지 함수를 반드시 구현해야 한다:

__init__, __len__, __getitem__

 

다음 예제에서는 FashionMNIST 구현 함수를 보여준다. 이미지들은 img_dir에 저장되어있고 라벨은 다른 CSV 파일인 annotations_file에 저장되어 있다.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

__init__ 함수에서는 Dataset 객체를 만들 때 실행된다.

이미지가 저장된 폴더, 주석 파일, transform 두 개가 모두 초기화된다.

 

labels.csv 파일은 다음과 같은 형식을 가지고 있다:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

 

__len__

__len__함수는 데이터셋에 있는 샘플 개수를 반환한다.

 

__getitem__

__getitem__함수는 주어진 index idx에 위치한 샘플을 반환한다.

  • 주어진 index에 기반해 이미지 위치를 찾고
  • read_image 함수를 이용해 tensor로 변환하고
  • 대응하는 label을 csv 데이터인 self.img_labels에서 찾고
  • transform 함수를 부르고 (필요하다면)
  • 마지막으로 tensor 이미지와 대응하는 label tuple을 반환한다

Preparing your data for training with DataLoaders

Dataset은 데이터셋의 feature와 label을 한번에 하나씩 찾는다. 하지만 보통 학습 과정에서는 샘플들을 "minibatches" 단위로 넘기고, 데이터를 epoch마다 다시 섞고, Python의 multiprocessing을 이용해 데이터 검색을 빠르게 진행한다.

 

DataLoader는 이런 복잡한 과정을 추상화한 API를 제공하는 iterable이다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

Iterate through the DataLoader

위 예제에서 데이터셋을 DataLoader를 활용해 불러왔고 필요에 따라 순회 할 수 있다.

아래 예제에서 하나의 iteration은 train_features와 train_labels의 batch를 반환한다 (feature와 label각 64개).

Suffle=True이기 때문에 모든 batch를 순회한 후 데이터 순서가 섞이게 된다.

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
# Feature batch shape: torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {train_labels.size()}")
# Labels batch shape: torch.Size([64])
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
# Label: 5