/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * C45PruneableClassifierTree.java
 * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees.j48;

import weka.core.Capabilities;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.Capabilities.Capability;

/**
 * Class for handling a tree structure that can
 * be pruned using C4.5 procedures.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.14 $
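 *
 * <p>A minimal usage sketch (J48 normally constructs this tree internally;
 * the two-argument <code>C45ModelSelection</code> constructor and the file
 * name "iris.arff" are illustrative assumptions, and constructor signatures
 * differ between Weka versions). Run inside a method that declares
 * <code>throws Exception</code>:
 * <pre>
 * Instances data = new Instances(
 *     new java.io.BufferedReader(new java.io.FileReader("iris.arff")));
 * data.setClassIndex(data.numAttributes() - 1);
 * ModelSelection sel = new C45ModelSelection(2, data);  // min 2 instances per leaf
 * C45PruneableClassifierTree tree =
 *     new C45PruneableClassifierTree(sel, true, 0.25f, true, true);
 * tree.buildClassifier(data);
 * double pred = tree.classifyInstance(data.instance(0));
 * </pre>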
 */

public class C45PruneableClassifierTree extends ClassifierTree {

  /** for serialization */
  static final long serialVersionUID = -4813820170260388194L;

  /** True if the tree is to be pruned. */
  boolean m_pruneTheTree = false;

  /** The confidence factor for pruning. */
  float m_CF = 0.25f;

  /** Is subtree raising to be performed? */
  boolean m_subtreeRaising = true;

  /** Cleanup after the tree has been built. */
  boolean m_cleanup = true;

  /**
   * Constructor for pruneable tree structure. Stores reference
   * to associated training data at each node.
   *
   * @param toSelectLocModel selection method for local splitting model
   * @param pruneTree true if the tree is to be pruned
   * @param cf the confidence factor for pruning
   * @param raiseTree true if subtree raising is to be performed
   * @param cleanup true if cleanup is to be performed after the tree has been built
   * @throws Exception if something goes wrong
   */
  public C45PruneableClassifierTree(ModelSelection toSelectLocModel,
      boolean pruneTree, float cf, boolean raiseTree,
      boolean cleanup) throws Exception {

    super(toSelectLocModel);

    m_pruneTheTree = pruneTree;
    m_CF = cf;
    m_subtreeRaising = raiseTree;
    m_cleanup = cleanup;
  }

  /**
   * Returns default capabilities of the classifier tree.
   *
   * @return the capabilities of this classifier tree
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // instances
    result.setMinimumNumberInstances(0);

    return result;
  }

  /**
   * Method for building a pruneable classifier tree.
   *
   * @param data the data for building the tree
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier tree handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

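    // grow the full tree, collapse subtrees that do not improve the
    // training error, then (optionally) apply C4.5 error-based pruning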
    buildTree(data, m_subtreeRaising);
    collapse();
    if (m_pruneTheTree) {
      prune();
    }
    if (m_cleanup) {
      cleanup(new Instances(data, 0));
    }
  }

  /**
   * Collapses a tree to a node if training error doesn't increase.
   */
  public final void collapse() {

    double errorsOfSubtree;
    double errorsOfTree;
    int i;

    if (!m_isLeaf) {
      errorsOfSubtree = getTrainingErrors();
      errorsOfTree = localModel().distribution().numIncorrect();
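
      // Collapse this node to a leaf unless the subtree beats the
      // single-node training error by more than a tiny (1e-3) tolerance.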
      if (errorsOfSubtree >= errorsOfTree - 1E-3) {

        // Free adjacent trees
        m_sons = null;
        m_isLeaf = true;

        // Get NoSplit Model for tree.
        m_localModel = new NoSplit(localModel().distribution());
      } else
        for (i = 0; i < m_sons.length; i++)
          son(i).collapse();
    }
  }

  /**
   * Prunes a tree using C4.5's pruning procedure.
   *
   * @throws Exception if something goes wrong
   */
  public void prune() throws Exception {

    double errorsLargestBranch;
    double errorsLeaf;
    double errorsTree;
    int indexOfLargestBranch;
    C45PruneableClassifierTree largestBranch;
    int i;

    if (!m_isLeaf) {

      // Prune all subtrees.
      for (i = 0; i < m_sons.length; i++)
        son(i).prune();

      // Compute error for largest branch
      indexOfLargestBranch = localModel().distribution().maxBag();
      if (m_subtreeRaising) {
        errorsLargestBranch = son(indexOfLargestBranch)
          .getEstimatedErrorsForBranch((Instances) m_train);
      } else {
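        // Subtree raising is disabled: rule the branch option out by
        // giving it an error estimate no other estimate can exceed.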
        errorsLargestBranch = Double.MAX_VALUE;
      }

      // Compute error if this Tree would be leaf
      errorsLeaf =
        getEstimatedErrorsForDistribution(localModel().distribution());

      // Compute error for the whole subtree
      errorsTree = getEstimatedErrors();

      // Decide if leaf is best choice.
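      // (the 0.1 slack gives the simpler model, the leaf, the benefit of
      // the doubt when the error estimates are nearly equal)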
      if (Utils.smOrEq(errorsLeaf, errorsTree + 0.1)
          && Utils.smOrEq(errorsLeaf, errorsLargestBranch + 0.1)) {

        // Free son Trees
        m_sons = null;
        m_isLeaf = true;

        // Get NoSplit Model for node.
        m_localModel = new NoSplit(localModel().distribution());
        return;
      }

      // Decide if largest branch is better choice
      // than whole subtree.
      if (Utils.smOrEq(errorsLargestBranch, errorsTree + 0.1)) {
        largestBranch = son(indexOfLargestBranch);
        m_sons = largestBranch.m_sons;
        m_localModel = largestBranch.localModel();
        m_isLeaf = largestBranch.m_isLeaf;
        newDistribution(m_train);
        prune();
      }
    }
  }

  /**
   * Returns a newly created tree.
   *
   * @param data the data to work with
   * @return the new tree
   * @throws Exception if something goes wrong
   */
  protected ClassifierTree getNewTree(Instances data) throws Exception {

    C45PruneableClassifierTree newTree = new C45PruneableClassifierTree(
      m_toSelectModel, m_pruneTheTree, m_CF, m_subtreeRaising, m_cleanup);
    newTree.buildTree((Instances) data, m_subtreeRaising);

    return newTree;
  }

  /**
   * Computes estimated errors for tree.
   *
   * @return the estimated errors
   */
  private double getEstimatedErrors() {

    double errors = 0;
    int i;

    if (m_isLeaf)
      return getEstimatedErrorsForDistribution(localModel().distribution());
    else {
      for (i = 0; i < m_sons.length; i++)
        errors = errors + son(i).getEstimatedErrors();
      return errors;
    }
  }

  /**
   * Computes estimated errors for one branch.
   *
   * @param data the data to work with
   * @return the estimated errors
   * @throws Exception if something goes wrong
   */
  private double getEstimatedErrorsForBranch(Instances data)
    throws Exception {

    Instances[] localInstances;
    double errors = 0;
    int i;

    if (m_isLeaf)
      return getEstimatedErrorsForDistribution(new Distribution(data));
    else {
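
      // Temporarily redistribute this branch's data down the subtree so
      // the error estimate reflects that data, then restore the original
      // training distribution before returning.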
      Distribution savedDist = localModel().m_distribution;
      localModel().resetDistribution(data);
      localInstances = (Instances[]) localModel().split(data);
      localModel().m_distribution = savedDist;
      for (i = 0; i < m_sons.length; i++)
        errors = errors
          + son(i).getEstimatedErrorsForBranch(localInstances[i]);
      return errors;
    }
  }

  /**
   * Computes estimated errors for leaf.
   *
   * @param theDistribution the distribution to use
   * @return the estimated errors
   */
  private double getEstimatedErrorsForDistribution(
    Distribution theDistribution) {

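    // C4.5's pessimistic error estimate: the observed errors plus the
    // extra errors implied by the upper limit of a binomial confidence
    // interval at level m_CF (Stats.addErrs computes that extra term,
    // roughly N * (U_CF(E, N) - E / N) for E errors out of N instances).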
    if (Utils.eq(theDistribution.total(), 0))
      return 0;
    else
      return theDistribution.numIncorrect()
        + Stats.addErrs(theDistribution.total(),
                        theDistribution.numIncorrect(), m_CF);
  }

  /**
   * Computes errors of tree on training data.
   *
   * @return the training errors
   */
  private double getTrainingErrors() {

    double errors = 0;
    int i;

    if (m_isLeaf)
      return localModel().distribution().numIncorrect();
    else {
      for (i = 0; i < m_sons.length; i++)
        errors = errors + son(i).getTrainingErrors();
      return errors;
    }
  }

  /**
   * Method just exists to make program easier to read.
   *
   * @return the local split model
   */
  private ClassifierSplitModel localModel() {

    return (ClassifierSplitModel) m_localModel;
  }

  /**
   * Computes new distributions of instances for nodes
   * in tree.
   *
   * @param data the data to compute the distributions for
   * @throws Exception if something goes wrong
   */
  private void newDistribution(Instances data) throws Exception {

    Instances[] localInstances;

    localModel().resetDistribution(data);
    m_train = data;
    if (!m_isLeaf) {
      localInstances = (Instances[]) localModel().split(data);
      for (int i = 0; i < m_sons.length; i++)
        son(i).newDistribution(localInstances[i]);
    } else {

      // Check whether there are some instances at the leaf now!
      if (!Utils.eq(data.sumOfWeights(), 0)) {
        m_isEmpty = false;
      }
    }
  }

  /**
   * Method just exists to make program easier to read.
   *
   * @param index the index of the son to return
   * @return the son at the given index, cast to this class
   */
  private C45PruneableClassifierTree son(int index) {

    return (C45PruneableClassifierTree) m_sons[index];
  }
}