Pytorch는 텐서의 형을 변환해주는 다양한 함수들을 제공해줍니다. 이번 포스팅에서는 텐서의 형변화를 위한 아래 4가지 함수의 사용법과 차이점에 대해 설명해보겠습니다.
- view()
- reshape()
- transpose()
- permute()
▶ view( ) 와 reshape( )
두 함수 모두 numpy의 reshape( ) 함수를 기반으로 하고 있습니다.
먼저 두 함수의 간단한 사용법 부터 설명하겠습니다. view()와 reshape() 모두 입력으로 shape을 받습니다. 원하는 차원의 shape을 적어주면 바로 형변환하여 리턴해줍니다.
import torch
x = torch.arange(12)
print(x) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
print(x.shape) # torch.Size([12])
x=x.reshape(3,4)
print(x)
'''
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
'''
print(x.shape) # torch.Size([3, 4])
x=x.view(2, 6)
print(x)
'''
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
'''
print(x.shape) # torch.Size([2, 6])
특이한 점은, view()와 reshape()의 인자 중 1개의 차원 값에 한하여-1로 값 지정이 가능합니다. 이 경우 -1자리에 자동으로 알맞은 차원을 계산하여 형변환해줍니다.
import torch
x = torch.arange(12)
print(x) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
print(x.shape) # torch.Size([12])
x=x.reshape(2,3,-1) # (2, 3, 2) 차원으로 자동 지정됩니다.
print(x)
'''
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]]])
'''
print(x.shape) # torch.Size([2, 3, 2])
x=x.view(2, 2,-1) # (2, 2, 3) 차원으로 자동 지정됩니다.
print(x)
'''
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
'''
print(x.shape) # torch.Size([2, 2, 3])
만약, -1이 두개 이상 포함되거나 원래 텐서의 차원에서 나올 수 없는 차원을 인자로 전달한다면 에러가 발생합니다!
두 함수는 겉으로 보기에는 큰 차이가 없어 보입니다. 하지만 보이지 않는 큰 차이점이 존재합니다.
바로, contiguous 속성을 만족하지 않는 텐서에 적용이 가능한지 여부(contiguous에 대한 설명은 가장 아래쪽을 참고해주세요!)입니다. view()는 contiguous 속성을 만족하지 않는 경우에는 사용이 제한됩니다. 두 함수의 차이를 보기 위해 임의의 텐서를 만들어 함수에 적용해 본 후 확인해 보겠습니다.
import torch
x = torch.ones(3, 4)
x.transpose_(0, 1)
x.is_contiguous() # False
(3,4) shape의 텐서를 만든 후 transpose함수로 형변환을 하면 contiguous 속성이 깨지게 됩니다. 이때 reshape 함수와 view함수를 적용해보면 둘은 다른 결과를 보입니다. reshape() 함수는 에러없이 기대하던 값을 리턴하지만, view()함수는 에러가 발생합니다.
![](https://blog.kakaocdn.net/dn/cvBzq8/btr0AMV9537/l4oIXThnBw84sksImWOylK/img.png)
따라서, 위와 같은 이유로 인해 텐서의 차원을 변경하고자 할 때는, 텐서의 상태를 정확히 파악하기 힘든 경우가 있으므로 reshape() 함수를 사용하는 것이 더 안전하겠습니다!
또한 view()함수의 결과값인 새로운 텐서는 항상 데이터를 원본 텐서와 공유한다고 합니다. 원본 텐서를 변경하면 새로운 텐서가 변경되는 것입니다. 그 반대의 경우도 동일합니다!
반대로, reshape()함수는 새로운 텐서와 원본 텐서의 데이터 공유를 보장하지 않습니다.
▶ transpose( ) 와 permute( )
두 함수도 비슷한 방식으로 동작합니다. transpose()는 딱 두 개의 차원을 교환할 때만 사용가능하지만, permute()는 모든 차원을 맞교환할 수 있습니다. view()와 reshape() 함수와는 다르게 입력으로 shape의 인덱스를 넘겨줍니다. 아래 예시를 보면 쉽게 이해할 수 있습니다.
import torch
x = torch.rand(16, 32, 3)
print(x.shape) # [16, 32, 3]
y = x.transpose(0, 2)
print(y.shape) # [3, 32, 16]
z = x.permute(2, 1, 0)
print(z.shape) # [3, 32, 16]
▶ view() 와 permute(), transpose()
역시나, 형변환을 해주는 세 함수 view()와 permute() / transpose() 연산 자체에도 차이가 존재합니다. 아래 예시를 보시면 view()는 순서를 유지하면서 다음 차원으로 넘어가지만 permute()는 내부적으로 transpose() 연산을 진행합니다.
import torch
a = [[1,2,3,4],
[1,2,3,4],
[1,2,3,4]]
a = torch.Tensor(a) # [3, 4]
b = a.view(4, 3) # [4, 3]
print("using view: \n", b)
c = a.permute(1, 0) # [4, 3]
print("using permute: \n", c)
![](https://blog.kakaocdn.net/dn/bmuc6t/btr0HR2Y40r/wzyaR6uj5mOV9HSKct1s4K/img.png)
또한, view 함수는 앞서 설명했듯이 contiguous한 텐서에만 작동하고 리턴하는 텐서 역시 contiguous합니다. 반면에 transpose와 permute는 non-contiguous와 contiguous 텐서 모두에 동작하지만 리턴하는 텐서는 contiguous하지 않다고 합니다. permute나 transpose를 사용한 후 non-continuous 텐서를 continuous하게 만들고싶다면 transpose().contiguous()를 꼭 해주어야 합니다.
요약하자면! view와 reshape은 연산 자체는 동일하며 기존 텐서가 변경되냐 아니냐의 차이가 존재하고
view, reshape / transpose, permute는 연산 자체도 다르고 출력의 contiguous 의 차이도 존재한다고 할 수 있겠군요!.
※ contiguous 속성이란?
파이토치에서 메모리 내의 자료형 저장 상태와 관련된 속성입니다. 간단한 예시를 들어 설명해보겠습니다.
import torch
a = torch.randn(3, 4)
a.transpose_(0, 1)
print(a)
'''
tensor([[ 0.0947, -0.7289, 0.2453],
[ 0.1352, 0.4020, -0.2195],
[-1.7819, -0.2174, -0.0530],
[ 0.7038, 1.1439, 0.1813]])
'''
b = torch.randn(4, 3)
print(b)
'''
tensor([[ 0.3413, 1.4892, 0.1824],
[-1.1999, 0.2948, 0.3660],
[-0.9538, 0.5553, -0.1672],
[ 0.0525, 0.2478, -1.5958]])
'''
텐서 a는 (3,4) 크기로 선언 후 transpose_함수를 이용해 (4,3) 크기로 변형해 주었고, 텐서 b는 처음부터 (4,3) 크기로 선언해주었습니다. 이제 각 텐서에 저장된 값들의 메모리 주소를 확인해보겠습니다.
print([a[i][j].data_ptr() for j in range(3) for i in range(4)])
'''
[94379339310976, 94379339310980, 94379339310984,
94379339310988, 94379339310992, 94379339310996,
94379339311000, 94379339311004, 94379339311008,
94379339311012, 94379339311016, 94379339311020]
'''
print([b[i][j].data_ptr() for j in range(3) for i in range(4)])
'''
[94379338807744, 94379338807756, 94379338807768,
94379338807780, 94379338807748, 94379338807760,
94379338807772, 94379338807784, 94379338807752,
94379338807764, 94379338807776, 94379338807788]
위 주소를 자세히 보면 텐서 b는 한줄에 4씩 값이 증가하지만, a는 그렇지 않습니다. 메모리 주소가 할당된 상황을 그림으로 표현해보면 아래처럼 나타낼 수 있습니다.
![](https://blog.kakaocdn.net/dn/Ba9Sk/btr0HtBkUuE/n8LtxHypFcqSzJERkkL4jK/img.png)
즉, b는 axis=0인 오른쪽 방향순으로 자료가 저장되지만, a는 transpose연산을 거치면서 axis=1인 아래 방향으로 자료가 저장되게 되는 것입니다. b처럼 axis=0 순서대로 자료가 저장된 상태를 "contiguous=True 상태"라고 하고, a처럼 순서가 원래 방향과 달라진 것을 False 상태라고 합니다. 이를 확인하는 방법은 stride()함수와 is_contiguous()함수가 있습니다. 각 함수의 설명과 사용법은 아래를 참고하시면 됩니다.
a.stride() # (1, 4), a[0][0]=>a[1][0]으로 증가할때는 자료 1개만큼의 메모리 주소가 이동되고
# a[0][0]=>a[0][1]로 증가할 때는 자료 4개만큼의 메모리 주소가 바뀐다는 의미
b.stride() # (3, 1)
a.is_contiguous() # False
b.is_contiguous() # True
텐서의 shape를 변형하는 과정에서 메모리 저장상태가 변하기도 합니다.
'AI Coding > Pytorch' 카테고리의 다른 글
[Pytorch] Inference Time Check (0) | 2023.02.24 |
---|---|
[Pytorch] loss nan 해결하기 (0) | 2023.02.24 |
[Pytorch] TorchVision Fine Tuning (0) | 2023.02.05 |
[Pytorch] torch.nn 과 torch.nn.functional (0) | 2023.02.05 |
[Pytorch] Dataset과 Dataloader 2(Custom) (0) | 2023.02.05 |
댓글