命令行接口CLI

CLI可以很容易地配置训练(主要是model、data、trainer)时的各个参数,将代码与配置分离,避免直接改动代码。

安装依赖

pip install "pytorch-lightning[extra]"

创建LightningCLI

实例化一个 LightningCLI 对象,类似Trainer对象一样使用,只是不在py文件中直接运行,而是等待命令和参数后运行。

# main.py文件内容

from lightning.pytorch.cli import LightningCLI

# DemoModel, BoringDataModule是任意可用的模型和数据对象

from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

def cli_main():

# 实例LightningCLI,传入model和data,但不fit()

cli = LightningCLI(DemoModel, BoringDataModule)

# note: don't call fit!!

if __name__ == "__main__":

cli_main()

# note: it is good practice to implement the CLI in a function and call it in the main if block

现在可以在终端输入help查看说明。

python main.py --help

usage: main.py [-h] [-c CONFIG] [--print_config [={comments,skip_null,skip_default}+]]

{fit,validate,test,predict,tune} ...

pytorch-lightning trainer command line tool

optional arguments:

-h, --help Show this help message and exit.展示帮助信息

-c CONFIG, --config CONFIG

yaml /json 文件路径

--print_config [={comments,skip_null,skip_default}+]

Print configuration and exit.

subcommands:

For more details of each subcommand add it as argument followed by --help. 子命令即类似调用trainer执行不同的流程代码。

{fit,validate,test,predict,tune}

fit Runs the full optimization routine.

validate Perform one evaluation epoch over the validation set.

test Perform one evaluation epoch over the test set.

predict Run inference on your data.

tune Runs routines to tune hyperparameters before training.

$ python main.py fit

$ python main.py validate

$ python main.py test

$ python main.py predict

$ python main.py tune

子命令中也可以查看说明:

$ python main.py fit --help

usage: main.py [options] fit [-h] [-c CONFIG] # 需要读取的配置文件

[--seed_everything SEED_EVERYTHING] [--trainer CONFIG] # trainer 的配置文件

...

[--ckpt_path CKPT_PATH]# checkpoint 路径

--trainer.logger LOGGER # 日志工具

# 其它参数,以下参数是自定义模型和数据相关的参数,也可以在命令中设置,设置方法是指出其引用路径和数值。

optional arguments:

:

--model.out_dim OUT_DIM

(type: int, default: 10)

--model.learning_rate LEARNING_RATE

(type: float, default: 0.02)

:

--data CONFIG Path to a configuration file.

--data.data_dir DATA_DIR

(type: str, default: ./)

# change the learning_rate 设置学习率

python main.py fit --model.learning_rate 0.1

# change the output dimensions also 设置输出维度和学习率

python main.py fit --model.out_dim 10 --model.learning_rate 0.1

# change trainer and data arguments too 设置data和trainer

python main.py fit --model.out_dim 2 --model.learning_rate 0.1 --data.data_dir '~/' --trainer.logger False

从项目多个模型和数据中选择

当项目变得庞大之后,其中可能包含多个模型model和数据data,也可以通过CLI指定,无需改动过多代码。对于原始的方法先用args工具设置参数,然后增加判断和读取的方法。

# 设置model和data

# Mix and match anything

$ python main.py fit --model=GAN --data=MNIST

$ python main.py fit --model=Transformer --data=MNIST

在代码中判断。

# choose model

if args.model == "gan":

model = GAN(args.feat_dim)

elif args.model == "transformer":

model = Transformer(args.feat_dim)

...

# choose datamodule

if args.data == "MNIST":

datamodule = MNIST()

elif args.data == "imagenet":

datamodule = Imagenet()

...

# mix them!

trainer.fit(model, datamodule)

而lightningCLI则将以上过程封装好,开发者可以用相似的方式使用。

比如多个模型的选择(下面的示例model和data中没有参数设置,后续会说明如何设置):

# main.py

from lightning.pytorch.cli import LightningCLI

from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

class Model1(DemoModel):

def configure_optimizers(self):

print("⚡", "using Model1", "⚡")

return super().configure_optimizers()

class Model2(DemoModel):

def configure_optimizers(self):

print("⚡", "using Model2", "⚡")

return super().configure_optimizers()

# datamodule_class指定一个LightningDataModule的子类,其内部必须返回一个实例。

cli = LightningCLI(datamodule_class=BoringDataModule)

# use Model1

python main.py fit --model Model1

# use Model2

python main.py fit --model Model2

data, optimizers,schedulers的选择与model一致。

使用YAML文件配置参数

当参数增多时,上述方法可能会变得不方便,这里可以将所有配置旋转在YAML文件中,并且可以根据功能分散在不同文件中,比如model的参数放在model.yaml文件中,data的参数放在data.yaml文件中,trainer的放在trainer.yaml中,callback的放置在callback.yaml中,使用时按需组合加载。

例如将所有配置放在config.yaml中:

python main.py fit --config config.yaml

# 将trainer.max_epochs的设置更新

python main.py fit --config config.yaml --trainer.max_epochs 100

训练时,lightning会自动保存实验参数,这样可以加载config.yaml文件重新实验。使用print_config选项可以打印默认参数,只需修改部分参数保存至文件中,避免对着model,data,trainer等实例一个设置参数。

python main.py fit --print_config

seed_everything: null

trainer:

logger: true

...

model:

out_dim: 10

learning_rate: 0.02

data:

data_dir: ./

ckpt_path: null

此外,YAML文件中的参数除了是数据基本类型外,也可以多层引用。例如以下模型的criterion属性在编程时并未指定。可以训练时在YAML中设置,使得实践更加灵活。

# model.py

class MyModel(pl.LightningModule):

def __init__(self, criterion: torch.nn.Module):

self.criterion = criterion

配置是用class_path声明引用路径,init_args给出实例化需要的参数值。可以通过缩进嵌套使用。

model:

class_path: model.MyModel # model 的引用路径

init_args: # 以下是model的参数设置

criterion: # criterion 是model 中的一个属性

class_path: torch.nn.CrossEntropyLoss # criterion 指定为交叉熵损失,它的引用路径为 torch.nn.CrossEntropyLoss,

init_args:

reduction: mean

...

子命令配置

fit/validation/test/predict 这些子命令也可以相同方式配置。

在YAML中使用变量

有些参数是重复的,比如train/val阶段的batch size,可以在命令行中给出,并且不用单独创建不同的yaml文件。

首先安装omegaconf包:

pip install omegaconf

在yaml文件中将decoder_layers的值声明为model.encoder_layers的值,${}是omegaconf工具包的语法。

model:

encoder_layers: 12

decoder_layers:

- ${model.encoder_layers}

- 4

创建CLI时设置omegaconf,

cli = LightningCLI(MyModel, parser_kwargs={"parser_mode": "omegaconf"})

最后,命令行中设置值,encoder_layers,decoder_layers中将自动插值为12。

python main.py --model.encoder_layers=12

其它

callback, data, trainer, 模型子模块,optimizer, scheduler等都可以用相同方法设置。

model:

class_path: mycode.mymodels.MyModel

init_args:

decoder_layers:

- 2

- 4

encoder_layers: 12

data:

class_path: mycode.mydatamodules.MyDataModule

init_args:

...

trainer:

callbacks:

- class_path: lightning.pytorch.callbacks.EarlyStopping

init_args:

patience: 5

...

多个模型和数据

# model必须为MyModelBaseClass的子类,data必须为MyDataModuleBaseClass的子类。

cli = LightningCLI(MyModelBaseClass, MyDataModuleBaseClass, subclass_mode_model=True, subclass_mode_data=True)

直接运行

CLI都是在命令行中给定参数,但在debug的时候这种方法可以不方便,需要代码中设置参数后debug,CLI也提供了这种方式。

首先在cli实例化中增加一个args参数,外层方法增加默认为None的args形参。

from lightning.pytorch.cli import ArgsType, LightningCLI

def cli_main(args: ArgsType = None):

cli = LightningCLI(MyModel, ..., args=args)

...

if __name__ == "__main__":

cli_main()

创建一个新的my_cli.py文件,然后写入参数。执行my_cli.py。

args = {

"trainer": {

"max_epochs": 100,

},

"model": {},

}

args["model"]["encoder_layers"] = 8

cli_main(args)

args["model"]["encoder_layers"] = 12

cli_main(args)

args["trainer"]["max_epochs"] = 200

cli_main(args)

推荐链接

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: