用Scala的高阶函数让Chisel代码更优雅

动机

前面模块中讨厌的for循环很冗长,破坏了函数式编程的目的,这一篇将会介绍高阶函数,让我们写生成器的时候更快。

FIR的卷积操作居然一行代码就能实现?

前面我们写了FIR过滤器的实现,其中的卷积部分是这么写的:

val muls = Wire(Vec(length, UInt(8.W)))

for(i <- 0 until length) {

if(i == 0) muls(i) := io.in * io.consts(i)

else muls(i) := regs(i - 1) * io.consts(i)

}

val scan = Wire(Vec(length, UInt(8.W)))

for(i <- 0 until length) {

if(i == 0) scan(i) := muls(i)

else scan(i) := muls(i) + scan(i - 1)

}

io.out := scan(length - 1)

简单说一下它的基本思想,首先把io.in的每个元素与对应的io.const相乘,然后存放到muls中,然后muls中的元素累加到scan,其中scan(0) = muls(0),scan(1) = scan(0) + muls(1) = muls(0) + muls(1),一般化就是scan(n) = scan(n-1) + muls(n) = muls(0) + ... + muls(n-1) + muls(n),最后scan的最后一个元素(等于muls所有元素的值)被赋值到io.out。

但是啊,真的很繁琐,罗里吧嗦半天就为了实现一个简单的操作。实际上呢,上面的操作可以写成一行代码:

io.out := (taps zip io.consts).map { case (a, b) => a * b }.reduce(_ + _)

它是怎么办到的呢?下面一点点来分析:

假设taps是所有采样的列表,也就是taps(0) = io.in,taps(1) = regs(0)等;(taps zip io.consts)接受两个列表,taps和io.const,然后组合他们为一个列表,其每个元素都是输入的相应位置的元素的元组,具体来说,它的值就像这样:[(taps(0), io.consts(0)), (taps(1), io.consts(1)), ..., (taps(n), io.consts(n))]。注意,因为.是可以省略的,所以这个其实等价于(taps.zip(io.const));.map {case (a, b) => a*b}应用了一个匿名函数,这个函数接受一个二元素的元组然后返回他们的乘积,这个匿名函数应用在了列表的元素上并返回结果也是列表,结果为[taps(0) * io.consts(0), taps(1) * io.consts(1), ..., taps(n) * io.consts(n)],和之前的muls是等价的。现在先简单了解写匿名函数的语法,后面还会详细介绍;最后,.reduce(_ + _)在列表的元素上应用了函数(元素之间相加)。然而,这里同样接受两个参数,第一个是当前的累加结果,第二个是列表元素(第一次迭代中两个参数都是列表元素)。这是通过圆括号里面的两个下划线给定的。假设从左到右遍历,结果会是(((muls(0) + muls(1)) + muls(2)) + ...) + muls(n),最先计算的有更深的括号嵌套。这个结果就是卷积的输出。

作为参数的函数

正式地来讲,像map和reduce这样的函数就是高阶函数,因为它们是以函数为参数的函数。结果也证明,使用高阶函数能够省不少事,可以封装一个一般的计算模式,允许你在写代码是专注于应用整体的逻辑而不是控制流,能写出非常简洁的代码。

指定函数的不同方法

上面已经提到了两种了,这里总结一下:

对于每个元素只会引用一次的函数,可能可以使用下划线_来指代每个元素,在上面的例子中,reduce的参数函数接受两个元素就可以被表示为_ + _。虽然方便,但这受限于极少数满足一些复杂规则的情况,所以如果不行的话,可以试试下面的方法:显式指定输入参数列表。上面的规约操作可以被显式写成(a, b) => a + b,用的是把参数列表放在括号里的一般形式,然后跟着=>符号,再跟着引用这些参数的函数体;如果需要解包元组的时候,就用case语句,就像case (a, b) => a + b这里的用法一样。这种情况是接受了单个参数,即一个两元素的元组,然后将其解包为变量a和b,然后用于后面的函数体。

Scala中的实践

上上一篇文章中,我们介绍了Scala集合类API中的主要类,比如List这种。这些高阶函数其实也是这些API的一部分,比如上面的map和reduce都是List上的API。这一小节我们通过一些例子来熟悉这些方法。在例子中,简洁起见我们都在Scala的数字(Int)上进行操作,但是因为Chisel的操作符也是类似的,所以这些概念可以泛化。

Map

List[A].map有类型签名map[B](f: (A) ⇒ B): List[B]。后面会有专门的一篇讲解关于类型的知识,现在就把这个类型A和B看成是Int或者SInt,意味着它们可以是软件类型也可以是硬件类型。

说人话就是,它接受一个类型为(f: (A) ⇒ B)的参数,或者是一个接受两个参数的函数,第一个参数类型为A,即和输入列表的元素类型一致,第二个参数类型就随便了,什么类型都可以,然后map会返回一个类型B的列表,即参数函数的返回值类型。

因为我们已经解释过了FIR例子中列表的行为,现在就直接看例子吧:

println(List(1, 2, 3, 4).map(x => x + 1)) // 函数中的显式参数列表

println(List(1, 2, 3, 4).map(_ + 1)) // 和上面的等价,但是隐式的

println(List(1, 2, 3, 4).map(_.toString + "a")) // 输出元素类型可和输入不同

println(List((1, 5), (2, 6), (3, 7), (4, 8)).map { case (x, y) => x*y }) // 用case解包元组,注意这里用的是大括号

// 提一嘴,Scala中有构造连续数字列表的语法

println(0 to 10) // to是inclusive的, 这里的10是包括在内的

println(0 until 10) // until是exclusive的,这里的10不包括在内

// 上面生成的和列表的行为基本一致,生成索引的时候很有用

val myList = List("a", "b", "c", "d")

println((0 until 4).map(myList(_)))

输出如下:

List(2, 3, 4, 5)

List(2, 3, 4, 5)

List(1a, 2a, 3a, 4a)

List(5, 12, 21, 32)

Range 0 to 10

Range 0 until 10

Vector(a, b, c, d)

再来个简单的练习,想要让列表中的每个元素的翻倍,???处填什么代码呢?

println(List(1, 2, 3, 4).map(???))

显然,填_ * 2就行了。

zipWithIndex

zipWithIndex的类型签名是zipWithIndex: List[(A, Int)]。

也就是说不接受任何参数,但是返回一个列表,其每个元素都是源数据和它的索引(第一个元素索引为0)。所以说,List("a", "b", "c", "d").zipWithIndex会返回List(("a", 0), ("b", 1), ("c", 2), ("d", 3))。

这在某些操作中,需要元素索引的场合特别有用。

这个也很简单,直接上例子:

println(List(1, 2, 3, 4).zipWithIndex) // 注意索引从0开始

println(List("a", "b", "c", "d").zipWithIndex)

println(List(("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")).zipWithIndex) // 嵌套元组

输出如下:

List((1,0), (2,1), (3,2), (4,3))

List((a,0), (b,1), (c,2), (d,3))

List(((a,b),0), ((c,d),1), ((e,f),2), ((g,h),3))

Reduce

List[A].reduce的类型签名和List[A].map差不多,为reduce (op: (A, A) ⇒ A),这里就很宽松了,A只需要是List类型的超类就行了,但是这里不讨论这些语法。

直接上例子:

println(List(1, 2, 3, 4).reduce((a, b) => a + b)) // 返回所有元素的和

println(List(1, 2, 3, 4).reduce(_ * _)) // 返回所有元素的积

println(List(1, 2, 3, 4).map(_ + 1).reduce(_ + _)) // 可以把reduce放在map后

输出为:

10

24

14

需要注意的是,在空列表上使用reduce是不行的:

println(List[Int]().reduce(_ * _))

会报错:

java.lang.UnsupportedOperationException: empty.reduceLeft

现在稍微练习以下,在???处填入代码,使得列表的元素先翻倍再累乘:

println(List(1, 2, 3, 4).map(???).reduce(???))

很简单,这么写就行:

println(List(1, 2, 3, 4).map(_ * 2).reduce(_ * _))

Fold

List[A].fold和reduce类型,除了以不能指定规约运算的初始值。类型签名和reduce是类似的:fold(z: A)(op: (A, A) ⇒ A): A。

注意,它有两个参数,第一个参数z是初始值,第二个参数是规约的函数。和reduce不同,对于空列表它不会失效,而是会直接返回初始值。

例子来了:

println(List(1, 2, 3, 4).fold(0)(_ + _)) // 等价于用reduce的累加

println(List(1, 2, 3, 4).fold(1)(_ + _)) // 和上面的差不多,但是从1开始累加

println(List().fold(1)(_ + _)) // 和reduce不一样,fold可以在空列表上执行

输出为:

10

11

1

小小练习一下,现在要用fold返回一个列表的累乘值的两倍,???处怎么写:

println(List(1, 2, 3, 4).fold(???)(???))

这还用想?

println(List(1, 2, 3, 4).fold(2)(_ * _))

不过需要注意的是,除非需要容忍空列表,不然还是用reduce更好。

Chisel中的实践——Decoupled Arbiter

现在结合上面所学,实现一个Decoupled的仲裁器,要求有n个Decoupled输入和一个Decoupled输出。仲裁器选择有效通道中索引最低的转发到输出。

几点提示:

如果有任何输入有效的话,io.out.valid就为真;可以考虑在模块内部给被选择的通道整个Wire;如果输出ready为真的话,且某个通道被选择,则对应的输入的ready为真,(注意这里把ready和valid耦合到一起去了,但是这里先忽略);可能会用到map,尤其是用来返回子元素的Vec时,比如io.in.map(_.valid)就会返回输入Bundle的有效信号的列表;可能用到PriorityMux(List[Bool, Bits]),接受一个列表的有效信号和数据,返回第一个有效的元素;可能用到Vec的动态索引,通过一个UInt数来索引,比如io.in(0.U)。

一样的,在???处填上自己的代码:

import chisel3._

import chisel3.util._

import chisel3.tester._

import chisel3.tester.RawTester.test

object MyModule extends App {

class MyRoutingArbiter(numChannels: Int) extends Module {

val io = IO(new Bundle {

val in = Vec(numChannels, Flipped(Decoupled(UInt(8.W))))

val out = Decoupled(UInt(8.W))

} )

// 在这里填上自己的代码

???

}

test(new MyRoutingArbiter(4)) { c =>

// 设置初始值

for(i <- 0 until 4) {

c.io.in(i).valid.poke(false.B)

c.io.in(i).bits.poke(i.U)

c.io.out.ready.poke(true.B)

}

c.io.out.valid.expect(false.B)

// 测试有背压的单输入有效的行为

for (i <- 0 until 4) {

c.io.in(i).valid.poke(true.B)

c.io.out.valid.expect(true.B)

c.io.out.bits.expect(i.U)

c.io.out.ready.poke(false.B)

c.io.in(i).ready.expect(false.B)

c.io.out.ready.poke(true.B)

c.io.in(i).valid.poke(false.B)

}

// 测试有背压的多输入有效的行为

c.io.in(1).valid.poke(true.B)

c.io.in(2).valid.poke(true.B)

c.io.out.bits.expect(1.U)

c.io.in(1).ready.expect(true.B)

c.io.in(0).ready.expect(false.B)

c.io.out.ready.poke(false.B)

c.io.in(1).ready.expect(false.B)

}

println("SUCCESS!!") // Scala Code: if we get here, our tests passed!

}

问号处代码应为:

io.out.valid := io.in.map(_.valid).reduce(_ || _)

val channel = PriorityMux(

io.in.map(_.valid).zipWithIndex.map { case (valid, index) => (valid, index.U) }

)

io.out.bits := io.in(channel).bits

io.in.map(_.ready).zipWithIndex.foreach { case (ready, index) =>

ready := io.out.ready && channel === index.U

}

测试通过,下面解释一下:

用map取出io.in的valid信号的列表,再规约进行或运算,就知道其中是否至少有一个有效;构造一个优先级多路选择器,输入是valid信号和索引才行,所以对valid列表进行了一个zipWithIndex,然后继续应用map,将索引转为硬件格式,多路选择器的输出就是有效且最低索引的索引;输出的bits就是channel索引对应的输入bits;要对输入的每个ready信号进行操作,需要用到一个foreach,它也是List上的函数,虽然前面没提到但是很好懂,就是对于每个元素进行操作,无需返回值,这里用同样的方法提取除了ready和索引,然后根据输出的ready信号和被通道的索引来判断是否应该将这个输入的ready信号置为有效。

有一点需要提一下,为什么不直接用PriorityMux输出io.in.bits呢?因为需要设置io.in.ready位,所以必须要知道被选择的输入的索引,所以PriorityMux用来找索引了。

文章链接

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: