001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * NBTreeModelSelection.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.core.Attribute;
026: import weka.core.Instances;
027: import weka.core.Utils;
028:
029: import java.util.Enumeration;
030:
031: /**
032: * Class for selecting a NB tree split.
033: *
034: * @author Mark Hall (mhall@cs.waikato.ac.nz)
035: * @version $Revision: 1.4 $
036: */
037: public class NBTreeModelSelection extends ModelSelection {
038:
039: /** for serialization */
040: private static final long serialVersionUID = 990097748931976704L;
041:
042: /** Minimum number of objects in interval. */
043: private int m_minNoObj;
044:
045: /** All the training data */
046: private Instances m_allData; //
047:
048: /**
049: * Initializes the split selection method with the given parameters.
050: *
051: * @param minNoObj minimum number of instances that have to occur in at least two
052: * subsets induced by split
053: * @param allData FULL training dataset (necessary for
054: * selection of split points).
055: */
056: public NBTreeModelSelection(int minNoObj, Instances allData) {
057: m_minNoObj = minNoObj;
058: m_allData = allData;
059: }
060:
061: /**
062: * Sets reference to training data to null.
063: */
064: public void cleanup() {
065:
066: m_allData = null;
067: }
068:
069: /**
070: * Selects NBTree-type split for the given dataset.
071: */
072: public final ClassifierSplitModel selectModel(Instances data) {
073:
074: double globalErrors = 0;
075:
076: double minResult;
077: double currentResult;
078: NBTreeSplit[] currentModel;
079: NBTreeSplit bestModel = null;
080: NBTreeNoSplit noSplitModel = null;
081: int validModels = 0;
082: boolean multiVal = true;
083: Distribution checkDistribution;
084: Attribute attribute;
085: double sumOfWeights;
086: int i;
087:
088: try {
089: // build the global model at this node
090: noSplitModel = new NBTreeNoSplit();
091: noSplitModel.buildClassifier(data);
092: if (data.numInstances() < 5) {
093: return noSplitModel;
094: }
095:
096: // evaluate it
097: globalErrors = noSplitModel.getErrors();
098: if (globalErrors == 0) {
099: return noSplitModel;
100: }
101:
102: // Check if all Instances belong to one class or if not
103: // enough Instances to split.
104: checkDistribution = new Distribution(data);
105: if (Utils.sm(checkDistribution.total(), m_minNoObj)
106: || Utils.eq(checkDistribution.total(),
107: checkDistribution
108: .perClass(checkDistribution
109: .maxClass()))) {
110: return noSplitModel;
111: }
112:
113: // Check if all attributes are nominal and have a
114: // lot of values.
115: if (m_allData != null) {
116: Enumeration enu = data.enumerateAttributes();
117: while (enu.hasMoreElements()) {
118: attribute = (Attribute) enu.nextElement();
119: if ((attribute.isNumeric())
120: || (Utils.sm(
121: (double) attribute.numValues(),
122: (0.3 * (double) m_allData
123: .numInstances())))) {
124: multiVal = false;
125: break;
126: }
127: }
128: }
129:
130: currentModel = new NBTreeSplit[data.numAttributes()];
131: sumOfWeights = data.sumOfWeights();
132:
133: // For each attribute.
134: for (i = 0; i < data.numAttributes(); i++) {
135:
136: // Apart from class attribute.
137: if (i != (data).classIndex()) {
138:
139: // Get models for current attribute.
140: currentModel[i] = new NBTreeSplit(i, m_minNoObj,
141: sumOfWeights);
142: currentModel[i].setGlobalModel(noSplitModel);
143: currentModel[i].buildClassifier(data);
144:
145: // Check if useful split for current attribute
146: // exists and check for enumerated attributes with
147: // a lot of values.
148: if (currentModel[i].checkModel()) {
149: validModels++;
150: }
151: } else {
152: currentModel[i] = null;
153: }
154: }
155:
156: // Check if any useful split was found.
157: if (validModels == 0) {
158: return noSplitModel;
159: }
160:
161: // Find "best" attribute to split on.
162: minResult = globalErrors;
163: for (i = 0; i < data.numAttributes(); i++) {
164: if ((i != (data).classIndex())
165: && (currentModel[i].checkModel())) {
166: /* System.err.println("Errors for "+data.attribute(i).name()+" "+
167: currentModel[i].getErrors()); */
168: if (currentModel[i].getErrors() < minResult) {
169: bestModel = currentModel[i];
170: minResult = currentModel[i].getErrors();
171: }
172: }
173: }
174: // System.exit(1);
175: // Check if useful split was found.
176:
177: if (((globalErrors - minResult) / globalErrors) < 0.05) {
178: return noSplitModel;
179: }
180:
181: /* if (bestModel == null) {
182: System.err.println("This shouldn't happen! glob : "+globalErrors+
183: " minRes : "+minResult);
184: System.exit(1);
185: } */
186: // Set the global model for the best split
187: // bestModel.setGlobalModel(noSplitModel);
188: return bestModel;
189: } catch (Exception e) {
190: e.printStackTrace();
191: }
192: return null;
193: }
194:
195: /**
196: * Selects NBTree-type split for the given dataset.
197: */
198: public final ClassifierSplitModel selectModel(Instances train,
199: Instances test) {
200:
201: return selectModel(train);
202: }
203: }
|