[ONNX] Pytorch 모델을 ONNX 모델로 변환하기

2023. 3. 30. 23:23·ONNX
728x90
반응형

이번 포스팅에서는 ONNX에 대한 간단한 설명과 pytorch를 ONNX로 변환하는 방법에 대해 설명하겠습니다!

 

▶ ONNX란?

ONNX는 Open Neural Network Exchange의 줄임말입니다. 말 그대로 서로 다른 딥러닝 프레임워크 환경인(Tensorflow, Pytorch 등등...) 에서 만들어진 모델들을 서로 호환하여 사용할 수 있도록 만들어진 플랫폼입니다. ONNX는 아래 두 가지의 장점을 갖습니다.

1. Framework Interoperability: 특정 프레임워크에서 생성된 모델을 다른 환경에서 import하여 자유롭게 사용할 수 있습니다. (모바일, PC의 구분없이 사용 할 수 있습니다.)

2. Shared Optimization: 하드웨어(linux, window, mac, CPU, GPU 등등...) 설계시에 ONNX representation을 기준으로 최적화 할 수 있습니다. 때문에 매우 효율적입니다.

 

 

▶ Pytorch 모델을 ONNX 모델로 변환하기

먼저 필요한 라이브러리들을 import 해줍니다.

import io
import numpy as np

from torch import nn
import torch.onnx

from efficientnet_pytorch import EfficientNet

 

저는 변환할 모델로 EfficientNet을 사용했습니다.

https://github.com/lukemelas/EfficientNet-PyTorch

 

GitHub - lukemelas/EfficientNet-PyTorch: A PyTorch implementation of EfficientNet and EfficientNetV2 (coming soon!)

A PyTorch implementation of EfficientNet and EfficientNetV2 (coming soon!) - GitHub - lukemelas/EfficientNet-PyTorch: A PyTorch implementation of EfficientNet and EfficientNetV2 (coming soon!)

github.com

먼저 변환하고자 하는 pytorch모델 파일을 준비하고, 모델의 파라미터를 학습된 가중치로 초기화합니다. 그리고 아래와 같이 네트워크를 eval 모드로 전환해줍니다.

torch_model = EfficientNet()

model_path = 'efficientnet-b0.pth'

map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
state_dict = torch.load(model_path)
torch_model.load_state_dict(state_dict)
torch_model.eval()

 

 

다음으로, torch.onnx.export를 이용하여 변환을 수행합니다. 변환할 떄는 랜덤한 값으로 채워진 텐서를 입력 값으로 줘야합니다. (input 값과 맞는 자료형과 shape이라면 어떤 값이던 상관없습니다!)

export() 함수의 가장 마지막 파라미터인 dynamic_axes는 가변적인 길이를 가진 차원을 지정할 때 사용하는 파라미터입니다. 특정 차원을 동적 차원으로 지정하지 않는 한, ONNX 그래프는 입력 값의 사이즈를 모든 차원에 대해 고정하게됩니다.

x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# 모델 변환
torch.onnx.export(torch_model,               # 실행될 모델
                  x,                         # 모델 입력값 (튜플 또는 여러 입력값들도 가능합니다.)
                  "effi_onnx_covert.onnx",   # 모델 저장 경로 (파일 또는 파일과 유사한 객체 모두 가능합니다.)
                  export_params=True,        # 모델 파일 안에 학습된 모델 가중치를 저장할지의 여부
                  opset_version=13,          # 모델을 변환할 때 사용할 연산자들의 ONNX 버전
                  do_constant_folding=True,  # 최적화 작업시 상수폴딩을 사용할지의 여부
                  input_names = ['input'],   # 모델의 입력값을 가리키는 이름
                  output_names = ['output'], # 모델의 출력값을 가리키는 이름
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 가변적인 길이를 가진 차원
                                'output' : {0 : 'batch_size'}})

 

변환이 완료되면 아래와 같이 모델을 로드하고, check_model을 이용해 확인해볼 수 있습니다.

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

아래 그림은 변환된 efficientNet의 ONNX 파일을 Neutron을 이용해 시각화한 것입니다.

 
728x90
반응형
저작자표시

'ONNX' 카테고리의 다른 글

[ONNX] ONNX Model Quantization  (1) 2023.09.05
[ONNX] ONNX Model Visualization(Netron)  (0) 2023.04.04
[ONNX] ONNX 변환모델에 메타데이터 추가하기  (0) 2023.04.01
[ONNX] ONNX Runtime에서 실행하기  (0) 2023.03.31
'ONNX' 카테고리의 다른 글
  • [ONNX] ONNX Model Quantization
  • [ONNX] ONNX Model Visualization(Netron)
  • [ONNX] ONNX 변환모델에 메타데이터 추가하기
  • [ONNX] ONNX Runtime에서 실행하기
ga.0_0.ga
ga.0_0.ga
    반응형
    250x250
  • ga.0_0.ga
    ##뚝딱뚝딱 딥러닝##
    ga.0_0.ga
  • 전체
    오늘
    어제
    • 분류 전체보기 (180)
      • Paper Review (50)
        • Video Scene Graph Generation (6)
        • Image Scene Graph Generation (18)
        • Graph Model (5)
        • Key Information Extraction (4)
        • Fake Detection (2)
        • Text to Image (1)
        • Diffusion Personalization (4)
        • etc (10)
      • AI Research (49)
        • Deep Learning (30)
        • Artificial Intelligence (15)
        • Data Analysis (4)
      • Pytorch (10)
      • ONNX (5)
      • OpenCV (2)
      • Error Note (34)
      • Linux (2)
      • Docker (3)
      • Etc (7)
      • My Study (16)
        • Algorithm (10)
        • Project (4)
        • Interview (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    GCN
    차원의 저주
    오차 역전파
    그래프신경망
    HRNet
    pytorch
    활성화 함수
    ONNX
    transformer
    permute
    Activation Function
    linear regression
    RuntimeError
    나이브 베이즈 분류
    fine tuning
    torch.nn
    3dinput
    pandas
    tensorflow
    JNI
    contiguous
    Inductive bias
    TypeError
    dataloader
    i3d
    Logistic regression
    dataset
    정규화
    forch.nn.functional
    알고리즘
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
ga.0_0.ga
[ONNX] Pytorch 모델을 ONNX 모델로 변환하기
상단으로

티스토리툴바