0001: /*
0002: * This program is free software; you can redistribute it and/or modify
0003: * it under the terms of the GNU General Public License as published by
0004: * the Free Software Foundation; either version 2 of the License, or
0005: * (at your option) any later version.
0006: *
0007: * This program is distributed in the hope that it will be useful,
0008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
0009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
0010: * GNU General Public License for more details.
0011: *
0012: * You should have received a copy of the GNU General Public License
0013: * along with this program; if not, write to the Free Software
0014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
0015: */
0016:
0017: /*
0018: * BVDecomposeSegCVSub.java
0019: * Copyright (C) 2003 Paul Conilione
0020: *
0021: * Based on the class: BVDecompose.java by Len Trigg (1999)
0022: */
0023:
0024: /*
0025: * DEDICATION
0026: *
0027: * Paul Conilione would like to express his deep gratitude and appreciation
0028: * to his Chinese Buddhist Taoist Master Sifu Chow Yuk Nen for the abilities
0029: * and insight that he has been taught, which have allowed him to program in
0030: * a clear and efficient manner.
0031: *
0032: * Master Sifu Chow Yuk Nen's Teachings are unique and precious. They are
0033: * applicable to any field of human endeavour. Through his unique and powerful
0034: * ability to skilfully apply Chinese Buddhist Teachings, people have achieved
0035: * success in; Computing, chemical engineering, business, accounting, philosophy
0036: * and more.
0037: *
0038: */
0039:
0040: package weka.classifiers;
0041:
0042: import weka.core.Attribute;
0043: import weka.core.Instance;
0044: import weka.core.Instances;
0045: import weka.core.Option;
0046: import weka.core.OptionHandler;
0047: import weka.core.TechnicalInformation;
0048: import weka.core.TechnicalInformation.Type;
0049: import weka.core.TechnicalInformation.Field;
0050: import weka.core.TechnicalInformationHandler;
0051: import weka.core.Utils;
0052:
0053: import java.io.BufferedReader;
0054: import java.io.FileReader;
0055: import java.io.Reader;
0056: import java.util.Enumeration;
0057: import java.util.Random;
0058: import java.util.Vector;
0059:
0060: /**
0061: <!-- globalinfo-start -->
* This class performs Bias-Variance decomposition on any classifier using the sub-sampled cross-validation procedure as specified in (1).<br/>
0063: * The Kohavi and Wolpert definition of bias and variance is specified in (2).<br/>
0064: * The Webb definition of bias and variance is specified in (3).<br/>
0065: * <br/>
0066: * Geoffrey I. Webb, Paul Conilione (2002). Estimating bias and variance from data. School of Computer Science and Software Engineering, Victoria, Australia.<br/>
0067: * <br/>
0068: * Ron Kohavi, David H. Wolpert: Bias Plus Variance Decomposition for Zero-One Loss Functions. In: Machine Learning: Proceedings of the Thirteenth International Conference, 275-283, 1996.<br/>
0069: * <br/>
0070: * Geoffrey I. Webb (2000). MultiBoosting: A Technique for Combining Boosting and Wagging. Machine Learning. 40(2):159-196.
0071: * <p/>
0072: <!-- globalinfo-end -->
0073: *
0074: <!-- technical-bibtex-start -->
0075: * BibTeX:
0076: * <pre>
0077: * @misc{Webb2002,
0078: * address = {School of Computer Science and Software Engineering, Victoria, Australia},
0079: * author = {Geoffrey I. Webb and Paul Conilione},
0080: * institution = {Monash University},
0081: * title = {Estimating bias and variance from data},
0082: * year = {2002},
0083: * PDF = {http://www.csse.monash.edu.au/~webb/Files/WebbConilione04.pdf}
0084: * }
0085: *
0086: * @inproceedings{Kohavi1996,
0087: * author = {Ron Kohavi and David H. Wolpert},
0088: * booktitle = {Machine Learning: Proceedings of the Thirteenth International Conference},
0089: * editor = {Lorenza Saitta},
0090: * pages = {275-283},
0091: * publisher = {Morgan Kaufmann},
0092: * title = {Bias Plus Variance Decomposition for Zero-One Loss Functions},
0093: * year = {1996},
0094: * PS = {http://robotics.stanford.edu/~ronnyk/biasVar.ps}
0095: * }
0096: *
0097: * @article{Webb2000,
0098: * author = {Geoffrey I. Webb},
0099: * journal = {Machine Learning},
0100: * number = {2},
0101: * pages = {159-196},
0102: * title = {MultiBoosting: A Technique for Combining Boosting and Wagging},
0103: * volume = {40},
0104: * year = {2000}
0105: * }
0106: * </pre>
0107: * <p/>
0108: <!-- technical-bibtex-end -->
0109: *
0110: <!-- options-start -->
0111: * Valid options are: <p/>
0112: *
0113: * <pre> -c <class index>
0114: * The index of the class attribute.
0115: * (default last)</pre>
0116: *
0117: * <pre> -D
0118: * Turn on debugging output.</pre>
0119: *
0120: * <pre> -l <num>
0121: * The number of times each instance is classified.
0122: * (default 10)</pre>
0123: *
0124: * <pre> -p <proportion of objects in common>
0125: * The average proportion of instances common between any two training sets</pre>
0126: *
0127: * <pre> -s <seed>
0128: * The random number seed used.</pre>
0129: *
0130: * <pre> -t <name of arff file>
0131: * The name of the arff file used for the decomposition.</pre>
0132: *
0133: * <pre> -T <number of instances in training set>
0134: * The number of instances in the training set.</pre>
0135: *
0136: * <pre> -W <classifier class name>
0137: * Full class name of the learner used in the decomposition.
0138: * eg: weka.classifiers.bayes.NaiveBayes</pre>
0139: *
0140: * <pre>
0141: * Options specific to learner weka.classifiers.rules.ZeroR:
0142: * </pre>
0143: *
0144: * <pre> -D
0145: * If set, classifier is run in debug mode and
0146: * may output additional info to the console</pre>
0147: *
0148: <!-- options-end -->
0149: *
0150: * Options after -- are passed to the designated sub-learner. <p>
0151: *
0152: * @author Paul Conilione (paulc4321@yahoo.com.au)
0153: * @version $Revision: 1.5 $
0154: */
0155: public class BVDecomposeSegCVSub implements OptionHandler,
0156: TechnicalInformationHandler {
0157:
    /** Debugging mode, gives extra output if true. */
    protected boolean m_Debug;

    /** An instantiated base classifier used for getting and testing options. */
    protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR();

    /** The options to be passed to the base classifier. */
    protected String[] m_ClassifierOptions;

    /** The number of times an instance is classified. */
    protected int m_ClassifyIterations;

    /** The name of the data file used for the decomposition. */
    protected String m_DataFileName;

    /** The index of the class attribute; a negative value means "use the last attribute". */
    protected int m_ClassIndex = -1;

    /** The random number seed, used to randomize the data in decompose(). */
    protected int m_Seed = 1;

    /** The calculated Kohavi & Wolpert bias (squared). */
    protected double m_KWBias;

    /** The calculated Kohavi & Wolpert variance. */
    protected double m_KWVariance;

    /** The calculated Kohavi & Wolpert sigma. */
    protected double m_KWSigma;

    /** The calculated Webb bias. */
    protected double m_WBias;

    /** The calculated Webb variance. */
    protected double m_WVariance;

    /** The error rate, normalised over all classify iterations by decompose(). */
    protected double m_Error;

    /** The training set size; -1 selects the default of half the data set. */
    protected int m_TrainSize;

    /** Proportion of instances common between any two training sets; -1 selects the default m/(|D|-1). */
    protected double m_P;
0202:
0203: /**
0204: * Returns a string describing this object
0205: * @return a description of the classifier suitable for
0206: * displaying in the explorer/experimenter gui
0207: */
0208: public String globalInfo() {
0209: return "This class performs Bias-Variance decomposion on any classifier using the "
0210: + "sub-sampled cross-validation procedure as specified in (1).\n"
0211: + "The Kohavi and Wolpert definition of bias and variance is specified in (2).\n"
0212: + "The Webb definition of bias and variance is specified in (3).\n\n"
0213: + getTechnicalInformation().toString();
0214: }
0215:
0216: /**
0217: * Returns an instance of a TechnicalInformation object, containing
0218: * detailed information about the technical background of this class,
0219: * e.g., paper reference or book this class is based on.
0220: *
0221: * @return the technical information about this class
0222: */
0223: public TechnicalInformation getTechnicalInformation() {
0224: TechnicalInformation result;
0225: TechnicalInformation additional;
0226:
0227: result = new TechnicalInformation(Type.MISC);
0228: result.setValue(Field.AUTHOR,
0229: "Geoffrey I. Webb and Paul Conilione");
0230: result.setValue(Field.YEAR, "2002");
0231: result.setValue(Field.TITLE,
0232: "Estimating bias and variance from data");
0233: result.setValue(Field.INSTITUTION, "Monash University");
0234: result
0235: .setValue(Field.ADDRESS,
0236: "School of Computer Science and Software Engineering, Victoria, Australia");
0237: result
0238: .setValue(Field.PDF,
0239: "http://www.csse.monash.edu.au/~webb/Files/WebbConilione04.pdf");
0240:
0241: additional = result.add(Type.INPROCEEDINGS);
0242: additional.setValue(Field.AUTHOR,
0243: "Ron Kohavi and David H. Wolpert");
0244: additional.setValue(Field.YEAR, "1996");
0245: additional
0246: .setValue(Field.TITLE,
0247: "Bias Plus Variance Decomposition for Zero-One Loss Functions");
0248: additional
0249: .setValue(Field.BOOKTITLE,
0250: "Machine Learning: Proceedings of the Thirteenth International Conference");
0251: additional.setValue(Field.PUBLISHER, "Morgan Kaufmann");
0252: additional.setValue(Field.EDITOR, "Lorenza Saitta");
0253: additional.setValue(Field.PAGES, "275-283");
0254: additional.setValue(Field.PS,
0255: "http://robotics.stanford.edu/~ronnyk/biasVar.ps");
0256:
0257: additional = result.add(Type.ARTICLE);
0258: additional.setValue(Field.AUTHOR, "Geoffrey I. Webb");
0259: additional.setValue(Field.YEAR, "2000");
0260: additional
0261: .setValue(Field.TITLE,
0262: "MultiBoosting: A Technique for Combining Boosting and Wagging");
0263: additional.setValue(Field.JOURNAL, "Machine Learning");
0264: additional.setValue(Field.VOLUME, "40");
0265: additional.setValue(Field.NUMBER, "2");
0266: additional.setValue(Field.PAGES, "159-196");
0267:
0268: return result;
0269: }
0270:
0271: /**
0272: * Returns an enumeration describing the available options.
0273: *
0274: * @return an enumeration of all the available options.
0275: */
0276: public Enumeration listOptions() {
0277:
0278: Vector newVector = new Vector(8);
0279:
0280: newVector.addElement(new Option(
0281: "\tThe index of the class attribute.\n"
0282: + "\t(default last)", "c", 1,
0283: "-c <class index>"));
0284: newVector.addElement(new Option("\tTurn on debugging output.",
0285: "D", 0, "-D"));
0286: newVector.addElement(new Option(
0287: "\tThe number of times each instance is classified.\n"
0288: + "\t(default 10)", "l", 1, "-l <num>"));
0289: newVector
0290: .addElement(new Option(
0291: "\tThe average proportion of instances common between any two training sets",
0292: "p", 1, "-p <proportion of objects in common>"));
0293: newVector.addElement(new Option(
0294: "\tThe random number seed used.", "s", 1, "-s <seed>"));
0295: newVector
0296: .addElement(new Option(
0297: "\tThe name of the arff file used for the decomposition.",
0298: "t", 1, "-t <name of arff file>"));
0299: newVector.addElement(new Option(
0300: "\tThe number of instances in the training set.", "T",
0301: 1, "-T <number of instances in training set>"));
0302: newVector.addElement(new Option(
0303: "\tFull class name of the learner used in the decomposition.\n"
0304: + "\teg: weka.classifiers.bayes.NaiveBayes",
0305: "W", 1, "-W <classifier class name>"));
0306:
0307: if ((m_Classifier != null)
0308: && (m_Classifier instanceof OptionHandler)) {
0309: newVector.addElement(new Option("", "", 0,
0310: "\nOptions specific to learner "
0311: + m_Classifier.getClass().getName() + ":"));
0312: Enumeration enu = ((OptionHandler) m_Classifier)
0313: .listOptions();
0314: while (enu.hasMoreElements()) {
0315: newVector.addElement(enu.nextElement());
0316: }
0317: }
0318: return newVector.elements();
0319: }
0320:
0321: /**
0322: * Sets the OptionHandler's options using the given list. All options
0323: * will be set (or reset) during this call (i.e. incremental setting
0324: * of options is not possible). <p/>
0325: *
0326: <!-- options-start -->
0327: * Valid options are: <p/>
0328: *
0329: * <pre> -c <class index>
0330: * The index of the class attribute.
0331: * (default last)</pre>
0332: *
0333: * <pre> -D
0334: * Turn on debugging output.</pre>
0335: *
0336: * <pre> -l <num>
0337: * The number of times each instance is classified.
0338: * (default 10)</pre>
0339: *
0340: * <pre> -p <proportion of objects in common>
0341: * The average proportion of instances common between any two training sets</pre>
0342: *
0343: * <pre> -s <seed>
0344: * The random number seed used.</pre>
0345: *
0346: * <pre> -t <name of arff file>
0347: * The name of the arff file used for the decomposition.</pre>
0348: *
0349: * <pre> -T <number of instances in training set>
0350: * The number of instances in the training set.</pre>
0351: *
0352: * <pre> -W <classifier class name>
0353: * Full class name of the learner used in the decomposition.
0354: * eg: weka.classifiers.bayes.NaiveBayes</pre>
0355: *
0356: * <pre>
0357: * Options specific to learner weka.classifiers.rules.ZeroR:
0358: * </pre>
0359: *
0360: * <pre> -D
0361: * If set, classifier is run in debug mode and
0362: * may output additional info to the console</pre>
0363: *
0364: <!-- options-end -->
0365: *
0366: * @param options the list of options as an array of strings
0367: * @throws Exception if an option is not supported
0368: */
0369: public void setOptions(String[] options) throws Exception {
0370: setDebug(Utils.getFlag('D', options));
0371:
0372: String classIndex = Utils.getOption('c', options);
0373: if (classIndex.length() != 0) {
0374: if (classIndex.toLowerCase().equals("last")) {
0375: setClassIndex(0);
0376: } else if (classIndex.toLowerCase().equals("first")) {
0377: setClassIndex(1);
0378: } else {
0379: setClassIndex(Integer.parseInt(classIndex));
0380: }
0381: } else {
0382: setClassIndex(0);
0383: }
0384:
0385: String classifyIterations = Utils.getOption('l', options);
0386: if (classifyIterations.length() != 0) {
0387: setClassifyIterations(Integer.parseInt(classifyIterations));
0388: } else {
0389: setClassifyIterations(10);
0390: }
0391:
0392: String prob = Utils.getOption('p', options);
0393: if (prob.length() != 0) {
0394: setP(Double.parseDouble(prob));
0395: } else {
0396: setP(-1);
0397: }
0398: //throw new Exception("A proportion must be specified" + " with a -p option.");
0399:
0400: String seedString = Utils.getOption('s', options);
0401: if (seedString.length() != 0) {
0402: setSeed(Integer.parseInt(seedString));
0403: } else {
0404: setSeed(1);
0405: }
0406:
0407: String dataFile = Utils.getOption('t', options);
0408: if (dataFile.length() != 0) {
0409: setDataFileName(dataFile);
0410: } else {
0411: throw new Exception("An arff file must be specified"
0412: + " with the -t option.");
0413: }
0414:
0415: String trainSize = Utils.getOption('T', options);
0416: if (trainSize.length() != 0) {
0417: setTrainSize(Integer.parseInt(trainSize));
0418: } else {
0419: setTrainSize(-1);
0420: }
0421: //throw new Exception("A training set size must be specified" + " with a -T option.");
0422:
0423: String classifierName = Utils.getOption('W', options);
0424: if (classifierName.length() != 0) {
0425: setClassifier(Classifier.forName(classifierName, Utils
0426: .partitionOptions(options)));
0427: } else {
0428: throw new Exception(
0429: "A learner must be specified with the -W option.");
0430: }
0431: }
0432:
0433: /**
0434: * Gets the current settings of the CheckClassifier.
0435: *
0436: * @return an array of strings suitable for passing to setOptions
0437: */
0438: public String[] getOptions() {
0439:
0440: String[] classifierOptions = new String[0];
0441: if ((m_Classifier != null)
0442: && (m_Classifier instanceof OptionHandler)) {
0443: classifierOptions = ((OptionHandler) m_Classifier)
0444: .getOptions();
0445: }
0446: String[] options = new String[classifierOptions.length + 14];
0447: int current = 0;
0448: if (getDebug()) {
0449: options[current++] = "-D";
0450: }
0451: options[current++] = "-c";
0452: options[current++] = "" + getClassIndex();
0453: options[current++] = "-l";
0454: options[current++] = "" + getClassifyIterations();
0455: options[current++] = "-p";
0456: options[current++] = "" + getP();
0457: options[current++] = "-s";
0458: options[current++] = "" + getSeed();
0459: if (getDataFileName() != null) {
0460: options[current++] = "-t";
0461: options[current++] = "" + getDataFileName();
0462: }
0463: options[current++] = "-T";
0464: options[current++] = "" + getTrainSize();
0465: if (getClassifier() != null) {
0466: options[current++] = "-W";
0467: options[current++] = getClassifier().getClass().getName();
0468: }
0469:
0470: options[current++] = "--";
0471: System.arraycopy(classifierOptions, 0, options, current,
0472: classifierOptions.length);
0473: current += classifierOptions.length;
0474: while (current < options.length) {
0475: options[current++] = "";
0476: }
0477: return options;
0478: }
0479:
    /**
     * Set the classifier being analysed.
     *
     * @param newClassifier the Classifier to use.
     */
    public void setClassifier(Classifier newClassifier) {

        m_Classifier = newClassifier;
    }
0489:
    /**
     * Gets the classifier being analysed.
     *
     * @return the classifier being analysed.
     */
    public Classifier getClassifier() {

        return m_Classifier;
    }
0499:
    /**
     * Sets debugging mode.
     *
     * @param debug true if debug output should be printed
     */
    public void setDebug(boolean debug) {

        m_Debug = debug;
    }
0509:
    /**
     * Gets whether debugging is turned on.
     *
     * @return true if debugging output is on
     */
    public boolean getDebug() {

        return m_Debug;
    }
0519:
    /**
     * Sets the random number seed used to randomize the data in decompose().
     *
     * @param seed the random number seed
     */
    public void setSeed(int seed) {

        m_Seed = seed;
    }
0529:
    /**
     * Gets the random number seed.
     *
     * @return the random number seed
     */
    public int getSeed() {

        return m_Seed;
    }
0539:
    /**
     * Sets the number of times an instance is classified.
     *
     * @param classifyIterations number of times an instance is classified
     */
    public void setClassifyIterations(int classifyIterations) {

        m_ClassifyIterations = classifyIterations;
    }
0549:
    /**
     * Gets the number of times an instance is classified.
     *
     * @return the number of times an instance is classified
     */
    public int getClassifyIterations() {

        return m_ClassifyIterations;
    }
0559:
    /**
     * Sets the name of the dataset (arff) file used for the decomposition.
     *
     * @param dataFileName name of dataset file.
     */
    public void setDataFileName(String dataFileName) {

        m_DataFileName = dataFileName;
    }
0569:
    /**
     * Get the name of the data file used for the decomposition.
     *
     * @return the name of the data file
     */
    public String getDataFileName() {

        return m_DataFileName;
    }
0579:
    /**
     * Get the index (starting from 1) of the attribute used as the class.
     * The internal 0-based index is shifted back to the 1-based convention.
     *
     * @return the index of the class attribute
     */
    public int getClassIndex() {

        return m_ClassIndex + 1;
    }
0589:
    /**
     * Sets the index of the class attribute. A value of 0 stores -1
     * internally, which decompose() interprets as "use the last attribute".
     *
     * @param classIndex the index (starting from 1) of the class attribute
     */
    public void setClassIndex(int classIndex) {

        m_ClassIndex = classIndex - 1;
    }
0599:
    /**
     * Get the calculated bias squared according to the Kohavi and Wolpert
     * definition. Only meaningful after decompose() has run.
     *
     * @return the bias squared
     */
    public double getKWBias() {

        return m_KWBias;
    }
0609:
    /**
     * Get the calculated bias according to the Webb definition.
     * Only meaningful after decompose() has run.
     *
     * @return the bias
     */
    public double getWBias() {

        return m_WBias;
    }
0620:
    /**
     * Get the calculated variance according to the Kohavi and Wolpert
     * definition. Only meaningful after decompose() has run.
     *
     * @return the variance
     */
    public double getKWVariance() {

        return m_KWVariance;
    }
0630:
    /**
     * Get the calculated variance according to the Webb definition.
     * Only meaningful after decompose() has run.
     *
     * @return the variance according to Webb
     */
    public double getWVariance() {

        return m_WVariance;
    }
0641:
    /**
     * Get the calculated sigma according to the Kohavi and Wolpert definition.
     * Only meaningful after decompose() has run.
     *
     * @return the sigma
     */
    public double getKWSigma() {

        return m_KWSigma;
    }
0652:
    /**
     * Set the training size. A value of -1 makes decompose() default to
     * half the data set.
     *
     * @param size the size of the training set
     */
    public void setTrainSize(int size) {

        m_TrainSize = size;
    }
0663:
    /**
     * Get the training size.
     *
     * @return the size of the training set
     */
    public int getTrainSize() {

        return m_TrainSize;
    }
0674:
    /**
     * Set the proportion of instances that are common between two training
     * sets used to train a classifier. A value of -1 makes decompose()
     * default to m/(|D|-1).
     *
     * @param proportion the proportion of instances that are common between
     * training sets.
     */
    public void setP(double proportion) {

        m_P = proportion;
    }
0687:
    /**
     * Get the proportion of instances that are common between two training sets.
     *
     * @return the proportion
     */
    public double getP() {

        return m_P;
    }
0698:
    /**
     * Get the calculated error rate. Only meaningful after decompose() has run.
     *
     * @return the error rate
     */
    public double getError() {

        return m_Error;
    }
0708:
0709: /**
0710: * Carry out the bias-variance decomposition using the sub-sampled cross-validation method.
0711: *
0712: * @throws Exception if the decomposition couldn't be carried out
0713: */
0714: public void decompose() throws Exception {
0715:
0716: Reader dataReader;
0717: Instances data;
0718:
0719: int tps; // training pool size, size of segment E.
0720: int k; // number of folds in segment E.
0721: int q; // number of segments of size tps.
0722:
0723: dataReader = new BufferedReader(new FileReader(m_DataFileName)); //open file
0724: data = new Instances(dataReader); // encapsulate in wrapper class called weka.Instances()
0725:
0726: if (m_ClassIndex < 0) {
0727: data.setClassIndex(data.numAttributes() - 1);
0728: } else {
0729: data.setClassIndex(m_ClassIndex);
0730: }
0731:
0732: if (data.classAttribute().type() != Attribute.NOMINAL) {
0733: throw new Exception("Class attribute must be nominal");
0734: }
0735: int numClasses = data.numClasses();
0736:
0737: data.deleteWithMissingClass();
0738: if (data.checkForStringAttributes()) {
0739: throw new Exception("Can't handle string attributes!");
0740: }
0741:
0742: // Dataset size must be greater than 2
0743: if (data.numInstances() <= 2) {
0744: throw new Exception("Dataset size must be greater than 2.");
0745: }
0746:
0747: if (m_TrainSize == -1) { // default value
0748: m_TrainSize = (int) Math
0749: .floor((double) data.numInstances() / 2.0);
0750: } else if (m_TrainSize < 0
0751: || m_TrainSize >= data.numInstances() - 1) { // Check if 0 < training Size < D - 1
0752: throw new Exception("Training set size of " + m_TrainSize
0753: + " is invalid.");
0754: }
0755:
0756: if (m_P == -1) { // default value
0757: m_P = (double) m_TrainSize
0758: / ((double) data.numInstances() - 1);
0759: } else if (m_P < (m_TrainSize / ((double) data.numInstances() - 1))
0760: || m_P >= 1.0) { //Check if p is in range: m/(|D|-1) <= p < 1.0
0761: throw new Exception(
0762: "Proportion is not in range: "
0763: + (m_TrainSize / ((double) data
0764: .numInstances() - 1))
0765: + " <= p < 1.0 ");
0766: }
0767:
0768: //roundup tps from double to integer
0769: tps = (int) Math
0770: .ceil(((double) m_TrainSize / (double) m_P) + 1);
0771: k = (int) Math.ceil(tps / (tps - (double) m_TrainSize));
0772:
0773: // number of folds cannot be more than the number of instances in the training pool
0774: if (k > tps) {
0775: throw new Exception(
0776: "The required number of folds is too many."
0777: + "Change p or the size of the training set.");
0778: }
0779:
0780: // calculate the number of segments, round down.
0781: q = (int) Math.floor((double) data.numInstances()
0782: / (double) tps);
0783:
0784: //create confusion matrix, columns = number of instances in data set, as all will be used, by rows = number of classes.
0785: double[][] instanceProbs = new double[data.numInstances()][numClasses];
0786: int[][] foldIndex = new int[k][2];
0787: Vector segmentList = new Vector(q + 1);
0788:
0789: //Set random seed
0790: Random random = new Random(m_Seed);
0791:
0792: data.randomize(random);
0793:
0794: //create index arrays for different segments
0795:
0796: int currentDataIndex = 0;
0797:
0798: for (int count = 1; count <= (q + 1); count++) {
0799: if (count > q) {
0800: int[] segmentIndex = new int[(data.numInstances() - (q * tps))];
0801: for (int index = 0; index < segmentIndex.length; index++, currentDataIndex++) {
0802:
0803: segmentIndex[index] = currentDataIndex;
0804: }
0805: segmentList.add(segmentIndex);
0806: } else {
0807: int[] segmentIndex = new int[tps];
0808:
0809: for (int index = 0; index < segmentIndex.length; index++, currentDataIndex++) {
0810: segmentIndex[index] = currentDataIndex;
0811: }
0812: segmentList.add(segmentIndex);
0813: }
0814: }
0815:
0816: int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
0817:
0818: //foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3, 3.3->4)
0819: int foldSize = (int) Math.ceil((double) tps / (double) k); //roundup fold size double to integer
0820: int index = 0;
0821: int currentIndex;
0822:
0823: for (int count = 0; count < k; count++) {
0824: if (remainder != 0 && count == remainder) {
0825: foldSize -= 1;
0826: }
0827: foldIndex[count][0] = index;
0828: foldIndex[count][1] = foldSize;
0829: index += foldSize;
0830: }
0831:
0832: for (int l = 0; l < m_ClassifyIterations; l++) {
0833:
0834: for (int i = 1; i <= q; i++) {
0835:
0836: int[] currentSegment = (int[]) segmentList.get(i - 1);
0837:
0838: randomize(currentSegment, random);
0839:
0840: //CROSS FOLD VALIDATION for current Segment
0841: for (int j = 1; j <= k; j++) {
0842:
0843: Instances TP = null;
0844: for (int foldNum = 1; foldNum <= k; foldNum++) {
0845: if (foldNum != j) {
0846:
0847: int startFoldIndex = foldIndex[foldNum - 1][0]; //start index
0848: foldSize = foldIndex[foldNum - 1][1];
0849: int endFoldIndex = startFoldIndex
0850: + foldSize - 1;
0851:
0852: for (int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++) {
0853:
0854: if (TP == null) {
0855: TP = new Instances(
0856: data,
0857: currentSegment[currentFoldIndex],
0858: 1);
0859: } else {
0860: TP
0861: .add(data
0862: .instance(currentSegment[currentFoldIndex]));
0863: }
0864: }
0865: }
0866: }
0867:
0868: TP.randomize(random);
0869:
0870: if (getTrainSize() > TP.numInstances()) {
0871: throw new Exception(
0872: "The training set size of "
0873: + getTrainSize()
0874: + ", is greater than the training pool "
0875: + TP.numInstances());
0876: }
0877:
0878: Instances train = new Instances(TP, 0, m_TrainSize);
0879:
0880: Classifier current = Classifier
0881: .makeCopy(m_Classifier);
0882: current.buildClassifier(train); // create a clssifier using the instances in train.
0883:
0884: int currentTestIndex = foldIndex[j - 1][0]; //start index
0885: int testFoldSize = foldIndex[j - 1][1]; //size
0886: int endTestIndex = currentTestIndex + testFoldSize
0887: - 1;
0888:
0889: while (currentTestIndex <= endTestIndex) {
0890:
0891: Instance testInst = data
0892: .instance(currentSegment[currentTestIndex]);
0893: int pred = (int) current
0894: .classifyInstance(testInst);
0895:
0896: if (pred != testInst.classValue()) {
0897: m_Error++; // add 1 to mis-classifications.
0898: }
0899: instanceProbs[currentSegment[currentTestIndex]][pred]++;
0900: currentTestIndex++;
0901: }
0902:
0903: if (i == 1 && j == 1) {
0904: int[] segmentElast = (int[]) segmentList
0905: .lastElement();
0906: for (currentIndex = 0; currentIndex < segmentElast.length; currentIndex++) {
0907: Instance testInst = data
0908: .instance(segmentElast[currentIndex]);
0909: int pred = (int) current
0910: .classifyInstance(testInst);
0911: if (pred != testInst.classValue()) {
0912: m_Error++; // add 1 to mis-classifications.
0913: }
0914:
0915: instanceProbs[segmentElast[currentIndex]][pred]++;
0916: }
0917: }
0918: }
0919: }
0920: }
0921:
0922: m_Error /= (double) (m_ClassifyIterations * data.numInstances());
0923:
0924: m_KWBias = 0.0;
0925: m_KWVariance = 0.0;
0926: m_KWSigma = 0.0;
0927:
0928: m_WBias = 0.0;
0929: m_WVariance = 0.0;
0930:
0931: for (int i = 0; i < data.numInstances(); i++) {
0932:
0933: Instance current = data.instance(i);
0934:
0935: double[] predProbs = instanceProbs[i];
0936: double pActual, pPred;
0937: double bsum = 0, vsum = 0, ssum = 0;
0938: double wBSum = 0, wVSum = 0;
0939:
0940: Vector centralTendencies = findCentralTendencies(predProbs);
0941:
0942: if (centralTendencies == null) {
0943: throw new Exception("Central tendency was null.");
0944: }
0945:
0946: for (int j = 0; j < numClasses; j++) {
0947: pActual = (current.classValue() == j) ? 1 : 0;
0948: pPred = predProbs[j] / m_ClassifyIterations;
0949: bsum += (pActual - pPred) * (pActual - pPred) - pPred
0950: * (1 - pPred) / (m_ClassifyIterations - 1);
0951: vsum += pPred * pPred;
0952: ssum += pActual * pActual;
0953: }
0954:
0955: m_KWBias += bsum;
0956: m_KWVariance += (1 - vsum);
0957: m_KWSigma += (1 - ssum);
0958:
0959: for (int count = 0; count < centralTendencies.size(); count++) {
0960:
0961: int wB = 0, wV = 0;
0962: int centralTendency = ((Integer) centralTendencies
0963: .get(count)).intValue();
0964:
0965: // For a single instance xi, find the bias and variance.
0966: for (int j = 0; j < numClasses; j++) {
0967:
0968: //Webb definition
0969: if (j != (int) current.classValue()
0970: && j == centralTendency) {
0971: wB += predProbs[j];
0972: }
0973: if (j != (int) current.classValue()
0974: && j != centralTendency) {
0975: wV += predProbs[j];
0976: }
0977:
0978: }
0979: wBSum += (double) wB;
0980: wVSum += (double) wV;
0981: }
0982:
0983: // calculate bais by dividing bSum by the number of central tendencies and
0984: // total number of instances. (effectively finding the average and dividing
0985: // by the number of instances to get the nominalised probability).
0986:
0987: m_WBias += (wBSum / ((double) (centralTendencies.size() * m_ClassifyIterations)));
0988: // calculate variance by dividing vSum by the total number of interations
0989: m_WVariance += (wVSum / ((double) (centralTendencies.size() * m_ClassifyIterations)));
0990:
0991: }
0992:
0993: m_KWBias /= (2.0 * (double) data.numInstances());
0994: m_KWVariance /= (2.0 * (double) data.numInstances());
0995: m_KWSigma /= (2.0 * (double) data.numInstances());
0996:
0997: // bias = bias / number of data instances
0998: m_WBias /= (double) data.numInstances();
0999: // variance = variance / number of data instances.
1000: m_WVariance /= (double) data.numInstances();
1001:
1002: if (m_Debug) {
1003: System.err.println("Decomposition finished");
1004: }
1005:
1006: }
1007:
1008: /** Finds the central tendency, given the classifications for an instance.
1009: *
1010: * Where the central tendency is defined as the class that was most commonly
1011: * selected for a given instance.<p>
1012: *
1013: * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},
1014: * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times
1015: * and '3' = 3 times. Then the central tendency is '2'. <p>
1016: *
1017: * However, it is important to note that this method returns a list of all classes
1018: * that have the highest number of classifications.
1019: *
1020: * In cases where there are several classes with the largest number of classifications, then
1021: * all of these classes are returned. For example if 'x' is classified '1' = 4 times,
1022: * '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>
1023: *
1024: * @param predProbs the array of classifications for a single instance.
1025: *
1026: * @return a Vector containing Integer objects which store the class(s) which
1027: * are the central tendency.
1028: */
1029: public Vector findCentralTendencies(double[] predProbs) {
1030:
1031: int centralTValue = 0;
1032: int currentValue = 0;
1033: //array to store the list of classes the have the greatest number of classifictions.
1034: Vector centralTClasses;
1035:
1036: centralTClasses = new Vector(); //create an array with size of the number of classes.
1037:
1038: // Go through array, finding the central tendency.
1039: for (int i = 0; i < predProbs.length; i++) {
1040: currentValue = (int) predProbs[i];
1041: // if current value is greater than the central tendency value then
1042: // clear vector and add new class to vector array.
1043: if (currentValue > centralTValue) {
1044: centralTClasses.clear();
1045: centralTClasses.addElement(new Integer(i));
1046: centralTValue = currentValue;
1047: } else if (currentValue != 0
1048: && currentValue == centralTValue) {
1049: centralTClasses.addElement(new Integer(i));
1050: }
1051: }
1052: //return all classes that have the greatest number of classifications.
1053: if (centralTValue != 0) {
1054: return centralTClasses;
1055: } else {
1056: return null;
1057: }
1058:
1059: }
1060:
1061: /**
1062: * Returns description of the bias-variance decomposition results.
1063: *
1064: * @return the bias-variance decomposition results as a string
1065: */
1066: public String toString() {
1067:
1068: String result = "\nBias-Variance Decomposition Segmentation, Cross Validation\n"
1069: + "with subsampling.\n";
1070:
1071: if (getClassifier() == null) {
1072: return "Invalid setup";
1073: }
1074:
1075: result += "\nClassifier : "
1076: + getClassifier().getClass().getName();
1077: if (getClassifier() instanceof OptionHandler) {
1078: result += Utils.joinOptions(((OptionHandler) m_Classifier)
1079: .getOptions());
1080: }
1081: result += "\nData File : " + getDataFileName();
1082: result += "\nClass Index : ";
1083: if (getClassIndex() == 0) {
1084: result += "last";
1085: } else {
1086: result += getClassIndex();
1087: }
1088: result += "\nIterations : " + getClassifyIterations();
1089: result += "\np : " + getP();
1090: result += "\nTraining Size : " + getTrainSize();
1091: result += "\nSeed : " + getSeed();
1092:
1093: result += "\n\nDefinition : " + "Kohavi and Wolpert";
1094: result += "\nError :"
1095: + Utils.doubleToString(getError(), 4);
1096: result += "\nBias^2 :"
1097: + Utils.doubleToString(getKWBias(), 4);
1098: result += "\nVariance :"
1099: + Utils.doubleToString(getKWVariance(), 4);
1100: result += "\nSigma^2 :"
1101: + Utils.doubleToString(getKWSigma(), 4);
1102:
1103: result += "\n\nDefinition : " + "Webb";
1104: result += "\nError :"
1105: + Utils.doubleToString(getError(), 4);
1106: result += "\nBias :"
1107: + Utils.doubleToString(getWBias(), 4);
1108: result += "\nVariance :"
1109: + Utils.doubleToString(getWVariance(), 4);
1110:
1111: return result;
1112: }
1113:
1114: /**
1115: * Test method for this class
1116: *
1117: * @param args the command line arguments
1118: */
1119: public static void main(String[] args) {
1120:
1121: try {
1122: BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();
1123:
1124: try {
1125: bvd.setOptions(args);
1126: Utils.checkForRemainingOptions(args);
1127: } catch (Exception ex) {
1128: String result = ex.getMessage()
1129: + "\nBVDecompose Options:\n\n";
1130: Enumeration enu = bvd.listOptions();
1131: while (enu.hasMoreElements()) {
1132: Option option = (Option) enu.nextElement();
1133: result += option.synopsis() + "\n"
1134: + option.description() + "\n";
1135: }
1136: throw new Exception(result);
1137: }
1138:
1139: bvd.decompose();
1140:
1141: System.out.println(bvd.toString());
1142:
1143: } catch (Exception ex) {
1144: System.err.println(ex.getMessage());
1145: }
1146:
1147: }
1148:
1149: /**
1150: * Accepts an array of ints and randomises the values in the array, using the
1151: * random seed.
1152: *
1153: *@param index is the array of integers
1154: *@param random is the Random seed.
1155: */
1156: public final void randomize(int[] index, Random random) {
1157: for (int j = index.length - 1; j > 0; j--) {
1158: int k = random.nextInt(j + 1);
1159: int temp = index[j];
1160: index[j] = index[k];
1161: index[k] = temp;
1162: }
1163: }
1164: }
|