ViT Notes and Code Examples (runnable, to help understanding)

Principle:

ViT (Vision Transformer) treats an image as a sequence of tokens: the image is cut into fixed-size patches, each patch is flattened and linearly projected into a model_dim-dimensional embedding, a learnable CLS token is prepended, position embeddings are added, and the resulting sequence goes through a standard Transformer encoder. Classification uses a linear head on the encoder output at the CLS position. The code below walks through exactly these steps.

Key implementation examples: (1) The individual building blocks

import torch
import torch.nn as nn
import torch.nn.functional as F

# test setup for image2emb
batch_size, imageChannel, width, height = 1, 3, 8, 8
patch_size = 4
model_dim = 8
max_num_token = 16
patch_depth = patch_size * patch_size * imageChannel  # flattened size of one patch (patch_size * patch_size * channels)
image = torch.randn(batch_size, imageChannel, width, height)
weight = torch.randn(patch_depth, model_dim)  # Conv2D view: model_dim is the output channel count, patch_depth is kernel size * input channels

# 1: convert the image into a sequence of embedding vectors
def image2emb_naive(image, patch_size, weight):
    # image shape: batch_size * imageChannel * height * width
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight  # project each flattened patch (patch_depth) to an NLP-style embedding
    return patch_embedding

# patch_embedding_naive = image2emb_naive(image, patch_size, weight)
# print(patch_embedding_naive.shape)

kernel = weight.transpose(0, 1).reshape((-1, imageChannel, patch_size, patch_size))  # outputChannel * inputChannel * kernel_height * kernel_width

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # batch_size * outputChannel * height * width
    batch_size, outChannel, height, width = conv_output.shape
    patch_embedding = conv_output.reshape(batch_size, outChannel, height * width).transpose(-1, -2)
    return patch_embedding

patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
print(patch_embedding_conv.shape)
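A quick sanity check (not in the original code): with kernel built from weight as above, the unfold-based and the conv-based constructions should produce the same patch embeddings.

# optional sanity check: the naive (unfold + matmul) and the conv constructions are equivalent
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(torch.allclose(patch_embedding_naive, patch_embedding_conv, atol=1e-5))  # expected: True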

##########################################
# 2: prepend the prepared CLS token embedding
cls_token_embedding = torch.randn(batch_size, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)

#########################################
# 3: add position embeddings
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding

#########################################
# 4: pass the token embeddings through a Transformer encoder
# batch_first=True so the input is interpreted as (batch_size, seq_len, model_dim)
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)  # a mask could also be passed here if needed
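The comment above mentions that a mask could be passed. A minimal sketch of what that might look like, assuming a hypothetical batch in which some token positions are padding; the mask itself is not part of the original example.

# hypothetical masking sketch: mark padded token positions so self-attention ignores them
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # True = this position is padding
# e.g. padding_mask[:, -2:] = True would hide the last two tokens
masked_output = transformer_encoder(token_embedding, src_key_padding_mask=padding_mask)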

# 5: classification head
cls_token_output = encoder_output[:, 0, :]  # index as (batch_size, position, channels); take the CLS position
num_classes = 10
label = torch.randint(num_classes, (batch_size,))
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)  # raw logits, no softmax applied (CrossEntropyLoss expects logits)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)
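An optional extra check (not in the original): backpropagating the toy loss shows which of the hand-made tensors would actually receive gradients, and which would not.

# optional: see which hand-made tensors are part of the autograd graph
loss.backward()
print(cls_token_embedding.grad.shape)       # has a gradient: created with requires_grad=True
print(position_embedding_table.grad.shape)  # has a gradient: created with requires_grad=True
print(kernel.requires_grad)                 # False: the patch projection would not be trained as written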

Result:

(2) Organized into a regular training setup

###################################################################
# trainable form (a regular train/test layout; no mask is added inside the encoder)
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms

# The naive form above builds patch embeddings the original (NLP-style) way.
# The version below builds them with a convolution, which can be seen as a CNN + Transformer combination (further variations are similar).

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # batch_size * outputChannel * height * width
    batch_size, outChannel, height, width = conv_output.shape
    patch_embedding = conv_output.reshape(batch_size, outChannel, height * width).transpose(-1, -2)
    # print(patch_embedding.shape)
    return patch_embedding

def make_token_embedding(patch_embedding_conv):
    # print(patch_embedding_conv.shape)
    # Builds the ViT input [CLS embedding; N patch embeddings],
    # where N = patch_num and D = model_dim (the conv output channel count).
    # With an exact tiling, height * width == patch_num * patch_size ** 2.

    # 2: prepend the prepared CLS token embedding (class embedding)
    # Note: instead of a CLS embedding, one can also average-pool the ViT outputs
    # to get the final image representation (see the sketch after the VitModel class below).
    # The original ViT keeps the structure as close to the original Transformer as possible,
    # so it follows BERT and adds this special token.
    cls_token_embedding = torch.randn(batch_size, 1, model_dim, requires_grad=True)
    token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
    # print(token_embedding.shape)

    # 3: add position embeddings of shape (patch_num + 1, model_dim) = (N + 1, D).
    # The code uses 1-D position indices, e.g. a 3x3 grid of 9 patches is numbered 1..9.
    # A 2-D encoding is also possible, e.g. patches numbered 11, 12, 13, 21, 22, 23, 31, 32, 33
    # (a sketch follows this function); relative encodings (e.g. patch 9 is at distance 8 from patch 1)
    # work too, but those must be expanded into embeddings or injected via the attention mask.
    position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
    seq_len = token_embedding.shape[1]
    position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
    # print(position_embedding.shape)
    token_embedding += position_embedding
    return token_embedding
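A minimal sketch of the 2-D position encoding mentioned in the comments above, here as a learnable row/column factorization; this is an illustrative assumption, not the encoding the code actually uses.

# sketch (hypothetical): 2-D position embeddings from separate row and column tables
def make_2d_position_embedding(num_rows, num_cols, model_dim):
    row_table = torch.randn(num_rows, model_dim, requires_grad=True)  # one embedding per patch row
    col_table = torch.randn(num_cols, model_dim, requires_grad=True)  # one embedding per patch column
    # the patch at grid position (r, c) gets row_table[r] + col_table[c]; then flatten the grid
    pos_2d = row_table[:, None, :] + col_table[None, :, :]            # (num_rows, num_cols, model_dim)
    return pos_2d.reshape(num_rows * num_cols, model_dim)             # (patch_num, model_dim)

# e.g. a 3 x 3 patch grid: make_2d_position_embedding(3, 3, model_dim) has shape (9, model_dim)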

#########################################

# 4: pass the token embeddings to the ViT model
class VitModel(nn.Module):
    def __init__(self):
        super(VitModel, self).__init__()
        # batch_first=True so inputs are (batch_size, seq_len, model_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        # build the encoder once here so its parameters are registered (and trainable),
        # rather than re-creating it on every forward pass
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.linear = nn.Linear(model_dim, num_classes)

    def forward(self, token_embedding):
        x = self.transformer_encoder(token_embedding)
        cls_token_output = x[:, 0, :]  # index as (batch_size, position, channels); take the CLS position
        # return raw logits: CrossEntropyLoss applies the softmax internally,
        # so an extra softmax here would distort training
        return self.linear(cls_token_output)
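As noted in make_token_embedding, the CLS token is optional; below is a sketch of the variant that global-average-pools (GAP) the patch outputs instead (also mentioned in the closing remarks). It is an illustrative alternative, not part of the original script.

# sketch: classification from a global average pool (GAP) over patch outputs, with no CLS token
class VitModelGAP(nn.Module):
    def __init__(self):
        super(VitModelGAP, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.linear = nn.Linear(model_dim, num_classes)

    def forward(self, patch_token_embedding):
        # patch_token_embedding: (batch_size, patch_num, model_dim), no CLS token prepended
        x = self.transformer_encoder(patch_token_embedding)
        image_representation = x.mean(dim=1)  # average over the patch dimension
        return self.linear(image_representation)  # raw logits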

# def train(model, dataset, lr, batch_size, num_epochs):
#     data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=False)
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.99)
#     for epoch in range(num_epochs):
#         losses = 0
#         for images, targets in data_loader:
#             token_embeddings = make_token_embedding(image2emb_conv(images, kernel, patch_size))
#             outputs = model(token_embeddings)
#             # print(outputs.shape)
#             loss = criterion(outputs, targets)  # how the train/test split and labels are set up has a large effect on results
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             losses = losses + loss.item()
#         if (epoch + 1) % 5 == 0:
#             print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(losses / len(data_loader)))
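One caveat about the commented-out loop above: in the current script the conv kernel, CLS token, and position table are loose tensors, so optim.SGD(model.parameters(), ...) would never update them. A minimal sketch, under my own assumptions rather than from the original code, of how the whole pipeline could be wrapped in a single module so that every component is a registered, trainable parameter:

# sketch (assumption): an end-to-end module where patch projection, CLS token and position table are all parameters
class VitEndToEnd(nn.Module):
    def __init__(self, image_channels, patch_size, model_dim, max_num_token, num_classes, nhead, num_layers):
        super(VitEndToEnd, self).__init__()
        # patch projection as a learnable conv (equivalent to flattening patches and multiplying by a weight matrix)
        self.patch_proj = nn.Conv2d(image_channels, model_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        self.position_table = nn.Parameter(torch.randn(1, max_num_token, model_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, images):
        # images: (batch_size, channels, height, width)
        patches = self.patch_proj(images).flatten(2).transpose(1, 2)   # (batch_size, patch_num, model_dim)
        cls = self.cls_token.expand(images.shape[0], -1, -1)           # one CLS token per image in the batch
        tokens = torch.cat([cls, patches], dim=1)
        tokens = tokens + self.position_table[:, :tokens.shape[1], :]
        return self.head(self.encoder(tokens)[:, 0, :])                # logits from the CLS position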

if __name__ == '__main__':
    # ViT is usually pre-trained on a very large dataset and then fine-tuned on a relatively small downstream dataset;
    # studies have shown that fine-tuning on higher-resolution images than those used for pre-training works better.
    image = Image.open("./cat.png").convert('RGB')
    image = transforms.Resize((450, 450))(image)  # resize directly to a 450 x 450 square
    # image = transforms.Resize(450)(image)  # alternative: a single int keeps the aspect ratio (shorter side becomes 450)
    image = transforms.ToTensor()(image)
    image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    ############################################################
    # 1: build the token_embedding inputs
    batch_size = 1
    imageChannel, height, width = image.shape[0], image.shape[1], image.shape[2]  # ToTensor gives (C, H, W)
    # print(imageChannel, height, width)
    image = image.unsqueeze(0)  # add the batch_size dimension
    # print(image.shape)
    patch_size = 4
    model_dim = 8
    patch_depth = patch_size * patch_size * imageChannel  # flattened size of one patch
    weight = torch.randn(model_dim, patch_depth)  # Conv2D view: model_dim is the output channel count, patch_depth is kernel size * input channels
    kernel = weight.reshape((-1, imageChannel, patch_size, patch_size))
    patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
    # the position table only needs max_num_token >= patch_num + 1 (the +1 is the class embedding);
    # height * width + 1 is a loose upper bound that is always large enough
    max_num_token = height * width + 1
    token_embedding = make_token_embedding(patch_embedding_conv)

    ##############################
    # define model attributes
    nhead = 8
    num_layers = 6
    num_classes = 10
    model = VitModel()
    label = torch.randint(num_classes, (batch_size,))
    criterion = nn.CrossEntropyLoss()
    output = model(token_embedding)
    loss = criterion(output, label)
    print(loss)

Input image: cat.png. Result:

Plus: for more detail on the underlying theory, see these two links: https://zhuanlan.zhihu.com/p/445122996

https://blog.csdn.net/verse_armour/article/details/128336786?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-128336786-blog-122799541.pc_relevant_3mothn_strategy_and_data_recovery&spm=1001.2101.3001.4242.1&utm_relevant_index=3 Hugging Face tooling: https://blog.csdn.net/m0_56722835/article/details/127437259

The above is only a learning example for ViT. Training a ViT from scratch is still out of reach for most individuals, but using a pre-trained ViT for transfer learning on downstream tasks, or borrowing the ViT structure for other tasks, works fine. For example, you can also take the Transformer outputs of the image patches, apply global average pooling (GAP), and run the final classification on that; careful hyperparameter tuning also matters.
