본문 바로가기
Error Note

[Pytorch] RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #

by ga.0_0.ga 2023. 1. 29.
728x90
반응형
반응형

- 전체 에러 문구

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2

cross entropy loss를 사용할 때 발생하는 에러입니다. target(정답 라벨) 자리에 잘못된 데이터 타입이 왔을 때 발생합니다.

- 해결 방법

기존의 F.cross_entropy(logits, targets)를 아래처럼 변경해주면 됩니다.

F.cross_entropy(logits, targets.to(device='cuda', dtype=torch.int64))

 

728x90
반응형

댓글