본문 바로가기

코딩/PyTorch

[PyTorch] Transforms

Data는 항상 machine learning 알고리즘에 사용되는 정재된 형태로 존재하지 않는다. 따라서 transforms를 활용해 데이터를 조정하고자 한다.

 

모든 TrochVision 데이터셋은 transformation logic이 포함된 callables를 받는 두 개의 parameter가 존재한다.

  • transform: feature를 바꿀 때 사용
  • target_transform: label을 바꿀 때 사용

torchvision.transforms 모듈은 자주 사용되는 transforms를 제공한다.

 

FashionMNIST 데이터셋의 feature들은 PIL 이미지 형식이고 label들은 integer이다. 학습을 위해서 feature들은 normalized tensor로, label들은 one-hot encoded tensor로 바꿔야 한다.

이런 transformation을 위해 ToTensor와 Lambda를 사용한다.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
	root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float)
.scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor는 PIL 이미지나 NumPy ndarray를 FloatTensor로 바꿔준다. 또 이미지의 pixel intensity vlaue를 0과 1 사이 값으로 조정한다.

Lambda Transforms

Lambda transforms는 사용자가 정의한 lambda 함수를 적용시켜준다. 위 예제에서는 integer를 one-hot encoded tensor로 바꾸는 함수를 정의했다. 먼저 size 10의 zero tensor를 만든 후, label y가 표시하는 index에 value=1을 할당하는 scatter_ 함수를 부른다.

target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float)
	.scatter_(dim=0, index=torch.tensor(y), value=1))

 

'코딩 > PyTorch' 카테고리의 다른 글