[Pytorch] cross entropy loss 에 3차원 input 사용하기

2023. 2. 5. 14:11·Pytorch
728x90
반응형
반응형

분류기를 통과한 후 나온 현재 tensor의 구조는 다음과 같습니다.

(batch_size, max_len, num_classes)

 

이를 아래와 같은 순서로 변경해주어야 합니다.

(batch_size, num_classes, max_len)

 

참고 사이트에 따르면 두번째 자리에 항상 클래스 수가 와야 합니다.

​

참고: https://stackoverflow.com/questions/63648735/pytorch-crossentropy-loss-with-3d-input

 

Pytorch crossentropy loss with 3d input

I have a network which outputs a 3D tensor of size (batch_size, max_len, num_classes). My groud truth is in the shape (batch_size, max_len). If I do perform one-hot encoding on the labels, it'll be...

stackoverflow.com

 

728x90
반응형
저작자표시 (새창열림)

'Pytorch' 카테고리의 다른 글

[Pytorch] Tensor Manipulation  (0) 2023.02.24
[Pytorch] TorchVision Fine Tuning  (0) 2023.02.05
[Pytorch] torch.nn 과 torch.nn.functional  (0) 2023.02.05
[Pytorch] Dataset과 Dataloader 2(Custom)  (0) 2023.02.05
[Pytorch] Dataset과 Dataloader 1(Basic)  (0) 2023.02.05
'Pytorch' 카테고리의 다른 글
  • [Pytorch] TorchVision Fine Tuning
  • [Pytorch] torch.nn 과 torch.nn.functional
  • [Pytorch] Dataset과 Dataloader 2(Custom)
  • [Pytorch] Dataset과 Dataloader 1(Basic)
ga.0_0.ga
ga.0_0.ga
    반응형
    250x250
  • ga.0_0.ga
    ##뚝딱뚝딱 딥러닝##
    ga.0_0.ga
  • 전체
    오늘
    어제
    • 분류 전체보기 (181)
      • Paper Review (51)
        • Video Scene Graph Generation (6)
        • Image Scene Graph Generation (18)
        • Graph Model (5)
        • Key Information Extraction (4)
        • Fake Detection (2)
        • Text to Image (1)
        • Diffusion Personalization (4)
        • etc (11)
      • AI Research (49)
        • Deep Learning (30)
        • Artificial Intelligence (15)
        • Data Analysis (4)
      • Pytorch (10)
      • ONNX (5)
      • OpenCV (2)
      • Error Note (34)
      • Linux (2)
      • Docker (3)
      • Etc (7)
      • My Study (16)
        • Algorithm (10)
        • Project (4)
        • Interview (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    Activation Function
    permute
    torch.nn
    linear regression
    contiguous
    dataloader
    i3d
    활성화 함수
    dataset
    JNI
    RuntimeError
    pandas
    forch.nn.functional
    알고리즘
    그래프신경망
    정규화
    Inductive bias
    나이브 베이즈 분류
    pytorch
    오차 역전파
    3dinput
    transformer
    HRNet
    GCN
    TypeError
    Logistic regression
    fine tuning
    차원의 저주
    ONNX
    tensorflow
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
ga.0_0.ga
[Pytorch] cross entropy loss 에 3차원 input 사용하기
상단으로

티스토리툴바