package DecisionTree
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.DataFrame
/**
 * Builds a decision tree with the ID3 algorithm, following "Statistical Learning Methods" (Li Hang).
 */
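// ID3 in brief (cf. "Statistical Learning Methods", ch. 5): for a dataset D with
// classes C_k, and a feature A that partitions D into subsets D_i,
//   H(D)   = - sum_k (|C_k| / |D|) * log2(|C_k| / |D|)   (empirical entropy)
//   H(D|A) =   sum_i (|D_i| / |D|) * H(D_i)              (empirical conditional entropy)
//   g(D,A) =   H(D) - H(D|A)                             (information gain)
// At every internal node, ID3 splits on the feature with the largest g(D, A).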
object ID3Tree {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("ML")
    val sc = new SparkContext(conf)
    val sqlcontext = new SQLContext(sc)

    // Training data: the loan-application example from "Statistical Learning Methods".
    // Columns: (id, age, hasWork, hasHouse, credit, label).
    // Values: 青年/中年/老年 = youth/middle-aged/elderly; 是/否 = yes/no;
    //         一般/好/非常好 = fair/good/excellent.
    val sampleData = Array(
      Array("1",  "青年", "否", "否", "一般",   "否"),
      Array("2",  "青年", "否", "否", "好",     "否"),
      Array("3",  "青年", "是", "否", "好",     "是"),
      Array("4",  "青年", "是", "是", "一般",   "是"),
      Array("5",  "青年", "否", "否", "一般",   "否"),
      Array("6",  "中年", "否", "否", "一般",   "否"),
      Array("7",  "中年", "否", "否", "好",     "否"),
      Array("8",  "中年", "是", "是", "好",     "是"),
      Array("9",  "中年", "否", "是", "非常好", "是"),
      Array("10", "中年", "否", "是", "非常好", "是"),
      Array("11", "老年", "否", "是", "非常好", "是"),
      Array("12", "老年", "否", "是", "好",     "是"),
      Array("13", "老年", "是", "否", "好",     "是"),
      Array("14", "老年", "是", "否", "非常好", "是"),
      Array("15", "老年", "否", "否", "一般",   "否"))

    import sqlcontext.implicits._
    var DF = sc.parallelize(sampleData).map { x =>
      val age    = x(1)
      val work   = x(2)
      val house  = x(3)
      val credit = x(4)
      val label  = x(5)
      (age, work, house, credit, label)
    }.toDF("age", "isWork", "isHouse", "credit", "label")

    // internalNode: the internal nodes of the tree; each String is the feature tested at that node
    val internalNode = ArrayBuffer[String]()
    // leafNode: the leaves as (str1, str2), where str1 is the branch value and str2 is the predicted label
    val leafNode = ArrayBuffer[(String, String)]()
    // countArr: the number of leaf children under each internal node
    val countArr = ArrayBuffer[Int]()

    var flag = true
    while (flag) {
      val totalRecord = DF.rdd.count().toInt
      val labels = DF.select("label").rdd
        .map(row => row.getString(row.fieldIndex("label")))
        .map(x => (x, 1))
        .reduceByKey(_ + _)
        .collect()
      // Pick the feature with the largest information gain as the next split point
      val featurePoint = Getfeature(DF, totalRecord, labels)(0)._1
      // df holds the distinct values of the selected feature
      val df = DF.select(featurePoint).distinct().rdd
        .map(row => row.getString(row.fieldIndex(featurePoint)))
        .collect()

      var count = 0
      // arr records the child value of this internal node that is not yet a leaf
      var arr = ""
      // Partition the dataset on each value of the best feature
      for (lb <- df) {
        val str = s"$featurePoint = '$lb'"
        val newDF = DF.where(str).select("label").distinct().rdd
          .map(row => row.getString(row.fieldIndex("label")))
          .collect()
        val D1 = newDF.length
        if (D1 == 1) {
          // Only one label remains under this branch, so it becomes a leaf
          leafNode.append((lb, newDF(0)))
          count += 1
        } else {
          arr = lb
        }
      }
      internalNode.append(featurePoint)
      // Check whether training is finished: if both children of the current
      // internal node are leaves (count == 2), the tree is complete.
      if (count == 2) flag = false
      else {
        val str2 = s"$featurePoint = '$arr'"
        val sk = DF.where(str2).rdd.map { row =>
          val ID      = "0"
          val age     = row.getString(row.fieldIndex("age"))
          val isWork  = row.getString(row.fieldIndex("isWork"))
          val isHouse = row.getString(row.fieldIndex("isHouse"))
          val credit  = row.getString(row.fieldIndex("credit"))
          val label   = row.getString(row.fieldIndex("label"))
          Array(ID, age, isWork, isHouse, credit, label)
        }.collect()
        // Refresh DF: keep only the records that fall under the unresolved branch
        DF = sc.parallelize(sk).map { x =>
          val age    = x(1)
          val work   = x(2)
          val house  = x(3)
          val credit = x(4)
          val label  = x(5)
          (age, work, house, credit, label)
        }.toDF("age", "isWork", "isHouse", "credit", "label")
      }
      countArr.append(count)
    }

    println("All internal nodes:")
    internalNode.foreach(println)
    println("All leaf nodes:")
    leafNode.foreach(println)
    println("Number of leaf children per internal node:")
    countArr.foreach(println)
  }

  /**
   * Selects the optimal feature by information gain, per the ID3 algorithm.
   */
  def Getfeature(DF: DataFrame, totalRecord: Int, labels: Array[(String, Int)]): ArrayBuffer[(String, Double)] = {
    val features = DF.columns

    // Empirical entropy H(D) of the dataset
    var Hd = 0.0
    for (lab <- labels) {
      val labelcount = lab._2.toDouble
      val pi = labelcount / totalRecord
      Hd += -1.0 * pi * Math.log(pi) / Math.log(2)
    }

    // Empirical conditional entropy H(D|A) for each feature A
    val Hda = ArrayBuffer[Double]()
    for (feature <- features) {
      var Hdik = 0.0
      if (!"label".equals(feature)) {
        // DI: the record count for each value of feature A
        val DI = DF.groupBy(feature).count()
        // Lab: all possible values of feature A; Di: their record counts
        val Lab = ArrayBuffer[String]()
        val Di = ArrayBuffer[Int]()
        DI.collect().foreach { row =>
          Lab += row.getString(row.fieldIndex(feature))
          Di += row.getLong(row.fieldIndex("count")).toInt
        }
        // Dik: the count of each label within each value of feature A
        val Dik = ArrayBuffer[(Int, Int)]()
        for (lab <- Lab) {
          val str = s"$feature = '$lab'"
          val newDF = DF.where(str).groupBy("label").count().persist(StorageLevel.MEMORY_ONLY_SER)
          val df = newDF.rdd.map(row => row.getLong(row.fieldIndex("count")).toInt).collect()
          if (newDF.count().toInt == 2) Dik.append((df(0), df(1)))
          else Dik.append((df(0), 0))
        }

        // Each value's weighted contribution to H(D|A); Di and Dik are aligned,
        // so pop the head of Dik as Di is traversed
        for (i <- Di) {
          val newDik = Dik.take(1)
          Dik.remove(0, 1)
          for (j <- newDik) {
            if (j._2 == 0) {
              // Only one label occurs under this value, so pi == 1 and the term is 0
              val pi = j._1.toDouble / i
              Hdik += i.toDouble / totalRecord * (-1.0) * pi * Math.log(pi) / Math.log(2)
            } else {
              val pi1 = j._1.toDouble / i
              val pi2 = j._2.toDouble / i
              Hdik += i.toDouble / totalRecord * (-1.0) * pi1 * Math.log(pi1) / Math.log(2) +
                i.toDouble / totalRecord * (-1.0) * pi2 * Math.log(pi2) / Math.log(2)
            }
          }
        }
        Hda.append(Hdik)
      }
    }
    // Gda: information gain g(D, A) = H(D) - H(D|A) per feature; the largest wins.
    // (Indexing features(i) against Hda(i) works because "label" is the last column.)
    val Gda = ArrayBuffer[(String, Double)]()
    for (i <- 0 until Hda.length) {
      val hda = Hda(i)
      Gda.append((features(i), Hd - hda))
    }
    Gda.sortBy(x => x._2).reverse.take(1)
  }
}

-----------------------------
Result:

All internal nodes:
isHouse
isWork
All leaf nodes:
(是,是)
(否,否)
(是,是)
Number of leaf children per internal node:
1
2
----------------------------
Decision Tree:

if ("是".equals(isHouse)) "是"
else if ("是".equals(isWork)) "是"
else "否"