본문 바로가기
AI Coding/Pytorch

[Pytorch] Tensor Manipulation

by ga.0_0.ga 2023. 2. 24.
728x90
반응형

Pytorch는 텐서의 형을 변환해주는 다양한 함수들을 제공해줍니다. 이번 포스팅에서는 텐서의 형변화를 위한 아래 4가지 함수의 사용법과 차이점에 대해 설명해보겠습니다.

- view()

- reshape()

- transpose()

- permute()

▶ view( ) 와 reshape( )

두 함수 모두 numpy의 reshape( ) 함수를 기반으로 하고 있습니다.

먼저 두 함수의 간단한 사용법 부터 설명하겠습니다. view()와 reshape() 모두 입력으로 shape을 받습니다. 원하는 차원의 shape을 적어주면 바로 형변환하여 리턴해줍니다.

import torch

x = torch.arange(12)
print(x)  # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
print(x.shape)  # torch.Size([12])

x=x.reshape(3,4)
print(x) 
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
'''
print(x.shape)  # torch.Size([3, 4])

x=x.view(2, 6)
print(x)
'''
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
'''
print(x.shape)  # torch.Size([2, 6])

특이한 점은, view()와 reshape()의 인자 중 1개의 차원 값에 한하여-1로 값 지정이 가능합니다. 이 경우 -1자리에 자동으로 알맞은 차원을 계산하여 형변환해줍니다.

import torch

x = torch.arange(12)
print(x)  # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
print(x.shape)  # torch.Size([12])

x=x.reshape(2,3,-1)  # (2, 3, 2) 차원으로 자동 지정됩니다.
print(x) 
'''
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],
        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])
'''
print(x.shape)  # torch.Size([2, 3, 2])

x=x.view(2, 2,-1)  # (2, 2, 3) 차원으로 자동 지정됩니다.
print(x)
'''
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],
        [[ 6,  7,  8],
         [ 9, 10, 11]]])
'''
print(x.shape)  # torch.Size([2, 2, 3])

만약, -1이 두개 이상 포함되거나 원래 텐서의 차원에서 나올 수 없는 차원을 인자로 전달한다면 에러가 발생합니다!

두 함수는 겉으로 보기에는 큰 차이가 없어 보입니다. 하지만 보이지 않는 큰 차이점이 존재합니다. 

바로,  contiguous 속성을 만족하지 않는 텐서에 적용이 가능한지 여부(contiguous에 대한 설명은 가장 아래쪽을 참고해주세요!)입니다. view()는 contiguous 속성을 만족하지 않는 경우에는 사용이 제한됩니다. 두 함수의 차이를 보기 위해 임의의 텐서를 만들어 함수에 적용해 본 후 확인해 보겠습니다.

import torch

x = torch.ones(3, 4)
x.transpose_(0, 1)
x.is_contiguous() # False

(3,4) shape의 텐서를 만든 후 transpose함수로 형변환을 하면 contiguous 속성이 깨지게 됩니다. 이때 reshape 함수와 view함수를 적용해보면 둘은 다른 결과를 보입니다. reshape() 함수는 에러없이 기대하던 값을 리턴하지만, view()함수는 에러가 발생합니다.

따라서, 위와 같은 이유로 인해 텐서의 차원을 변경하고자 할 때는, 텐서의 상태를 정확히 파악하기 힘든 경우가 있으므로 reshape() 함수를 사용하는 것이 더 안전하겠습니다!

또한 view()함수의 결과값인 새로운 텐서는 항상 데이터를 원본 텐서와 공유한다고 합니다. 원본 텐서를 변경하면 새로운 텐서가 변경되는 것입니다. 그 반대의 경우도 동일합니다!
반대로, reshape()함수는 새로운 텐서와 원본 텐서의 데이터 공유를 보장하지 않습니다.

▶ transpose( ) 와 permute( )

두 함수도 비슷한 방식으로 동작합니다. transpose()는 딱 두 개의 차원을 교환할 때만 사용가능하지만, permute()는 모든 차원을 맞교환할 수 있습니다. view()와 reshape() 함수와는 다르게 입력으로 shape의 인덱스를 넘겨줍니다. 아래 예시를 보면 쉽게 이해할 수 있습니다.

import torch

x = torch.rand(16, 32, 3)
print(x.shape)  # [16, 32, 3]

y = x.transpose(0, 2)
print(y.shape)  # [3, 32, 16]

z = x.permute(2, 1, 0)
print(z.shape)  # [3, 32, 16]

 

 

view() 와 permute(), transpose()

역시나, 형변환을 해주는 세 함수 view()와 permute() / transpose() 연산 자체에도 차이가 존재합니다. 아래 예시를 보시면 view()는 순서를 유지하면서 다음 차원으로 넘어가지만 permute()는 내부적으로 transpose() 연산을 진행합니다.

import torch

a = [[1,2,3,4],
     [1,2,3,4],
     [1,2,3,4]]

a = torch.Tensor(a) # [3, 4]

b = a.view(4, 3) # [4, 3]
print("using view: \n", b)

c = a.permute(1, 0) # [4, 3]
print("using permute: \n", c)

또한, view 함수는 앞서 설명했듯이 contiguous한 텐서에만 작동하고 리턴하는 텐서 역시 contiguous합니다. 반면에 transpose와 permute는 non-contiguous와 contiguous 텐서 모두에 동작하지만 리턴하는 텐서는 contiguous하지 않다고 합니다. permute나 transpose를 사용한 후 non-continuous 텐서를 continuous하게 만들고싶다면 transpose().contiguous()를 꼭 해주어야 합니다.

요약하자면! view와 reshape은 연산 자체는 동일하며 기존 텐서가 변경되냐 아니냐의 차이가 존재하고

view, reshape / transpose, permute는 연산 자체도 다르고 출력의 contiguous 의 차이도 존재한다고 할 수 있겠군요!.

※ contiguous 속성이란?

파이토치에서 메모리 내의 자료형 저장 상태와 관련된 속성입니다. 간단한 예시를 들어 설명해보겠습니다.

import torch

a = torch.randn(3, 4)
a.transpose_(0, 1)
print(a)
'''
tensor([[ 0.0947, -0.7289,  0.2453],
        [ 0.1352,  0.4020, -0.2195],
        [-1.7819, -0.2174, -0.0530],
        [ 0.7038,  1.1439,  0.1813]])
'''

b = torch.randn(4, 3)
print(b)
'''
tensor([[ 0.3413,  1.4892,  0.1824],
        [-1.1999,  0.2948,  0.3660],
        [-0.9538,  0.5553, -0.1672],
        [ 0.0525,  0.2478, -1.5958]])
'''

텐서 a는 (3,4) 크기로 선언 후 transpose_함수를 이용해 (4,3) 크기로 변형해 주었고, 텐서 b는 처음부터 (4,3) 크기로 선언해주었습니다. 이제 각 텐서에 저장된 값들의 메모리 주소를 확인해보겠습니다.

print([a[i][j].data_ptr() for j in range(3) for i in range(4)])
'''
[94379339310976, 94379339310980, 94379339310984, 
94379339310988, 94379339310992, 94379339310996, 
94379339311000, 94379339311004, 94379339311008, 
94379339311012, 94379339311016, 94379339311020]
'''
print([b[i][j].data_ptr() for j in range(3) for i in range(4)])
'''
[94379338807744, 94379338807756, 94379338807768, 
94379338807780, 94379338807748, 94379338807760, 
94379338807772, 94379338807784, 94379338807752, 
94379338807764, 94379338807776, 94379338807788]

위 주소를 자세히 보면 텐서 b는 한줄에 4씩 값이 증가하지만, a는 그렇지 않습니다. 메모리 주소가 할당된 상황을 그림으로 표현해보면 아래처럼 나타낼 수 있습니다.

https://jimmy-ai.tistory.com/122

 

즉, b는 axis=0인 오른쪽 방향순으로 자료가 저장되지만, a는 transpose연산을 거치면서 axis=1인 아래 방향으로 자료가 저장되게 되는 것입니다. b처럼 axis=0 순서대로 자료가 저장된 상태를 "contiguous=True 상태"라고 하고, a처럼 순서가 원래 방향과 달라진 것을 False 상태라고 합니다. 이를 확인하는 방법은 stride()함수와 is_contiguous()함수가 있습니다. 각 함수의 설명과 사용법은 아래를 참고하시면 됩니다.

a.stride() # (1, 4), a[0][0]=>a[1][0]으로 증가할때는 자료 1개만큼의 메모리 주소가 이동되고
           #         a[0][0]=>a[0][1]로 증가할 때는 자료 4개만큼의 메모리 주소가 바뀐다는 의미
b.stride() # (3, 1)

a.is_contiguous() # False
b.is_contiguous() # True

텐서의 shape를 변형하는 과정에서 메모리 저장상태가 변하기도 합니다.

728x90
반응형

'AI Coding > Pytorch' 카테고리의 다른 글

[Pytorch] Inference Time Check  (0) 2023.02.24
[Pytorch] loss nan 해결하기  (0) 2023.02.24
[Pytorch] TorchVision Fine Tuning  (0) 2023.02.05
[Pytorch] torch.nn 과 torch.nn.functional  (0) 2023.02.05
[Pytorch] Dataset과 Dataloader 2(Custom)  (0) 2023.02.05

댓글