import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}

// Assumes a live SparkSession named `spark` (predefined in spark-shell).
// In a standalone app, create it and import its implicits for $"..." columns:
// val spark = SparkSession.builder().master("local[*]").getOrCreate()
// import spark.implicits._
// A row-counting UDAF: the single input column only drives one update call
// per row; its value is ignored.
class MyCountUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    new StructType().add("id", LongType, nullable = true)

  override def bufferSchema: StructType =
    new StructType().add("count", LongType, nullable = true)

  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = 0L

  // Called once per input row within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getLong(0) + 1

  // Combines partial counts produced by different partitions
  override def merge(buffer: MutableAggregationBuffer, row: Row): Unit =
    buffer(0) = buffer.getLong(0) + row.getLong(0)

  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}
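// A UDAF can also be exposed to SQL by name. This is a minimal sketch,
// assuming the pre-Spark-3.0 spark.udf.register overload for
// UserDefinedAggregateFunction; `mycount` and the `ids` view are names
// chosen here for illustration.
spark.udf.register("mycount", new MyCountUDAF)
spark.range(4).createOrReplaceTempView("ids")
spark.sql("SELECT id % 2 AS grp, mycount(id) AS cnt FROM ids GROUP BY id % 2").show()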
// A simple average: the buffer tracks a running sum and a running count.
class MyAverageUDAF extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    new StructType().add("inputColumn", LongType)

  override def bufferSchema: StructType =
    new StructType()
      .add("sum", LongType)
      .add("count", LongType)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // `override` was missing, and Long / Long would truncate: convert to
  // Double before dividing to match the declared DoubleType result.
  override def evaluate(buffer: Row): Double =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}
val myCount = new MyCountUDAF
val myAverage = new MyAverageUDAF

// Distinct ids per group: group 0 holds {0, 2} and group 1 holds {1, 3},
// so both counts come out as 2.
spark
  .range(start = 0, end = 4, step = 1, numPartitions = 2)
  .withColumn("group", $"id" % 2)
  .groupBy("group")
  .agg(myCount.distinct($"id") as "count")
  .show()

// Average over all four ids: (0 + 1 + 2 + 3) / 4 = 1.5
spark
  .range(start = 0, end = 4, step = 1, numPartitions = 2)
  .agg(myAverage($"id"))
  .show()
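// The Aggregator and Encoder imports above belong to the newer, type-safe
// API that supersedes UserDefinedAggregateFunction (deprecated since
// Spark 3.0). A minimal sketch of the same average, assuming Spark 3.x where
// functions.udaf bridges a typed Aggregator into an untyped column function;
// AverageBuffer and MyAverageAggregator are names chosen here.
import org.apache.spark.sql.functions.udaf

case class AverageBuffer(var sum: Long, var count: Long)

object MyAverageAggregator extends Aggregator[Long, AverageBuffer, Double] {
  override def zero: AverageBuffer = AverageBuffer(0L, 0L)
  // Fold one input value into a partition-local buffer
  override def reduce(b: AverageBuffer, a: Long): AverageBuffer = {
    b.sum += a; b.count += 1; b
  }
  // Combine partial buffers from different partitions
  override def merge(b1: AverageBuffer, b2: AverageBuffer): AverageBuffer = {
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  override def finish(b: AverageBuffer): Double = b.sum.toDouble / b.count
  override def bufferEncoder: Encoder[AverageBuffer] = Encoders.product[AverageBuffer]
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Same query, same 1.5 result, without MutableAggregationBuffer plumbing
spark
  .range(start = 0, end = 4, step = 1, numPartitions = 2)
  .agg(udaf(MyAverageAggregator)($"id"))
  .show()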