/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * BVDecompose.java
 * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.Reader;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Class for performing a Bias-Variance decomposition on any classifier using the method specified in:<br/>
 * <br/>
 * Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Kohavi1996,
 *    author = {Ron Kohavi and David H. Wolpert},
 *    booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
 *    editor = {Lorenza Saitta},
 *    pages = {275-283},
 *    publisher = {Morgan Kaufmann},
 *    title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
 *    year = {1996},
 *    PS = {http://robotics.stanford.edu/~ronnyk/biasVar.ps}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -c &lt;class index&gt;
 *  The index of the class attribute.
 *  (default last)</pre>
 *
 * <pre> -t &lt;name of arff file&gt;
 *  The name of the arff file used for the decomposition.</pre>
 *
 * <pre> -T &lt;training pool size&gt;
 *  The number of instances placed in the training pool.
 *  The remainder will be used for testing. (default 100)</pre>
 *
 * <pre> -s &lt;seed&gt;
 *  The random number seed used.</pre>
 *
 * <pre> -x &lt;num&gt;
 *  The number of training repetitions used.
 *  (default 50)</pre>
 *
 * <pre> -D
 *  Turn on debugging output.</pre>
 *
 * <pre> -W &lt;classifier class name&gt;
 *  Full class name of the learner used in the decomposition.
 *  eg: weka.classifiers.bayes.NaiveBayes</pre>
 *
 * <pre>
 * Options specific to learner weka.classifiers.rules.ZeroR:
 * </pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * Options after -- are passed to the designated sub-learner. <p>
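 *
 * As a rough usage sketch (the data file name below is hypothetical, and
 * the data set must contain at least twice the training pool size of
 * instances, i.e. 200 with the defaults), the decomposition can be run
 * from the command line with weka.jar on the classpath: <p>
 *
 * <pre>
 * java weka.classifiers.BVDecompose -t data.arff -W weka.classifiers.bayes.NaiveBayes
 * </pre>
 *
 * The reported error is the average zero-one loss over all test
 * predictions and, apart from the small finite-sample correction applied
 * to the bias estimate, decomposes into sigma^2 + bias^2 + variance. <p>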
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.13 $
 */
public class BVDecompose implements OptionHandler, TechnicalInformationHandler {

  /** Debugging mode, gives extra output if true */
  protected boolean m_Debug;

  /** An instantiated base classifier used for getting and testing options. */
  protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();

  /** The options to be passed to the base classifier. */
  protected String[] m_ClassifierOptions;

  /** The number of train iterations */
  protected int m_TrainIterations = 50;

  /** The name of the data file used for the decomposition */
  protected String m_DataFileName;

  /** The index of the class attribute */
  protected int m_ClassIndex = -1;

  /** The random number seed */
  protected int m_Seed = 1;

  /** The calculated bias (squared) */
  protected double m_Bias;

  /** The calculated variance */
  protected double m_Variance;

  /** The calculated sigma (squared) */
  protected double m_Sigma;

  /** The error rate */
  protected double m_Error;

  /** The number of instances used in the training pool */
  protected int m_TrainPoolSize = 100;

  /**
   * Returns a string describing this object
   * @return a description of this object suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Class for performing a Bias-Variance decomposition on any classifier "
      + "using the method specified in:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR, "Ron Kohavi and David H. Wolpert");
    result.setValue(Field.YEAR, "1996");
    result.setValue(Field.TITLE, "Bias Plus Variance Decomposition for Zero-One Loss Functions");
    result.setValue(Field.BOOKTITLE, "Machine Learning: Proceedings of the Thirteenth International Conference");
    result.setValue(Field.PUBLISHER, "Morgan Kaufmann");
    result.setValue(Field.EDITOR, "Lorenza Saitta");
    result.setValue(Field.PAGES, "275-283");
    result.setValue(Field.PS, "http://robotics.stanford.edu/~ronnyk/biasVar.ps");

    return result;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(7);

    newVector.addElement(new Option(
      "\tThe index of the class attribute.\n"
      + "\t(default last)", "c", 1, "-c <class index>"));
    newVector.addElement(new Option(
      "\tThe name of the arff file used for the decomposition.",
      "t", 1, "-t <name of arff file>"));
    newVector.addElement(new Option(
      "\tThe number of instances placed in the training pool.\n"
      + "\tThe remainder will be used for testing. (default 100)",
      "T", 1, "-T <training pool size>"));
    newVector.addElement(new Option(
      "\tThe random number seed used.", "s", 1, "-s <seed>"));
    newVector.addElement(new Option(
      "\tThe number of training repetitions used.\n"
      + "\t(default 50)", "x", 1, "-x <num>"));
    newVector.addElement(new Option(
      "\tTurn on debugging output.", "D", 0, "-D"));
    newVector.addElement(new Option(
      "\tFull class name of the learner used in the decomposition.\n"
      + "\teg: weka.classifiers.bayes.NaiveBayes",
      "W", 1, "-W <classifier class name>"));

    if ((m_Classifier != null)
        && (m_Classifier instanceof OptionHandler)) {
      newVector.addElement(new Option("", "", 0,
        "\nOptions specific to learner "
        + m_Classifier.getClass().getName() + ":"));
      Enumeration enu = ((OptionHandler) m_Classifier).listOptions();
      while (enu.hasMoreElements()) {
        newVector.addElement(enu.nextElement());
      }
    }
    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -c &lt;class index&gt;
   *  The index of the class attribute.
   *  (default last)</pre>
   *
   * <pre> -t &lt;name of arff file&gt;
   *  The name of the arff file used for the decomposition.</pre>
   *
   * <pre> -T &lt;training pool size&gt;
   *  The number of instances placed in the training pool.
   *  The remainder will be used for testing. (default 100)</pre>
   *
   * <pre> -s &lt;seed&gt;
   *  The random number seed used.</pre>
   *
   * <pre> -x &lt;num&gt;
   *  The number of training repetitions used.
   *  (default 50)</pre>
   *
   * <pre> -D
   *  Turn on debugging output.</pre>
   *
   * <pre> -W &lt;classifier class name&gt;
   *  Full class name of the learner used in the decomposition.
   *  eg: weka.classifiers.bayes.NaiveBayes</pre>
   *
   * <pre>
   * Options specific to learner weka.classifiers.rules.ZeroR:
   * </pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   <!-- options-end -->
   *
   * Options after -- are passed to the designated sub-learner. <p>
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    setDebug(Utils.getFlag('D', options));

    String classIndex = Utils.getOption('c', options);
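    // Note: the user-visible class index is 1-based; "last" (or no -c value)
    // is stored as 0 here, which setClassIndex() turns into the internal
    // value -1, and decompose() then falls back to the last attribute.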
    if (classIndex.length() != 0) {
      if (classIndex.toLowerCase().equals("last")) {
        setClassIndex(0);
      } else if (classIndex.toLowerCase().equals("first")) {
        setClassIndex(1);
      } else {
        setClassIndex(Integer.parseInt(classIndex));
      }
    } else {
      setClassIndex(0);
    }

    String trainIterations = Utils.getOption('x', options);
    if (trainIterations.length() != 0) {
      setTrainIterations(Integer.parseInt(trainIterations));
    } else {
      setTrainIterations(50);
    }

    String trainPoolSize = Utils.getOption('T', options);
    if (trainPoolSize.length() != 0) {
      setTrainPoolSize(Integer.parseInt(trainPoolSize));
    } else {
      setTrainPoolSize(100);
    }

    String seedString = Utils.getOption('s', options);
    if (seedString.length() != 0) {
      setSeed(Integer.parseInt(seedString));
    } else {
      setSeed(1);
    }

    String dataFile = Utils.getOption('t', options);
    if (dataFile.length() == 0) {
      throw new Exception("An arff file must be specified"
        + " with the -t option.");
    }
    setDataFileName(dataFile);

    String classifierName = Utils.getOption('W', options);
    if (classifierName.length() == 0) {
      throw new Exception("A learner must be specified with the -W option.");
    }
    setClassifier(Classifier.forName(classifierName,
      Utils.partitionOptions(options)));
  }

  /**
   * Gets the current settings of the BVDecompose object.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {

    String[] classifierOptions = new String[0];
    if ((m_Classifier != null)
        && (m_Classifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler) m_Classifier).getOptions();
    }
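    // 14 slots are enough for the scheme-specific options at most: the -D
    // flag, six option/value pairs (-c, -x, -T, -s, -t, -W) and the "--"
    // separator; any unused slots are filled with empty strings at the end.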
    String[] options = new String[classifierOptions.length + 14];
    int current = 0;
    if (getDebug()) {
      options[current++] = "-D";
    }
    options[current++] = "-c";
    options[current++] = "" + getClassIndex();
    options[current++] = "-x";
    options[current++] = "" + getTrainIterations();
    options[current++] = "-T";
    options[current++] = "" + getTrainPoolSize();
    options[current++] = "-s";
    options[current++] = "" + getSeed();
    if (getDataFileName() != null) {
      options[current++] = "-t";
      options[current++] = "" + getDataFileName();
    }
    if (getClassifier() != null) {
      options[current++] = "-W";
      options[current++] = getClassifier().getClass().getName();
    }
    options[current++] = "--";
    System.arraycopy(classifierOptions, 0, options, current,
      classifierOptions.length);
    current += classifierOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Get the number of instances in the training pool.
   *
   * @return number of instances in the training pool.
   */
  public int getTrainPoolSize() {

    return m_TrainPoolSize;
  }

  /**
   * Set the number of instances in the training pool.
   *
   * @param numTrain number of instances in the training pool.
   */
  public void setTrainPoolSize(int numTrain) {

    m_TrainPoolSize = numTrain;
  }

  /**
   * Set the classifier to be analysed
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(Classifier newClassifier) {

    m_Classifier = newClassifier;
  }

  /**
   * Gets the classifier being analysed
   *
   * @return the classifier being analysed.
   */
  public Classifier getClassifier() {

    return m_Classifier;
  }

  /**
   * Sets debugging mode
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {

    m_Debug = debug;
  }

  /**
   * Gets whether debugging is turned on
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {

    return m_Debug;
  }

  /**
   * Sets the random number seed
   *
   * @param seed the random number seed
   */
  public void setSeed(int seed) {

    m_Seed = seed;
  }

  /**
   * Gets the random number seed
   *
   * @return the random number seed
   */
  public int getSeed() {

    return m_Seed;
  }

  /**
   * Sets the number of training iterations, i.e. how many times a
   * classifier is trained on a fresh sample from the training pool
   *
   * @param trainIterations the number of training iterations
   */
  public void setTrainIterations(int trainIterations) {

    m_TrainIterations = trainIterations;
  }

  /**
   * Gets the number of training iterations
   *
   * @return the number of training iterations
   */
  public int getTrainIterations() {

    return m_TrainIterations;
  }

  /**
   * Sets the name of the data file used for the decomposition
   *
   * @param dataFileName the data file to use
   */
  public void setDataFileName(String dataFileName) {

    m_DataFileName = dataFileName;
  }

  /**
   * Get the name of the data file used for the decomposition
   *
   * @return the name of the data file
   */
  public String getDataFileName() {

    return m_DataFileName;
  }

  /**
   * Get the index (starting from 1) of the attribute used as the class.
   *
   * @return the index of the class attribute
   */
  public int getClassIndex() {

    return m_ClassIndex + 1;
  }

  /**
   * Sets the index (starting from 1) of the attribute used as the class.
   *
   * @param classIndex the index (starting from 1) of the class attribute
   */
  public void setClassIndex(int classIndex) {

    m_ClassIndex = classIndex - 1;
  }

  /**
   * Get the calculated bias squared
   *
   * @return the bias squared
   */
  public double getBias() {

    return m_Bias;
  }

  /**
   * Get the calculated variance
   *
   * @return the variance
   */
  public double getVariance() {

    return m_Variance;
  }

  /**
   * Get the calculated sigma squared
   *
   * @return the sigma squared
   */
  public double getSigma() {

    return m_Sigma;
  }

  /**
   * Get the calculated error rate
   *
   * @return the error rate
   */
  public double getError() {

    return m_Error;
  }

  /**
   * Carry out the bias-variance decomposition
   *
   * @throws Exception if the decomposition couldn't be carried out
   */
  public void decompose() throws Exception {

    Reader dataReader = new BufferedReader(new FileReader(m_DataFileName));
    Instances data = new Instances(dataReader);

    if (m_ClassIndex < 0) {
      data.setClassIndex(data.numAttributes() - 1);
    } else {
      data.setClassIndex(m_ClassIndex);
    }
    if (data.classAttribute().type() != Attribute.NOMINAL) {
      throw new Exception("Class attribute must be nominal");
    }
    int numClasses = data.numClasses();

    data.deleteWithMissingClass();
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    if (data.numInstances() < 2 * m_TrainPoolSize) {
      throw new Exception("The dataset must contain at least "
        + (2 * m_TrainPoolSize) + " instances");
    }
    Random random = new Random(m_Seed);
    data.randomize(random);
    Instances trainPool = new Instances(data, 0, m_TrainPoolSize);
    Instances test = new Instances(data, m_TrainPoolSize,
      data.numInstances() - m_TrainPoolSize);
    int numTest = test.numInstances();
    double[][] instanceProbs = new double[numTest][numClasses];
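    // instanceProbs[j][k] counts how many of the trained classifiers
    // predicted class k for test instance j; divided by the number of
    // iterations below, it becomes the estimated prediction distribution.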

    m_Error = 0;
    for (int i = 0; i < m_TrainIterations; i++) {
      if (m_Debug) {
        System.err.println("Iteration " + (i + 1));
      }
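      // Each iteration re-shuffles the pool and trains on a fresh half of
      // it, so successive training sets overlap but generally differ
      // (training sets of half the pool size drawn without replacement,
      // as in the Kohavi-Wolpert setup).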
      trainPool.randomize(random);
      Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2);

      Classifier current = Classifier.makeCopy(m_Classifier);
      current.buildClassifier(train);

      // Evaluate the classifier on test, updating BVD stats
      for (int j = 0; j < numTest; j++) {
        int pred = (int) current.classifyInstance(test.instance(j));
        if (pred != test.instance(j).classValue()) {
          m_Error++;
        }
        instanceProbs[j][pred]++;
      }
    }
    m_Error /= (m_TrainIterations * numTest);

    // Average the BV over each instance in test.
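    // For each test instance, pActual is the indicator distribution of the
    // observed class and pPred the distribution of predictions over the
    // iterations. bias^2 accumulates the squared difference between the two
    // minus pPred*(1-pPred)/(iterations-1), the correction the
    // Kohavi-Wolpert paper applies for the finite number of training
    // repetitions; variance accumulates 1 - sum(pPred^2) and sigma^2
    // accumulates 1 - sum(pActual^2). Each sum is finally halved and
    // averaged over the test instances.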
    m_Bias = 0;
    m_Variance = 0;
    m_Sigma = 0;
    for (int i = 0; i < numTest; i++) {
      Instance current = test.instance(i);
      double[] predProbs = instanceProbs[i];
      double pActual, pPred;
      double bsum = 0, vsum = 0, ssum = 0;
      for (int j = 0; j < numClasses; j++) {
        pActual = (current.classValue() == j) ? 1 : 0; // Or via 1NN from test data?
        pPred = predProbs[j] / m_TrainIterations;
        bsum += (pActual - pPred) * (pActual - pPred)
          - pPred * (1 - pPred) / (m_TrainIterations - 1);
        vsum += pPred * pPred;
        ssum += pActual * pActual;
      }
      m_Bias += bsum;
      m_Variance += (1 - vsum);
      m_Sigma += (1 - ssum);
    }
    m_Bias /= (2 * numTest);
    m_Variance /= (2 * numTest);
    m_Sigma /= (2 * numTest);

    if (m_Debug) {
      System.err.println("Decomposition finished");
    }
  }

  /**
   * Returns description of the bias-variance decomposition results.
   *
   * @return the bias-variance decomposition results as a string
   */
  public String toString() {

    String result = "\nBias-Variance Decomposition\n";

    if (getClassifier() == null) {
      return "Invalid setup";
    }

    result += "\nClassifier   : " + getClassifier().getClass().getName();
    if (getClassifier() instanceof OptionHandler) {
      result += Utils.joinOptions(((OptionHandler) m_Classifier).getOptions());
    }
    result += "\nData File    : " + getDataFileName();
    result += "\nClass Index  : ";
    if (getClassIndex() == 0) {
      result += "last";
    } else {
      result += getClassIndex();
    }
    result += "\nTraining Pool: " + getTrainPoolSize();
    result += "\nIterations   : " + getTrainIterations();
    result += "\nSeed         : " + getSeed();
    result += "\nError        : " + Utils.doubleToString(getError(), 6, 4);
    result += "\nSigma^2      : " + Utils.doubleToString(getSigma(), 6, 4);
    result += "\nBias^2       : " + Utils.doubleToString(getBias(), 6, 4);
    result += "\nVariance     : " + Utils.doubleToString(getVariance(), 6, 4);

    return result + "\n";
  }

  /**
   * Test method for this class
   *
   * @param args the command line arguments
   */
  public static void main(String[] args) {

    try {
      BVDecompose bvd = new BVDecompose();

      try {
        bvd.setOptions(args);
        Utils.checkForRemainingOptions(args);
      } catch (Exception ex) {
        String result = ex.getMessage() + "\nBVDecompose Options:\n\n";
        Enumeration enu = bvd.listOptions();
        while (enu.hasMoreElements()) {
          Option option = (Option) enu.nextElement();
          result += option.synopsis() + "\n" + option.description() + "\n";
        }
        throw new Exception(result);
      }

      bvd.decompose();
      System.out.println(bvd.toString());
    } catch (Exception ex) {
      System.err.println(ex.getMessage());
    }
  }
}
|