코딩/PyTorch

[PyTorch] Save and Load the Model

guungyul 2025. 1. 11. 17:43

이 글에서는 모델의 상태를 저장하고 불러오는 방법을 소개한다.

import torch
import torchvision.models as models

 

Save and Loading Model Weights

PyTorch 모델은 학습된 parameter들을 state_dict라 불리는 내부 dictionary에 저장한다. 이 dictionary는 torch.save 함수로 저장될 수 있다.

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

 

저장된 모델 weight들을 불러오기 위해서는 같은 모델의 객체를 만든 후 load_state_dict() 함수를 사용해 parameter들을 불러온다.

 

아래 예제에서는 weights_only=True로 설정해 오직 weight들을 불러오는데에 필요한 함수들만 실행시킨다. 모델의 weights들을 불러올 때 weights_only=True로 설정하는 것은 모범 사례이다.

model = models.vgg16()  # 'weights'를 명시하지 않음
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()

 

Note

  • model.eval() 함수는 dropout과 batch 정규화 layer들을 evaluation mode로 설정하기 전에 불러야 한다.

 

Saving and Loading Models with Shapes

모델 weights들을 불러올 때 모델 class의 객체를 먼저 만들어야 하는데 이는 class가 network의 구조를 정의하기 때문이다. 모델과 함께 class의 구조도 함께 저장하고 싶을 땐 model을 인자로 넘겨준다. (model.state_dict()가 아님)

torch.save(model, 'model.pth')

 

그 후 아래 예제처럼 모델을 불러올 수 있다.

weights_only=False로 설정된 이유는 모델을 불러와야 하기 때문이다.

model = torch.load('model.pth', weights_only=False)

 

Note

  • 위 방법은 모델을 직렬화 할 때 Python pickle 모듈을 이용한다. 따라서 모델을 불러올 때 실제 class 정의가 접근 가능해야 한다.