본문 바로가기
AI Coding/Pytorch

[Pytorch] loss nan 해결하기

by ga.0_0.ga 2023. 2. 24.
728x90
반응형

신경망을 학습시키다 보면 학습 도중 loss 가 nan 값이 등장하는 경우가 발생하기도 합니다.

nan loss나 nan output이 발생했을 때 원인을 찾고 해결할 수 있는 방법에 대해 포스팅 하겠습니다.

▶ nan 이 발생한 연산 찾기

먼저 torch. autograd 함수중에서 nan loss 가 발생하면 그 즉시 실행을 멈추고 nan을 유발한 코드 라인을 찾아야합니다. 이를 쉽게 해주는 함수가 바로  출력해주는 함수가 autograd.set_detect_anomaly() 입니다.

autograd.set_detect_anomaly(True)

스크립트 제일 위에 위 코드를 추가해주면, 어느 라인에서 nan이 발생했는지 터미널 창의 문구를 통해 알 수 있습니다.

좀 더 구체적으로 이 함수는 autograd의 모든 엔진에 대해 오류를 감지하는 context manager입니다. 감지를 활성화 한 상태에서 모델이 forward 연산을 진행하고, backward 연산 중 오류가 발생한 forward 연산 부분을 출력해주는 함수입니다.

코드의 특정 부분에만 적용하고 싶다면 아래처럼 with문을 이용해 작성해주면 됩니다!

with torch.autograd.detect_anomaly():
        input = torch.rand(6, 10, requires_grad=True)
        output = my_func(input)
        output.backward()

아래 이미지는 위 함수를 적용했을 때 콘솔에 출력되는 내용입니다.

nan은 단순 연산 뿐만 아니라 forward, backward 연산에서 발생할 가능성이 아주 높습니다. 그렇기 때문에 원인을 직접 찾기 어려운 경우가 많은데 위 코드를 쓰면 비교적 간편하게 원인을 찾아낼 수 있습니다.

주의할 점은! 이 코드를 추가하면 nan을 검출하기 위해 모든 텐서를 확인하게 됩니다. 때문에 속도가 느려지므로, 디버깅을 위해서만 사용하는 것이 좋을 것 같습니다.

▶ 연산 수정

저의 경우에는 두 개의 텐서를 곱하는 과정에서 feature의 값이 과도하게 커져 'inf' 값이 도출되는 부분이 있었습니다. 이로 인해 nan 값이 발생하는 것이었습니다. pytorch에서 제공하는 LayerNorm 함수를 이용해 feature 값을 정규화 하여 해결해주었습니다.

또 다른 경우로는,

- 0으로 나누는 연산

- log(x)연산에서 x값이 0이거나 매우 작은 경우

에 nan loss가 발생하기 쉽다고 합니다. 1e-5와 같이 아주 작은 값을 더해주거나, nan_to_num 함수를 사용해 쉽게 해결 가능합니다.

 x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(x) 
#### tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])

torch.nan_to_num(x, nan=2.0)  
#### tensor([ 2.0000e+00,  3.4028e+38, -3.4028e+38,  3.1400e+00])

torch.nan_to_num(x, nan=2.0, posinf=1.0) 
#### tensor([ 2.0000e+00,  1.0000e+00, -3.4028e+38,  3.1400e+00])

단, 이 함수는 pytorch 1.8.0 이후의 버전부터 지원되는 함수입니다!

 

이 외에도 ,

  • Gradient exploding, vanishing에 의해서도 발생할 수 있습니다. 이는 원인이 되는 layer의 weight나 grad를 출력해 보면서 해결해야 합니다. 이를 확인하는 방법은 아래와 같습니다.
torch.any(torch.isnan(weight)) # weight에 NaN 존재 여부를 확인하는 함수
model.layer.grad # layer의 gradient를 확인하는 함수
  • 마지막으로, learning rate가 너무 높은 경우에도 발생할 가능성이 있습니다.
728x90
반응형

'AI Coding > Pytorch' 카테고리의 다른 글

[Pytorch] Inference Time Check  (0) 2023.02.24
[Pytorch] Tensor Manipulation  (0) 2023.02.24
[Pytorch] TorchVision Fine Tuning  (0) 2023.02.05
[Pytorch] torch.nn 과 torch.nn.functional  (0) 2023.02.05
[Pytorch] Dataset과 Dataloader 2(Custom)  (0) 2023.02.05

댓글