본문 바로가기
AI Coding/Pytorch

[Pytorch] torch.nn 과 torch.nn.functional

by ga.0_0.ga 2023. 2. 5.
728x90
반응형
반응형

pytorch를 이용해 코딩을 하다 보면 같은 기능에 대해 torch.nn 과 torch.nn.functional 두 방식으로 제공하는 함수들이 여럿 있습니다.

 

▶ torch.nn이 제공하는 기능들

- Parameters

- Conv

- Pooling

- Padding

- Non-linear Activation Function

- Normalization

- Linear

- Dropout

- Loss

- .......

▶ torch.nn.functional이 제공하는 기능들

- Conv

- Pooling

- Non-linear Activation Function

- Normalization

- Dropout

- Loss

- .......

두 방식 모두 같은 결과를 제공해주지만 차이점도 존재합니다. 이번 포스팅에서는 두 개의 차이점을 알아보겠습니다!

 

이름에서도 알 수 있는 것처럼 torch.nn.functional은 함수이고, torch.nn은 클래스입니다. 클래스의 특징 중 하나는 클래스의 속성(attribute, 클래스 내부에 포함된 함수나 변수들)을 이용해 클래스의 상태를 저장하고 활용할 수 있다는 점이죠!

따라서 속성이나 상태를 저장하여 그 값을 중간에 활용해야 한다면 torch.nn을 사용하는 것이 바람직하고, 굳이 인스턴스화 시킬 필요없다면 torch.nn.functional로 바로 사용해주면 되겠습니다.

모델을 선언해줄 때 초기화 부분에 torch.nn 클래스를 이용해 모델을 정의할 것인지, forward를 진행할 때 직접 torch.nn.functional을 이용해 계산해 줄 것인지의 차이라고도 할 수 있습니다.

 

몇 가지 예제를 통해서 살펴보겠습니다!

▶ Cross Entropy Loss를 이용해 살펴보기

- torch.nn

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.FloatTensor([[-1.4922, -0.1335,  0.2527,  0.0334,  0.0705],
        [-0.1801, -1.0769,  0.0612, -0.3233,  0.0075],
        [ 0.5383, -0.3063,  0.0163,  0.5453,  0.3191]])
target = torch.LongTensor([2, 3, 1])
output = loss(input, target)
print(output)  ## tensor(1.7135)
output.requires_grad_(True)
output.backward()
 

- torch.nn.functional

import torch
import torch.nn.functional as F

input = torch.FloatTensor([[-1.4922, -0.1335,  0.2527,  0.0334,  0.0705],
        [-0.1801, -1.0769,  0.0612, -0.3233,  0.0075],
        [ 0.5383, -0.3063,  0.0163,  0.5453,  0.3191]])
target = torch.LongTensor([2, 3, 1])
loss = F.cross_entropy(input, target)
loss.requires_grad_(True)
print(loss)  ## tensor(1.7135, requires_grad=True)
loss.backward()

 

동일한 input과 target 으로 실험해보았을 때 두 경우 모두 1.7135라는 값을 loss로 출력합니다.

그럼 cross entropy에서 state는 무엇일까요? 바로 각 클래스에 대한 가중치 정보인 weight입니다. 각 클래스의 데이터 갯수가 불균형하여 각 클래스에 가중치를 주어 특정 클래스에 좀 더 집중하여 학습하고자 할 때, torch.nn 함수에서는 이 파라미터를 처음 CrossEntropyLoss()를 선언할 때 인자로 넣어주게 됩니다. 아래 CrossEntropyLoss의 파라미터들 중 weight가 이를 설정해주는 부분에 해당합니다! 하지만 torch.nn.functional에서는 매번 함수를 호출할 때마다 weight를 곱해줘야 한다는 불편함이 있습니다. 이런 경우에는  torch.nn에서 제공하는 CrossEntropyLoss()를 사용하는 것이 편하겠죠!

이는 Pytorch 공식 문서에서 제공하는 설명을 통해서도 알 수 있습니다.

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, 
                          reduction='mean', label_smoothing=0.0)
torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, 
                                  ignore_index=- 100, reduce=None, reduction='mean', 
                                  label_smoothing=0.0)

 

 

▶ Conv2d를 이용해 살펴보기

- torch.nn

torch.nn에서는 weight를 직접 선언해주지 않습니다. 클래스 내부에서 자체적으로 선언하고 초기화합니다.

import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np

input = torch.Tensor(np.array([[[ [1,1,1,0,0], 
                                 [0,1,1,1,0], 
                                 [0,0,1,1,1], 
                                 [0,0,1,1,0],
                                 [0,1,1,0,0] ]]]))

cnn = nn.Conv2d(1,1,3)
out= cnn(input)

 

 

- torch.nn.functional

torch.nn.functional에서는 weight를 직접 선언해줍니다.

import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np

input = torch.Tensor(np.array([[[ [1,1,1,0,0], 
                                 [0,1,1,1,0], 
                                 [0,0,1,1,1], 
                                 [0,0,1,1,0],
                                 [0,1,1,0,0] ]]]))

filter = torch.Tensor(np.array([[[ [1,0,1], 
                                  [0,1,0], 
                                  [1,0,1] ]]]))

output = F.conv2d(input, filter)
print(output)

 

torch.nn에서는 별도로 weight를 선언해주지 않고 nn.Conv2d함수에서 weight를 제공해줍니다. 물론 torch.nn에서는 print(cnn.weight)를 이용해 제공된 weight를 확인할 수 있습니다. 역시 Pytorch 공식 문서에서 제공하는 함수 설명을 통해서도 차이점을 알 수 있습니다.

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, 
                dilation=1, groups=1, bias=True, padding_mode='zeros', 
                device=None, dtype=None)
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

 

728x90
반응형

댓글