001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * NBTreeNoSplit.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.classifiers.Classifier;
026: import weka.classifiers.Evaluation;
027: import weka.classifiers.bayes.NaiveBayesUpdateable;
028: import weka.core.Instance;
029: import weka.core.Instances;
030: import weka.filters.Filter;
031: import weka.filters.supervised.attribute.Discretize;
032:
033: import java.util.Random;
034:
035: /**
036: * Class implementing a "no-split"-split (leaf node) for naive bayes
037: * trees.
038: *
039: * @author Mark Hall (mhall@cs.waikato.ac.nz)
040: * @version $Revision: 1.3 $
041: */
042: public final class NBTreeNoSplit extends ClassifierSplitModel {
043:
044: /** for serialization */
045: private static final long serialVersionUID = 7824804381545259618L;
046:
047: /** the naive bayes classifier */
048: private NaiveBayesUpdateable m_nb;
049:
050: /** the discretizer used */
051: private Discretize m_disc;
052:
053: /** errors on the training data at this node */
054: private double m_errors;
055:
056: public NBTreeNoSplit() {
057: m_numSubsets = 1;
058: }
059:
060: /**
061: * Build the no-split node
062: *
063: * @param instances an <code>Instances</code> value
064: * @exception Exception if an error occurs
065: */
066: public final void buildClassifier(Instances instances)
067: throws Exception {
068: m_nb = new NaiveBayesUpdateable();
069: m_disc = new Discretize();
070: m_disc.setInputFormat(instances);
071: Instances temp = Filter.useFilter(instances, m_disc);
072: m_nb.buildClassifier(temp);
073: if (temp.numInstances() >= 5) {
074: m_errors = crossValidate(m_nb, temp, new Random(1));
075: }
076: m_numSubsets = 1;
077: }
078:
079: /**
080: * Return the errors made by the naive bayes model at this node
081: *
082: * @return the number of errors made
083: */
084: public double getErrors() {
085: return m_errors;
086: }
087:
088: /**
089: * Return the discretizer used at this node
090: *
091: * @return a <code>Discretize</code> value
092: */
093: public Discretize getDiscretizer() {
094: return m_disc;
095: }
096:
097: /**
098: * Get the naive bayes model at this node
099: *
100: * @return a <code>NaiveBayesUpdateable</code> value
101: */
102: public NaiveBayesUpdateable getNaiveBayesModel() {
103: return m_nb;
104: }
105:
106: /**
107: * Always returns 0 because only there is only one subset.
108: */
109: public final int whichSubset(Instance instance) {
110:
111: return 0;
112: }
113:
114: /**
115: * Always returns null because there is only one subset.
116: */
117: public final double[] weights(Instance instance) {
118:
119: return null;
120: }
121:
122: /**
123: * Does nothing because no condition has to be satisfied.
124: */
125: public final String leftSide(Instances instances) {
126:
127: return "";
128: }
129:
130: /**
131: * Does nothing because no condition has to be satisfied.
132: */
133: public final String rightSide(int index, Instances instances) {
134:
135: return "";
136: }
137:
138: /**
139: * Returns a string containing java source code equivalent to the test
140: * made at this node. The instance being tested is called "i".
141: *
142: * @param index index of the nominal value tested
143: * @param data the data containing instance structure info
144: * @return a value of type 'String'
145: */
146: public final String sourceExpression(int index, Instances data) {
147:
148: return "true"; // or should this be false??
149: }
150:
151: /**
152: * Return the probability for a class value
153: *
154: * @param classIndex the index of the class value
155: * @param instance the instance to generate a probability for
156: * @param theSubset the subset to consider
157: * @return a probability
158: * @exception Exception if an error occurs
159: */
160: public double classProb(int classIndex, Instance instance,
161: int theSubset) throws Exception {
162: m_disc.input(instance);
163: Instance temp = m_disc.output();
164: return m_nb.distributionForInstance(temp)[classIndex];
165: }
166:
167: /**
168: * Return a textual description of the node
169: *
170: * @return a <code>String</code> value
171: */
172: public String toString() {
173: return m_nb.toString();
174: }
175:
176: /**
177: * Utility method for fast 5-fold cross validation of a naive bayes
178: * model
179: *
180: * @param fullModel a <code>NaiveBayesUpdateable</code> value
181: * @param trainingSet an <code>Instances</code> value
182: * @param r a <code>Random</code> value
183: * @return a <code>double</code> value
184: * @exception Exception if an error occurs
185: */
186: public static double crossValidate(NaiveBayesUpdateable fullModel,
187: Instances trainingSet, Random r) throws Exception {
188: // make some copies for fast evaluation of 5-fold xval
189: Classifier[] copies = Classifier.makeCopies(fullModel, 5);
190: Evaluation eval = new Evaluation(trainingSet);
191: // make some splits
192: for (int j = 0; j < 5; j++) {
193: Instances test = trainingSet.testCV(5, j);
194: // unlearn these test instances
195: for (int k = 0; k < test.numInstances(); k++) {
196: test.instance(k).setWeight(-test.instance(k).weight());
197: ((NaiveBayesUpdateable) copies[j])
198: .updateClassifier(test.instance(k));
199: // reset the weight back to its original value
200: test.instance(k).setWeight(-test.instance(k).weight());
201: }
202: eval.evaluateModel(copies[j], test);
203: }
204: return eval.incorrect();
205: }
206: }
|