[Pytorch] loss nan 해결하기

2023. 2. 24. 23:09·Pytorch
728x90
반응형

신경망을 학습시키다 보면 학습 도중 loss 가 nan 값이 등장하는 경우가 발생하기도 합니다.

nan loss나 nan output이 발생했을 때 원인을 찾고 해결할 수 있는 방법에 대해 포스팅 하겠습니다.

​

​

▶ nan 이 발생한 연산 찾기

먼저 torch. autograd 함수중에서 nan loss 가 발생하면 그 즉시 실행을 멈추고 nan을 유발한 코드 라인을 찾아야합니다. 이를 쉽게 해주는 함수가 바로  출력해주는 함수가 autograd.set_detect_anomaly() 입니다.

autograd.set_detect_anomaly(True)

스크립트 제일 위에 위 코드를 추가해주면, 어느 라인에서 nan이 발생했는지 터미널 창의 문구를 통해 알 수 있습니다.

좀 더 구체적으로 이 함수는 autograd의 모든 엔진에 대해 오류를 감지하는 context manager입니다. 감지를 활성화 한 상태에서 모델이 forward 연산을 진행하고, backward 연산 중 오류가 발생한 forward 연산 부분을 출력해주는 함수입니다.

코드의 특정 부분에만 적용하고 싶다면 아래처럼 with문을 이용해 작성해주면 됩니다!

with torch.autograd.detect_anomaly():
        input = torch.rand(6, 10, requires_grad=True)
        output = my_func(input)
        output.backward()

아래 이미지는 위 함수를 적용했을 때 콘솔에 출력되는 내용입니다.

nan은 단순 연산 뿐만 아니라 forward, backward 연산에서 발생할 가능성이 아주 높습니다. 그렇기 때문에 원인을 직접 찾기 어려운 경우가 많은데 위 코드를 쓰면 비교적 간편하게 원인을 찾아낼 수 있습니다.

​

주의할 점은! 이 코드를 추가하면 nan을 검출하기 위해 모든 텐서를 확인하게 됩니다. 때문에 속도가 느려지므로, 디버깅을 위해서만 사용하는 것이 좋을 것 같습니다.

​

​

▶ 연산 수정

저의 경우에는 두 개의 텐서를 곱하는 과정에서 feature의 값이 과도하게 커져 'inf' 값이 도출되는 부분이 있었습니다. 이로 인해 nan 값이 발생하는 것이었습니다. pytorch에서 제공하는 LayerNorm 함수를 이용해 feature 값을 정규화 하여 해결해주었습니다.

​

또 다른 경우로는,

- 0으로 나누는 연산

- log(x)연산에서 x값이 0이거나 매우 작은 경우

에 nan loss가 발생하기 쉽다고 합니다. 1e-5와 같이 아주 작은 값을 더해주거나, nan_to_num 함수를 사용해 쉽게 해결 가능합니다.

 x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(x) 
#### tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])

torch.nan_to_num(x, nan=2.0)  
#### tensor([ 2.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])

torch.nan_to_num(x, nan=2.0, posinf=1.0) 
#### tensor([ 2.0000e+00,  1.0000e+00, -3.4028e+38,  3.1400e+00])

단, 이 함수는 pytorch 1.8.0 이후의 버전부터 지원되는 함수입니다!

​

 

이 외에도 ,

  • Gradient exploding, vanishing에 의해서도 발생할 수 있습니다. 이는 원인이 되는 layer의 weight나 grad를 출력해 보면서 해결해야 합니다. 이를 확인하는 방법은 아래와 같습니다.
torch.any(torch.isnan(weight)) # weight에 NaN 존재 여부를 확인하는 함수
model.layer.grad # layer의 gradient를 확인하는 함수
  • 마지막으로, learning rate가 너무 높은 경우에도 발생할 가능성이 있습니다.
728x90
반응형
저작자표시 (새창열림)

'Pytorch' 카테고리의 다른 글

[Pytorch] 메모리 효율적으로 사용하기  (3) 2024.08.30
[Pytorch] Inference Time Check  (0) 2023.02.24
[Pytorch] Tensor Manipulation  (0) 2023.02.24
[Pytorch] TorchVision Fine Tuning  (0) 2023.02.05
[Pytorch] torch.nn 과 torch.nn.functional  (0) 2023.02.05
'Pytorch' 카테고리의 다른 글
  • [Pytorch] 메모리 효율적으로 사용하기
  • [Pytorch] Inference Time Check
  • [Pytorch] Tensor Manipulation
  • [Pytorch] TorchVision Fine Tuning
ga.0_0.ga
ga.0_0.ga
    반응형
    250x250
  • ga.0_0.ga
    ##뚝딱뚝딱 딥러닝##
    ga.0_0.ga
  • 전체
    오늘
    어제
    • 분류 전체보기 (182)
      • Paper Review (52)
        • Video Scene Graph Generation (6)
        • Image Scene Graph Generation (18)
        • Graph Model (5)
        • Key Information Extraction (4)
        • Fake Detection (2)
        • Text to Image (1)
        • Diffusion Personalization (4)
        • etc (12)
      • AI Research (49)
        • Deep Learning (30)
        • Artificial Intelligence (15)
        • Data Analysis (4)
      • Pytorch (10)
      • ONNX (5)
      • OpenCV (2)
      • Error Note (34)
      • Linux (2)
      • Docker (3)
      • Etc (7)
      • My Study (16)
        • Algorithm (10)
        • Project (4)
        • Interview (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    그래프신경망
    pytorch
    i3d
    permute
    오차 역전파
    GCN
    정규화
    Logistic regression
    dataset
    활성화 함수
    tensorflow
    알고리즘
    Inductive bias
    pandas
    linear regression
    Activation Function
    transformer
    TypeError
    fine tuning
    나이브 베이즈 분류
    dataloader
    contiguous
    차원의 저주
    torch.nn
    forch.nn.functional
    RuntimeError
    3dinput
    ONNX
    JNI
    HRNet
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
ga.0_0.ga
[Pytorch] loss nan 해결하기
상단으로

티스토리툴바