[PyTorch] timm.create_model(): loading pretrained weights from a local file via pretrained_cfg

Problem description

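With pretrained=True, timm.models.create_model first looks for the corresponding pretrained weight file in a local cache directory; if the file is not there, it is downloaded into that directory. The default locations are: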

Windows: C:\Users\<username>\.cache\torch\hub\checkpoints

Linux: /home/<username>/.cache/torch/hub/checkpoints
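These paths are not hard-coded in timm: as far as I know, timm downloads into the torch.hub directory (torch.hub.get_dir()) plus a checkpoints subfolder. The stand-alone sketch below mirrors PyTorch's default rules ($TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch) using only the standard library. It is an approximation that ignores torch.hub.set_dir() overrides, and the function name default_checkpoint_dir is hypothetical.

```python
import os


def default_checkpoint_dir() -> str:
    """Approximate PyTorch's default hub checkpoint directory.

    Hypothetical helper: uses $TORCH_HOME if set, otherwise
    $XDG_CACHE_HOME/torch, otherwise ~/.cache/torch, then appends
    hub/checkpoints. Does NOT account for torch.hub.set_dir().
    """
    torch_home = os.getenv(
        "TORCH_HOME",
        os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"),
    )
    return os.path.join(os.path.expanduser(torch_home), "hub", "checkpoints")


if __name__ == "__main__":
    # Print where a downloaded .pth file would most likely end up
    print(default_checkpoint_dir())
```

If torch is installed, os.path.join(torch.hub.get_dir(), 'checkpoints') gives the authoritative answer, including any set_dir() override.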

import timm

model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True)

Solution: set the 'file' entry of pretrained_cfg so that the pretrained weights are loaded from a local path instead of being downloaded

Method 1

print(timm.models.create_model('swinv2_tiny_window8_256').default_cfg)

'''

{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',

'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True,

'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'architecture': 'swinv2_tiny_window8_256'}

'''

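# Build the model once without weights, only to read its default config
pretrained_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg

# Point 'file' at the locally stored checkpoint
pretrained_cfg['file'] = r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth'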

print(pretrained_cfg)

'''

{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',

'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True,

'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'architecture': 'swinv2_tiny_window8_256',

'file': 'E:\\proj\\AI\\dataset\\build_dataset\\pretrained\\swinv2_tiny_patch4_window8_256.pth'}

'''

model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True, pretrained_cfg=pretrained_cfg)

print(model)
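Both methods edit the dict returned by default_cfg in place and only discover a wrong path once timm tries to load it. A small helper like the one below (hypothetical, not part of timm) copies the config, checks that the checkpoint exists, and sets 'file', so a typo in the path fails early with a clear message. It only manipulates a plain dict, so it is sketched without importing timm.

```python
import copy
import os


def with_local_weights(pretrained_cfg: dict, path: str) -> dict:
    """Return a copy of pretrained_cfg whose 'file' entry points at a local checkpoint.

    Hypothetical helper: the input dict is left untouched, and a missing
    file raises FileNotFoundError before any model is built.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"pretrained checkpoint not found: {path}")
    cfg = copy.deepcopy(dict(pretrained_cfg))
    cfg["file"] = os.path.abspath(path)  # absolute path survives later cwd changes
    return cfg


# Intended usage with timm (assumes timm is installed and the .pth file exists):
# import timm
# base_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg
# cfg = with_local_weights(base_cfg, r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth')
# model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True, pretrained_cfg=cfg)
```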

Method 2: call the model's registered entry-point function directly

pretrained_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg

pretrained_cfg['file'] = r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth'

model = timm.models.swinv2_tiny_window8_256(pretrained=True, pretrained_cfg=pretrained_cfg)

Debug notes

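The call chain observed while debugging:

1. _create_swin_transformer_v2 calls build_model_with_cfg.

2. build_model_with_cfg calls load_pretrained.

3. load_pretrained calls _resolve_pretrained_source(pretrained_cfg), which first checks whether the pretrained_cfg dict contains a 'file' key. If it does, the weights are loaded from that path; if not, they are downloaded from 'url'.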
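The precedence in step 3 can be pictured with the simplified sketch below. It is not timm's source: it models only the two branches described above ('file' before 'url') and leaves out the Hugging Face Hub and other sources that real timm versions may also check. The function name resolve_pretrained_source is hypothetical.

```python
from typing import Tuple


def resolve_pretrained_source(pretrained_cfg: dict) -> Tuple[str, str]:
    """Simplified model of the lookup order: a local 'file' wins over 'url'.

    Returns (load_from, location); ('', '') means no weight source was found.
    """
    pretrained_file = pretrained_cfg.get("file")
    pretrained_url = pretrained_cfg.get("url")
    if pretrained_file:
        return "file", pretrained_file  # load directly from local disk
    if pretrained_url:
        return "url", pretrained_url  # download into the cache (or reuse it)
    return "", ""
```

This is why both methods work without touching 'url': once 'file' is set, the download branch is never reached.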
