001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * BinC45Split.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.core.Instance;
026: import weka.core.Instances;
027: import weka.core.Utils;
028:
029: import java.util.Enumeration;
030:
031: /**
032: * Class implementing a binary C4.5-like split on an attribute.
033: *
034: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
035: * @version $Revision: 1.13 $
036: */
037: public class BinC45Split extends ClassifierSplitModel {
038:
039: /** for serialization */
040: private static final long serialVersionUID = -1278776919563022474L;
041:
042: /** Attribute to split on. */
043: private int m_attIndex;
044:
045: /** Minimum number of objects in a split. */
046: private int m_minNoObj;
047:
048: /** Value of split point. */
049: private double m_splitPoint;
050:
051: /** InfoGain of split. */
052: private double m_infoGain;
053:
054: /** GainRatio of split. */
055: private double m_gainRatio;
056:
057: /** The sum of the weights of the instances. */
058: private double m_sumOfWeights;
059:
060: /** Static reference to splitting criterion. */
061: private static InfoGainSplitCrit m_infoGainCrit = new InfoGainSplitCrit();
062:
063: /** Static reference to splitting criterion. */
064: private static GainRatioSplitCrit m_gainRatioCrit = new GainRatioSplitCrit();
065:
066: /**
067: * Initializes the split model.
068: */
069: public BinC45Split(int attIndex, int minNoObj, double sumOfWeights) {
070:
071: // Get index of attribute to split on.
072: m_attIndex = attIndex;
073:
074: // Set minimum number of objects.
075: m_minNoObj = minNoObj;
076:
077: // Set sum of weights;
078: m_sumOfWeights = sumOfWeights;
079: }
080:
081: /**
082: * Creates a C4.5-type split on the given data.
083: *
084: * @exception Exception if something goes wrong
085: */
086: public void buildClassifier(Instances trainInstances)
087: throws Exception {
088:
089: // Initialize the remaining instance variables.
090: m_numSubsets = 0;
091: m_splitPoint = Double.MAX_VALUE;
092: m_infoGain = 0;
093: m_gainRatio = 0;
094:
095: // Different treatment for enumerated and numeric
096: // attributes.
097: if (trainInstances.attribute(m_attIndex).isNominal()) {
098: handleEnumeratedAttribute(trainInstances);
099: } else {
100: trainInstances.sort(trainInstances.attribute(m_attIndex));
101: handleNumericAttribute(trainInstances);
102: }
103: }
104:
105: /**
106: * Returns index of attribute for which split was generated.
107: */
108: public final int attIndex() {
109:
110: return m_attIndex;
111: }
112:
113: /**
114: * Returns (C4.5-type) gain ratio for the generated split.
115: */
116: public final double gainRatio() {
117: return m_gainRatio;
118: }
119:
120: /**
121: * Gets class probability for instance.
122: *
123: * @exception Exception if something goes wrong
124: */
125: public final double classProb(int classIndex, Instance instance,
126: int theSubset) throws Exception {
127:
128: if (theSubset <= -1) {
129: double[] weights = weights(instance);
130: if (weights == null) {
131: return m_distribution.prob(classIndex);
132: } else {
133: double prob = 0;
134: for (int i = 0; i < weights.length; i++) {
135: prob += weights[i]
136: * m_distribution.prob(classIndex, i);
137: }
138: return prob;
139: }
140: } else {
141: if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
142: return m_distribution.prob(classIndex, theSubset);
143: } else {
144: return m_distribution.prob(classIndex);
145: }
146: }
147: }
148:
149: /**
150: * Creates split on enumerated attribute.
151: *
152: * @exception Exception if something goes wrong
153: */
154: private void handleEnumeratedAttribute(Instances trainInstances)
155: throws Exception {
156:
157: Distribution newDistribution, secondDistribution;
158: int numAttValues;
159: double currIG, currGR;
160: Instance instance;
161: int i;
162:
163: numAttValues = trainInstances.attribute(m_attIndex).numValues();
164: newDistribution = new Distribution(numAttValues, trainInstances
165: .numClasses());
166:
167: // Only Instances with known values are relevant.
168: Enumeration enu = trainInstances.enumerateInstances();
169: while (enu.hasMoreElements()) {
170: instance = (Instance) enu.nextElement();
171: if (!instance.isMissing(m_attIndex))
172: newDistribution.add((int) instance.value(m_attIndex),
173: instance);
174: }
175: m_distribution = newDistribution;
176:
177: // For all values
178: for (i = 0; i < numAttValues; i++) {
179:
180: if (Utils.grOrEq(newDistribution.perBag(i), m_minNoObj)) {
181: secondDistribution = new Distribution(newDistribution,
182: i);
183:
184: // Check if minimum number of Instances in the two
185: // subsets.
186: if (secondDistribution.check(m_minNoObj)) {
187: m_numSubsets = 2;
188: currIG = m_infoGainCrit.splitCritValue(
189: secondDistribution, m_sumOfWeights);
190: currGR = m_gainRatioCrit.splitCritValue(
191: secondDistribution, m_sumOfWeights, currIG);
192: if ((i == 0) || Utils.gr(currGR, m_gainRatio)) {
193: m_gainRatio = currGR;
194: m_infoGain = currIG;
195: m_splitPoint = (double) i;
196: m_distribution = secondDistribution;
197: }
198: }
199: }
200: }
201: }
202:
203: /**
204: * Creates split on numeric attribute.
205: *
206: * @exception Exception if something goes wrong
207: */
208: private void handleNumericAttribute(Instances trainInstances)
209: throws Exception {
210:
211: int firstMiss;
212: int next = 1;
213: int last = 0;
214: int index = 0;
215: int splitIndex = -1;
216: double currentInfoGain;
217: double defaultEnt;
218: double minSplit;
219: Instance instance;
220: int i;
221:
222: // Current attribute is a numeric attribute.
223: m_distribution = new Distribution(2, trainInstances
224: .numClasses());
225:
226: // Only Instances with known values are relevant.
227: Enumeration enu = trainInstances.enumerateInstances();
228: i = 0;
229: while (enu.hasMoreElements()) {
230: instance = (Instance) enu.nextElement();
231: if (instance.isMissing(m_attIndex))
232: break;
233: m_distribution.add(1, instance);
234: i++;
235: }
236: firstMiss = i;
237:
238: // Compute minimum number of Instances required in each
239: // subset.
240: minSplit = 0.1 * (m_distribution.total())
241: / ((double) trainInstances.numClasses());
242: if (Utils.smOrEq(minSplit, m_minNoObj))
243: minSplit = m_minNoObj;
244: else if (Utils.gr(minSplit, 25))
245: minSplit = 25;
246:
247: // Enough Instances with known values?
248: if (Utils.sm((double) firstMiss, 2 * minSplit))
249: return;
250:
251: // Compute values of criteria for all possible split
252: // indices.
253: defaultEnt = m_infoGainCrit.oldEnt(m_distribution);
254: while (next < firstMiss) {
255:
256: if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5 < trainInstances
257: .instance(next).value(m_attIndex)) {
258:
259: // Move class values for all Instances up to next
260: // possible split point.
261: m_distribution.shiftRange(1, 0, trainInstances, last,
262: next);
263:
264: // Check if enough Instances in each subset and compute
265: // values for criteria.
266: if (Utils.grOrEq(m_distribution.perBag(0), minSplit)
267: && Utils.grOrEq(m_distribution.perBag(1),
268: minSplit)) {
269: currentInfoGain = m_infoGainCrit.splitCritValue(
270: m_distribution, m_sumOfWeights, defaultEnt);
271: if (Utils.gr(currentInfoGain, m_infoGain)) {
272: m_infoGain = currentInfoGain;
273: splitIndex = next - 1;
274: }
275: index++;
276: }
277: last = next;
278: }
279: next++;
280: }
281:
282: // Was there any useful split?
283: if (index == 0)
284: return;
285:
286: // Compute modified information gain for best split.
287: m_infoGain = m_infoGain - (Utils.log2(index) / m_sumOfWeights);
288: if (Utils.smOrEq(m_infoGain, 0))
289: return;
290:
291: // Set instance variables' values to values for
292: // best split.
293: m_numSubsets = 2;
294: m_splitPoint = (trainInstances.instance(splitIndex + 1).value(
295: m_attIndex) + trainInstances.instance(splitIndex)
296: .value(m_attIndex)) / 2;
297:
298: // In case we have a numerical precision problem we need to choose the
299: // smaller value
300: if (m_splitPoint == trainInstances.instance(splitIndex + 1)
301: .value(m_attIndex)) {
302: m_splitPoint = trainInstances.instance(splitIndex).value(
303: m_attIndex);
304: }
305:
306: // Restore distributioN for best split.
307: m_distribution = new Distribution(2, trainInstances
308: .numClasses());
309: m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);
310: m_distribution.addRange(1, trainInstances, splitIndex + 1,
311: firstMiss);
312:
313: // Compute modified gain ratio for best split.
314: m_gainRatio = m_gainRatioCrit.splitCritValue(m_distribution,
315: m_sumOfWeights, m_infoGain);
316: }
317:
318: /**
319: * Returns (C4.5-type) information gain for the generated split.
320: */
321: public final double infoGain() {
322:
323: return m_infoGain;
324: }
325:
326: /**
327: * Prints left side of condition.
328: *
329: * @param data the data to get the attribute name from.
330: * @return the attribute name
331: */
332: public final String leftSide(Instances data) {
333:
334: return data.attribute(m_attIndex).name();
335: }
336:
337: /**
338: * Prints the condition satisfied by instances in a subset.
339: *
340: * @param index of subset and training set.
341: */
342: public final String rightSide(int index, Instances data) {
343:
344: StringBuffer text;
345:
346: text = new StringBuffer();
347: if (data.attribute(m_attIndex).isNominal()) {
348: if (index == 0)
349: text.append(" = "
350: + data.attribute(m_attIndex).value(
351: (int) m_splitPoint));
352: else
353: text.append(" != "
354: + data.attribute(m_attIndex).value(
355: (int) m_splitPoint));
356: } else if (index == 0)
357: text.append(" <= " + m_splitPoint);
358: else
359: text.append(" > " + m_splitPoint);
360:
361: return text.toString();
362: }
363:
364: /**
365: * Returns a string containing java source code equivalent to the test
366: * made at this node. The instance being tested is called "i".
367: *
368: * @param index index of the nominal value tested
369: * @param data the data containing instance structure info
370: * @return a value of type 'String'
371: */
372: public final String sourceExpression(int index, Instances data) {
373:
374: StringBuffer expr = null;
375: if (index < 0) {
376: return "i[" + m_attIndex + "] == null";
377: }
378: if (data.attribute(m_attIndex).isNominal()) {
379: if (index == 0) {
380: expr = new StringBuffer("i[");
381: } else {
382: expr = new StringBuffer("!i[");
383: }
384: expr.append(m_attIndex).append("]");
385: expr.append(".equals(\"").append(
386: data.attribute(m_attIndex)
387: .value((int) m_splitPoint)).append("\")");
388: } else {
389: expr = new StringBuffer("((Double) i[");
390: expr.append(m_attIndex).append("])");
391: if (index == 0) {
392: expr.append(".doubleValue() <= ").append(m_splitPoint);
393: } else {
394: expr.append(".doubleValue() > ").append(m_splitPoint);
395: }
396: }
397: return expr.toString();
398: }
399:
400: /**
401: * Sets split point to greatest value in given data smaller or equal to
402: * old split point.
403: * (C4.5 does this for some strange reason).
404: */
405: public final void setSplitPoint(Instances allInstances) {
406:
407: double newSplitPoint = -Double.MAX_VALUE;
408: double tempValue;
409: Instance instance;
410:
411: if ((!allInstances.attribute(m_attIndex).isNominal())
412: && (m_numSubsets > 1)) {
413: Enumeration enu = allInstances.enumerateInstances();
414: while (enu.hasMoreElements()) {
415: instance = (Instance) enu.nextElement();
416: if (!instance.isMissing(m_attIndex)) {
417: tempValue = instance.value(m_attIndex);
418: if (Utils.gr(tempValue, newSplitPoint)
419: && Utils.smOrEq(tempValue, m_splitPoint))
420: newSplitPoint = tempValue;
421: }
422: }
423: m_splitPoint = newSplitPoint;
424: }
425: }
426:
427: /**
428: * Sets distribution associated with model.
429: */
430: public void resetDistribution(Instances data) throws Exception {
431:
432: Instances insts = new Instances(data, data.numInstances());
433: for (int i = 0; i < data.numInstances(); i++) {
434: if (whichSubset(data.instance(i)) > -1) {
435: insts.add(data.instance(i));
436: }
437: }
438: Distribution newD = new Distribution(insts, this );
439: newD.addInstWithUnknown(data, m_attIndex);
440: m_distribution = newD;
441: }
442:
443: /**
444: * Returns weights if instance is assigned to more than one subset.
445: * Returns null if instance is only assigned to one subset.
446: */
447: public final double[] weights(Instance instance) {
448:
449: double[] weights;
450: int i;
451:
452: if (instance.isMissing(m_attIndex)) {
453: weights = new double[m_numSubsets];
454: for (i = 0; i < m_numSubsets; i++)
455: weights[i] = m_distribution.perBag(i)
456: / m_distribution.total();
457: return weights;
458: } else {
459: return null;
460: }
461: }
462:
463: /**
464: * Returns index of subset instance is assigned to.
465: * Returns -1 if instance is assigned to more than one subset.
466: *
467: * @exception Exception if something goes wrong
468: */
469:
470: public final int whichSubset(Instance instance) throws Exception {
471:
472: if (instance.isMissing(m_attIndex))
473: return -1;
474: else {
475: if (instance.attribute(m_attIndex).isNominal()) {
476: if ((int) m_splitPoint == (int) instance
477: .value(m_attIndex))
478: return 0;
479: else
480: return 1;
481: } else if (Utils.smOrEq(instance.value(m_attIndex),
482: m_splitPoint))
483: return 0;
484: else
485: return 1;
486: }
487: }
488: }
|