[Pytorch] Dataset과 Dataloader 2(Custom)

2023. 2. 5. 15:00·Pytorch
728x90
반응형
반응형

이전 게시물에 이어 Dataset과 Dataloader를 커스텀하는 방법에 대해 작성하겠습니다.

​

▶ 먼저! Custom Dataset과 Dataloader가 필요한 이유가 무엇일까요?

딥러닝을 이용하는 거의 모든 작업들에서는 상당히 많은 양의 데이터를 이용하여 학습을 진행합니다. 그런데, 이 엄청난 양의 데이터를 한번에 불러오려면 시간도 오래 걸리고, 메모리 부족 현상이 발생할 수 있습니다. 데이터를 한번에 다 로드하지 않고 조금씩만 불러다 쓰면 이런 문제를 해결할 수 있겠죠!

이전 게시물에서 설명했던 Dataset은 모든 데이터를 한번에 불러오는 방식이었습니다. 데이터를 조금씩 불러오기 위해서는 custom Dataset을 만들어야 합니다. 또한 데이터 속에는 서로 다른 길이의 input이 있을 수 있기 때문에 batch를 만들어 주기 위해서는 Dataloader에서 batch를 만드는 부분을 수정해 custom Dataloader를 사용해야 합니다. (input의 길이가 다르면 하나의 batch로 묶는 것이 불가능하기 때문입니다.)

​

 

​

▶ Custom Dataset과 Dataloader 사용하기

Dataset을 커스텀 하려면 먼저 데이터 클래스가 pytorch에서 데이터셋을 제공하는 추상 클래스인 torch.utils.data.Dataset 를 상속받아야 합니다. Dataset을 상속받아 "__init__", "__len__", "__getitem__" 함수들을 오버라이드 해야합니다. 이 세 함수가 가장 기본적인 뼈대입니다.

새로 만들 Dataset 클래스를 CustomDataset이라고 하겠습니다. 각 함수의 역할은 다음과 같습니다.

이전 게시물에 TensorDataset을 이용하여 만들었던 데이터셋을 별도의 클래스로 정의하여 만들어보겠습니다.

​

1. CustomDataset 클래스 작성하기

 
class CustomDataset(Dataset): ## Dataset 클래스 상속
	def __init__(self): 
        # 데이터셋의 전처리를 해주는 부분
		self.x_train = [[70, 77, 72],
						[90, 85, 90],
						[86, 88, 87],
						[93, 95, 97],
						[70, 63, 67]]
		self.y_train = [[152],[185],[180],[196],[142]]
	
	def __len__(self):
        # 데이터셋의 길이. 총 데이터의 갯수를 반환하는 함수, len(dataset)으로 사용 가능
		return len(self.x_train)
	
	def __getitem__(self, idx):
        # 데이터셋에서 특정 1개의 샘플을 가져와 반환하는 함수, 
        # dataset[i]를 했을때 i번째 데이터를 가져올 수 있도록 함

		x = torch.FloatTensor(self.x_train[idx])
		y = torch.FloatTensor(self.y_train[idx])
		return x, y

__init__ 함수에는 전체 데이터를 로드하거나 파일 목록을 로드하여 학습데이터와 정답 라벨을 변수에 넣어주고 필요한 전처리를 해줍니다. 별도의 전처리 과정이 필요없다면 변수에만 넣어주면 됩니다. 
__len__ 함수에는 전체 데이터의 길이를 반환하도록 합니다.

__getitem__함수는 인자로 index를 받고 index번째 데이터를 반환하도록 작성해주면 됩니다.

dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

CustomDataset 클래스를 생성하고 생성된 클래스는 DataLoader에 전달해주면 됩니다! 그리고 batch size와 data shuffle의 정보도 함께 설정합니다.

​

- 나머지 학습코드는 이전 게시물과 동일합니다!

전체 코드는 github을 참고해주세요!

​

​

 

▶ Custom Dataloader로 길이가 변하는 input 사용하기

DataLoader에는 다양한 파라미터들이 존재합니다.

dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

이 중, 다양한 길이의 input을 에러없이 처리하기 위해서는 collate_fn의 인자로 주어질 함수를 설계해야 합니다. 간단한 예시를 통해 설명하겠습니다.

​

1. 다양한 길이를 갖는 데이터를 만들어줍니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class MyData(Dataset):
	def __init__(self): 
		self.x_train = [[i]*(i+1) for i in range(10)]
		self.y_train = [[i] for i in range(10)]
		
	def __len__(self):
		return 10
    
	def __getitem__(self, idx):
		x = torch.FloatTensor(self.x_train[idx])
		y = torch.FloatTensor(self.y_train[idx])
		return x, y

my_dataset = MyData()
dataloader = DataLoader(my_dataset)
for data in dataloader:
    print(data[0])

생성된 데이터는 아래와 같습니다. 모든 데이터들이 다른 길이를 같습니다. 

이때 별도의 처리없이 batch size를 2로 설정하고 학습을 진행하려고 한다면 , 아래와 같은 오류가 발생하게 됩니다.

​

따라서! batch로 묶이게 될 모든 데이터들을 길이에 상관없이 하나로 잘 묶어주기 위해 collate_fn 함수를 구현해주어야 합니다.

아래는 위의 데이터셋에 맞게 구현한 collate_fn 함수입니다.

def my_collate_fn(data):
	input = [d[0] for d in data]
	label = [d[1] for d in data]
	padding = torch.nn.utils.rnn.pad_sequence(input, batch_first=True)
    return padding, label

가장 긴 길이에 맞춰 짧은 것들의 나머지 자리들을 0으로 채우도록(패딩하도록) 구현하였습니다.

 

구현한 collate_fn 함수를 사용하려면 DataLoader의 인자로 전달하면 됩니다. 그럼, batch size를 3으로 설정하고 collate함수 사용결과를 출력해보겠습니다.

my_dataset = MyData()
dataloader = DataLoader(my_dataset, collate_fn=my_collate_fn, batch_size=3)
for data in dataloader:
    print(data[0], data[1])

에러없이 잘 진행됩니다!

 

 

전체 코드는 저의 Github에서 확인할 수 있습니다.

 

728x90
반응형
저작자표시 (새창열림)

'Pytorch' 카테고리의 다른 글

[Pytorch] Tensor Manipulation  (0) 2023.02.24
[Pytorch] TorchVision Fine Tuning  (0) 2023.02.05
[Pytorch] torch.nn 과 torch.nn.functional  (0) 2023.02.05
[Pytorch] Dataset과 Dataloader 1(Basic)  (0) 2023.02.05
[Pytorch] cross entropy loss 에 3차원 input 사용하기  (0) 2023.02.05
'Pytorch' 카테고리의 다른 글
  • [Pytorch] TorchVision Fine Tuning
  • [Pytorch] torch.nn 과 torch.nn.functional
  • [Pytorch] Dataset과 Dataloader 1(Basic)
  • [Pytorch] cross entropy loss 에 3차원 input 사용하기
ga.0_0.ga
ga.0_0.ga
    반응형
    250x250
  • ga.0_0.ga
    ##뚝딱뚝딱 딥러닝##
    ga.0_0.ga
  • 전체
    오늘
    어제
    • 분류 전체보기 (181)
      • Paper Review (51)
        • Video Scene Graph Generation (6)
        • Image Scene Graph Generation (18)
        • Graph Model (5)
        • Key Information Extraction (4)
        • Fake Detection (2)
        • Text to Image (1)
        • Diffusion Personalization (4)
        • etc (11)
      • AI Research (49)
        • Deep Learning (30)
        • Artificial Intelligence (15)
        • Data Analysis (4)
      • Pytorch (10)
      • ONNX (5)
      • OpenCV (2)
      • Error Note (34)
      • Linux (2)
      • Docker (3)
      • Etc (7)
      • My Study (16)
        • Algorithm (10)
        • Project (4)
        • Interview (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    permute
    forch.nn.functional
    fine tuning
    GCN
    tensorflow
    차원의 저주
    JNI
    contiguous
    TypeError
    Inductive bias
    pandas
    torch.nn
    나이브 베이즈 분류
    i3d
    dataset
    3dinput
    Activation Function
    정규화
    pytorch
    HRNet
    Logistic regression
    활성화 함수
    ONNX
    linear regression
    dataloader
    RuntimeError
    알고리즘
    오차 역전파
    그래프신경망
    transformer
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
ga.0_0.ga
[Pytorch] Dataset과 Dataloader 2(Custom)
상단으로

티스토리툴바