custom Dataset과 Dataloader에 대해 설명하기 전에 pytorch에서 제공하는 Dataset과 Dataloader의 기본적인 사용법 부터 설명하도록 하겠습니다.
파이토치에서는 데이터를 좀 더 편리하게 다룰 수 있도록 데이터셋(Dataset)과 데이터로더(DataLoader)라는 모듈을 기본적으로 제공합니다. 이를 사용하면 batch size 설정, 데이터 셔플(shuffle, 랜덤하게 데이터를 전달), 병렬 처리까지 파라미터로 간단히 조절하여 수행하는 것이 가능해집니다.
기본적인 사용 방법은 Dataset을 정의하고, 이를 DataLoader에 전달하는 것입니다. 간단한 사용법을 설명하기 위해 Float형 텐서를 입력받아 Dataset의 형태로 변환해주는 TensorDataset을 사용하도록 하겠습니다.
1. 먼저 필요한 pytorch 라이브러리들을 import 합니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset ## 텐서형 데이터셋
from torch.utils.data import DataLoader ## 데이터 로더
2. TensorDataset에 데이터를 넣어줍니다. x_train은 훈련 데이터를, y_train 훈련데이터들의 정답 라벨입니다. 이 두 가지를 TensorDataset의 입력으로 주고 dataset이라는 변수에 저장합니다.
x_train = torch.FloatTensor([[70, 77, 72],
[90, 85, 90],
[86, 88, 87],
[93, 95, 97],
[70, 63, 67]])
y_train = torch.FloatTensor([[152],[185],[180],[196],[142]])
dataset = TensorDataset(x_train, y_train)
3. 2번에서 만든 데이터 집합을 가지고 Dataloader를 사용해보겠습니다. Dataloader는 필수적으로 두개의 인자를 받습니다. 하나는 사용할 데이터셋, 하나는 batch size입니다. 저는 batch size를 2로 설정해주었습니다. shuffle은 새로운 epoch을 시작할 때마다 데이터셋의 학습 순서를 바꿀지 결정하는 인지입니다. True로 설정하면 epoch마다 입력으로 들어오는 데이터의 순서가 변경됩니다. shuffle을 사용하는 이유는 모델이 데이터셋의 순서에 익숙해지는 것을 방지하기 위함입니다.(매 epoch마다 같은 순서로 데이터가 들어오면 모델이 데이터의 순서까지 학습해버릴 수 있기 때문입니다.)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
4. 모델과 optimizer를 설정해줍니다. 저는 모델은 간단하게 linear 1 layer만 사용하였고, optimizer는 SGD를 사용했습니다.
model = nn.Linear(3,1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
5. epoch은 20으로 설정하고, 학습을 진행해보겠습니다.
total_epochs = 20
for epoch in range(total_epochs+1):
for batch_idx, samples in enumerate(dataloader):
x_train, y_train=samples
prediction = model(x_train)
loss=F.mse_loss(prediction, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch {:4d}/{} Barch {}/{} Cost: {:.6f}'.format(epoch, total_epochs, batch_idx+1,len(dataloader), loss.item()))
batch_idx와 samples를 출력해보면, 다음과 같이 출력됩니다.
Dataloader에서 설정한 batch size 크기인 2만큼 데이터의 크기가 나오는 것을 알 수 있습니다.
'Pytorch' 카테고리의 다른 글
[Pytorch] Tensor Manipulation (0) | 2023.02.24 |
---|---|
[Pytorch] TorchVision Fine Tuning (0) | 2023.02.05 |
[Pytorch] torch.nn 과 torch.nn.functional (0) | 2023.02.05 |
[Pytorch] Dataset과 Dataloader 2(Custom) (0) | 2023.02.05 |
[Pytorch] cross entropy loss 에 3차원 input 사용하기 (0) | 2023.02.05 |