코딩/PyTorch
[PyTorch] Save and Load the Model
guungyul
2025. 1. 11. 17:43
이 글에서는 모델의 상태를 저장하고 불러오는 방법을 소개한다.
import torch
import torchvision.models as models
Save and Loading Model Weights
PyTorch 모델은 학습된 parameter들을 state_dict라 불리는 내부 dictionary에 저장한다. 이 dictionary는 torch.save 함수로 저장될 수 있다.
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
저장된 모델 weight들을 불러오기 위해서는 같은 모델의 객체를 만든 후 load_state_dict() 함수를 사용해 parameter들을 불러온다.
아래 예제에서는 weights_only=True로 설정해 오직 weight들을 불러오는데에 필요한 함수들만 실행시킨다. 모델의 weights들을 불러올 때 weights_only=True로 설정하는 것은 모범 사례이다.
model = models.vgg16() # 'weights'를 명시하지 않음
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
Note
- model.eval() 함수는 dropout과 batch 정규화 layer들을 evaluation mode로 설정하기 전에 불러야 한다.
Saving and Loading Models with Shapes
모델 weights들을 불러올 때 모델 class의 객체를 먼저 만들어야 하는데 이는 class가 network의 구조를 정의하기 때문이다. 모델과 함께 class의 구조도 함께 저장하고 싶을 땐 model을 인자로 넘겨준다. (model.state_dict()가 아님)
torch.save(model, 'model.pth')
그 후 아래 예제처럼 모델을 불러올 수 있다.
weights_only=False로 설정된 이유는 모델을 불러와야 하기 때문이다.
model = torch.load('model.pth', weights_only=False)
Note
- 위 방법은 모델을 직렬화 할 때 Python pickle 모듈을 이용한다. 따라서 모델을 불러올 때 실제 class 정의가 접근 가능해야 한다.