Background

This article starts from a unit test on Spark 3.3.0 to explore the logic of Spark whole-stage code generation (Codegen):

test("SortAggregate should be included in WholeStageCodegen") {

val df = spark.range(10).agg(max(col("id")), avg(col("id")))

withSQLConf("spark.sql.test.forceApplySortAggregate" -> "true") {

val plan = df.queryExecution.executedPlan

assert(plan.exists(p =>

p.isInstanceOf[WholeStageCodegenExec] &&

p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortAggregateExec]))

assert(df.collect() === Array(Row(9, 4.5)))

}

}

The whole-stage-codegen portion of the second stage of the physical plan produced for this query is:

WholeStageCodegen
*(2) SortAggregate(key=[], functions=[max(id#0L), avg(id#0L)], output=[max(id)#5L, avg(id)#6])
+- InputAdapter
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#13]
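
As an aside, if you want to inspect the Java source that whole-stage codegen actually emits for this plan, Spark ships a debug helper. A minimal sketch, assuming a SparkSession named spark is in scope (e.g. in spark-shell):

import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions.{avg, col, max}

// Same aggregation as the test above; `spark` is assumed in scope.
val df = spark.range(10).agg(max(col("id")), avg(col("id")))
// Prints each WholeStageCodegen subtree followed by the Java it generates.
df.debugCodegen()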

Analysis

Stage 2: WholeStageCodegen

The code generation of the second stage involves the produce and consume methods of SortAggregateExec, ShuffleExchangeExec, and InputAdapter; we analyze them one by one. The data flow of the second-stage WholeStageCodegen is as follows:

WholeStageCodegenExec    SortAggregateExec(Final)           InputAdapter         ShuffleExchangeExec
====================================================================================
-> execute()
      |
   doExecute() --------> inputRDDs() ----------------------> inputRDDs() -------> execute()
      |                                                                               |
   doCodeGen()                                                                    doExecute()
      |                                                                               |
      +----------------> produce()                                               ShuffledRowRDD
                            |
                         doProduce()
                            |
                         doProduceWithoutKeys() -----------> produce()
                                                                |
                                                             doProduce()
                            |
                         doConsume() <---------------------- consume()
                            |
                         doConsumeWithoutKeys()
                            |  (note: the consume below is invoked by
                            |   doProduceWithoutKeys, not by doConsumeWithoutKeys)
   doConsume() <--------- consume()
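
The diagram is just an instance of the recursive produce/consume protocol of CodegenSupport: a parent delegates the row-driving loop to its child via produce, and the child calls back into the parent's consume for every row. Below is a minimal standalone sketch of that pattern (a hypothetical MiniCodegen trait and names, not Spark's actual API):

// Minimal sketch of the produce/consume codegen protocol (hypothetical types).
trait MiniCodegen {
  var parent: MiniCodegen = null
  // Ask this operator to generate its row-driving loop, remembering the caller.
  def produce(p: MiniCodegen): String = { parent = p; doProduce() }
  protected def doProduce(): String
  // Called by the child for every row it emits.
  def consume(row: String): String = doConsume(row)
  protected def doConsume(row: String): String = ""
}

// Owns the loop: drains an iterator and hands each row to the parent.
class MiniInputAdapter extends MiniCodegen {
  protected def doProduce(): String =
    s"""while (input.hasNext()) {
       |  InternalRow row = (InternalRow) input.next();
       |  ${parent.consume("row")}
       |}""".stripMargin
}

// Owns no loop of its own: delegates production to the child and
// contributes per-row aggregation code through doConsume.
class MiniAgg(child: MiniCodegen) extends MiniCodegen {
  protected def doProduce(): String = child.produce(this)
  override protected def doConsume(row: String): String = s"updateAggBuffer($row);"
}

object MiniDemo extends App {
  // Prints a while-loop whose body folds each row into the aggregation buffer.
  println(new MiniAgg(new MiniInputAdapter).produce(null))
}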

doProduce of SortAggregateExec(Final)

Only the parts that differ from SortAggregateExec(Partial) are listed here:

val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
  // evaluate aggregate results
  ctx.currentVars = flatBufVars
  val aggResults = bindReferences(
    functions.map(_.evaluateExpression),
    aggregateBufferAttributes).map(_.genCode(ctx))
  val evaluateAggResults = evaluateVariables(aggResults)
  // evaluate result expressions
  ctx.currentVars = aggResults
  val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
  (resultVars,
    s"""
       |$evaluateAggResults
       |${evaluateVariables(resultVars)}
     """.stripMargin)

Because we are in the Final stage here, the data flow differs from the Partial stage:

- `ctx.currentVars = flatBufVars` assigns the current aggregation-buffer variables to currentVars so that the reference binding below can resolve against them; these buffer variables are global variables (mutable state of the generated class).
- `val aggResults = bindReferences(functions.map(_.evaluateExpression), ...)` computes the final output value. For Average, evaluateExpression is Divide(sum.cast(resultType), count.cast(resultType), failOnError = false), and the generated code is:

boolean sortAgg_isNull_6 = sortAgg_bufIsNull_2;
double sortAgg_value_6 = -1.0;
if (!sortAgg_bufIsNull_2) {
  sortAgg_value_6 = (double) sortAgg_bufValue_2;
}

boolean sortAgg_isNull_4 = false;
double sortAgg_value_4 = -1.0;
if (sortAgg_isNull_6 || sortAgg_value_6 == 0) {
  sortAgg_isNull_4 = true;
} else {
  if (sortAgg_bufIsNull_1) {
    sortAgg_isNull_4 = true;
  } else {
    sortAgg_value_4 = (double)(sortAgg_bufValue_1 / sortAgg_value_6);
  }
}

- `aggregateBufferAttributes` is the aggregate function's buffer attributes, `sum :: count :: Nil`, so during reference binding the variables line up one-to-one with currentVars.
- `val evaluateAggResults = evaluateVariables(aggResults)` emits the code that performs the final computation of the aggregates.
- `ctx.currentVars = aggResults` assigns the final-result variables to currentVars for the subsequent binding.
- `val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))` binds the aggregate result variables into the result expressions. Here resultExpressions is List(avg(id#0L)#3 AS avg(id)#6) (we only consider AVG), and aggregateAttributes is the AttributeReference form of resultExpressions, used for the mapping when BoundReference resolves them. The corresponding ExprCode is ExprCode(, sortAgg_isNull_4, sortAgg_value_4).
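
Taken together, the generated sortAgg_* branches above implement nothing more than a null-safe division avg = sum / count. A plain-Scala sketch of the same semantics (a hypothetical AvgBuffer type, not Spark code):

// Mirrors the generated Final-stage code: a null or zero count yields null
// (sortAgg_isNull_4 = true); otherwise the result is sum / count.
final case class AvgBuffer(sum: Option[Double], count: Option[Long])

def evalAvg(buf: AvgBuffer): Option[Double] = buf.count match {
  case None | Some(0L) => None
  case Some(c)         => buf.sum.map(_ / c)
}

assert(evalAvg(AvgBuffer(Some(45.0), Some(10L))) == Some(4.5)) // matches Row(9, 4.5)
assert(evalAvg(AvgBuffer(Some(45.0), Some(0L))).isEmpty)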

doProduce of InputAdapter

InputAdapter is a bridge operator: it adapts a child physical plan that does not support codegen so that its output can feed the whole-stage-codegen subtree. Its code is as follows:

override def doProduce(ctx: CodegenContext): String = {
  // Inline mutable state since an InputRDDCodegen is used once in a task for WholeStageCodegen
  val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];",
    forceInline = true)
  val row = ctx.freshName("row")
  val outputVars = if (createUnsafeProjection) {
    // creating the vars will make the parent consume add an unsafe projection.
    ctx.INPUT_ROW = row
    ctx.currentVars = null
    output.zipWithIndex.map { case (a, i) =>
      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
    }
  } else {
    null
  }
  val updateNumOutputRowsMetrics = if (metrics.contains("numOutputRows")) {
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    s"$numOutputRows.add(1);"
  } else {
    ""
  }
  s"""
     | while ($limitNotReachedCond $input.hasNext()) {
     |   InternalRow $row = (InternalRow) $input.next();
     |   ${updateNumOutputRowsMetrics}
     |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
     |   ${shouldStopCheckCode}
     | }
   """.stripMargin
}

- `val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];", forceInline = true)` defines an input field that receives the InternalRows (UnsafeRows) output by SortAggregate(Partial); the corresponding initialization code is invoked in the generated class's init method.
- `val row = ctx.freshName("row")` defines a fresh variable name to hold each UnsafeRow-typed InternalRow pulled from input during iteration.
- `val outputVars = if (createUnsafeProjection)`: for InputAdapter, createUnsafeProjection is false, so this returns null.
- `val updateNumOutputRowsMetrics`: the metrics condition is not met here, so this is the empty string.
- Code assembly:

s"""
   | while ($limitNotReachedCond $input.hasNext()) {
   |   InternalRow $row = (InternalRow) $input.next();
   |   ${updateNumOutputRowsMetrics}
   |   ${consume(ctx, outputVars, if (createUnsafeProjection) null else row).trim}
   |   ${shouldStopCheckCode}
   | }
 """.stripMargin

This iterates over every input row and then calls the consume method. Note: the consume call here passes row, a variable of type InternalRow, rather than the Long-typed variable used in RangeExec.
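
Stripped of the string templating, the loop InputAdapter emits is just an iterator drain. A runtime-level Scala sketch (a hypothetical drainAdapter helper, not Spark's API):

import org.apache.spark.sql.catalyst.InternalRow

// What the generated while-loop does at runtime: pull each row from the
// upstream iterator and hand it to the parent operator's per-row code.
def drainAdapter(input: Iterator[InternalRow])(consumeRow: InternalRow => Unit): Unit =
  while (input.hasNext) {
    consumeRow(input.next()) // corresponds to the inlined consume(...) body
  }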

consume of InputAdapter

Here we only explain the parts that differ from before. The relevant code is:

final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String =

Note the arguments here: outputVars is null, and row is a variable of type InternalRow.

val inputVarsCandidate =
  if (outputVars != null) {
    assert(outputVars.length == output.length)
    // outputVars will be used to generate the code for UnsafeRow, so we should copy them
    outputVars.map(_.copy())
  } else {
    assert(row != null, "outputVars and row cannot both be null.")
    ctx.currentVars = null
    ctx.INPUT_ROW = row
    output.zipWithIndex.map { case (attr, i) =>
      BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
    }
  }

Here the data flow takes the else branch:

`ctx.INPUT_ROW = row` sets the current INPUT_ROW to row, and BoundReference's doGenCode therefore also takes the other branch:

assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
val javaType = JavaCode.javaType(dataType)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
  ev.copy(code =
    code"""
          |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
          |$javaType ${ev.value} = ${ev.isNull} ?
          |  ${CodeGenerator.defaultValue(dataType)} : ($value);
        """.stripMargin)
} else {
  ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}

Analysis

`val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)` calls a different UnsafeRow accessor depending on the data type. As for `if (nullable)`: because the expressions AttributeReference("sum", sumDataType)() and AttributeReference("count", LongType)() have nullable = true, the generated code is:

boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
long inputadapter_value_0 = inputadapter_isNull_0 ?
  -1L : (inputadapter_row_0.getLong(0));

boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
double inputadapter_value_1 = inputadapter_isNull_1 ?
  -1.0 : (inputadapter_row_0.getDouble(1));

boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
long inputadapter_value_2 = inputadapter_isNull_2 ?
  -1L : (inputadapter_row_0.getLong(2));
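
CodeGenerator.getValue is what picks the typed UnsafeRow accessor for each column, which is why the snippet above mixes getLong (max, count) and getDouble (sum). A rough, simplified sketch of that dispatch (the real method covers many more types; the fallback case here is an assumption):

import org.apache.spark.sql.types._

// Simplified view of choosing an accessor expression per data type.
def getValueSketch(row: String, dataType: DataType, ordinal: String): String =
  dataType match {
    case LongType    => s"$row.getLong($ordinal)"   // e.g. the max and count slots
    case DoubleType  => s"$row.getDouble($ordinal)" // e.g. the sum slot
    case IntegerType => s"$row.getInt($ordinal)"
    case _           => s"$row.get($ordinal, ...)"  // fallback sketch for other types
  }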

In the constructDoConsumeFunction method, inputVarsInFunc gains one extra actual argument here: an InternalRow-typed variable named inputadapter_row_0.
