001 /**
002 *
003 */
004 package org.wdssii.decisiontree;
005
006 import java.util.ArrayList;
007 import java.util.List;
008
009 /**
010 * C45 learning algorithm to create an axial decision tree. J. R. Quinlan.
011 * Improved use of continuous attributes in c4.5. Journal of Artificial
012 * Intelligence Research, 4:77-90, 1996
013 *
014 * Usage:
015 * <pre>
016 float[][] data = new float[numTraining][numAttr];
017 int[] categories = new int[numTraining];
018 // populate arrays
019 ...
020 QuinlanC45AxialDecisionTreeCreator classifier = new QuinlanC45AxialDecisionTreeCreator(0.1f); // pruning fraction
021 DecisionTree tree = classifier.learn(data, categories);
022 * </pre>
023 *
024 * @author lakshman
025 *
026 */
027 public class QuinlanC45AxialDecisionTreeCreator implements DecisionTreeCreator {
028 /** How many members can be in a population before it is split? */
029 private int populationToConsiderSplitting = 10;
030
031 /**
032 * fraction of the training data set to keep aside so that the learned tree
033 * is not overfit. A value of 0.1f may be a pretty good choice. The pruning
034 * points will be the last few instances of the training data set. Pass in a
035 * randomized sample if this simply won't do.
036 */
037 private float pruningFraction = 0.1f;
038
039 /** how deep can this tree go? The deeper the tree the less general it is. */
040 private int maxDepth = 10;
041
042 /** how many classes are there? */
043 private int numCategories = 0;
044
045 /** By default, InformationGain is used. */
046 private FitnessFunction fitness = new GainRatioFitnessFunction();
047
048 @SuppressWarnings("serial")
049 public static class TreeCreationException extends RuntimeException {
050 TreeCreationException(String cause) {
051 super(cause);
052 }
053 }
054
055 public QuinlanC45AxialDecisionTreeCreator(float pruningFraction) {
056 this.pruningFraction = pruningFraction;
057 }
058
059 public QuinlanC45AxialDecisionTreeCreator() {
060 }
061
062 /**
063 * @param inputData
064 * an array where each row corresponds to a single instance (to
065 * be classified) and the columns hold the attributes of that
066 * instance
067 * @param targetClass
068 * an array where each row corresponds to a single instance,
069 * specifically the actual classification of that instance. The
070 * class needs to be a number 0,1,2,...,N-1 where N is the number
071 * of classes. Some of these classes may have no examples.
072 * @return decisiontree
073 */
074 public AxialDecisionTree learn(float[][] inputData, int[] targetClass)
075 throws TreeCreationException, IllegalArgumentException {
076 if (inputData.length == 0 || inputData[0].length == 0
077 || targetClass.length != inputData.length
078 || pruningFraction < 0.0f) {
079 throw new IllegalArgumentException();
080 }
081 int numTraining = Math.round(inputData.length * (1 - pruningFraction));
082 int numTesting = inputData.length - numTraining;
083
084 int[] toConsider = new int[numTraining];
085 for (int i = 0; i < numTraining; ++i) {
086 toConsider[i] = i;
087 if (targetClass[i] >= numCategories) {
088 numCategories = targetClass[i] + 1;
089 }
090 }
091
092 AxialTreeNode node = buildTree(inputData, targetClass, toConsider, 0);
093 if (node == null) {
094 throw new TreeCreationException(
095 "Can not classify decision tree as there are too few unique inputs");
096 }
097
098 if (numTesting > 0) {
099 toConsider = new int[numTesting];
100 for (int i = 0; i < numTesting; ++i) {
101 toConsider[i] = numTraining + i; // the last few
102 }
103 node = pruneTree(node, inputData, targetClass, toConsider);
104 }
105
106 return new AxialDecisionTree(node, -1, inputData[0].length, numCategories);
107 }
108
109 /**
110 * removes nodes that do not perform well on the validation dataset.
111 */
112 private AxialTreeNode pruneTree(AxialTreeNode node, float[][] inputData,
113 int[] targetClass, int[] toConsider) {
114 // Prune the left/right branches
115 List<Integer> leftPoints = new ArrayList<Integer>();
116 List<Integer> rightPoints = new ArrayList<Integer>();
117 for (int i = 0, n = toConsider.length; i < n; ++i) {
118 int row = toConsider[i];
119 if (node.isHandledByLeftBranch(inputData[row])) {
120 leftPoints.add(i);
121 } else if (node.isHandledByRightBranch(inputData[row])) {
122 rightPoints.add(i);
123 }
124 }
125 if (leftPoints.size() > 0) {
126 int[] toConsiderLeft = new int[leftPoints.size()];
127 for (int i = 0, n = toConsiderLeft.length; i < n; ++i) {
128 toConsiderLeft[i] = leftPoints.get(i).intValue();
129 }
130 node.setLeft(pruneTree(node.getLeft(), inputData, targetClass,
131 toConsiderLeft));
132 }
133 if (rightPoints.size() > 0) {
134 int[] toConsiderRight = new int[rightPoints.size()];
135 for (int i = 0, n = toConsiderRight.length; i < n; ++i) {
136 toConsiderRight[i] = rightPoints.get(i).intValue();
137 }
138 node.setRight(pruneTree(node.getRight(), inputData, targetClass,
139 toConsiderRight));
140 }
141
142 node.normalize();
143
144 // Decide whether to replace this node by just a stump
145 // The replacement will happen if it will increase the number of correct
146 int numCorrect = 0;
147 int numCorrect_asStump = 0;
148 int defaultCategory = node.getDefaultCategory();
149 for (int i = 0, n = toConsider.length; i < n; ++i) {
150 int row = toConsider[i];
151 int trueCategory = targetClass[row];
152 int estCategory = node.classify(inputData[row]);
153 if (trueCategory == estCategory) {
154 ++numCorrect;
155 }
156 if (trueCategory == defaultCategory) {
157 ++numCorrect_asStump;
158 }
159 }
160 if (numCorrect_asStump > numCorrect) {
161 System.out.println("Pruning " + node);
162 return new AxialTreeNode(defaultCategory);
163 }
164
165 float fractionCorrect = (float) numCorrect / toConsider.length;
166 if (fractionCorrect < 0.4f) {
167 System.out.println("Pruning " + node);
168 return new AxialTreeNode(defaultCategory);
169 }
170
171 return node;
172 }
173
174 /**
175 * Helper method that creates a sub-tree and returns a node
176 *
177 * @return
178 */
179 private AxialTreeNode buildTree(float[][] inputData, int[] targetClass,
180 int[] toConsider, int depth) {
181 if (toConsider.length < populationToConsiderSplitting
182 || depth == maxDepth) {
183 // find most likely category and return a node that supplies it
184 // always
185 int mostLikelyCategory = getMostLikelyCategory(inputData,
186 targetClass, toConsider);
187 return new AxialTreeNode(mostLikelyCategory);
188 }
189
190 // Is everything of same category?
191 boolean allSame = true;
192 int startCategory = targetClass[toConsider[0]];
193 for (int i = 0, n = toConsider.length; i < n; ++i) {
194 if (targetClass[toConsider[i]] != startCategory) {
195 allSame = false;
196 break;
197 }
198 }
199 if (allSame) {
200 return new AxialTreeNode(startCategory);
201 }
202
203 // compute best attribute by finding the one that has highest
204 // information gain
205 int numAttributes = inputData[0].length;
206 FitnessFunction.Split[] splits = new FitnessFunction.Split[numAttributes];
207 for (int i = 0; i < numAttributes; ++i) {
208 splits[i] = fitness.computeSplitAndGain(inputData, targetClass,
209 numCategories, toConsider, i);
210 }
211 int bestAttribute = 0;
212 for (int i = 1; i < numAttributes; ++i) {
213 if (splits[i].score > splits[bestAttribute].score) {
214 bestAttribute = i;
215 }
216 }
217
218 // create node to split on bestAttribute
219 float thresh = splits[bestAttribute].thresh;
220 int[][] toConsiderSplit = split(inputData, toConsider, bestAttribute,
221 thresh);
222
223 // if there are no examples on one side of the branch, simply return the
224 // other side
225 if (toConsiderSplit[0].length == 0) {
226 return buildTree(inputData, targetClass, toConsiderSplit[1],
227 depth + 1);
228 } else if (toConsiderSplit[1].length == 0) {
229 return buildTree(inputData, targetClass, toConsiderSplit[0],
230 depth + 1);
231 }
232
233 int leftCategory = getMostLikelyCategory(inputData, targetClass,
234 toConsiderSplit[0]);
235 int rightCategory = getMostLikelyCategory(inputData, targetClass,
236 toConsiderSplit[1]);
237 AxialTreeNode leftNode = buildTree(inputData, targetClass,
238 toConsiderSplit[0], depth + 1);
239 AxialTreeNode rightNode = buildTree(inputData, targetClass,
240 toConsiderSplit[1], depth + 1);
241 int defaultCategory = (toConsiderSplit[0].length > toConsiderSplit[1].length) ? leftCategory
242 : rightCategory;
243 AxialTreeNode branch = new AxialTreeNode(bestAttribute, thresh,
244 leftNode, rightNode, defaultCategory);
245 branch.normalize();
246 return branch;
247 }
248
249 private int[][] split(float[][] inputData, int[] toConsider,
250 int bestAttribute, float thresh) {
251 int numLeft = 0;
252 for (int i = 0, n = toConsider.length; i < n; ++i) {
253 if (inputData[toConsider[i]][bestAttribute] < thresh) {
254 ++numLeft;
255 }
256 }
257 int numRight = toConsider.length - numLeft;
258 int[][] result = new int[2][];
259 result[0] = new int[numLeft];
260 result[1] = new int[numRight];
261 int leftIndex = 0;
262 int rightIndex = 0;
263 for (int i = 0, n = toConsider.length; i < n; ++i) {
264 if (inputData[toConsider[i]][bestAttribute] < thresh) {
265 result[0][leftIndex] = toConsider[i];
266 ++leftIndex;
267 } else {
268 result[1][rightIndex] = toConsider[i];
269 ++rightIndex;
270 }
271 }
272 return result;
273 }
274
275 private int getMostLikelyCategory(float[][] inputData, int[] targetClass,
276 int[] toConsider) {
277 if (toConsider.length == 0) {
278 throw new IllegalStateException(
279 "should not have empty toConsider array here");
280 }
281 int[] populationByCategory = new int[numCategories];
282 for (int i = 0, n = toConsider.length; i < n; ++i) {
283 int category = targetClass[toConsider[i]];
284 ++populationByCategory[category];
285 }
286 int bestCategory = 0;
287 for (int i = 1; i < numCategories; ++i) {
288 if (populationByCategory[i] > populationByCategory[bestCategory]) {
289 bestCategory = i;
290 }
291 }
292 return bestCategory;
293 }
294
295 /** Corresponds to previously learnt data set */
296 public int getNumCategories() {
297 return numCategories;
298 }
299
300
301 }