【start:20231103】

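Table of contents

Approach
    Problem description: the server cannot download timm pretrained weights
    Official links
    References

Downloading pretrained weights from the "huggingface source"
    First, ping the network
    Running timm API code inside a third-party project
    Error: LocalEntryNotFoundError
    Error: MaxRetryError; HTTPSConnectionPool

Running generic timm API code directly
    Downloading the bin file from the huggingface website (success)
    Creating a timm model from the huggingface file (success)
    Downloading with git (abandoned)

Downloading pretrained weights from the "torch source"
    Goal
    Getting the pth weight link from the timm API for manual download (success)
    Creating a timm model from the torch file (success)
    Special case: the torch url returned by the timm API is empty

Where pretrained models are stored on the server
    Storage paths on Windows and Linux servers
    Storage paths for huggingface and torch files

Hands-on project
    Project: cellseg_models_pytorch

Advanced use of pretrained weights
    Selectively loading certain feature layers after loading weights manually
    Freezing certain layers of the model

Other issues
    Error: resuming training from a checkpoint with pretrained
    Error: missing features_only argument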

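Approach

Problem description: the server cannot download timm pretrained weights

When the machine cannot reach huggingface, timm keeps failing with huggingface network errors. To use pretrained weights in that situation, they have to be read from local files, which raises three questions:

How do I download a timm pretrained model manually?
Where does timm store downloaded weights by default?
How do I make the code work with manually downloaded weights?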

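Official links

[huggingface website] https://huggingface.co/timm/tf_efficientnetv2_s.in1k

[timm GitHub repository] https://github.com/huggingface/pytorch-image-models#introduction

References

[ref] "python timm library / python timm library download": introduces the timm GitHub project.

[ref] "timm: downloading models manually": points out that timm has both new and old download links (huggingface and torch), but does not solve downloading weights through the torch link.

[ref] "A fix for failed Timm pretrained weight downloads": describes changing the pretrained_cfg_overlay argument, which solves the case of weights downloaded through the huggingface link.

[ref] "[Pytorch] timm.create_model(): loading a pretrained model locally by specifying pretrained_cfg": solves the case of weights downloaded through the torch link.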

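Downloading pretrained weights from the "huggingface source"

First, ping the network

C:\Users\lenovo>ping huggingface.co

Pinging huggingface.co [31.13.83.34] with 32 bytes of data:
Request timed out.
Request timed out.
Request timed out.
Request timed out.

So the connection to huggingface really is unusable.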
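The ping above tests ICMP, while the weight download itself goes over HTTPS. As a complementary check, here is a minimal sketch that tries to open a TCP connection to port 443 using only the standard library. The function name can_reach and the 5-second timeout are choices made for this illustration, not part of timm or huggingface_hub.

```python
import socket


def can_reach(host: str, port: int = 443, timeout: float = 5.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # OSError covers DNS failures, refused connections and timeouts.
        return False


# Example call (performs a real network attempt when run):
# print("huggingface.co reachable:", can_reach("huggingface.co"))
```

If this returns False on the server, the manual download routes described in the rest of this post are the ones to follow.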

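Running timm API code inside a third-party project

On the Linux server, the problem "timm cannot download pretrained weights when huggingface is unreachable" first showed up in a specific project:

# Define the model with the function API.
model = cppnet_base_multiclass(
    enc_name="tf_efficientnetv2_s",
    n_rays=32,  # number of predicted rays
    type_classes=len(pannuke_module.type_classes),
)

This raises the following errors:

Error: LocalEntryNotFoundError

Error reported with external network access enabled: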

TimeoutError: timed out

...

1367 ) from head_call_error

1369 # From now on, etag and commit_hash are not None.

1370 assert etag is not None, "etag must have been retrieved from server"

LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.


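As shown above, the error persists even with external network access enabled, and it does not report the download addresses of the key files.

Error: MaxRetryError; HTTPSConnectionPool

MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443):

Error reported with external network access disabled: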

Seed set to 42

**kwargs : {'enc_name': 'tf_efficientnetv2_s'}

kwargs.get("checkpoint_path", None): None

checkpoint_path: None

'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/model.safetensors (Caused by ConnectTimeoutError(, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 1792c043-b41c-4c0a-9deb-eda1a8438775)')' thrown while requesting HEAD https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/model.safetensors

'(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/pytorch_model.bin (Caused by ConnectTimeoutError(, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: ee241300-1909-4254-aaad-e6a073600cdb)')' thrown while requesting HEAD https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k/resolve/main/pytorch_model.bin

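As shown above, without external network access huggingface fails as well, this time with MaxRetryError. The good news is that the message includes the download addresses of the two key files (the safetensors file and the bin file).

Running generic timm API code directly

Rather than keep struggling inside the third-party code, let's lower the difficulty and run generic timm API code directly.

For example, to look up the pretrained weights of tf_efficientnetv2_s, run: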

import timm

print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)

Output:

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

This gives us the repository path of the weights: 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k'

Combining it with the domain gives the full address: https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k
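The URL construction can be written down as a small sketch. The resolve/main pattern is taken from the HEAD requests in the MaxRetryError log earlier in this post; the helper names hf_model_page and hf_file_url are made up for this illustration.

```python
HF_ENDPOINT = "https://huggingface.co"


def hf_model_page(hf_hub_id: str) -> str:
    """Model page: domain + hf_hub_id."""
    return f"{HF_ENDPOINT}/{hf_hub_id}"


def hf_file_url(hf_hub_id: str, filename: str, revision: str = "main") -> str:
    """Direct file link, following the pattern seen in the error log."""
    return f"{HF_ENDPOINT}/{hf_hub_id}/resolve/{revision}/{filename}"


hub_id = "timm/tf_efficientnetv2_s.in21k_ft_in1k"
page = hf_model_page(hub_id)
bin_url = hf_file_url(hub_id, "pytorch_model.bin")
safetensors_url = hf_file_url(hub_id, "model.safetensors")
print(page)
print(bin_url)
print(safetensors_url)
```

These links can be opened in a browser on any machine that does have access, and the downloaded file can then be copied to the server.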

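Downloading the bin file from the huggingface website (success)

Open the model page that matches the hf_hub_id above:

[website link] https://huggingface.co/timm/tf_efficientnetv2_s.in21k_ft_in1k

Then download the bin file (pytorch_model.bin).

(Note: once the API has been called, the models--timm--tf_efficientnetv2_s.in21k_ft_in1k folder is created automatically, whether or not the download succeeded.)

Going by the links the code printed, in theory we only need to download the bin file and place it under "/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/".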
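The folder name follows a visible pattern: "models--" followed by the hf_hub_id with "/" replaced by "--". Below is a minimal sketch of that mapping, inferred only from the folder name shown above. The helper name hf_cache_folder is hypothetical, and the default cache root is assumed to be ~/.cache/huggingface/hub as in this post.

```python
from pathlib import Path


def hf_cache_folder(hf_hub_id: str, cache_root: Path) -> Path:
    """Map an hf_hub_id such as 'timm/xxx' to its folder under the hub cache."""
    return cache_root / ("models--" + hf_hub_id.replace("/", "--"))


# Assumed default cache root, matching the path used in this post.
cache_root = Path.home() / ".cache" / "huggingface" / "hub"
folder = hf_cache_folder("timm/tf_efficientnetv2_s.in21k_ft_in1k", cache_root)
print(folder)
```

This only computes where the folder lives. It does not make huggingface_hub accept a file copied into it by hand.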

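Creating a timm model from the huggingface file (success)

In theory, placing the weights downloaded from HF in the corresponding path (xxx/.cache/huggingface/hub/) should be enough.

In practice, even with the weight file in place, timm still raises the huggingface "LocalEntryNotFoundError" network error. This is possibly because huggingface_hub still runs its own online check, or because a file dropped directly into that folder does not match the blobs/refs/snapshots layout its cache expects:

LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

The workaround is to pass timm a pretrained_cfg_overlay argument whose 'file' entry points at the downloaded weights: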

import timm

print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)

pretrained_cfg_overlay = {'file' : r"/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin"}

model = timm.models.create_model('tf_efficientnetv2_s', pretrained=True, pretrained_cfg_overlay=pretrained_cfg_overlay, num_classes=6)

print(model)

If the model is printed, the load succeeded. (See the hands-on project later for a real use case.)

pretrained_cfg if pretrained: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

pretrained_cfg getattr: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/huggingface/hub/models--timm--tf_efficientnetv2_s.in21k_ft_in1k/pytorch_model.bin', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

EfficientNet(

(conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

(act): SiLU(inplace=True)

)

(blocks): Sequential(

(0): Sequential(

(0): ConvBnAct(

(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

(act): SiLU(inplace=True)

)

(drop_path): Identity()

)

(1): ConvBnAct(

(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

...

)

(global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))

(classifier): Linear(in_features=1280, out_features=6, bias=True)

)
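One practical risk with the pretrained_cfg_overlay approach is a mistyped path, which can be mistaken for yet another network problem. The sketch below checks the file before handing it to timm. The helper names local_file_overlay and create_from_local are hypothetical; only timm.create_model and its pretrained_cfg_overlay argument come from the code above, and timm is imported lazily so the path check can be used on its own.

```python
from pathlib import Path


def local_file_overlay(weight_path: str) -> dict:
    """Build a pretrained_cfg_overlay dict that points at a local weight file.

    Raises FileNotFoundError early if the file does not exist.
    """
    path = Path(weight_path).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f"weight file not found: {path}")
    return {"file": str(path)}


def create_from_local(model_name: str, weight_path: str, **kwargs):
    """Create a timm model from a manually downloaded weight file."""
    import timm  # lazy import: only needed when a model is actually built

    return timm.create_model(
        model_name,
        pretrained=True,
        pretrained_cfg_overlay=local_file_overlay(weight_path),
        **kwargs,
    )
```

A call such as create_from_local('tf_efficientnetv2_s', '<path to pytorch_model.bin>', num_classes=6) would then mirror the example above.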


Downloading with git (abandoned)

I also tried git, but it fails with the same network problem, so this method was dropped:

C:\Users\lenovo>git clone https://huggingface.co/timm/tf_efficientnetv2_s.in1k.git

Cloning into 'tf_efficientnetv2_s.in1k'...

fatal: unable to access 'https://huggingface.co/timm/tf_efficientnetv2_s.in1k.git/': Failed to connect to huggingface.co port 443 after 21082 ms: Couldn't connect to server

[ref] "How to batch-download huggingface model and dataset files"

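Downloading pretrained weights from the "torch source"

If none of your machines can reach the huggingface website, can the torch files be used in place of the huggingface files?

Yes: most weights are available from the torch source as well.

Goal

Instead of the huggingface copy of tf_efficientnetv2_s, download the torch file tf_efficientnetv2_s_21ft1k-d7dafa41.pth.

Getting the pth weight link from the timm API for manual download (success)

Temporarily change pretrained from True to False so that no download is attempted.

Run:

import timm

model = timm.create_model('tf_efficientnetv2_s', pretrained=False, num_classes=6)

print(model.default_cfg)  # print the config, which contains the url

Output: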

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

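Here, url is the link used by older timm versions, and hf_hub_id is the HF download entry added in newer versions; the hf entry has the higher priority.

url: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth
hf_hub_id: timm/tf_efficientnetv2_s.in21k_ft_in1k

The url corresponds to the torch source. Download the pth file behind the url and save it to the target path.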
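As a rough sketch of this step, the file name can be derived from the url and joined with the torch checkpoint directory used in this post. The helper names are hypothetical, and download_checkpoint only makes sense on a machine that can actually reach the url (for example a laptop, before copying the file to the server).

```python
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import urlretrieve


def checkpoint_destination(url: str, checkpoints_dir: Path) -> Path:
    """Return the local path where the file behind `url` should be stored."""
    filename = Path(urlparse(url).path).name
    return checkpoints_dir / filename


def download_checkpoint(url: str, checkpoints_dir: Path) -> Path:
    """Download the file unless it is already present; return its local path."""
    dest = checkpoint_destination(url, checkpoints_dir)
    if not dest.exists():
        dest.parent.mkdir(parents=True, exist_ok=True)
        urlretrieve(url, str(dest))
    return dest


url = (
    "https://github.com/rwightman/pytorch-image-models/releases/download/"
    "v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
)
# Assumed default directory, matching the paths used in this post.
checkpoints_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
dest = checkpoint_destination(url, checkpoints_dir)
print(dest)
```

Nothing is downloaded until download_checkpoint(url, checkpoints_dir) is called explicitly.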

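Creating a timm model from the torch file (success)

Two problems remain: huggingface still reports a network error even when the timm weights are already present, and we want to avoid the huggingface link entirely (even when it is not used for downloading, timm may still use it for a check, which brings the network problem back).

Both can be solved by inserting the pth file downloaded through the torch link directly into pretrained_cfg.

Concretely, use the downloaded pth file as the weights in pretrained_cfg['file']: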

import timm

print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)

pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg

pretrained_cfg['file'] = r"/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"

print(pretrained_cfg)

model = timm.models.create_model('tf_efficientnetv2_s', pretrained=True, pretrained_cfg=pretrained_cfg)

print(model)

If the model is printed, the load succeeded. (See the hands-on project later for a real use case.)

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier', 'file': '/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth'}

pretrained_cfg if pretrained: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

pretrained_cfg getattr: {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'file': '/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', 'hf_hub_id': 'timm/tf_efficientnetv2_s.in21k_ft_in1k', 'architecture': 'tf_efficientnetv2_s', 'tag': 'in21k_ft_in1k', 'custom_load': False, 'input_size': (3, 300, 300), 'test_input_size': (3, 384, 384), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1000, 'pool_size': (10, 10), 'first_conv': 'conv_stem', 'classifier': 'classifier'}

EfficientNet(

(conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

(act): SiLU(inplace=True)

)

(blocks): Sequential(

(0): Sequential(

(0): ConvBnAct(

(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

(act): SiLU(inplace=True)

)

(drop_path): Identity()

)

(1): ConvBnAct(

(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNormAct2d(

...

)

(global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))

(classifier): Linear(in_features=1280, out_features=1000, bias=True)

)


[ref] "[Pytorch] timm.create_model(): loading a pretrained model locally by specifying pretrained_cfg"

Special case: the torch url returned by the timm API is empty

Run:

import timm

model = timm.create_model('convnext_small', pretrained=False, num_classes=6)

print(model.default_cfg)

Output:

{'url': '', 'hf_hub_id': 'timm/convnext_small.in12k_ft_in1k', 'architecture': 'convnext_small', 'tag': 'in12k_ft_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'stem.0', 'classifier': 'head.fc'}

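As shown above, when the returned url is empty, there is no pth file to download from it; the "convnext_small" model is one such case.

The pth file then has to be found elsewhere, for example:

https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth

Reference: [resource post] "ResNet, ConvNeXt, Transformer pretrained models and more"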
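Before deciding which route to take for a given model, it can help to inspect its default_cfg for non-empty source entries. The sketch below works on plain dicts like the ones printed above; the function name available_sources is made up for this illustration, and it only reports which entries exist, not which one timm will pick.

```python
def available_sources(cfg: dict) -> list:
    """List which weight sources ('file', 'url', 'hf_hub_id') are non-empty in a config dict."""
    return [key for key in ("file", "url", "hf_hub_id") if cfg.get(key)]


# Trimmed copies of the two configs printed in this post.
effnet_cfg = {
    "url": "https://github.com/rwightman/pytorch-image-models/releases/download/"
           "v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth",
    "hf_hub_id": "timm/tf_efficientnetv2_s.in21k_ft_in1k",
}
convnext_cfg = {"url": "", "hf_hub_id": "timm/convnext_small.in12k_ft_in1k"}

print(available_sources(effnet_cfg))
print(available_sources(convnext_cfg))
```

For convnext_small only the hf_hub_id entry is present, which is why the pth file has to come from somewhere else.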

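Where pretrained models are stored on the server

To summarize where pretrained models are stored on different servers:

Storage paths on Windows and Linux servers

With pretrained=True, timm.models.create_model first looks for the matching weight file in a default local directory and, if it is missing, downloads it there:

Windows: C:\Users\<username>\.cache\torch\hub\checkpoints
Linux: /home/<username>/.cache/torch/hub/checkpoints

Storage paths for huggingface and torch files

To upload manually downloaded weights to the server, you need to know the default storage paths. There are generally two:

huggingface: xxx/.cache/huggingface/hub/
torch: xxx/.cache/torch/hub/checkpoints/

These correspond to weights downloaded from huggingface and weights downloaded from the old-style links, respectively.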
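The two default locations can be computed rather than typed by hand. This sketch is an approximation: it assumes the defaults listed above and, to my knowledge, the TORCH_HOME and HF_HOME environment variables that can relocate them. Those variables are not mentioned in the original post, other overrides such as XDG_CACHE_HOME are deliberately ignored, and the function name default_weight_dirs is hypothetical.

```python
import os
from pathlib import Path


def default_weight_dirs(env=None, home=None) -> dict:
    """Return the assumed default weight directories for torch and huggingface."""
    env = os.environ if env is None else env
    home = Path.home() if home is None else Path(home)
    # Assumption: TORCH_HOME / HF_HOME relocate the caches when set.
    torch_home = Path(env.get("TORCH_HOME", home / ".cache" / "torch"))
    hf_home = Path(env.get("HF_HOME", home / ".cache" / "huggingface"))
    return {
        "torch": torch_home / "hub" / "checkpoints",
        "huggingface": hf_home / "hub",
    }


dirs = default_weight_dirs()
print(dirs["torch"])
print(dirs["huggingface"])
```

On a machine with PyTorch installed, torch.hub.get_dir() reports the hub directory actually in use and is the more reliable source.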

Hands-on project

Project: cellseg_models_pytorch

[code link] https://github.com/okunator/cellseg_models.pytorch/tree/main

Notebook path:

/home/linxq/code/cell_seg_workflow/src_cite/segment/cellseg_models.pytorch-main/examples/pannuke_nuclei_segmentation_cppnet.ipynb

In the outermost application code, just set enc_name; no other arguments need to change (the part to modify is the inner timm.create_model call):

import timm

# pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
# pretrained_cfg['file'] = r"/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"

# Define the model with the function API.
model = cppnet_base_multiclass(
    enc_name="tf_efficientnetv2_s",
    n_rays=32,  # number of predicted rays
    type_classes=len(pannuke_module.type_classes),
    pretrained=True,
    # pretrained_cfg=pretrained_cfg,
)

/home/linxq/code/cell_seg_workflow/src_cite/segment/cellseg_models.pytorch-main/cellseg_models_pytorch/encoders/timm_encoder.py

In that file, find the inner timm.create_model call and pass it a pretrained_cfg argument that contains the file entry:

# create the timm model
pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
pretrained_cfg['file'] = r"/home/linxq/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21ft1k-d7dafa41.pth"

try:
    self.backbone = timm.create_model(
        name,
        pretrained=pretrained,
        pretrained_cfg=pretrained_cfg,
        checkpoint_path=checkpoint_path,
        in_chans=in_channels,
        features_only=True,
        out_indices=self.out_indices,
        **kwargs,
    )
except (AttributeError, RuntimeError) as err:
    print(err)
    raise RuntimeError(
        f"timm backbone: {name} is not supported due to missing "
        "features_only argument implementation in timm-package."
    )
except IndexError as err:
    print(err)
    raise IndexError(
        f"It's possible that the given depth: {depth} is too large for "
        f"the given backbone: {name}. Try passing a smaller `depth` argument "
        "or a different backbone."
    )

With this change, the notebook below runs successfully:

/home/linxq/code/cell_seg_workflow/src_cite/segment/cellseg_models.pytorch-main/examples/pannuke_nuclei_segmentation_cppnet.ipynb
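The patch in timm_encoder.py hardcodes 'tf_efficientnetv2_s', so it would hand the same weight file to every encoder name. One possible refinement, sketched here under the assumption that only a few encoders have manually downloaded weights, is a small lookup table keyed by encoder name. LOCAL_WEIGHTS and overlay_for are hypothetical names, not part of cellseg_models_pytorch or timm.

```python
from typing import Optional

# Hypothetical table: encoder name -> manually downloaded weight file.
LOCAL_WEIGHTS = {
    "tf_efficientnetv2_s": (
        "/home/linxq/.cache/torch/hub/checkpoints/"
        "tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
    ),
}


def overlay_for(name: str) -> Optional[dict]:
    """Return a pretrained_cfg_overlay for `name`, or None if no local file is registered."""
    path = LOCAL_WEIGHTS.get(name)
    return {"file": path} if path else None


print(overlay_for("tf_efficientnetv2_s"))
print(overlay_for("convnext_small"))
```

Inside the encoder, the call could then pass pretrained_cfg_overlay=overlay_for(name) instead of the hardcoded pretrained_cfg; encoders without a registered file fall back to timm's normal behavior because the argument is simply None.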

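Advanced use of pretrained weights

Selectively loading certain feature layers after loading weights manually

To load only certain feature layers, modify the model's state dict (state_dict):

'''
Selectively load certain feature layers
by modifying the model's state dict (state_dict).
'''
import torch
import timm

# Print the default config
print(timm.models.create_model('tf_efficientnetv2_s').default_cfg)

# Path of the manually downloaded pretrained weight file
pretrained_cfg = timm.models.create_model('tf_efficientnetv2_s').default_cfg
pretrained_cfg['file'] = r"C:\Users\lenovo\.cache\torch\hub\checkpoints\tf_efficientnetv2_s_21ft1k-d7dafa41.pth"
print(pretrained_cfg)

# Create the model without loading pretrained weights
model = timm.models.create_model('tf_efficientnetv2_s', pretrained=False)

# Load the pretrained state dict onto the CPU
pretrained_state_dict = torch.load(pretrained_cfg['file'], map_location='cpu')

# State dict of the current model
model_state_dict = model.state_dict()

# Copy only selected layers (here: keys containing 'conv')
for key in list(pretrained_state_dict.keys()):
    # Skip keys the model does not have, so load_state_dict sees no unexpected keys
    if 'conv' in key and key in model_state_dict:
        model_state_dict[key] = pretrained_state_dict[key]

# Load the modified state dict into the model
model.load_state_dict(model_state_dict)

# Print the model
print(model)

Printed model structure: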

EfficientNet(

(conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

(act): SiLU(inplace=True)

)

(blocks): Sequential(

(0): Sequential(

(0): ConvBnAct(

(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

(act): SiLU(inplace=True)

)

(drop_path): Identity()

)

(1): ConvBnAct(

(conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNormAct2d(

24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True

(drop): Identity()

...
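The selection loop above can fail later in load_state_dict if a selected tensor has a different shape from the model's, for example after changing num_classes. The sketch below separates the selection step and adds a shape check. It is written against plain mappings whose values expose a .shape attribute, so it does not import torch itself; the function name select_matching_weights is hypothetical.

```python
def select_matching_weights(pretrained_sd, model_sd, keep):
    """Select entries of pretrained_sd that pass keep(key) and fit model_sd.

    Returns (selected, skipped): `selected` maps key -> pretrained value,
    `skipped` lists keys that passed keep() but are missing from model_sd
    or have a different shape.
    """
    selected, skipped = {}, []
    for key, value in pretrained_sd.items():
        if not keep(key):
            continue
        target = model_sd.get(key)
        same_shape = (
            target is not None
            and getattr(target, "shape", None) == getattr(value, "shape", None)
        )
        if same_shape:
            selected[key] = value
        else:
            skipped.append(key)
    return selected, skipped
```

With real tensors, a call might look like selected, skipped = select_matching_weights(pretrained_state_dict, model.state_dict(), lambda k: 'conv' in k), followed by model.load_state_dict(selected, strict=False) so that the keys left out are simply kept at their initial values.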

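Freezing certain layers of the model

To freeze certain layers, set the requires_grad attribute of their parameters to False:

# Freeze some layers of the model, e.g. all convolution layers
for name, param in model.named_parameters():
    if 'conv' in name:
        param.requires_grad = False

# Check which layers are frozen
for name, param in model.named_parameters():
    print(f'{name}: requires_grad={param.requires_grad}')

Printed status of each layer: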

conv_stem.weight: requires_grad=False

bn1.weight: requires_grad=True

bn1.bias: requires_grad=True

blocks.0.0.conv.weight: requires_grad=False

blocks.0.0.bn1.weight: requires_grad=True

blocks.0.0.bn1.bias: requires_grad=True

blocks.0.1.conv.weight: requires_grad=False

...
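The freezing loop can be wrapped in a small helper that also reports what it froze, which makes it easier to confirm that the substring matched the intended layers. This is a minimal sketch that works on any iterable of (name, parameter) pairs whose parameters have a requires_grad attribute, such as model.named_parameters(); the function name freeze_by_name is made up for this illustration.

```python
def freeze_by_name(named_parameters, substring: str) -> list:
    """Set requires_grad=False on parameters whose name contains `substring`.

    Returns the list of frozen parameter names, in iteration order.
    """
    frozen = []
    for name, param in named_parameters:
        if substring in name:
            param.requires_grad = False
            frozen.append(name)
    return frozen
```

A typical call would be frozen = freeze_by_name(model.named_parameters(), 'conv'). When building the optimizer afterwards, passing only the parameters that still have requires_grad=True keeps the frozen ones out of the update step.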

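Other issues

Error: resuming training from a checkpoint with pretrained

With the pretrained problem solved, the checkpoint problem is worth a look:

The former are pretrained weights, where only some layers' parameters are specific to the task; the latter are weights obtained from your own training, where the parameters of all layers are specific to the task.

A successful first training run produces a checkpoint. However, when the model was defined for a second training run with checkpoint_path added, it raised:

RuntimeError: timm backbone: tf_efficientnetv2_s is not supported due to missing features_only argument implementation in timm-package.

Unresolved for now.

Error: missing features_only argument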

RuntimeError: timm backbone: convnext_small is not supported due to missing features_only argument implementation in timm-package.

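features_only is an argument that controls the model's output: it tells the model to return feature maps only, without running the final classification or regression head. This is useful when we care about intermediate feature representations rather than the final prediction, for example to analyze them or feed them to another task.

The error above may mean that the timm package does not implement features_only for the ConvNext_Small model, or that the model has not been fully integrated yet. Note, however, that this message is produced by the except branch in timm_encoder.py shown earlier, which catches any AttributeError or RuntimeError raised by timm.create_model. The real cause may therefore be something else (for instance a failed weight load), so check the original error that is printed just before it.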
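Because the wrapper replaces the original exception with a fixed message, the root cause is easy to miss. The sketch below shows one way to keep it visible using Python's exception chaining (raise ... from err). It illustrates the pattern only and is not the project's actual code; build_backbone and the factory argument are hypothetical stand-ins for the encoder and timm.create_model.

```python
def build_backbone(factory, name: str):
    """Call factory(name) and re-raise failures without hiding the original error."""
    try:
        return factory(name)
    except (AttributeError, RuntimeError) as err:
        # Include the original message and chain the exception so the
        # traceback shows both the wrapper error and its cause.
        raise RuntimeError(
            f"timm backbone: {name} could not be created: {err}"
        ) from err
```

With chaining, the traceback contains both errors, so a key or shape mismatch while loading a checkpoint would no longer be reported only as a missing features_only implementation.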
