신경망을 학습시키다 보면 학습 도중 loss 가 nan 값이 등장하는 경우가 발생하기도 합니다.
nan loss나 nan output이 발생했을 때 원인을 찾고 해결할 수 있는 방법에 대해 포스팅 하겠습니다.
▶ nan 이 발생한 연산 찾기
먼저 torch. autograd 함수중에서 nan loss 가 발생하면 그 즉시 실행을 멈추고 nan을 유발한 코드 라인을 찾아야합니다. 이를 쉽게 해주는 함수가 바로 출력해주는 함수가 autograd.set_detect_anomaly() 입니다.
autograd.set_detect_anomaly(True)
스크립트 제일 위에 위 코드를 추가해주면, 어느 라인에서 nan이 발생했는지 터미널 창의 문구를 통해 알 수 있습니다.
좀 더 구체적으로 이 함수는 autograd의 모든 엔진에 대해 오류를 감지하는 context manager입니다. 감지를 활성화 한 상태에서 모델이 forward 연산을 진행하고, backward 연산 중 오류가 발생한 forward 연산 부분을 출력해주는 함수입니다.
코드의 특정 부분에만 적용하고 싶다면 아래처럼 with문을 이용해 작성해주면 됩니다!
with torch.autograd.detect_anomaly():
input = torch.rand(6, 10, requires_grad=True)
output = my_func(input)
output.backward()
아래 이미지는 위 함수를 적용했을 때 콘솔에 출력되는 내용입니다.
nan은 단순 연산 뿐만 아니라 forward, backward 연산에서 발생할 가능성이 아주 높습니다. 그렇기 때문에 원인을 직접 찾기 어려운 경우가 많은데 위 코드를 쓰면 비교적 간편하게 원인을 찾아낼 수 있습니다.
주의할 점은! 이 코드를 추가하면 nan을 검출하기 위해 모든 텐서를 확인하게 됩니다. 때문에 속도가 느려지므로, 디버깅을 위해서만 사용하는 것이 좋을 것 같습니다.
▶ 연산 수정
저의 경우에는 두 개의 텐서를 곱하는 과정에서 feature의 값이 과도하게 커져 'inf' 값이 도출되는 부분이 있었습니다. 이로 인해 nan 값이 발생하는 것이었습니다. pytorch에서 제공하는 LayerNorm 함수를 이용해 feature 값을 정규화 하여 해결해주었습니다.
또 다른 경우로는,
- 0으로 나누는 연산
- log(x)연산에서 x값이 0이거나 매우 작은 경우
에 nan loss가 발생하기 쉽다고 합니다. 1e-5와 같이 아주 작은 값을 더해주거나, nan_to_num 함수를 사용해 쉽게 해결 가능합니다.
x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
torch.nan_to_num(x)
#### tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00])
torch.nan_to_num(x, nan=2.0)
#### tensor([ 2.0000e+00, 3.4028e+38, -3.4028e+38, 3.1400e+00])
torch.nan_to_num(x, nan=2.0, posinf=1.0)
#### tensor([ 2.0000e+00, 1.0000e+00, -3.4028e+38, 3.1400e+00])
단, 이 함수는 pytorch 1.8.0 이후의 버전부터 지원되는 함수입니다!
이 외에도 ,
- Gradient exploding, vanishing에 의해서도 발생할 수 있습니다. 이는 원인이 되는 layer의 weight나 grad를 출력해 보면서 해결해야 합니다. 이를 확인하는 방법은 아래와 같습니다.
torch.any(torch.isnan(weight)) # weight에 NaN 존재 여부를 확인하는 함수
model.layer.grad # layer의 gradient를 확인하는 함수
- 마지막으로, learning rate가 너무 높은 경우에도 발생할 가능성이 있습니다.
'Pytorch' 카테고리의 다른 글
[Pytorch] 메모리 효율적으로 사용하기 (3) | 2024.08.30 |
---|---|
[Pytorch] Inference Time Check (0) | 2023.02.24 |
[Pytorch] Tensor Manipulation (0) | 2023.02.24 |
[Pytorch] TorchVision Fine Tuning (0) | 2023.02.05 |
[Pytorch] torch.nn 과 torch.nn.functional (0) | 2023.02.05 |