Neural Architecture Search with Reinforcement Learning

Background

arvix原文

神经网络在诸多任务中表现较好,但是设计/调参过程复制。

本文提出一种使用RNN生成模型架构,并且使用强化学习来训练RNN,使其生成的模型在验证集上的准确率最大

论文工作

提出了Neural Architecture Search,一种基于梯度的方法

神经网络的结构structure和连通性connectivity可以用可变长字符串来表示,因此

(1)希望使用循环神经网络RNN(controller)来生成这个网络结构

(2)在数据集上训练生成的子网络child network,获得其准确率

(3)将子网络的在数据集上的准确率作为奖励信号,计算梯度更新控制器RNN

相关工作

本文的Neural Architecture Search和程序合成program synthesis,归纳编程inductive programming有一定相似性

Neural Architecture Search是自回归的,即每预测一个超参数,都以先前的预测为条件

自回归模型通过在每个时间步生成一个元素,依次建立序列。每个元素的生成都依赖于之前生成的元素。在语言模型中,这通常涉及到给定前面的词语来预测下一个词语。(源自ChatGPT)

其他相关工作详见论文

方法 

1. RNN生成网络架构

提高controller的灵活性,使用RNN生成神经网络的超参数

这里预测的神经网络只有卷积层,用控制器生成超参作为a sequence of tokens

每一层包括:filter数量、filter的两个尺寸、stride的两个尺寸

实验里,层数超过阈值就停止生成;随训练进行,作者增加这个阈值

RNN完成了一个模型结构的生成,就构建并训练具有该结构的神经网络,待其收敛时记录准确率

2. 强化训练

这里最好有强化学习的基础,特别是policy gradient,不太了解该内容的建议学习一下李宏毅老师的相关部分。

李宏毅-强化学习

控制器生成的tokens作为一系列actions(a1:T)

子网络在hold out数据集上训练达到精度R,作为奖励信号

最大化期望奖励:(因变量是RNN的参数)

 由于R不可微,所以使用policy gradient策略来更新RNN的参数,这里使用了Williams的REINFORCE rule(1992)

具体策略可参考https://zhuanlan.zhihu.com/p/110881517

经验近似:

m是controller在一个batch里采样网不同结构的数量,T是超参数量,生成的第k个网络的精度为Rk

上述的更新为无偏估计,但是方差高,未来减少方差,使用了基线函数:

 只要b不依赖于当前action,就仍是无偏估计。本文作者使用的是前k-1轮精度的指数移动平均

上述公式的推导:

为什么要加baseline b,参考李宏毅老师的讲解,因为update时是一个sample的过程,如果所有的R都是正数,而有些选项没有被sample到(比如A),那么随着其他选项的概率更新(normalization之后概率之和为1)那么A被选中的概率就会减小,但这是我们不希望的,因为A只是“很不幸”的没有被选中而已

因此我们希望 reward R不要总是正数

所以baseline是为了解决:The probability of actions not sampled will decrease.

使用并行和异步更新加速训练(有钱真好)

每一次训练完子网络才更新controller的参数,使得时间较长,这里作者使用了分布式计算,异步更新

S个parameter server,存储k个控制器副本的共享参数(RNN),每个控制器副本对并行训练的m个子架构将进行采样。控制器将小批量数据的梯度发送给服务器,更新所有控制器副本的权重

3.使用skip connection和不同layer type增加模型复杂度

为了扩大搜索空间,增加skip connection(类比GoogleNet ResNet)

为了能预测这种连接,使用set-selection type attention机制,在每一层都增加一个anchor point锚点,指示前N-1层的内容信息(是否需要connect sigmoid)每个sigmoid都是controller当前状态和前N-1层锚点的隐藏状态的函数。

hj是第j层锚点的状态(0

几种情况的处理:

(1)没有输入的层,把图像作为输入

(2) 在最后一层,把之前所有没有被connect的层的输出连接起来,将隐藏状态给分类器

(3)输入层大小不同,用0作为padding

4.GENERATE RECURRENT CELL ARCHITECTURES

RNN cell:输入xt,ht-1生成ht

控制器RNN需要用组合方法(加,乘等)以及激活函数来标记节点,合并输入,再将两个输出送给下一个节点。

实现上,按顺序索引树的节点,以便逐个访问每个节点需要的超参

参考LSTM,加入ct,ct-1标记记忆状态

例子:叶子节点0 1,内部节点2

 计算步骤

效果

CIFAR10上,test error = 3.65% ,超SOTA0.09%,训练速度1.05倍

Penn Treebank dataset,形成的新的cell性能优于LSTM,test set perplexity = 62.4,超过SOTA3.6

转移到PTB的字符语言建模任务上,达到SOTA 1.214 perplexity

其他结果参考原文,不一一列举

精彩文章

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: