[PyTorch] Loading a pretrained model from a local file with timm.create_model() by specifying pretrained_cfg
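Problem description
When pretrained=True is passed, timm.models.create_model first looks in a local cache directory for the matching pretrained weights file. If the file is not there, it is downloaded into that directory:
Windows: C:\Users\<username>\.cache\torch\hub\checkpoints Linux: /home/<username>/.cache/torch/hub/checkpoints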
import timm

model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True)
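To check whether a checkpoint is already in the cache before calling create_model, the directory above can be computed by hand. The following is a minimal sketch that mirrors, as far as I understand it, the default lookup used by torch.hub (TORCH_HOME, then XDG_CACHE_HOME, then ~/.cache). The helper names default_checkpoint_dir and cached_checkpoint are hypothetical, and the sketch ignores any override made through torch.hub.set_dir.

```python
import os
from pathlib import Path
from typing import Optional


def default_checkpoint_dir() -> Path:
    """Approximate the directory where downloaded checkpoints are cached."""
    # Assumption: TORCH_HOME wins; otherwise use <cache>/torch,
    # where <cache> is XDG_CACHE_HOME or ~/.cache.
    cache_root = os.getenv("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = os.getenv("TORCH_HOME", os.path.join(cache_root, "torch"))
    return Path(os.path.expanduser(torch_home)) / "hub" / "checkpoints"


def cached_checkpoint(filename: str) -> Optional[Path]:
    """Return the cached checkpoint path if the file exists, else None."""
    path = default_checkpoint_dir() / filename
    return path if path.is_file() else None
```

For example, cached_checkpoint('swinv2_tiny_patch4_window8_256.pth') returns None when the weights have not been downloaded yet.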
Loading a local pretrained model from a file path by setting pretrained_cfg
Method 1
print(timm.models.create_model('swinv2_tiny_window8_256').default_cfg)
'''
{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'architecture': 'swinv2_tiny_window8_256'}
'''
pretrained_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg
pretrained_cfg['file'] = r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth'
print(pretrained_cfg)
'''
{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'architecture': 'swinv2_tiny_window8_256',
'file': 'E:\\proj\\AI\\dataset\\build_dataset\\pretrained\\swinv2_tiny_patch4_window8_256.pth'}
'''
model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True, pretrained_cfg=pretrained_cfg)
print(model)
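Method 1 edits the dictionary returned by default_cfg in place and only fails once the weights are actually loaded if the path is wrong. A small helper can copy the config and check the path up front. This is a sketch under my own assumptions: with_local_file is a hypothetical name, not part of timm, and it only relies on the 'file' key used above.

```python
import copy
from pathlib import Path


def with_local_file(pretrained_cfg: dict, checkpoint_path: str) -> dict:
    """Return a copy of pretrained_cfg whose 'file' entry points at a local checkpoint."""
    path = Path(checkpoint_path)
    if not path.is_file():
        # Fail early instead of letting the loader fall back or crash later.
        raise FileNotFoundError(f"checkpoint not found: {path}")
    cfg = copy.deepcopy(pretrained_cfg)  # leave the original config untouched
    cfg["file"] = str(path)
    return cfg
```

The returned dictionary can then be passed as pretrained_cfg together with pretrained=True, exactly as in the create_model call above.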
Method 2
pretrained_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg
pretrained_cfg['file'] = r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth'
model = timm.models.swinv2_tiny_window8_256(pretrained=True, pretrained_cfg=pretrained_cfg)
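Debug notes
Stepping through the call in a debugger shows the following chain: _create_swin_transformer_v2 calls build_model_with_cfg, which calls load_pretrained, which in turn calls _resolve_pretrained_source(pretrained_cfg). That function first checks whether pretrained_cfg contains a file entry. If it does, the weights are loaded from the path stored there; otherwise they are downloaded from url.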
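The priority described above can be sketched in a few lines. This is a simplified illustration of the file-before-url order only, not timm's actual implementation, which may handle further sources and options. The function name resolve_source and its return format are hypothetical.

```python
from typing import Tuple


def resolve_source(pretrained_cfg: dict) -> Tuple[str, str]:
    """Pick where weights are loaded from, preferring 'file' over 'url'."""
    local_file = pretrained_cfg.get("file")
    if local_file:
        # A local path is present, so no download is needed.
        return "file", local_file
    url = pretrained_cfg.get("url")
    if url:
        # No local path: fall back to downloading from the URL.
        return "url", url
    # Neither source is configured.
    return "", ""
```

With the pretrained_cfg from Method 1, this sketch would select the 'file' entry even though 'url' is also set.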