본문 바로가기

머신러닝

파이토치 모델 저장; pytorch model save

반응형
SMALL

PyTorch에서 학습된 모델을 저장하는 방법은 크게 state dictionary를 저장하는 방법모델 전체를 저장하는 방법이 있습니다. 두 방법 모두 모델을 저장하고 불러오는 데 사용되지만, 저장 형태와 유연성에서 차이가 있습니다.


1. 모델의 State Dictionary 저장

  • 설명:
    • 모델의 state dictionary는 모델의 **가중치(weight)**와 바이어스(bias) 같은 학습된 파라미터들을 저장한 Python 사전(dictionary)입니다.
    • 모델 구조는 저장하지 않고, 학습된 파라미터만 저장합니다.
    • 다른 모델 구조에 파라미터를 재사용하거나, 코드 내에서 모델 클래스를 따로 정의해야 하는 경우에 적합합니다.
import torch

# 예제 모델 정의
model = MyModel()  # MyModel은 사용자 정의 모델 클래스
torch.save(model.state_dict(), 'model_state.pth')

# 모델 불러오기
model = MyModel()  # 동일한 모델 클래스를 다시 선언
model.load_state_dict(torch.load('model_state.pth'))
model.eval()  # 평가 모드로 전환

 

  • 장점:
    • 저장 파일이 상대적으로 작음.
    • 모델 구조를 유연하게 변경하거나 수정 가능.
    • 다른 모델 구조에 일부 파라미터만 로드 가능.
  • 단점:
    • 저장된 파일만으로는 모델 구조를 알 수 없으므로, 모델 클래스를 다시 정의해야 함.

 

 

2. 모델을 통째로 저장

  • 설명:
    • 모델의 state dictionary뿐만 아니라 모델의 구조(클래스 정의)까지 모두 저장합니다.
    • 저장된 파일만으로 모델을 불러오고 바로 사용할 수 있습니다.
import torch

# 예제 모델 정의
model = MyModel()  # MyModel은 사용자 정의 모델 클래스
torch.save(model, 'model_complete.pth')

# 모델 불러오기
model = torch.load('model_complete.pth')
model.eval()  # 평가 모드로 전환

 

  • 장점:
    • 저장된 파일만으로 모델 구조와 파라미터를 함께 로드할 수 있음.
    • 사용이 간단하고 직관적.
  • 단점:
    • 저장 파일 크기가 더 큼.
    • 모델 구조를 저장하므로, PyTorch 버전 간 호환성이 떨어질 수 있음.
    • 모델 클래스 코드가 변경되면 불러오기가 어려울 수 있음.

 


 

반응형
LIST