I have been thinking about Generic UDAFs lately, and the crux is understanding the input and output of each step of UDAF execution; once that is clear, writing the code by following the steps below is mostly straightforward. One caveat: keep the data types consistent. I recommend the Hadoop types, i.e. the distributable (serializable) objects; in practice the writable family has proven simpler to use than the java family. Do not try to mix the two families, or a ClassCastException will easily surface somewhere in the pipeline.
0) In the resolver, validate the input arguments (types and count).
1) First, map out the steps the data goes through from raw input to final output.
2) In init, since each phase receives different input, add the corresponding checks and choose the matching output type.
3) In init, pay close attention to the output type of each phase.
4) Define a static class implementing the AggregationBuffer interface; give it a field that temporarily holds the collection being aggregated.
5) reset is called manually: in getNewAggregationBuffer, instantiate the static AggregationBuffer class and call reset on it.
6) iterate turns raw rows into partial-aggregation state; make sure the raw value is copied into the AggregationBuffer.
7) terminatePartial returns a partial aggregation result, an object encapsulating the current state of the aggregation.
8) merge combines a partial aggregation produced by terminatePartial with another partial aggregation.
9) terminate returns the final aggregation result. (A minimal end-to-end sketch of these steps follows below.)
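To make the steps concrete, here is a minimal sketch of a Generic UDAF that sums long values, sticking to writable types throughout as recommended above. The class name GenericUDAFSimpleSum and all of its details are illustrative, not from the Hive source; type checking is simplified, and the point is only to show where each numbered step lands.

import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

public class GenericUDAFSimpleSum extends AbstractGenericUDAFResolver {

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    // Step 0: validate the argument count (real code should also check types).
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }
    return new SumEvaluator();
  }

  public static class SumEvaluator extends GenericUDAFEvaluator {
    // Steps 2/3: in every mode the input is a single long-compatible value
    // (a raw row or a partial sum), and the output is always a long.
    private PrimitiveObjectInspector inputOI;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
      super.init(m, parameters);
      inputOI = (PrimitiveObjectInspector) parameters[0];
      // Stay in the writable family everywhere, per the advice above.
      return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
    }

    // Step 4: the buffer holding the intermediate aggregation state.
    static class SumBuffer implements AggregationBuffer {
      long sum;
    }

    // Step 5: create the buffer and call reset on it manually.
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      SumBuffer buf = new SumBuffer();
      reset(buf);
      return buf;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((SumBuffer) agg).sum = 0;
    }

    // Step 6: fold one raw row into the buffer.
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      if (parameters[0] != null) {
        ((SumBuffer) agg).sum +=
            PrimitiveObjectInspectorUtils.getLong(parameters[0], inputOI);
      }
    }

    // Step 7: expose the current partial state.
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      return new LongWritable(((SumBuffer) agg).sum);
    }

    // Step 8: fold another partial sum into the buffer.
    @Override
    public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
      if (partial != null) {
        ((SumBuffer) agg).sum +=
            PrimitiveObjectInspectorUtils.getLong(partial, inputOI);
      }
    }

    // Step 9: return the final result as a writable.
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      return new LongWritable(((SumBuffer) agg).sum);
    }
  }
}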
Source code analysis of the Hive built-in Generic UDAF collect_set
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

/**
 * GenericUDAFCollectSet
 */
@Description(name = "collect_set", value = "_FUNC_(x) - Returns a set of objects with duplicate elements eliminated")
public class GenericUDAFCollectSet extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFCollectSet.class.getName());

  public GenericUDAFCollectSet() {
  }

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
      throws SemanticException {
    // Check the number of arguments.
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }
    // Only primitive types are accepted; this could be rewritten to support complex types.
    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
          + parameters[0].getTypeName() + " was passed as parameter 1.");
    }
    // Specify the Evaluator to use; it receives the data and drives how the UDAF is invoked.
    return new GenericUDAFMkSetEvaluator();
  }
  public static class GenericUDAFMkSetEvaluator extends GenericUDAFEvaluator {

    // For PARTIAL1 and COMPLETE: ObjectInspector for the original data
    private PrimitiveObjectInspector inputOI;
    // For PARTIAL2 and FINAL: ObjectInspectors for the partial aggregations
    // (lists of objects)
    private StandardListObjectInspector loi;
    private StandardListObjectInspector internalMergeOI;
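    // Which callbacks each Mode triggers (per the GenericUDAFEvaluator contract):
    //   PARTIAL1: iterate() then terminatePartial()  -- map side, raw rows -> partial list
    //   PARTIAL2: merge()   then terminatePartial()  -- combine partial lists into a partial list
    //   FINAL:    merge()   then terminate()         -- reduce side, partial lists -> final list
    //   COMPLETE: iterate() then terminate()         -- map-only job, raw rows -> final list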
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
        throws HiveException {
      super.init(m, parameters);
      // Init the output object inspectors.
      // The output of a partial aggregation is a list.
      /**
       * What each phase means for collect_set:
       * 1. PARTIAL1: original data to partial aggregation. For collect_set this
       *    means putting raw values into a set, so the input type is a
       *    PrimitiveObjectInspector and the output type is a
       *    StandardListObjectInspector.
       * 2. Otherwise there are two cases:
       *    (1) merging two sets (PARTIAL2/FINAL), i.e. when the instanceof
       *        check below fails;
       *    (2) going straight from original data to the final set, which exists
       *        to handle a map-only job with no reduce phase (map output is the
       *        final output), i.e. the COMPLETE mode.
       */
      if (m == Mode.PARTIAL1) {
        inputOI = (PrimitiveObjectInspector) parameters[0];
        return ObjectInspectorFactory
            .getStandardListObjectInspector((PrimitiveObjectInspector) ObjectInspectorUtils
                .getStandardObjectInspector(inputOI));
      } else {
        if (!(parameters[0] instanceof StandardListObjectInspector)) {
          // COMPLETE mode: no map-side aggregation, raw data arrives directly.
          inputOI = (PrimitiveObjectInspector) ObjectInspectorUtils
              .getStandardObjectInspector(parameters[0]);
          return (StandardListObjectInspector) ObjectInspectorFactory
              .getStandardListObjectInspector(inputOI);
        } else {
          // PARTIAL2 and FINAL modes: both merge a list into a list, so the
          // handling is identical.
          internalMergeOI = (StandardListObjectInspector) parameters[0];
          inputOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
          loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
          return loi;
        }
      }
    }
    static class MkArrayAggregationBuffer implements AggregationBuffer {
      Set<Object> container;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      ((MkArrayAggregationBuffer) agg).container = new HashSet<Object>();
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      MkArrayAggregationBuffer ret = new MkArrayAggregationBuffer();
      reset(ret);
      return ret;
    }
    // Map side: copy each raw value into the set.
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
        throws HiveException {
      assert (parameters.length == 1);
      Object p = parameters[0];
      if (p != null) {
        MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
        putIntoSet(p, myagg);
      }
    }

    // Map side: emit the partial aggregation collected so far.
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
      ArrayList<Object> ret = new ArrayList<Object>(myagg.container.size());
      ret.addAll(myagg.container);
      return ret;
    }

    // Merge a partial aggregation produced by terminatePartial into this one.
    @Override
    public void merge(AggregationBuffer agg, Object partial)
        throws HiveException {
      MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
      ArrayList<Object> partialResult = (ArrayList<Object>) internalMergeOI.getList(partial);
      for (Object i : partialResult) {
        putIntoSet(i, myagg);
      }
    }

    // Emit the final aggregation result as a list.
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
      ArrayList<Object> ret = new ArrayList<Object>(myagg.container.size());
      ret.addAll(myagg.container);
      return ret;
    }
    private void putIntoSet(Object p, MkArrayAggregationBuffer myagg) {
      if (myagg.container.contains(p)) {
        return;
      }
      Object pCopy = ObjectInspectorUtils.copyToStandardObject(p,
          this.inputOI);
      myagg.container.add(pCopy);
    }
  }
}
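Since collect_set ships with Hive, it needs no registration and can be called directly, e.g. select uid, collect_set(page) from clicks group by uid (the table and column names here are made up for illustration). A custom UDAF written along the same lines must first be packaged into a jar, loaded with add jar, and registered with create temporary function, supplying your own function name and fully qualified class name.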