본문 바로가기
AI Coding/Pytorch

[Pytorch] cross entropy loss 에 3차원 input 사용하기

by ga.0_0.ga 2023. 2. 5.
728x90
반응형
반응형

분류기를 통과한 후 나온 현재 tensor의 구조는 다음과 같습니다.

(batch_size, max_len, num_classes)

 

이를 아래와 같은 순서로 변경해주어야 합니다.

(batch_size, num_classes, max_len)

 

참고 사이트에 따르면 두번째 자리에 항상 클래스 수가 와야 합니다.

참고: https://stackoverflow.com/questions/63648735/pytorch-crossentropy-loss-with-3d-input

 

Pytorch crossentropy loss with 3d input

I have a network which outputs a 3D tensor of size (batch_size, max_len, num_classes). My groud truth is in the shape (batch_size, max_len). If I do perform one-hot encoding on the labels, it'll be...

stackoverflow.com

 

728x90
반응형

댓글