Contents

Motivation

Example: setting config parameters for MMEngine.runner

MMEngine.runner source code

IterBasedTrainLoop explained

Inputs

Output

IterBasedTrainLoop source code

EpochBasedTrainLoop explained

Inputs

Output

EpochBasedTrainLoop source code

Summary

Iteration-based training

❤️config

❤️Parameters

Epoch-based training

❤️config

❤️Parameters

✌️✌️Takeaways

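Motivation

When training models with MMEngine and configuring its various hooks, the underlying source code usually stays out of sight: you follow the established templates, and any change means trial and error with parameters. So let's dig all the way down once and read how the lowest-level code is written, so we no longer have to guess parameters every time.

MMEngine supports two training modes:

- epoch-based training (EpochBased)
- iteration-based training (IterBased)

Both modes are used by downstream libraries: MMDetection uses the EpochBased mode by default, while MMSegmentation uses the IterBased mode by default. This post covers everything you need to know to switch between them.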

Example: setting config parameters for MMEngine.runner

```python
from mmengine.runner import Runner

# NOTE: ToyModel, ToyDataset and ToyEvaluator are placeholders; they must be
# implemented and registered before this config can actually be built.
cfg = dict(
    model=dict(type='ToyModel'),
    work_dir='path/of/work_dir',
    train_dataloader=dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=True),
        batch_size=1,
        num_workers=0),
    val_dataloader=dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=False),
        batch_size=1,
        num_workers=0),
    test_dataloader=dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=False),
        batch_size=1,
        num_workers=0),
    auto_scale_lr=dict(base_batch_size=16, enable=False),
    # The optimizer wrapper registered in MMEngine is `OptimWrapper`.
    optim_wrapper=dict(type='OptimWrapper', optimizer=dict(
        type='SGD', lr=0.01)),
    param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
    val_evaluator=dict(type='ToyEvaluator'),
    test_evaluator=dict(type='ToyEvaluator'),
    # No `type` key, so `by_epoch=True` selects EpochBasedTrainLoop.
    train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
    val_cfg=dict(),
    test_cfg=dict(),
    custom_hooks=[],
    default_hooks=dict(
        timer=dict(type='IterTimerHook'),
        checkpoint=dict(type='CheckpointHook', interval=1),
        logger=dict(type='LoggerHook'),
        # MMEngine has no OptimizerHook: the optimizer step is performed
        # through `optim_wrapper` inside `model.train_step`.
        param_scheduler=dict(type='ParamSchedulerHook')),
    launcher='none',
    env_cfg=dict(dist_cfg=dict(backend='nccl')),
    log_processor=dict(window_size=20),
    visualizer=dict(type='Visualizer',
                    vis_backends=[dict(type='LocalVisBackend',
                                       save_dir='temp_dir')])
)

runner = Runner.from_cfg(cfg)
runner.train()
runner.test()
```

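Today we focus on how to set `train_cfg`. The official docstring defines the `train_cfg` parameter as follows:

train_cfg (dict, optional): A dict to build a training loop. If it does not provide "type" key, it should contain "by_epoch" to decide which type of training loop :class:`EpochBasedTrainLoop` or :class:`IterBasedTrainLoop` should be used. If ``train_cfg`` specified, :attr:`train_dataloader` should also be specified. Defaults to None. See :meth:`build_train_loop` for more details.

In other words, `train_cfg` selects one of two training loop classes, `EpochBasedTrainLoop` or `IterBasedTrainLoop`. You either name the class explicitly with the `type` key, or omit `type` and let `by_epoch` decide (`True` for `EpochBasedTrainLoop`, `False` for `IterBasedTrainLoop`).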
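To make that rule concrete, here is a minimal sketch of how a `train_cfg` dict could be mapped to a loop class. It roughly follows the behaviour the docstring describes (an explicit `type` is used as-is, otherwise `by_epoch` chooses between the two built-in loops), but `resolve_train_loop_type` is a hypothetical helper: it only returns the class name, does not touch MMEngine's real `LOOPS` registry, and its error handling is my own choice.

```python
import copy
from typing import Any, Dict


def resolve_train_loop_type(train_cfg: Dict[str, Any]) -> str:
    """Hypothetical helper: decide which training loop a ``train_cfg`` selects."""
    # Work on a copy so popping keys never mutates the caller's config.
    loop_cfg = copy.deepcopy(train_cfg)
    if 'type' in loop_cfg and 'by_epoch' in loop_cfg:
        raise ValueError('Specify only one of `type` or `by_epoch`.')
    if 'type' in loop_cfg:
        # An explicit class name wins (MMEngine would look it up in LOOPS).
        return loop_cfg.pop('type')
    if 'by_epoch' not in loop_cfg:
        raise KeyError('train_cfg needs either `type` or `by_epoch`.')
    by_epoch = loop_cfg.pop('by_epoch')
    return 'EpochBasedTrainLoop' if by_epoch else 'IterBasedTrainLoop'


cfg = dict(by_epoch=True, max_epochs=3, val_interval=1)
print(resolve_train_loop_type(cfg))  # EpochBasedTrainLoop
print('by_epoch' in cfg)  # True: the original dict is left untouched
print(resolve_train_loop_type(dict(by_epoch=False, max_iters=1000)))
print(resolve_train_loop_type(dict(type='IterBasedTrainLoop', max_iters=80000)))
```

The deep copy matters because the selection step pops keys; without it, the user's config would lose `by_epoch` after the first build.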

MMEngine.runner source code

```python
class Runner:
    cfg: Config
    _train_loop: Optional[Union[BaseLoop, Dict]]
    _val_loop: Optional[Union[BaseLoop, Dict]]
    _test_loop: Optional[Union[BaseLoop, Dict]]

    def __init__(
        self,
        model: Union[nn.Module, Dict],
        work_dir: str,
        train_dataloader: Optional[Union[DataLoader, Dict]] = None,
        val_dataloader: Optional[Union[DataLoader, Dict]] = None,
        test_dataloader: Optional[Union[DataLoader, Dict]] = None,
        train_cfg: Optional[Dict] = None,
        val_cfg: Optional[Dict] = None,
        test_cfg: Optional[Dict] = None,
        auto_scale_lr: Optional[Dict] = None,
        optim_wrapper: Optional[Union[OptimWrapper, Dict]] = None,
        param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
        val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
        custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
        data_preprocessor: Union[nn.Module, Dict, None] = None,
        load_from: Optional[str] = None,
        resume: bool = False,
        launcher: str = 'none',
        env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
        log_processor: Optional[Dict] = None,
        log_level: str = 'INFO',
        visualizer: Optional[Union[Visualizer, Dict]] = None,
        default_scope: str = 'mmengine',
        randomness: Dict = dict(seed=None),
        experiment_name: Optional[str] = None,
        cfg: Optional[ConfigType] = None,
    ):
        self._work_dir = osp.ab
        # ......

    @classmethod
    def from_cfg(cls, cfg: ConfigType) -> 'Runner':
        """Build a runner from config.

        Args:
            cfg (ConfigType): A config used for building runner. Keys of
                ``cfg`` can see :meth:`__init__`.

        Returns:
            Runner: A runner build from ``cfg``.
        """
        cfg = copy.deepcopy(cfg)
        runner = cls(
            model=cfg['model'],
            work_dir=cfg['work_dir'],
            train_dataloader=cfg.get('train_dataloader'),
            val_dataloader=cfg.get('val_dataloader'),
            test_dataloader=cfg.get('test_dataloader'),
            train_cfg=cfg.get('train_cfg'),
            val_cfg=cfg.get('val_cfg'),
            test_cfg=cfg.get('test_cfg'),
            auto_scale_lr=cfg.get('auto_scale_lr'),
            optim_wrapper=cfg.get('optim_wrapper'),
            param_scheduler=cfg.get('param_scheduler'),
            val_evaluator=cfg.get('val_evaluator'),
            test_evaluator=cfg.get('test_evaluator'),
            default_hooks=cfg.get('default_hooks'),
            custom_hooks=cfg.get('custom_hooks'),
            data_preprocessor=cfg.get('data_preprocessor'),
            load_from=cfg.get('load_from'),
            resume=cfg.get('resume', False),
            launcher=cfg.get('launcher', 'none'),
            env_cfg=cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl'))),
            log_processor=cfg.get('log_processor'),
            log_level=cfg.get('log_level', 'INFO'),
            visualizer=cfg.get('visualizer'),
            default_scope=cfg.get('default_scope', 'mmengine'),
            randomness=cfg.get('randomness', dict(seed=None)),
            experiment_name=cfg.get('experiment_name'),
            cfg=cfg,
        )

        return runner

    # ......
    # ......

    @property
    def train_loop(self):
        """:obj:`BaseLoop`: A loop to run training."""
        if isinstance(self._train_loop, BaseLoop) or self._train_loop is None:
            return self._train_loop
        else:
            self._train_loop = self.build_train_loop(self._train_loop)
            return self._train_loop
```

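The `train_loop` property returns `self._train_loop` directly if it is already a `BaseLoop` instance (or `None`); otherwise it treats the stored value as a config dict and builds the loop with `build_train_loop`. So the parameters we can set in `train_cfg` are simply the constructor arguments of the `BaseLoop` subclass being built. MMEngine provides two such training loops, `IterBasedTrainLoop` and `EpochBasedTrainLoop`, and these are exactly the values we pass as `type` in the config.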
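The pattern inside `train_loop` is lazy construction: the constructor only stores the `train_cfg` dict, and the real loop object is built the first time the property is read, then cached in place of the dict. Below is a self-contained toy version of that pattern; `ToyLoop`, `ToyRunner` and its `build_train_loop` are hypothetical stand-ins, not MMEngine classes.

```python
from typing import Any, Dict, Optional, Union


class ToyLoop:
    """Hypothetical stand-in for a BaseLoop subclass."""

    def __init__(self, max_iters: int) -> None:
        self.max_iters = max_iters


class ToyRunner:
    def __init__(self, train_cfg: Optional[Dict[str, Any]] = None) -> None:
        # Like Runner, keep only the config dict at construction time.
        self._train_loop: Optional[Union[ToyLoop, Dict[str, Any]]] = train_cfg
        self.build_count = 0  # counts how often a loop was actually built

    def build_train_loop(self, cfg: Dict[str, Any]) -> ToyLoop:
        self.build_count += 1
        return ToyLoop(**cfg)

    @property
    def train_loop(self) -> Optional[ToyLoop]:
        # Already built (or no training configured): return it unchanged.
        if isinstance(self._train_loop, ToyLoop) or self._train_loop is None:
            return self._train_loop
        # First access: build from the dict and cache the object.
        self._train_loop = self.build_train_loop(self._train_loop)
        return self._train_loop


runner = ToyRunner(dict(max_iters=100))
loop = runner.train_loop
print(loop is runner.train_loop)  # True: the second access hits the cache
print(runner.build_count)  # 1
```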

IterBasedTrainLoop explained

Inputs

- runner (Runner): a reference to the runner.
- dataloader (Dataloader or dict): a dataloader object, or a dict used to build one.
- max_iters (int): total number of training iterations.
- val_begin (int): the iteration at which validation begins. Defaults to 1.
- val_interval (int): validation interval, in iterations. Defaults to 1000.
- dynamic_intervals (List[Tuple[int, int]], optional): each tuple is (milestone, interval); the interval is used after the corresponding milestone. Defaults to None.

Output

None (the constructor returns nothing; `run()` returns `self.runner.model` when training finishes).
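Because `IterBasedTrainLoop` counts iterations rather than epochs, it pulls batches with `next(self.dataloader_iterator)` from an `_InfiniteDataloaderIterator`, which keeps producing data after the dataloader is exhausted. The class below is a simplified, hypothetical `InfiniteIterator` that shows only the restart idea; the real MMEngine class does more bookkeeping (for example, advancing the sampler's epoch), so treat this as a sketch.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar('T')


class InfiniteIterator(Iterator[T]):
    """Hypothetical, simplified infinite iterator over a finite iterable."""

    def __init__(self, iterable: Iterable[T]) -> None:
        self._iterable = iterable
        self._iterator = iter(iterable)
        self.epoch = 0  # how many times the underlying iterable was exhausted

    def __iter__(self) -> 'InfiniteIterator[T]':
        return self

    def __next__(self) -> T:
        try:
            return next(self._iterator)
        except StopIteration:
            # Exhausted: start a new pass over the data and keep going.
            self.epoch += 1
            self._iterator = iter(self._iterable)
            return next(self._iterator)


batches = [0, 1, 2]  # stands in for a dataloader with 3 batches
it = InfiniteIterator(batches)
print([next(it) for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]
print(it.epoch)  # 2
```

This is why `max_iters` can be larger than the number of batches in one pass over the dataset.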

IterBasedTrainLoop source code

```python
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
    """Loop for iter-based training.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_iters (int): Total training iterations.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is a interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_iters: int,
            val_begin: int = 1,
            val_interval: int = 1000,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_iters = int(max_iters)
        assert self._max_iters == max_iters, \
            f'`max_iters` should be a integer number, but get {max_iters}'
        self._max_epochs = 1  # for compatibility with EpochBasedTrainLoop
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # This attribute will be updated by `EarlyStoppingHook`
        # when it is enabled.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)
        # get the iterator of the dataloader
        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> None:
        """Launch training."""
        self.runner.call_hook('before_train')
        # In iteration-based training loop, we treat the whole training process
        # as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
        while self._iter < self._max_iters and not self.stop_training:
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
            self.run_iter(data_batch)

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._iter >= self.val_begin
                    and self._iter % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')
        return self.runner.model

    def run_iter(self, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
        # Enable gradient accumulation mode and avoid unnecessary gradient
        # synchronization during gradient accumulation process.
        # outputs should be a dict of loss.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=self._iter,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
```
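`dynamic_intervals` and `_decide_current_val_interval` work together: `calc_dynamic_intervals` (not shown in the excerpt above) turns the list of `(milestone, interval)` tuples into two parallel lists, and `bisect.bisect` picks the interval whose milestone has been reached. The sketch below re-implements that idea with a hypothetical `split_dynamic_intervals` helper. Its output format (milestones starting at `0`, paired with intervals starting at the initial `val_interval`) is my assumption, inferred from how the lists are indexed above, not a copy of the MMEngine function.

```python
import bisect
from typing import List, Optional, Tuple


def split_dynamic_intervals(
        start_interval: int,
        dynamic_intervals: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: build parallel (milestones, intervals) lists.

    Milestone 0 always maps to the initial ``val_interval``.
    """
    milestones = [0]
    intervals = [start_interval]
    for milestone, interval in sorted(dynamic_intervals or []):
        milestones.append(milestone)
        intervals.append(interval)
    return milestones, intervals


def current_val_interval(milestones: List[int], intervals: List[int],
                         finished_iters: int) -> int:
    # Same lookup as _decide_current_val_interval, which runs after
    # run_iter has already incremented the iteration counter.
    step = bisect.bisect(milestones, finished_iters + 1)
    return intervals[step - 1]


# Validate every 1000 iterations at first, then every 500 from iteration 5000.
milestones, intervals = split_dynamic_intervals(1000, [(5000, 500)])
print(milestones, intervals)  # [0, 5000] [1000, 500]
for n in (1000, 4000, 5000, 6500):
    print(n, current_val_interval(milestones, intervals, n))
```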

EpochBasedTrainLoop explained

Inputs

- runner (Runner): a reference to the runner.
- dataloader (Dataloader or dict): a dataloader object, or a dict used to build one.
- max_epochs (int): total number of training epochs.
- val_begin (int): the epoch at which validation begins. Defaults to 1.
- val_interval (int): validation interval, in epochs. Defaults to 1.
- dynamic_intervals (List[Tuple[int, int]], optional): each tuple is (milestone, interval); the interval is used after the corresponding milestone. Defaults to None.

Output

None (the constructor returns nothing; `run()` returns `self.runner.model` when training finishes).

EpochBasedTrainLoop source code

```python
@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
    """Loop for epoch-based training.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_epochs (int): Total training epochs.
        val_begin (int): The epoch that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is a interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_epochs: int,
            val_begin: int = 1,
            val_interval: int = 1,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)
        assert self._max_epochs == max_epochs, \
            f'`max_epochs` should be a integer number, but get {max_epochs}.'
        self._max_iters = self._max_epochs * len(self.dataloader)
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # This attribute will be updated by `EarlyStoppingHook`
        # when it is enabled.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)

        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training."""
        self.runner.call_hook('before_train')

        while self._epoch < self._max_epochs and not self.stop_training:
            self.run_epoch()

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model

    def run_epoch(self) -> None:
        """Iterate one epoch."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        self.runner.call_hook('after_train_epoch')
        self._epoch += 1

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one min-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)
        # Enable gradient accumulation mode and avoid unnecessary gradient
        # synchronization during gradient accumulation process.
        # outputs should be a dict of loss.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
        self.val_interval = self.dynamic_intervals[step - 1]
```
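Putting the two `run` methods side by side, the main structural difference is how often the epoch-level hooks fire: `IterBasedTrainLoop` treats the whole run as one big epoch, so `before_train_epoch`/`after_train_epoch` fire once, while `EpochBasedTrainLoop` fires them every epoch. The toy trace below reproduces only the hook-call skeleton of the two `run` methods shown above (no model, no validation); `iter_based_trace` and `epoch_based_trace` are hypothetical names.

```python
from typing import List


def iter_based_trace(max_iters: int) -> List[str]:
    """Hook order of IterBasedTrainLoop.run (validation omitted)."""
    calls = ['before_train', 'before_train_epoch']
    for _ in range(max_iters):
        calls += ['before_train_iter', 'after_train_iter']
    calls += ['after_train_epoch', 'after_train']
    return calls


def epoch_based_trace(max_epochs: int, iters_per_epoch: int) -> List[str]:
    """Hook order of EpochBasedTrainLoop.run (validation omitted)."""
    calls = ['before_train']
    for _ in range(max_epochs):
        calls.append('before_train_epoch')
        for _ in range(iters_per_epoch):
            calls += ['before_train_iter', 'after_train_iter']
        calls.append('after_train_epoch')
    calls.append('after_train')
    return calls


# Both runs perform 4 training iterations in total.
print(iter_based_trace(4).count('before_train_epoch'))      # 1
print(epoch_based_trace(2, 2).count('before_train_epoch'))  # 2
```

This matters for hooks that act per epoch: under `IterBasedTrainLoop` they only ever see one epoch, which is why iteration-based configs (MMSegmentation's, for instance) set `by_epoch=False` on `CheckpointHook`. Note also that `batch_idx` passed to the iteration hooks is the global iteration counter in the iteration-based loop, but the index within the current epoch in the epoch-based loop.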

Summary

`train_cfg` in a config file offers two built-in training modes, one based on iterations and one based on epochs. They are configured as follows.

Iteration-based training

❤️config

```python
train_cfg = dict(type='IterBasedTrainLoop', max_iters=80000, val_interval=4000)
```

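❤️Parameters

- type: the training loop type.
- max_iters: maximum number of training iterations; training stops once 80000 iterations have run.
- val_interval: validation interval, in iterations; validation runs every 4000 iterations.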
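To check which iterations actually trigger validation for the config above, one can replay the condition from `IterBasedTrainLoop.run` (`self._iter >= self.val_begin and self._iter % self.val_interval == 0`, evaluated after `_iter` has been incremented). The helper below is a hypothetical sketch that ignores `dynamic_intervals` and assumes a validation loop is configured.

```python
from typing import List


def validation_iters(max_iters: int, val_interval: int,
                     val_begin: int = 1) -> List[int]:
    """Hypothetical helper: iteration counts after which validation runs."""
    return [i for i in range(1, max_iters + 1)
            if i >= val_begin and i % val_interval == 0]


points = validation_iters(max_iters=80000, val_interval=4000)
print(len(points), points[0], points[-1])  # 20 4000 80000
```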

Epoch-based training

❤️config

```python
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200, val_interval=1)
```

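❤️Parameters

- type: the training loop type.
- max_epochs: maximum number of training epochs; training stops after 200 epochs.
- val_interval: validation interval, in epochs; validation runs after every epoch.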
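Switching a config from one mode to the other mostly comes down to converting units. Since `EpochBasedTrainLoop` sets `_max_iters = max_epochs * len(dataloader)`, a hypothetical helper can translate an epoch-based `train_cfg` into a roughly equivalent iteration-based one, given the number of iterations per epoch (`iters_per_epoch`, i.e. `len(train_dataloader)`, which you have to supply yourself). This is only a sketch: `val_begin` and `dynamic_intervals` are not handled, and epoch-based settings elsewhere in the config (for example `by_epoch` on checkpoint hooks or parameter schedulers) need adjusting separately.

```python
from typing import Any, Dict


def epoch_cfg_to_iter_cfg(train_cfg: Dict[str, Any],
                          iters_per_epoch: int) -> Dict[str, Any]:
    """Hypothetical helper: rewrite an EpochBasedTrainLoop cfg as IterBasedTrainLoop."""
    if train_cfg.get('type') != 'EpochBasedTrainLoop':
        raise ValueError('expected an EpochBasedTrainLoop config')
    return dict(
        type='IterBasedTrainLoop',
        max_iters=train_cfg['max_epochs'] * iters_per_epoch,
        # EpochBasedTrainLoop defaults val_interval to 1 epoch.
        val_interval=train_cfg.get('val_interval', 1) * iters_per_epoch,
    )


epoch_cfg = dict(type='EpochBasedTrainLoop', max_epochs=200, val_interval=1)
# 500 iterations per epoch is an assumed value, purely for illustration.
print(epoch_cfg_to_iter_cfg(epoch_cfg, iters_per_epoch=500))
# {'type': 'IterBasedTrainLoop', 'max_iters': 100000, 'val_interval': 500}
```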

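✌️✌️Takeaways

The conclusion itself is simple: there are just two training modes and their parameters. Still, the carefully designed training code along the way is worth reading in your spare time; it is a piece of "masterfully crafted prose" in its own right.

Full MMEngine source code: link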
