본문 바로가기

머신러닝

파이토치 모델 로딩, pytorch model loading

반응형

파이토치로 학습시킨 모델의 파라미터들을 파일로 저장해 두었다면,

 

다시 그 파일을 모델에 읽어들여서 사용할 필요가 있습니다.

 

여기서는 파일로부터  모델 파라미터들을 읽어서 모델에 로딩시키는 과정을 설명합니다.


반응형

1단계

 

파일로 저장된 모델을 dictionary로 읽어들여야 한다.

  • 파일 확장자는 .pth, .pt를 사용합니다.

 

모델 읽기를 위해서 파이토치에서는 torch.load( )라는 함수를 제공합니다.

 

_state_dict = torch.load('saved_model.pth')

 

_state_dict는 Python dictionary type으로, keyd와 value 쌍으로 이루어져 있습니다.

 


2단계

 

Dictionary에 읽혀진 파라미터들의 이름을 바꾸어야 할 때가 있습니다. 

 

아래 설명한 경우가 아니라면, 2단계는 건너뛰어도 됩니다.

 

예를 들어, 모델 학습할 때, nn.DataParallel( )을 사용했다면 파라미터 앞에 'module.' 이라는 prefix가 붙습니다.

 

이 파라미터들은 모델로 로딩이 불가하기 때문에 이름에서 'module.' prefix를 제거해 줘야 합니다.

 

바뀐 이름으로 파라미터를 저장해야 하기 때문에 새로운 dictionary에 저장해 주는 것이 좋습니다.

 

새로운 dictionary로는 OrderedDict 를 사용합니다. 

 

OrderedDict는 dictionary와 같지만, 저장되는 (key, value)들의 순서도 같이 저장합니다. 순차 저장하고, 그 순서대로 (key, value)를 뽑아낼 수 있다는 겁니다.

이게 왜 필요하나면, 모델 로딩할 때, 각 파라미터들의 값을 순차적으로 채워나가기 때문입니다.

일반 dictionary를 사용하면, 순서가 바뀔 수 있어서 모델 로딩할 때 문제가 생길 수 있습니다.

 

아래 코드는 저장된 파라미터 이름에서 'module.'을 제외한 새 이름으로 파라미터를 신규 OrderedDict에 저장하는 예입니다.

    from collections import OrderedDict
    
    new_state_dict = OrderedDict()
    
    for k, v in _state_dict.items():
    	new_state_dict[k[7:]] = v

위 코드에서 k[7:]은 이름에서 'module.' 문자열을 제거한 이후라는 뜻입니다.

 

이제 new_state_dict에는 원래 이름들로 저장된 모델 파라미터들이 있습니다.

 


3단계

 

new_state_dict를 실제 모델에 로딩하는 과정입니다.

 

모델을 생성한 후, 모델의 load_state_dict() 함수를 이용해서 로딩하면 됩니다.

 

 # new_state_dict : 파일에서 읽어들인 모델 파라미터들이 OrderedDict에 저장
 
 model = resnet50()
    
 model.load_state_dict(new_state_dict, strict=False)

위 코드에서는 ResNet50 모델을 생성하고, 그 모델에 파라미터들을 로딩합니다.

 

load_state_dict( ) 함수의 argument로 strict=False를 지정한 것을 주의깊게 보세요.

 

모델의 모든 파라미터들과 new_state_dict에 있는 파라미터들이 이름과 개수가 모두 동일하다면, 로딩이 성공합니다.

 

그러나, 이름과 개수가 틀릴 경우, exception이 발생하게 됩니다. 예를 들어, Missing key 같은 겁니다. 모델에 파라미터 A가 있는데, 로딩하려는 new_state_dict에는 없다라는 겁니다.

 

그런데, 이런 exception이 에러 상황이 아니라도 발생할 때가 있습니다. 예를 들어, 저장된 모델이 일부 파라미터들만 저장한 경우도 있거든요. 이럴 때, 에러가 아니라고 알려주는 것이 strict=False 입니다.

 

주의할 점도 있습니다. strict=False로 설정하면, 모델은 매칭되는 파라미터들만 로딩하고, 나머지는 버려두게 됩니다. 그래서 최악의 경우, 매칭되는 파라미터들이 없다면, 모델은 어떤 파라미터도 로딩하지 않게 됩니다. 에러도 발생시키지 않으면서 말이죠.

 


4단계

 

모델 파라미터에 값들이 제대로 로딩되었는지 확인하는 방법입니다.

 

    model = resnet50(num_classes=2)
    
    print('before--->\n', model.state_dict()['layer4.2.conv3.weight'])
    
    model.load_state_dict(new_state_dict, strict=False)
    
    print('after--->\n', model.state_dict()['layer4.2.conv3.weight'])

로딩 전 후에 같은 위치의 파라미터 값을 출력해 보면 됩니다.

 

만약, 로딩이 제대로 되지 않았다면 파라미터 값이 변하지 않았을 겁니다.

 

위 코드 예에서는 layer4.2.conv3.weight 파라미터 값을 확인해 봤습니다.

 

파라미터 이름들을 알아내는 코드는 아래와 같습니다.

 

    for name, param in model.named_parameters():
        if param.requires_grad is True:
            print(name)

named_parameters()를 이용해서 파라미터들의 이름과 파라미터 값을 추출하고, 출력하면 됩니다.

 


5단계

 

파라미터값들이 로딩된 모델을 이용해서 추가 학습을 할 때가 있습니다.

 

모델 전체를 하는 것이 아니라, 모델의 일부분만 하고 싶을 때가 있습니다.

 

이 경우는, 학습을 진행할 파라미터들과 그렇지 않을 파라미터를 찾아서 아래 코드와 같이 지정해 주면 됩니다.

 

    for name, param in model.named_parameters():
        if contains_keyword(name, omit_keywords) is True:
            param.requires_grad = True
        else:
            param.requires_grad = False

학습 진행 여부는 requires_grad 값을 True 혹은 False로 지정하여 결정하는데,

 

위 코드에서는 파라미터 이름을 기준으로 학습할 파라미터들을 결정합니다.

반응형