|
| 1 | +package DataMining_Viterbi; |
| 2 | + |
| 3 | +import java.io.BufferedReader; |
| 4 | +import java.io.File; |
| 5 | +import java.io.FileReader; |
| 6 | +import java.io.IOException; |
| 7 | +import java.util.ArrayList; |
| 8 | +import java.util.HashMap; |
| 9 | +import java.util.Map; |
| 10 | + |
| 11 | +/** |
| 12 | + * 维特比算法工具类 |
| 13 | + * |
| 14 | + * @author lyq |
| 15 | + * |
| 16 | + */ |
| 17 | +public class ViterbiTool { |
| 18 | + // 状态转移概率矩阵文件地址 |
| 19 | + private String stmFilePath; |
| 20 | + // 混淆矩阵文件地址 |
| 21 | + private String confusionFilePath; |
| 22 | + // 初始状态概率 |
| 23 | + private double[] initStatePro; |
| 24 | + // 观察到的状态序列 |
| 25 | + public String[] observeStates; |
| 26 | + // 状态转移矩阵值 |
| 27 | + private double[][] stMatrix; |
| 28 | + // 混淆矩阵值 |
| 29 | + private double[][] confusionMatrix; |
| 30 | + // 各个条件下的潜在特征概率值 |
| 31 | + private double[][] potentialValues; |
| 32 | + // 潜在特征 |
| 33 | + private ArrayList<String> potentialAttrs; |
| 34 | + // 属性值列坐标映射图 |
| 35 | + private HashMap<String, Integer> name2Index; |
| 36 | + // 列坐标属性值映射图 |
| 37 | + private HashMap<Integer, String> index2name; |
| 38 | + |
| 39 | + public ViterbiTool(String stmFilePath, String confusionFilePath, |
| 40 | + double[] initStatePro, String[] observeStates) { |
| 41 | + this.stmFilePath = stmFilePath; |
| 42 | + this.confusionFilePath = confusionFilePath; |
| 43 | + this.initStatePro = initStatePro; |
| 44 | + this.observeStates = observeStates; |
| 45 | + |
| 46 | + initOperation(); |
| 47 | + } |
| 48 | + |
| 49 | + /** |
| 50 | + * 初始化数据操作 |
| 51 | + */ |
| 52 | + private void initOperation() { |
| 53 | + double[] temp; |
| 54 | + int index; |
| 55 | + ArrayList<String[]> smtDatas; |
| 56 | + ArrayList<String[]> cfDatas; |
| 57 | + |
| 58 | + smtDatas = readDataFile(stmFilePath); |
| 59 | + cfDatas = readDataFile(confusionFilePath); |
| 60 | + |
| 61 | + index = 0; |
| 62 | + this.stMatrix = new double[smtDatas.size()][]; |
| 63 | + for (String[] array : smtDatas) { |
| 64 | + temp = new double[array.length]; |
| 65 | + for (int i = 0; i < array.length; i++) { |
| 66 | + try { |
| 67 | + temp[i] = Double.parseDouble(array[i]); |
| 68 | + } catch (NumberFormatException e) { |
| 69 | + temp[i] = -1; |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + // 将转换后的值赋给数组中 |
| 74 | + this.stMatrix[index] = temp; |
| 75 | + index++; |
| 76 | + } |
| 77 | + |
| 78 | + index = 0; |
| 79 | + this.confusionMatrix = new double[cfDatas.size()][]; |
| 80 | + for (String[] array : cfDatas) { |
| 81 | + temp = new double[array.length]; |
| 82 | + for (int i = 0; i < array.length; i++) { |
| 83 | + try { |
| 84 | + temp[i] = Double.parseDouble(array[i]); |
| 85 | + } catch (NumberFormatException e) { |
| 86 | + temp[i] = -1; |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + // 将转换后的值赋给数组中 |
| 91 | + this.confusionMatrix[index] = temp; |
| 92 | + index++; |
| 93 | + } |
| 94 | + |
| 95 | + this.potentialAttrs = new ArrayList<>(); |
| 96 | + // 添加潜在特征属性 |
| 97 | + for (String s : smtDatas.get(0)) { |
| 98 | + this.potentialAttrs.add(s); |
| 99 | + } |
| 100 | + // 去除首列无效列 |
| 101 | + potentialAttrs.remove(0); |
| 102 | + |
| 103 | + this.name2Index = new HashMap<>(); |
| 104 | + this.index2name = new HashMap<>(); |
| 105 | + |
| 106 | + // 添加名称下标映射关系 |
| 107 | + for (int i = 1; i < smtDatas.get(0).length; i++) { |
| 108 | + this.name2Index.put(smtDatas.get(0)[i], i); |
| 109 | + // 添加下标到名称的映射 |
| 110 | + this.index2name.put(i, smtDatas.get(0)[i]); |
| 111 | + } |
| 112 | + |
| 113 | + for (int i = 1; i < cfDatas.get(0).length; i++) { |
| 114 | + this.name2Index.put(cfDatas.get(0)[i], i); |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + /** |
| 119 | + * 从文件中读取数据 |
| 120 | + */ |
| 121 | + private ArrayList<String[]> readDataFile(String filePath) { |
| 122 | + File file = new File(filePath); |
| 123 | + ArrayList<String[]> dataArray = new ArrayList<String[]>(); |
| 124 | + |
| 125 | + try { |
| 126 | + BufferedReader in = new BufferedReader(new FileReader(file)); |
| 127 | + String str; |
| 128 | + String[] tempArray; |
| 129 | + while ((str = in.readLine()) != null) { |
| 130 | + tempArray = str.split(" "); |
| 131 | + dataArray.add(tempArray); |
| 132 | + } |
| 133 | + in.close(); |
| 134 | + } catch (IOException e) { |
| 135 | + e.getStackTrace(); |
| 136 | + } |
| 137 | + |
| 138 | + return dataArray; |
| 139 | + } |
| 140 | + |
| 141 | + /** |
| 142 | + * 根据观察特征计算隐藏的特征概率矩阵 |
| 143 | + */ |
| 144 | + private void calPotencialProMatrix() { |
| 145 | + String curObserveState; |
| 146 | + // 观察特征和潜在特征的下标 |
| 147 | + int osIndex; |
| 148 | + int psIndex; |
| 149 | + double temp; |
| 150 | + double maxPro; |
| 151 | + // 混淆矩阵概率值,就是相关影响的因素概率 |
| 152 | + double confusionPro; |
| 153 | + |
| 154 | + this.potentialValues = new double[observeStates.length][potentialAttrs |
| 155 | + .size() + 1]; |
| 156 | + for (int i = 0; i < this.observeStates.length; i++) { |
| 157 | + curObserveState = this.observeStates[i]; |
| 158 | + osIndex = this.name2Index.get(curObserveState); |
| 159 | + maxPro = -1; |
| 160 | + |
| 161 | + // 因为是第一个观察特征,没有前面的影响,根据初始状态计算 |
| 162 | + if (i == 0) { |
| 163 | + for (String attr : this.potentialAttrs) { |
| 164 | + psIndex = this.name2Index.get(attr); |
| 165 | + confusionPro = this.confusionMatrix[psIndex][osIndex]; |
| 166 | + |
| 167 | + temp = this.initStatePro[psIndex - 1] * confusionPro; |
| 168 | + this.potentialValues[BaseNames.DAY1][psIndex] = temp; |
| 169 | + } |
| 170 | + } else { |
| 171 | + // 后面的潜在特征受前一个特征的影响,以及当前的混淆因素影响 |
| 172 | + for (String toDayAttr : this.potentialAttrs) { |
| 173 | + psIndex = this.name2Index.get(toDayAttr); |
| 174 | + confusionPro = this.confusionMatrix[psIndex][osIndex]; |
| 175 | + |
| 176 | + int index; |
| 177 | + maxPro = -1; |
| 178 | + // 通过昨天的概率计算今天此特征的最大概率 |
| 179 | + for (String yAttr : this.potentialAttrs) { |
| 180 | + index = this.name2Index.get(yAttr); |
| 181 | + temp = this.potentialValues[i - 1][index] |
| 182 | + * this.stMatrix[index][psIndex]; |
| 183 | + |
| 184 | + // 计算得到今天此潜在特征的最大概率 |
| 185 | + if (temp > maxPro) { |
| 186 | + maxPro = temp; |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + this.potentialValues[i][psIndex] = maxPro * confusionPro; |
| 191 | + } |
| 192 | + } |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + /** |
| 197 | + * 根据同时期最大概率值输出潜在特征值 |
| 198 | + */ |
| 199 | + private void outputResultAttr() { |
| 200 | + double maxPro; |
| 201 | + int maxIndex; |
| 202 | + ArrayList<String> psValues; |
| 203 | + |
| 204 | + psValues = new ArrayList<>(); |
| 205 | + for (int i = 0; i < this.potentialValues.length; i++) { |
| 206 | + maxPro = -1; |
| 207 | + maxIndex = 0; |
| 208 | + |
| 209 | + for (int j = 0; j < potentialValues[i].length; j++) { |
| 210 | + if (this.potentialValues[i][j] > maxPro) { |
| 211 | + maxPro = potentialValues[i][j]; |
| 212 | + maxIndex = j; |
| 213 | + } |
| 214 | + } |
| 215 | + |
| 216 | + // 取出最大概率下标对应的潜在特征 |
| 217 | + psValues.add(this.index2name.get(maxIndex)); |
| 218 | + } |
| 219 | + |
| 220 | + System.out.println("观察特征为:"); |
| 221 | + for (String s : this.observeStates) { |
| 222 | + System.out.print(s + ", "); |
| 223 | + } |
| 224 | + System.out.println(); |
| 225 | + |
| 226 | + System.out.println("潜在特征为:"); |
| 227 | + for (String s : psValues) { |
| 228 | + System.out.print(s + ", "); |
| 229 | + } |
| 230 | + System.out.println(); |
| 231 | + } |
| 232 | + |
| 233 | + /** |
| 234 | + * 根据观察属性,得到潜在属性信息 |
| 235 | + */ |
| 236 | + public void calHMMObserve() { |
| 237 | + calPotencialProMatrix(); |
| 238 | + outputResultAttr(); |
| 239 | + } |
| 240 | +} |
0 commit comments