mmFewShot

mmFewShot is OpenMMLab's few-shot learning / meta-learning toolbox. It provides a unified framework for training, inference, and evaluation of popular deep-learning-based few-shot classification and detection algorithms.

GitHub: https://github.com/open-mmlab/mmfewshot
Official docs: https://mmfewshot.readthedocs.io/

Environment

Python 3.7, CUDA 10.1, torch 1.7.0, torchvision 0.8.0, torchaudio 0.7.0

mmcv 1.4.0, mmdet 2.20.0, mmcls 0.15.0 (these three packages also have version constraints on one another)
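For reference, the versions above can be installed along these lines (a sketch; the mmcv-full find-links URL must match your CUDA/torch combination, cu101/torch1.7.0 here):

pip install torch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0
pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7.0/index.html
pip install mmdet==2.20.0 mmcls==0.15.0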

Configuring your own dataset

Taking detection on VOC split1 as the example. (A meta-learning dataset is divided into two parts, Base and Novel. The base classes use a relatively large dataset (a public dataset such as VOC or COCO); stage one trains on the base classes to produce a model that "knows how to learn". The model is then moved to the Novel set, which is very small, just a few (1-10) images per class, and whose classes are disjoint from the base classes; the goal is to learn a model with strong generalization from those few images.)

In mmfewshot-main/mmfewshot/detection/datasets/voc.py, the official setup splits VOC's 20 classes into 15 base + 5 novel:

ALL_CLASSES_SPLIT1 = ('replace with all classes of your dataset, base + novel'),

NOVEL_CLASSES_SPLIT1 = ('replace with your novel classes'),

BASE_CLASSES_SPLIT1 = ('replace with your base classes'),

Note: if you installed the mmfewshot package into a conda environment, then to configure the dataset you must locate the mmfewshot folder inside that anaconda environment and modify its detection/datasets/voc.py there.
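A minimal sketch of what the edited tuples could look like, for a hypothetical six-class dataset whose novel classes are cat and dog (all class names here are made up for illustration):

# mmfewshot/detection/datasets/voc.py (hypothetical example)
ALL_CLASSES_SPLIT1 = ('car', 'bus', 'person', 'bicycle', 'cat', 'dog')
BASE_CLASSES_SPLIT1 = ('car', 'bus', 'person', 'bicycle')
NOVEL_CLASSES_SPLIT1 = ('cat', 'dog')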

Dataset preparation

VOC format: you need the xml annotation files, the jpg images, and the txt image-list files.

configs/detection/_base_/datasets/nway_kshot/base_voc.py (the subdirectory depends on which sampling method the chosen model uses)

data_root = 'your dataset root directory'

Replace each ann_file below it with the txt file for the corresponding split of your dataset.

samples_per_gpu = 'the training batch size'

num_support_ways = 'the number of base classes'
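Putting those edits together, a sketch of the relevant fields, assuming a hypothetical dataset with 4 base classes rooted at data/my_voc/ (the exact nesting may differ slightly in your copy of base_voc.py, so match it against the real file):

# base_voc.py - fields to edit (illustrative values only)
data_root = 'data/my_voc/'
data = dict(
    samples_per_gpu=4,           # training batch size per GPU
    train=dict(
        num_support_ways=4,      # number of base classes
        dataset=dict(
            ann_cfg=[
                dict(type='ann_file',
                     ann_file=data_root + 'VOC2007/ImageSets/Main/trainval.txt'),
            ])))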

With that, the stage-one (base-class) dataset is configured and training can start. Taking the meta_rcnn model as the example:

configs/detection/meta_rcnn/voc/split1/meta-rcnn_xxxxx_split1_base-training.py

num_classes and num_meta_classes = 'the number of base classes'

max_iters = 'the number of training iterations'
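A sketch of those overrides, again assuming 4 base classes (the keys follow the usual Meta R-CNN config layout, so verify them against your file):

# meta-rcnn_xxxxx_split1_base-training.py (illustrative values only)
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=4, num_meta_classes=4)))
runner = dict(max_iters=18000)   # number of training iterations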

Then stage-one training can begin. Multi-GPU:

bash tools/detection/dist_train.sh configs/detection/meta_rcnn/voc/split1/meta-xxxxx_base-training.py 'number of GPUs'

Single GPU (note that this uses train.py directly, not the dist_train.sh wrapper):

python tools/detection/train.py configs/detection/meta_rcnn/voc/split1/meta-xxxxx_base-training.py

The training results are saved automatically to work_dirs/meta-xxxxxx

Note: I have not tested whether changing num_classes in /mmfewshot-main/configs/detection/meta_rcnn/meta-rcnn_r101_c4.py and meta-rcnn_r50_c4.py actually has any effect; since the field is there I changed it anyway, it costs nothing. In these two files, set num_classes to the total number of classes.

Stage-two training

To learn the few-shot classes, first prepare the few-shot txt files, plus a txt file for each class.

configs/detection/_base_/datasets/nway_kshot/few_shot_voc.py (again, check which method the chosen model uses): same as base_voc above, change the dataset paths and samples_per_gpu; the difference is that the way count here (num_support_ways) must be set to the total number of classes, base + novel.
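A sketch mirroring the base_voc example, assuming 4 base + 2 novel = 6 classes (field name assumed to match base_voc.py; verify against your file):

# few_shot_voc.py - fields to edit (illustrative values only)
data = dict(
    samples_per_gpu=4,           # training batch size per GPU
    train=dict(
        num_support_ways=6))     # all classes: base + novel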

Dataset preparation and configuration

You need a folder like /mmfewshot-main/data/few_shot_ann/voc/benchmark_5shot, containing one txt file per class; each file lists the paths of five images that contain that class. For the official VOC dataset you can download these files here: https://drive.google.com/file/d/1EQSKo5n2obj7tW8RytYTJ-eEYbXqtUXE/view?usp=sharing For your own dataset, write a small script (see the sketch below) or just make the files by hand; there are only a few entries in total anyway.
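A minimal sketch of such a script, under these assumptions: VOC-style xml annotations, made-up class names, and the box_{K}shot_{class}_train.txt naming used by the downloadable benchmark files (check the downloaded files for the exact naming and line format your setup expects):

# make_few_shot_ann.py - hypothetical helper for building K-shot txt files
import os
import xml.etree.ElementTree as ET

ANN_DIR = 'data/VOCdevkit/VOC2007/Annotations'      # your xml annotations
OUT_DIR = 'data/few_shot_ann/voc/benchmark_5shot'   # target folder
CLASSES = ('cat', 'dog')                            # your classes (made-up names)
K = 5                                               # images per class

os.makedirs(OUT_DIR, exist_ok=True)
per_class = {c: [] for c in CLASSES}
for xml_name in sorted(os.listdir(ANN_DIR)):
    root = ET.parse(os.path.join(ANN_DIR, xml_name)).getroot()
    # classes annotated in this image
    names = {obj.find('name').text for obj in root.findall('object')}
    stem = os.path.splitext(xml_name)[0]
    for c in names.intersection(CLASSES):
        if len(per_class[c]) < K:
            per_class[c].append(f'VOC2007/JPEGImages/{stem}.jpg')

for c, paths in per_class.items():
    # one txt per class, K image paths each
    with open(os.path.join(OUT_DIR, f'box_{K}shot_{c}_train.txt'), 'w') as f:
        f.write('\n'.join(paths) + '\n')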

You also need to put per-class train / val / trainval / test txt files for your own classes into the VOCdevkit/VOC2007/ImageSets/Main folder; each file records, for every image, whether it is a positive or negative sample of that class. See the official VOC documentation for the exact format.
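For reference, the VOC per-class format is one image id per line followed by a flag (1 = positive, -1 = negative, 0 = difficult), e.g. a hypothetical cat_trainval.txt:

000005 -1
000012 1
000019 1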

In stage two you choose how many images per class to train on: 1, 2, 3, 5, or 10. For 5-shot, use configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py

Set max_iters (the number of training iterations), and change load_from to the weight file produced by stage one, work_dirs/meta-rcnnxxxxx/latest.pth. Alternatively, use the official weights directly: https://mmfewshot.readthedocs.io/en/latest/model_zoo.html
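A sketch of those two overrides (the load_from path is a placeholder for whatever directory your stage-one run produced under work_dirs):

# meta-rcnn_xxxxx_5shot-fine-tuning.py (illustrative values only)
load_from = 'work_dirs/meta-rcnn_xxxxx_base-training/latest.pth'
runner = dict(max_iters=2000)    # fine-tuning iterations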

Start training

Multi-GPU:

bash tools/detection/dist_train.sh configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py 'number of GPUs'

Single GPU:

python tools/detection/train.py configs/detection/meta_rcnn/voc/split1/meta-rcnn_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py

When training finishes, the results are written out and you get the weight file.

Validation and visualization

I found the following code online; it carries the OpenMMLab copyright header, but I could not find it in this project. Just create a new file:

inference.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.ops import RoIPool
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.parallel import collate, scatter
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmfewshot.detection.datasets import (build_dataloader, build_dataset,
                                          get_copy_dataset_type)
from mmfewshot.detection.models import build_detector, QuerySupportDetector
from mmfewshot.detection.apis import (inference_detector, init_detector,
                                      process_support_images)


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMFewShot test (and eval) a model')
    parser.add_argument(
        '-input', help='directory where source images will be detected')
    parser.add_argument(
        '-output', help='directory where painted images will be saved')
    parser.add_argument(
        '--config', default='mytools/xyb-rcnn_r50_c4_8xb4_novel-fine-tuning.py')
    parser.add_argument(
        '--checkpoint',
        default='path/to/your_checkpoint.pth',
        help='checkpoint file')
    parser.add_argument(
        '--save-support-heatmap',
        default=False,
        action='store_true',
        help='whether to save the support heat map')
    parser.add_argument(
        '--save-query-heatmap',
        default=False,
        action='store_true',
        help='whether to save the query heat map')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--eval',
        type=str,
        default=['bbox'],
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecate), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both '
            'specified, --options is deprecated in favor of --eval-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
        args.cfg_options = args.options
    return args


def check_create_dirs(dirs):
    if isinstance(dirs, str):
        dirs = [dirs]
    for dir in dirs:
        # skip unset (empty) entries; os.makedirs('') would raise
        if dir and not os.path.exists(dir):
            os.makedirs(dir)
            print(f"\t--create dir*** {dir}")


def write_to_result_txt(file, result, categories, save_dir, score_thr=0.3):
    # one txt per image: "<class>(<idx>) x1 y1 x2 y2 score" per kept box
    txt_file = os.path.join(save_dir,
                            os.path.basename(file).rsplit('.', 1)[0] + '.txt')
    assert len(result) == len(categories)
    with open(txt_file, 'w') as f:
        for i, category_result in enumerate(result):
            category_result = category_result.tolist()
            for category_bbox_result in category_result:
                if category_bbox_result[-1] >= score_thr:
                    category_bbox_result = [
                        str(round(x, 3)) for x in category_bbox_result
                    ]
                    f.write(f"{categories[i]}({i}) " +
                            " ".join(category_bbox_result).strip() + "\n")


def main():
    args = parse_args()
    # these output dirs were hard-coded placeholders in the original snippet;
    # painted_dir now falls back to --show-dir when that flag is given
    painted_dir = args.show_dir or 'output/painted'  # images drawn with boxes
    heatmap_dir = ''                                 # empty: no heatmaps saved
    txt_result_dir = 'output/txt_results'            # detection txt files
    check_create_dirs([painted_dir, heatmap_dir, txt_result_dir])

    cfg = Config.fromfile(args.config)
    cfg.heatmap_dir = heatmap_dir
    cfg.save_support_heatmap = args.save_support_heatmap
    cfg.save_query_heatmap = args.save_query_heatmap
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    # currently only support single images testing
    samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
    assert samples_per_gpu == 1, 'currently only support single images testing'
    # pop frozen_parameters
    cfg.model.pop('frozen_parameters', None)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    # for meta-learning methods which require support template dataset
    # for model initialization.
    if cfg.data.get('model_init', None) is not None:
        cfg.data.model_init.pop('copy_from_train_dataset')
        model_init_samples_per_gpu = cfg.data.model_init.pop(
            'samples_per_gpu', 1)
        model_init_workers_per_gpu = cfg.data.model_init.pop(
            'workers_per_gpu', 1)
        if cfg.data.model_init.get('ann_cfg', None) is None:
            assert checkpoint['meta'].get('model_init_ann_cfg',
                                          None) is not None
            cfg.data.model_init.type = \
                get_copy_dataset_type(cfg.data.model_init.type)
            cfg.data.model_init.ann_cfg = \
                checkpoint['meta']['model_init_ann_cfg']
        model_init_dataset = build_dataset(cfg.data.model_init)
        # disable dist to make all ranks get the same data
        model_init_dataloader = build_dataloader(
            model_init_dataset,
            samples_per_gpu=model_init_samples_per_gpu,
            workers_per_gpu=model_init_workers_per_gpu,
            dist=False,
            shuffle=False)

    model.cfg = cfg
    model = MMDataParallel(model, device_ids=[0])
    if cfg.data.get('model_init', None) is not None:
        from mmfewshot.detection.apis import (single_gpu_model_init,
                                              single_gpu_test)
        single_gpu_model_init(model, model_init_dataloader)
    else:
        from mmdet.apis.test import single_gpu_test
    if hasattr(model, "module"):
        model = model.module

    files = sorted(os.listdir(args.input))
    prog_bar = mmcv.ProgressBar(len(files))
    for file in files:
        img = os.path.join(args.input, file)
        result = inference_detector(model, img)
        write_to_result_txt(img, result, model.CLASSES, txt_result_dir,
                            score_thr=args.show_score_thr)
        model.show_result(
            img,
            result,
            score_thr=args.show_score_thr,
            out_file=os.path.join(painted_dir, file))
        prog_bar.update(1)


if __name__ == '__main__':
    main()

Detection visualizations

python tools/detection/inference.py -input 'folder of images to detect' --config configs/detection/meta_rcnn/voc/split3/meta-rcnn_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py --checkpoint 'weight file' --show-score-thr 'threshold' --show-dir 'output folder'

This produces the visualizations for the novel classes, and we're done.
