0001: /*
0002: * This program is free software; you can redistribute it and/or modify
0003: * it under the terms of the GNU General Public License as published by
0004: * the Free Software Foundation; either version 2 of the License, or
0005: * (at your option) any later version.
0006: *
0007: * This program is distributed in the hope that it will be useful,
0008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
0009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
0010: * GNU General Public License for more details.
0011: *
0012: * You should have received a copy of the GNU General Public License
0013: * along with this program; if not, write to the Free Software
0014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
0015: */
0016:
0017: /*
0018: * Evaluation.java
0019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
0020: *
0021: */
0022:
0023: package weka.classifiers;
0024:
0025: import weka.classifiers.evaluation.NominalPrediction;
0026: import weka.classifiers.evaluation.ThresholdCurve;
0027: import weka.classifiers.xml.XMLClassifier;
0028: import weka.core.Drawable;
0029: import weka.core.FastVector;
0030: import weka.core.Instance;
0031: import weka.core.Instances;
0032: import weka.core.Option;
0033: import weka.core.OptionHandler;
0034: import weka.core.Range;
0035: import weka.core.Summarizable;
0036: import weka.core.Utils;
0037: import weka.core.converters.ConverterUtils.DataSink;
0038: import weka.core.converters.ConverterUtils.DataSource;
0039: import weka.core.xml.KOML;
0040: import weka.core.xml.XMLOptions;
0041: import weka.core.xml.XMLSerialization;
0042: import weka.estimators.Estimator;
0043: import weka.estimators.KernelEstimator;
0044:
0045: import java.io.BufferedInputStream;
0046: import java.io.BufferedOutputStream;
0047: import java.io.BufferedReader;
0048: import java.io.FileInputStream;
0049: import java.io.FileOutputStream;
0050: import java.io.FileReader;
0051: import java.io.InputStream;
0052: import java.io.ObjectInputStream;
0053: import java.io.ObjectOutputStream;
0054: import java.io.OutputStream;
0055: import java.io.Reader;
0056: import java.util.Enumeration;
0057: import java.util.Random;
0058: import java.util.zip.GZIPInputStream;
0059: import java.util.zip.GZIPOutputStream;
0060:
0061: /**
0062: * Class for evaluating machine learning models. <p/>
0063: *
0064: * ------------------------------------------------------------------- <p/>
0065: *
0066: * General options when evaluating a learning scheme from the command-line: <p/>
0067: *
0068: * -t filename <br/>
0069: * Name of the file with the training data. (required) <p/>
0070: *
0071: * -T filename <br/>
0072: * Name of the file with the test data. If missing a cross-validation
0073: * is performed. <p/>
0074: *
0075: * -c index <br/>
0076: * Index of the class attribute (1, 2, ...; default: last). <p/>
0077: *
0078: * -x number <br/>
0079: * The number of folds for the cross-validation (default: 10). <p/>
0080: *
0081: * -no-cv <br/>
0082: * No cross validation. If no test file is provided, no evaluation
0083: * is done. <p/>
0084: *
0085: * -split-percentage percentage <br/>
0086: * Sets the percentage for the train/test set split, e.g., 66. <p/>
0087: *
0088: * -preserve-order <br/>
0089: * Preserves the order in the percentage split instead of randomizing
0090: * the data first with the seed value ('-s'). <p/>
0091: *
0092: * -s seed <br/>
0093: * Random number seed for the cross-validation and percentage split
0094: * (default: 1). <p/>
0095: *
0096: * -m filename <br/>
0097: * The name of a file containing a cost matrix. <p/>
0098: *
0099: * -l filename <br/>
0100: * Loads classifier from the given file. In case the filename ends with ".xml"
0101: * the options are loaded from XML. <p/>
0102: *
0103: * -d filename <br/>
0104: * Saves classifier built from the training data into the given file. In case
0105: * the filename ends with ".xml" the options are saved XML, not the model. <p/>
0106: *
0107: * -v <br/>
0108: * Outputs no statistics for the training data. <p/>
0109: *
0110: * -o <br/>
0111: * Outputs statistics only, not the classifier. <p/>
0112: *
0113: * -i <br/>
0114: * Outputs information-retrieval statistics per class. <p/>
0115: *
0116: * -k <br/>
0117: * Outputs information-theoretic statistics. <p/>
0118: *
0119: * -p range <br/>
0120: * Outputs predictions for test instances (or the train instances if no test
0121: * instances provided), along with the attributes in the specified range
0122: * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
0123: *
0124: * -distribution <br/>
0125: * Outputs the distribution instead of only the prediction
0126: * in conjunction with the '-p' option (only nominal classes). <p/>
0127: *
0128: * -r <br/>
0129: * Outputs cumulative margin distribution (and nothing else). <p/>
0130: *
0131: * -g <br/>
0132: * Only for classifiers that implement "Graphable." Outputs
0133: * the graph representation of the classifier (and nothing
0134: * else). <p/>
0135: *
0136: * -xml filename | xml-string <br/>
0137: * Retrieves the options from the XML-data instead of the command line. <p/>
0138: *
0139: * -threshold-file file <br/>
0140: * The file to save the threshold data to.
0141: * The format is determined by the extensions, e.g., '.arff' for ARFF
0142: * format or '.csv' for CSV. <p/>
0143: *
0144: * -threshold-label label <br/>
0145: * The class label to determine the threshold data for
0146: * (default is the first label) <p/>
0147: *
0148: * ------------------------------------------------------------------- <p/>
0149: *
0150: * Example usage as the main of a classifier (called FunkyClassifier):
0151: * <code> <pre>
0152: * public static void main(String [] args) {
0153: * runClassifier(new FunkyClassifier(), args);
0154: * }
0155: * </pre> </code>
0156: * <p/>
0157: *
0158: * ------------------------------------------------------------------ <p/>
0159: *
0160: * Example usage from within an application:
0161: * <code> <pre>
0162: * Instances trainInstances = ... instances got from somewhere
0163: * Instances testInstances = ... instances got from somewhere
0164: * Classifier scheme = ... scheme got from somewhere
0165: *
0166: * Evaluation evaluation = new Evaluation(trainInstances);
0167: * evaluation.evaluateModel(scheme, testInstances);
0168: * System.out.println(evaluation.toSummaryString());
0169: * </pre> </code>
0170: *
0171: *
0172: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
0173: * @author Len Trigg (trigg@cs.waikato.ac.nz)
0174: * @version $Revision: 1.77 $
0175: */
0176: public class Evaluation implements Summarizable {
0177:
  /** The number of classes. */
  protected int m_NumClasses;

  /** The number of folds for a cross-validation. */
  protected int m_NumFolds;

  /** The weight of all incorrectly classified instances. */
  protected double m_Incorrect;

  /** The weight of all correctly classified instances. */
  protected double m_Correct;

  /** The weight of all unclassified instances. */
  protected double m_Unclassified;

  /** The weight of all instances that had no class assigned to them. */
  protected double m_MissingClass;

  /** The weight of all instances that had a class assigned to them. */
  protected double m_WithClass;

  /** Array for storing the confusion matrix (only allocated when the
   * class attribute is nominal). */
  protected double[][] m_ConfusionMatrix;

  /** The names of the classes (only filled in when the class attribute
   * is nominal). */
  protected String[] m_ClassNames;

  /** Is the class nominal or numeric? */
  protected boolean m_ClassIsNominal;

  /** The prior probabilities of the classes. */
  protected double[] m_ClassPriors;

  /** The sum of counts for priors. */
  protected double m_ClassPriorsSum;

  /** The cost matrix (if given). */
  protected CostMatrix m_CostMatrix;

  /** The total cost of predictions (includes instance weights). */
  protected double m_TotalCost;

  /** Sum of errors. */
  protected double m_SumErr;

  /** Sum of absolute errors. */
  protected double m_SumAbsErr;

  /** Sum of squared errors. */
  protected double m_SumSqrErr;

  /** Sum of class values. */
  protected double m_SumClass;

  /** Sum of squared class values. */
  protected double m_SumSqrClass;

  /** Sum of predicted values. */
  protected double m_SumPredicted;

  /** Sum of squared predicted values. */
  protected double m_SumSqrPredicted;

  /** Sum of predicted * class values. */
  protected double m_SumClassPredicted;

  /** Sum of absolute errors of the prior. */
  protected double m_SumPriorAbsErr;

  /** Sum of squared errors of the prior. */
  protected double m_SumPriorSqrErr;

  /** Total Kononenko &amp; Bratko Information. */
  protected double m_SumKBInfo;

  /** Resolution of the margin histogram (number of bins). */
  protected static int k_MarginResolution = 500;

  /** Cumulative margin distribution (one slot per histogram bin). */
  protected double m_MarginCounts[];

  /** Number of non-missing class training instances seen. */
  protected int m_NumTrainClassVals;

  /** Array containing all numeric training class values seen. */
  protected double[] m_TrainClassVals;

  /** Array containing all numeric training class weights. */
  protected double[] m_TrainClassWeights;

  /** Numeric class error estimator for prior. */
  protected Estimator m_PriorErrorEstimator;

  /** Numeric class error estimator for scheme. */
  protected Estimator m_ErrorEstimator;

  /**
   * The minimum probability accepted from an estimator to avoid
   * taking log(0) in Sf calculations.
   */
  protected static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of prior predictions. */
  protected double m_SumPriorEntropy;

  /** Total entropy of scheme predictions. */
  protected double m_SumSchemeEntropy;

  /** The list of predictions that have been generated (for computing AUC). */
  private FastVector m_Predictions;

  /** enables/disables the use of priors, e.g., if no training set is
   * present in case of de-serialized schemes */
  protected boolean m_NoPriors = false;
0292:
0293: /**
0294: * Initializes all the counters for the evaluation.
0295: * Use <code>useNoPriors()</code> if the dataset is the test set and you
0296: * can't initialize with the priors from the training set via
0297: * <code>setPriors(Instances)</code>.
0298: *
0299: * @param data set of training instances, to get some header
0300: * information and prior class distribution information
0301: * @throws Exception if the class is not defined
0302: * @see #useNoPriors()
0303: * @see #setPriors(Instances)
0304: */
  public Evaluation(Instances data) throws Exception {

    // Delegate to the two-argument constructor with no cost matrix,
    // i.e. default (0/1) costs will be used.
    this (data, null);
  }
0309:
0310: /**
0311: * Initializes all the counters for the evaluation and also takes a
0312: * cost matrix as parameter.
0313: * Use <code>useNoPriors()</code> if the dataset is the test set and you
0314: * can't initialize with the priors from the training set via
0315: * <code>setPriors(Instances)</code>.
0316: *
0317: * @param data set of training instances, to get some header
0318: * information and prior class distribution information
0319: * @param costMatrix the cost matrix---if null, default costs will be used
0320: * @throws Exception if cost matrix is not compatible with
0321: * data, the class is not defined or the class is numeric
0322: * @see #useNoPriors()
0323: * @see #setPriors(Instances)
0324: */
0325: public Evaluation(Instances data, CostMatrix costMatrix)
0326: throws Exception {
0327:
0328: m_NumClasses = data.numClasses();
0329: m_NumFolds = 1;
0330: m_ClassIsNominal = data.classAttribute().isNominal();
0331:
0332: if (m_ClassIsNominal) {
0333: m_ConfusionMatrix = new double[m_NumClasses][m_NumClasses];
0334: m_ClassNames = new String[m_NumClasses];
0335: for (int i = 0; i < m_NumClasses; i++) {
0336: m_ClassNames[i] = data.classAttribute().value(i);
0337: }
0338: }
0339: m_CostMatrix = costMatrix;
0340: if (m_CostMatrix != null) {
0341: if (!m_ClassIsNominal) {
0342: throw new Exception(
0343: "Class has to be nominal if cost matrix "
0344: + "given!");
0345: }
0346: if (m_CostMatrix.size() != m_NumClasses) {
0347: throw new Exception(
0348: "Cost matrix not compatible with data!");
0349: }
0350: }
0351: m_ClassPriors = new double[m_NumClasses];
0352: setPriors(data);
0353: m_MarginCounts = new double[k_MarginResolution + 1];
0354: }
0355:
0356: /**
0357: * Returns the area under ROC for those predictions that have been collected
0358: * in the evaluateClassifier(Classifier, Instances) method. Returns
0359: * Instance.missingValue() if the area is not available.
0360: *
0361: * @param classIndex the index of the class to consider as "positive"
0362: * @return the area under the ROC curve or not a number
0363: */
0364: public double areaUnderROC(int classIndex) {
0365:
0366: // Check if any predictions have been collected
0367: if (m_Predictions == null) {
0368: return Instance.missingValue();
0369: } else {
0370: ThresholdCurve tc = new ThresholdCurve();
0371: Instances result = tc.getCurve(m_Predictions, classIndex);
0372: return ThresholdCurve.getROCArea(result);
0373: }
0374: }
0375:
0376: /**
0377: * Returns a copy of the confusion matrix.
0378: *
0379: * @return a copy of the confusion matrix as a two-dimensional array
0380: */
0381: public double[][] confusionMatrix() {
0382:
0383: double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
0384:
0385: for (int i = 0; i < m_ConfusionMatrix.length; i++) {
0386: newMatrix[i] = new double[m_ConfusionMatrix[i].length];
0387: System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
0388: m_ConfusionMatrix[i].length);
0389: }
0390: return newMatrix;
0391: }
0392:
0393: /**
0394: * Performs a (stratified if class is nominal) cross-validation
0395: * for a classifier on a set of instances. Now performs
0396: * a deep copy of the classifier before each call to
0397: * buildClassifier() (just in case the classifier is not
0398: * initialized properly).
0399: *
0400: * @param classifier the classifier with any options set.
0401: * @param data the data on which the cross-validation is to be
0402: * performed
0403: * @param numFolds the number of folds for the cross-validation
0404: * @param random random number generator for randomization
0405: * @throws Exception if a classifier could not be generated
0406: * successfully or the class is not defined
0407: */
0408: public void crossValidateModel(Classifier classifier,
0409: Instances data, int numFolds, Random random)
0410: throws Exception {
0411:
0412: // Make a copy of the data we can reorder
0413: data = new Instances(data);
0414: data.randomize(random);
0415: if (data.classAttribute().isNominal()) {
0416: data.stratify(numFolds);
0417: }
0418: // Do the folds
0419: for (int i = 0; i < numFolds; i++) {
0420: Instances train = data.trainCV(numFolds, i, random);
0421: setPriors(train);
0422: Classifier copiedClassifier = Classifier
0423: .makeCopy(classifier);
0424: copiedClassifier.buildClassifier(train);
0425: Instances test = data.testCV(numFolds, i);
0426: evaluateModel(copiedClassifier, test);
0427: }
0428: m_NumFolds = numFolds;
0429: }
0430:
0431: /**
0432: * Performs a (stratified if class is nominal) cross-validation
0433: * for a classifier on a set of instances.
0434: *
0435: * @param classifierString a string naming the class of the classifier
0436: * @param data the data on which the cross-validation is to be
0437: * performed
0438: * @param numFolds the number of folds for the cross-validation
* @param options the options to the classifier. Any options
* accepted by the classifier will be removed from this array.
* @param random the random number generator for randomizing the data
0442: * @throws Exception if a classifier could not be generated
0443: * successfully or the class is not defined
0444: */
0445: public void crossValidateModel(String classifierString,
0446: Instances data, int numFolds, String[] options,
0447: Random random) throws Exception {
0448:
0449: crossValidateModel(Classifier
0450: .forName(classifierString, options), data, numFolds,
0451: random);
0452: }
0453:
0454: /**
0455: * Evaluates a classifier with the options given in an array of
0456: * strings. <p/>
0457: *
0458: * Valid options are: <p/>
0459: *
0460: * -t filename <br/>
0461: * Name of the file with the training data. (required) <p/>
0462: *
0463: * -T filename <br/>
0464: * Name of the file with the test data. If missing a cross-validation
0465: * is performed. <p/>
0466: *
0467: * -c index <br/>
0468: * Index of the class attribute (1, 2, ...; default: last). <p/>
0469: *
0470: * -x number <br/>
0471: * The number of folds for the cross-validation (default: 10). <p/>
0472: *
0473: * -no-cv <br/>
0474: * No cross validation. If no test file is provided, no evaluation
0475: * is done. <p/>
0476: *
0477: * -split-percentage percentage <br/>
0478: * Sets the percentage for the train/test set split, e.g., 66. <p/>
0479: *
0480: * -preserve-order <br/>
0481: * Preserves the order in the percentage split instead of randomizing
0482: * the data first with the seed value ('-s'). <p/>
0483: *
0484: * -s seed <br/>
0485: * Random number seed for the cross-validation and percentage split
0486: * (default: 1). <p/>
0487: *
0488: * -m filename <br/>
0489: * The name of a file containing a cost matrix. <p/>
0490: *
0491: * -l filename <br/>
0492: * Loads classifier from the given file. In case the filename ends with
0493: * ".xml" the options are loaded from XML. <p/>
0494: *
0495: * -d filename <br/>
0496: * Saves classifier built from the training data into the given file. In case
0497: * the filename ends with ".xml" the options are saved XML, not the model. <p/>
0498: *
0499: * -v <br/>
0500: * Outputs no statistics for the training data. <p/>
0501: *
0502: * -o <br/>
0503: * Outputs statistics only, not the classifier. <p/>
0504: *
0505: * -i <br/>
0506: * Outputs detailed information-retrieval statistics per class. <p/>
0507: *
0508: * -k <br/>
0509: * Outputs information-theoretic statistics. <p/>
0510: *
0511: * -p range <br/>
0512: * Outputs predictions for test instances (or the train instances if no test
0513: * instances provided), along with the attributes in the specified range (and
0514: * nothing else). Use '-p 0' if no attributes are desired. <p/>
0515: *
0516: * -distribution <br/>
0517: * Outputs the distribution instead of only the prediction
0518: * in conjunction with the '-p' option (only nominal classes). <p/>
0519: *
0520: * -r <br/>
0521: * Outputs cumulative margin distribution (and nothing else). <p/>
0522: *
0523: * -g <br/>
0524: * Only for classifiers that implement "Graphable." Outputs
0525: * the graph representation of the classifier (and nothing
0526: * else). <p/>
0527: *
0528: * -xml filename | xml-string <br/>
0529: * Retrieves the options from the XML-data instead of the command line. <p/>
0530: *
0531: * -threshold-file file <br/>
0532: * The file to save the threshold data to.
0533: * The format is determined by the extensions, e.g., '.arff' for ARFF
0534: * format or '.csv' for CSV. <p/>
0535: *
0536: * -threshold-label label <br/>
0537: * The class label to determine the threshold data for
0538: * (default is the first label) <p/>
0539: *
0540: * @param classifierString class of machine learning classifier as a string
0541: * @param options the array of string containing the options
0542: * @throws Exception if model could not be evaluated successfully
0543: * @return a string describing the results
0544: */
0545: public static String evaluateModel(String classifierString,
0546: String[] options) throws Exception {
0547:
0548: Classifier classifier;
0549:
0550: // Create classifier
0551: try {
0552: classifier = (Classifier) Class.forName(classifierString)
0553: .newInstance();
0554: } catch (Exception e) {
0555: throw new Exception("Can't find class with name "
0556: + classifierString + '.');
0557: }
0558: return evaluateModel(classifier, options);
0559: }
0560:
0561: /**
0562: * A test method for this class. Just extracts the first command line
0563: * argument as a classifier class name and calls evaluateModel.
0564: * @param args an array of command line arguments, the first of which
0565: * must be the class name of a classifier.
0566: */
0567: public static void main(String[] args) {
0568:
0569: try {
0570: if (args.length == 0) {
0571: throw new Exception(
0572: "The first argument must be the class name"
0573: + " of a classifier");
0574: }
0575: String classifier = args[0];
0576: args[0] = "";
0577: System.out.println(evaluateModel(classifier, args));
0578: } catch (Exception ex) {
0579: ex.printStackTrace();
0580: System.err.println(ex.getMessage());
0581: }
0582: }
0583:
0584: /**
0585: * Evaluates a classifier with the options given in an array of
0586: * strings. <p/>
0587: *
0588: * Valid options are: <p/>
0589: *
0590: * -t name of training file <br/>
0591: * Name of the file with the training data. (required) <p/>
0592: *
0593: * -T name of test file <br/>
0594: * Name of the file with the test data. If missing a cross-validation
0595: * is performed. <p/>
0596: *
0597: * -c class index <br/>
0598: * Index of the class attribute (1, 2, ...; default: last). <p/>
0599: *
0600: * -x number of folds <br/>
0601: * The number of folds for the cross-validation (default: 10). <p/>
0602: *
0603: * -no-cv <br/>
0604: * No cross validation. If no test file is provided, no evaluation
0605: * is done. <p/>
0606: *
0607: * -split-percentage percentage <br/>
0608: * Sets the percentage for the train/test set split, e.g., 66. <p/>
0609: *
0610: * -preserve-order <br/>
0611: * Preserves the order in the percentage split instead of randomizing
0612: * the data first with the seed value ('-s'). <p/>
0613: *
0614: * -s seed <br/>
0615: * Random number seed for the cross-validation and percentage split
0616: * (default: 1). <p/>
0617: *
0618: * -m file with cost matrix <br/>
0619: * The name of a file containing a cost matrix. <p/>
0620: *
0621: * -l filename <br/>
0622: * Loads classifier from the given file. In case the filename ends with
0623: * ".xml" the options are loaded from XML. <p/>
0624: *
0625: * -d filename <br/>
0626: * Saves classifier built from the training data into the given file. In case
0627: * the filename ends with ".xml" the options are saved XML, not the model. <p/>
0628: *
0629: * -v <br/>
0630: * Outputs no statistics for the training data. <p/>
0631: *
0632: * -o <br/>
0633: * Outputs statistics only, not the classifier. <p/>
0634: *
0635: * -i <br/>
0636: * Outputs detailed information-retrieval statistics per class. <p/>
0637: *
0638: * -k <br/>
0639: * Outputs information-theoretic statistics. <p/>
0640: *
0641: * -p range <br/>
0642: * Outputs predictions for test instances (or the train instances if no test
0643: * instances provided), along with the attributes in the specified range
0644: * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
0645: *
0646: * -distribution <br/>
0647: * Outputs the distribution instead of only the prediction
0648: * in conjunction with the '-p' option (only nominal classes). <p/>
0649: *
0650: * -r <br/>
0651: * Outputs cumulative margin distribution (and nothing else). <p/>
0652: *
0653: * -g <br/>
0654: * Only for classifiers that implement "Graphable." Outputs
0655: * the graph representation of the classifier (and nothing
0656: * else). <p/>
0657: *
0658: * -xml filename | xml-string <br/>
0659: * Retrieves the options from the XML-data instead of the command line. <p/>
0660: *
0661: * @param classifier machine learning classifier
0662: * @param options the array of string containing the options
0663: * @throws Exception if model could not be evaluated successfully
0664: * @return a string describing the results
0665: */
0666: public static String evaluateModel(Classifier classifier,
0667: String[] options) throws Exception {
0668:
0669: Instances train = null, tempTrain, test = null, template = null;
0670: int seed = 1, folds = 10, classIndex = -1;
0671: boolean noCrossValidation = false;
0672: String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString, objectInputFileName, objectOutputFileName, attributeRangeString;
0673: boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false;
0674: StringBuffer text = new StringBuffer();
0675: DataSource trainSource = null, testSource = null;
0676: ObjectInputStream objectInputStream = null;
0677: BufferedInputStream xmlInputStream = null;
0678: CostMatrix costMatrix = null;
0679: StringBuffer schemeOptionsText = null;
0680: Range attributesToOutput = null;
0681: long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0;
0682: String xml = "";
0683: String[] optionsTmp = null;
0684: Classifier classifierBackup;
0685: Classifier classifierClassifications = null;
0686: boolean printDistribution = false;
0687: int actualClassIndex = -1; // 0-based class index
0688: String splitPercentageString = "";
0689: int splitPercentage = -1;
0690: boolean preserveOrder = false;
0691: boolean trainSetPresent = false;
0692: boolean testSetPresent = false;
0693: String thresholdFile;
0694: String thresholdLabel;
0695:
0696: // help requested?
0697: if (Utils.getFlag("h", options)
0698: || Utils.getFlag("help", options)) {
0699: throw new Exception("\nHelp requested."
0700: + makeOptionString(classifier));
0701: }
0702:
0703: try {
0704: // do we get the input from XML instead of normal parameters?
0705: xml = Utils.getOption("xml", options);
0706: if (!xml.equals(""))
0707: options = new XMLOptions(xml).toArray();
0708:
0709: // is the input model only the XML-Options, i.e. w/o built model?
0710: optionsTmp = new String[options.length];
0711: for (int i = 0; i < options.length; i++)
0712: optionsTmp[i] = options[i];
0713:
0714: if (Utils.getOption('l', optionsTmp).toLowerCase()
0715: .endsWith(".xml")) {
0716: // load options from serialized data ('-l' is automatically erased!)
0717: XMLClassifier xmlserial = new XMLClassifier();
0718: Classifier cl = (Classifier) xmlserial.read(Utils
0719: .getOption('l', options));
0720: // merge options
0721: optionsTmp = new String[options.length
0722: + cl.getOptions().length];
0723: System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl
0724: .getOptions().length);
0725: System.arraycopy(options, 0, optionsTmp, cl
0726: .getOptions().length, options.length);
0727: options = optionsTmp;
0728: }
0729:
0730: noCrossValidation = Utils.getFlag("no-cv", options);
0731: // Get basic options (options the same for all schemes)
0732: classIndexString = Utils.getOption('c', options);
0733: if (classIndexString.length() != 0) {
0734: if (classIndexString.equals("first"))
0735: classIndex = 1;
0736: else if (classIndexString.equals("last"))
0737: classIndex = -1;
0738: else
0739: classIndex = Integer.parseInt(classIndexString);
0740: }
0741: trainFileName = Utils.getOption('t', options);
0742: objectInputFileName = Utils.getOption('l', options);
0743: objectOutputFileName = Utils.getOption('d', options);
0744: testFileName = Utils.getOption('T', options);
0745: foldsString = Utils.getOption('x', options);
0746: if (foldsString.length() != 0) {
0747: folds = Integer.parseInt(foldsString);
0748: }
0749: seedString = Utils.getOption('s', options);
0750: if (seedString.length() != 0) {
0751: seed = Integer.parseInt(seedString);
0752: }
0753: if (trainFileName.length() == 0) {
0754: if (objectInputFileName.length() == 0) {
0755: throw new Exception(
0756: "No training file and no object "
0757: + "input file given.");
0758: }
0759: if (testFileName.length() == 0) {
0760: throw new Exception("No training file and no test "
0761: + "file given.");
0762: }
0763: } else if ((objectInputFileName.length() != 0)
0764: && ((!(classifier instanceof UpdateableClassifier)) || (testFileName
0765: .length() == 0))) {
0766: throw new Exception(
0767: "Classifier not incremental, or no "
0768: + "test file provided: can't "
0769: + "use both train and model file.");
0770: }
0771: try {
0772: if (trainFileName.length() != 0) {
0773: trainSetPresent = true;
0774: trainSource = new DataSource(trainFileName);
0775: }
0776: if (testFileName.length() != 0) {
0777: testSetPresent = true;
0778: testSource = new DataSource(testFileName);
0779: }
0780: if (objectInputFileName.length() != 0) {
0781: InputStream is = new FileInputStream(
0782: objectInputFileName);
0783: if (objectInputFileName.endsWith(".gz")) {
0784: is = new GZIPInputStream(is);
0785: }
0786: // load from KOML?
0787: if (!(objectInputFileName.endsWith(".koml") && KOML
0788: .isPresent())) {
0789: objectInputStream = new ObjectInputStream(is);
0790: xmlInputStream = null;
0791: } else {
0792: objectInputStream = null;
0793: xmlInputStream = new BufferedInputStream(is);
0794: }
0795: }
0796: } catch (Exception e) {
0797: throw new Exception("Can't open file " + e.getMessage()
0798: + '.');
0799: }
0800: if (testSetPresent) {
0801: template = test = testSource.getStructure();
0802: if (classIndex != -1) {
0803: test.setClassIndex(classIndex - 1);
0804: } else {
0805: if ((test.classIndex() == -1)
0806: || (classIndexString.length() != 0))
0807: test.setClassIndex(test.numAttributes() - 1);
0808: }
0809: actualClassIndex = test.classIndex();
0810: } else {
0811: // percentage split
0812: splitPercentageString = Utils.getOption(
0813: "split-percentage", options);
0814: if (splitPercentageString.length() != 0) {
0815: if (foldsString.length() != 0)
0816: throw new Exception(
0817: "Percentage split cannot be used in conjunction with "
0818: + "cross-validation ('-x').");
0819: splitPercentage = Integer
0820: .parseInt(splitPercentageString);
0821: if ((splitPercentage <= 0)
0822: || (splitPercentage >= 100))
0823: throw new Exception(
0824: "Percentage split value needs be >0 and <100.");
0825: } else {
0826: splitPercentage = -1;
0827: }
0828: preserveOrder = Utils
0829: .getFlag("preserve-order", options);
0830: if (preserveOrder) {
0831: if (splitPercentage == -1)
0832: throw new Exception(
0833: "Percentage split ('-percentage-split') is missing.");
0834: }
0835: // create new train/test sources
0836: if (splitPercentage > 0) {
0837: testSetPresent = true;
0838: Instances tmpInst = trainSource
0839: .getDataSet(actualClassIndex);
0840: if (!preserveOrder)
0841: tmpInst.randomize(new Random(seed));
0842: int trainSize = tmpInst.numInstances()
0843: * splitPercentage / 100;
0844: int testSize = tmpInst.numInstances() - trainSize;
0845: Instances trainInst = new Instances(tmpInst, 0,
0846: trainSize);
0847: Instances testInst = new Instances(tmpInst,
0848: trainSize, testSize);
0849: trainSource = new DataSource(trainInst);
0850: testSource = new DataSource(testInst);
0851: template = test = testSource.getStructure();
0852: if (classIndex != -1) {
0853: test.setClassIndex(classIndex - 1);
0854: } else {
0855: if ((test.classIndex() == -1)
0856: || (classIndexString.length() != 0))
0857: test
0858: .setClassIndex(test.numAttributes() - 1);
0859: }
0860: actualClassIndex = test.classIndex();
0861: }
0862: }
0863: if (trainSetPresent) {
0864: template = train = trainSource.getStructure();
0865: if (classIndex != -1) {
0866: train.setClassIndex(classIndex - 1);
0867: } else {
0868: if ((train.classIndex() == -1)
0869: || (classIndexString.length() != 0))
0870: train.setClassIndex(train.numAttributes() - 1);
0871: }
0872: actualClassIndex = train.classIndex();
0873: if ((testSetPresent) && !test.equalHeaders(train)) {
0874: throw new IllegalArgumentException(
0875: "Train and test file not compatible!");
0876: }
0877: }
0878: if (template == null) {
0879: throw new Exception(
0880: "No actual dataset provided to use as template");
0881: }
0882: costMatrix = handleCostOption(
0883: Utils.getOption('m', options), template
0884: .numClasses());
0885:
0886: classStatistics = Utils.getFlag('i', options);
0887: noOutput = Utils.getFlag('o', options);
0888: trainStatistics = !Utils.getFlag('v', options);
0889: printComplexityStatistics = Utils.getFlag('k', options);
0890: printMargins = Utils.getFlag('r', options);
0891: printGraph = Utils.getFlag('g', options);
0892: sourceClass = Utils.getOption('z', options);
0893: printSource = (sourceClass.length() != 0);
0894: printDistribution = Utils.getFlag("distribution", options);
0895: thresholdFile = Utils.getOption("threshold-file", options);
0896: thresholdLabel = Utils
0897: .getOption("threshold-label", options);
0898:
0899: // Check -p option
0900: try {
0901: attributeRangeString = Utils.getOption('p', options);
0902: } catch (Exception e) {
0903: throw new Exception(
0904: e.getMessage()
0905: + "\nNOTE: the -p option has changed. "
0906: + "It now expects a parameter specifying a range of attributes "
0907: + "to list with the predictions. Use '-p 0' for none.");
0908: }
0909: if (attributeRangeString.length() != 0) {
0910: printClassifications = true;
0911: if (!attributeRangeString.equals("0"))
0912: attributesToOutput = new Range(attributeRangeString);
0913: }
0914:
0915: if (!printClassifications && printDistribution)
0916: throw new Exception(
0917: "Cannot print distribution without '-p' option!");
0918:
0919: // if no training file given, we don't have any priors
0920: if ((!trainSetPresent) && (printComplexityStatistics))
0921: throw new Exception(
0922: "Cannot print complexity statistics ('-k') without training file ('-t')!");
0923:
0924: // If a model file is given, we can't process
0925: // scheme-specific options
0926: if (objectInputFileName.length() != 0) {
0927: Utils.checkForRemainingOptions(options);
0928: } else {
0929:
0930: // Set options for classifier
0931: if (classifier instanceof OptionHandler) {
0932: for (int i = 0; i < options.length; i++) {
0933: if (options[i].length() != 0) {
0934: if (schemeOptionsText == null) {
0935: schemeOptionsText = new StringBuffer();
0936: }
0937: if (options[i].indexOf(' ') != -1) {
0938: schemeOptionsText.append('"'
0939: + options[i] + "\" ");
0940: } else {
0941: schemeOptionsText.append(options[i]
0942: + " ");
0943: }
0944: }
0945: }
0946: ((OptionHandler) classifier).setOptions(options);
0947: }
0948: }
0949: Utils.checkForRemainingOptions(options);
0950: } catch (Exception e) {
0951: throw new Exception("\nWeka exception: " + e.getMessage()
0952: + makeOptionString(classifier));
0953: }
0954:
0955: // Setup up evaluation objects
0956: Evaluation trainingEvaluation = new Evaluation(new Instances(
0957: template, 0), costMatrix);
0958: Evaluation testingEvaluation = new Evaluation(new Instances(
0959: template, 0), costMatrix);
0960:
0961: // disable use of priors if no training file given
0962: if (!trainSetPresent)
0963: testingEvaluation.useNoPriors();
0964:
0965: if (objectInputFileName.length() != 0) {
0966: // Load classifier from file
0967: if (objectInputStream != null) {
0968: classifier = (Classifier) objectInputStream
0969: .readObject();
0970: objectInputStream.close();
0971: } else {
0972: // whether KOML is available has already been checked (objectInputStream would null otherwise)!
0973: classifier = (Classifier) KOML.read(xmlInputStream);
0974: xmlInputStream.close();
0975: }
0976: }
0977:
0978: // backup of fully setup classifier for cross-validation
0979: classifierBackup = Classifier.makeCopy(classifier);
0980:
0981: // Build the classifier if no object file provided
0982: if ((classifier instanceof UpdateableClassifier)
0983: && (testSetPresent) && (costMatrix == null)
0984: && (trainSetPresent)) {
0985:
0986: // Build classifier incrementally
0987: trainingEvaluation.setPriors(train);
0988: testingEvaluation.setPriors(train);
0989: trainTimeStart = System.currentTimeMillis();
0990: if (objectInputFileName.length() == 0) {
0991: classifier.buildClassifier(train);
0992: }
0993: Instance trainInst;
0994: while (trainSource.hasMoreElements(train)) {
0995: trainInst = trainSource.nextElement(train);
0996: trainingEvaluation.updatePriors(trainInst);
0997: testingEvaluation.updatePriors(trainInst);
0998: ((UpdateableClassifier) classifier)
0999: .updateClassifier(trainInst);
1000: }
1001: trainTimeElapsed = System.currentTimeMillis()
1002: - trainTimeStart;
1003: } else if (objectInputFileName.length() == 0) {
1004: // Build classifier in one go
1005: tempTrain = trainSource.getDataSet(actualClassIndex);
1006: trainingEvaluation.setPriors(tempTrain);
1007: testingEvaluation.setPriors(tempTrain);
1008: trainTimeStart = System.currentTimeMillis();
1009: classifier.buildClassifier(tempTrain);
1010: trainTimeElapsed = System.currentTimeMillis()
1011: - trainTimeStart;
1012: }
1013:
1014: // backup of fully trained classifier for printing the classifications
1015: if (printClassifications)
1016: classifierClassifications = Classifier.makeCopy(classifier);
1017:
1018: // Save the classifier if an object output file is provided
1019: if (objectOutputFileName.length() != 0) {
1020: OutputStream os = new FileOutputStream(objectOutputFileName);
1021: // binary
1022: if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName
1023: .endsWith(".koml") && KOML.isPresent()))) {
1024: if (objectOutputFileName.endsWith(".gz")) {
1025: os = new GZIPOutputStream(os);
1026: }
1027: ObjectOutputStream objectOutputStream = new ObjectOutputStream(
1028: os);
1029: objectOutputStream.writeObject(classifier);
1030: objectOutputStream.flush();
1031: objectOutputStream.close();
1032: }
1033: // KOML/XML
1034: else {
1035: BufferedOutputStream xmlOutputStream = new BufferedOutputStream(
1036: os);
1037: if (objectOutputFileName.endsWith(".xml")) {
1038: XMLSerialization xmlSerial = new XMLClassifier();
1039: xmlSerial.write(xmlOutputStream, classifier);
1040: } else
1041: // whether KOML is present has already been checked
1042: // if not present -> ".koml" is interpreted as binary - see above
1043: if (objectOutputFileName.endsWith(".koml")) {
1044: KOML.write(xmlOutputStream, classifier);
1045: }
1046: xmlOutputStream.close();
1047: }
1048: }
1049:
1050: // If classifier is drawable output string describing graph
1051: if ((classifier instanceof Drawable) && (printGraph)) {
1052: return ((Drawable) classifier).graph();
1053: }
1054:
1055: // Output the classifier as equivalent source
1056: if ((classifier instanceof Sourcable) && (printSource)) {
1057: return wekaStaticWrapper((Sourcable) classifier,
1058: sourceClass);
1059: }
1060:
1061: // Output model
1062: if (!(noOutput || printMargins)) {
1063: if (classifier instanceof OptionHandler) {
1064: if (schemeOptionsText != null) {
1065: text.append("\nOptions: " + schemeOptionsText);
1066: text.append("\n");
1067: }
1068: }
1069: text.append("\n" + classifier.toString() + "\n");
1070: }
1071:
1072: if (!printMargins && (costMatrix != null)) {
1073: text.append("\n=== Evaluation Cost Matrix ===\n\n");
1074: text.append(costMatrix.toString());
1075: }
1076:
1077: // Output test instance predictions only
1078: if (printClassifications) {
1079: DataSource source = testSource;
1080: // no test set -> use train set
1081: if (source == null)
1082: source = trainSource;
1083: return printClassifications(classifierClassifications,
1084: new Instances(template, 0), source,
1085: actualClassIndex + 1, attributesToOutput,
1086: printDistribution);
1087: }
1088:
1089: // Compute error estimate from training data
1090: if ((trainStatistics) && (trainSetPresent)) {
1091:
1092: if ((classifier instanceof UpdateableClassifier)
1093: && (testSetPresent) && (costMatrix == null)) {
1094:
1095: // Classifier was trained incrementally, so we have to
1096: // reset the source.
1097: trainSource.reset();
1098:
1099: // Incremental testing
1100: train = trainSource.getStructure(actualClassIndex);
1101: testTimeStart = System.currentTimeMillis();
1102: Instance trainInst;
1103: while (trainSource.hasMoreElements(train)) {
1104: trainInst = trainSource.nextElement(train);
1105: trainingEvaluation.evaluateModelOnce(
1106: (Classifier) classifier, trainInst);
1107: }
1108: testTimeElapsed = System.currentTimeMillis()
1109: - testTimeStart;
1110: } else {
1111: testTimeStart = System.currentTimeMillis();
1112: trainingEvaluation.evaluateModel(classifier,
1113: trainSource.getDataSet(actualClassIndex));
1114: testTimeElapsed = System.currentTimeMillis()
1115: - testTimeStart;
1116: }
1117:
1118: // Print the results of the training evaluation
1119: if (printMargins) {
1120: return trainingEvaluation
1121: .toCumulativeMarginDistributionString();
1122: } else {
1123: text.append("\nTime taken to build model: "
1124: + Utils.doubleToString(
1125: trainTimeElapsed / 1000.0, 2)
1126: + " seconds");
1127:
1128: if (splitPercentage > 0)
1129: text
1130: .append("\nTime taken to test model on training split: ");
1131: else
1132: text
1133: .append("\nTime taken to test model on training data: ");
1134: text.append(Utils.doubleToString(
1135: testTimeElapsed / 1000.0, 2)
1136: + " seconds");
1137:
1138: if (splitPercentage > 0)
1139: text.append(trainingEvaluation.toSummaryString(
1140: "\n\n=== Error on training"
1141: + " split ===\n",
1142: printComplexityStatistics));
1143: else
1144: text.append(trainingEvaluation
1145: .toSummaryString(
1146: "\n\n=== Error on training"
1147: + " data ===\n",
1148: printComplexityStatistics));
1149:
1150: if (template.classAttribute().isNominal()) {
1151: if (classStatistics) {
1152: text.append("\n\n"
1153: + trainingEvaluation
1154: .toClassDetailsString());
1155: }
1156: if (!noCrossValidation)
1157: text.append("\n\n"
1158: + trainingEvaluation.toMatrixString());
1159: }
1160:
1161: }
1162: }
1163:
1164: // Compute proper error estimates
1165: if (testSource != null) {
1166: // Testing is on the supplied test data
1167: Instance testInst;
1168: while (testSource.hasMoreElements(test)) {
1169: testInst = testSource.nextElement(test);
1170: testingEvaluation.evaluateModelOnceAndRecordPrediction(
1171: (Classifier) classifier, testInst);
1172: }
1173:
1174: if (splitPercentage > 0)
1175: text.append("\n\n"
1176: + testingEvaluation.toSummaryString(
1177: "=== Error on test split ===\n",
1178: printComplexityStatistics));
1179: else
1180: text.append("\n\n"
1181: + testingEvaluation.toSummaryString(
1182: "=== Error on test data ===\n",
1183: printComplexityStatistics));
1184:
1185: } else if (trainSource != null) {
1186: if (!noCrossValidation) {
1187: // Testing is via cross-validation on training data
1188: Random random = new Random(seed);
1189: // use untrained (!) classifier for cross-validation
1190: classifier = Classifier.makeCopy(classifierBackup);
1191: testingEvaluation.crossValidateModel(classifier,
1192: trainSource.getDataSet(actualClassIndex),
1193: folds, random);
1194: if (template.classAttribute().isNumeric()) {
1195: text.append("\n\n\n"
1196: + testingEvaluation.toSummaryString(
1197: "=== Cross-validation ===\n",
1198: printComplexityStatistics));
1199: } else {
1200: text.append("\n\n\n"
1201: + testingEvaluation.toSummaryString(
1202: "=== Stratified "
1203: + "cross-validation ===\n",
1204: printComplexityStatistics));
1205: }
1206: }
1207: }
1208: if (template.classAttribute().isNominal()) {
1209: if (classStatistics) {
1210: text.append("\n\n"
1211: + testingEvaluation.toClassDetailsString());
1212: }
1213: if (!noCrossValidation)
1214: text
1215: .append("\n\n"
1216: + testingEvaluation.toMatrixString());
1217: }
1218:
1219: if ((thresholdFile.length() != 0)
1220: && template.classAttribute().isNominal()) {
1221: int labelIndex = 0;
1222: if (thresholdLabel.length() != 0)
1223: labelIndex = template.classAttribute().indexOfValue(
1224: thresholdLabel);
1225: if (labelIndex == -1)
1226: throw new IllegalArgumentException("Class label '"
1227: + thresholdLabel + "' is unknown!");
1228: ThresholdCurve tc = new ThresholdCurve();
1229: Instances result = tc.getCurve(testingEvaluation
1230: .predictions(), labelIndex);
1231: DataSink.write(thresholdFile, result);
1232: }
1233:
1234: return text.toString();
1235: }
1236:
1237: /**
1238: * Attempts to load a cost matrix.
1239: *
1240: * @param costFileName the filename of the cost matrix
1241: * @param numClasses the number of classes that should be in the cost matrix
1242: * (only used if the cost file is in old format).
1243: * @return a <code>CostMatrix</code> value, or null if costFileName is empty
1244: * @throws Exception if an error occurs.
1245: */
1246: protected static CostMatrix handleCostOption(String costFileName,
1247: int numClasses) throws Exception {
1248:
1249: if ((costFileName != null) && (costFileName.length() != 0)) {
1250: System.out
1251: .println("NOTE: The behaviour of the -m option has changed between WEKA 3.0"
1252: + " and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
1253: + " only. For cost-sensitive *prediction*, use one of the"
1254: + " cost-sensitive metaschemes such as"
1255: + " weka.classifiers.meta.CostSensitiveClassifier or"
1256: + " weka.classifiers.meta.MetaCost");
1257:
1258: Reader costReader = null;
1259: try {
1260: costReader = new BufferedReader(new FileReader(
1261: costFileName));
1262: } catch (Exception e) {
1263: throw new Exception("Can't open file " + e.getMessage()
1264: + '.');
1265: }
1266: try {
1267: // First try as a proper cost matrix format
1268: return new CostMatrix(costReader);
1269: } catch (Exception ex) {
1270: try {
1271: // Now try as the poxy old format :-)
1272: //System.err.println("Attempting to read old format cost file");
1273: try {
1274: costReader.close(); // Close the old one
1275: costReader = new BufferedReader(new FileReader(
1276: costFileName));
1277: } catch (Exception e) {
1278: throw new Exception("Can't open file "
1279: + e.getMessage() + '.');
1280: }
1281: CostMatrix costMatrix = new CostMatrix(numClasses);
1282: //System.err.println("Created default cost matrix");
1283: costMatrix.readOldFormat(costReader);
1284: return costMatrix;
1285: //System.err.println("Read old format");
1286: } catch (Exception e2) {
1287: // re-throw the original exception
1288: //System.err.println("Re-throwing original exception");
1289: throw ex;
1290: }
1291: }
1292: } else {
1293: return null;
1294: }
1295: }
1296:
1297: /**
1298: * Evaluates the classifier on a given set of instances. Note that
1299: * the data must have exactly the same format (e.g. order of
1300: * attributes) as the data used to train the classifier! Otherwise
1301: * the results will generally be meaningless.
1302: *
1303: * @param classifier machine learning classifier
1304: * @param data set of test instances for evaluation
1305: * @return the predictions
1306: * @throws Exception if model could not be evaluated
1307: * successfully
1308: */
1309: public double[] evaluateModel(Classifier classifier, Instances data)
1310: throws Exception {
1311:
1312: double predictions[] = new double[data.numInstances()];
1313:
1314: // Need to be able to collect predictions if appropriate (for AUC)
1315:
1316: for (int i = 0; i < data.numInstances(); i++) {
1317: predictions[i] = evaluateModelOnceAndRecordPrediction(
1318: (Classifier) classifier, data.instance(i));
1319: }
1320:
1321: return predictions;
1322: }
1323:
1324: /**
1325: * Evaluates the classifier on a single instance and records the
1326: * prediction (if the class is nominal).
1327: *
1328: * @param classifier machine learning classifier
1329: * @param instance the test instance to be classified
1330: * @return the prediction made by the clasifier
1331: * @throws Exception if model could not be evaluated
1332: * successfully or the data contains string attributes
1333: */
1334: public double evaluateModelOnceAndRecordPrediction(
1335: Classifier classifier, Instance instance) throws Exception {
1336:
1337: Instance classMissing = (Instance) instance.copy();
1338: double pred = 0;
1339: classMissing.setDataset(instance.dataset());
1340: classMissing.setClassMissing();
1341: if (m_ClassIsNominal) {
1342: if (m_Predictions == null) {
1343: m_Predictions = new FastVector();
1344: }
1345: double[] dist = classifier
1346: .distributionForInstance(classMissing);
1347: pred = Utils.maxIndex(dist);
1348: if (dist[(int) pred] <= 0) {
1349: pred = Instance.missingValue();
1350: }
1351: updateStatsForClassifier(dist, instance);
1352: m_Predictions.addElement(new NominalPrediction(instance
1353: .classValue(), dist, instance.weight()));
1354: } else {
1355: pred = classifier.classifyInstance(classMissing);
1356: updateStatsForPredictor(pred, instance);
1357: }
1358: return pred;
1359: }
1360:
1361: /**
1362: * Evaluates the classifier on a single instance.
1363: *
1364: * @param classifier machine learning classifier
1365: * @param instance the test instance to be classified
1366: * @return the prediction made by the clasifier
1367: * @throws Exception if model could not be evaluated
1368: * successfully or the data contains string attributes
1369: */
1370: public double evaluateModelOnce(Classifier classifier,
1371: Instance instance) throws Exception {
1372:
1373: Instance classMissing = (Instance) instance.copy();
1374: double pred = 0;
1375: classMissing.setDataset(instance.dataset());
1376: classMissing.setClassMissing();
1377: if (m_ClassIsNominal) {
1378: double[] dist = classifier
1379: .distributionForInstance(classMissing);
1380: pred = Utils.maxIndex(dist);
1381: if (dist[(int) pred] <= 0) {
1382: pred = Instance.missingValue();
1383: }
1384: updateStatsForClassifier(dist, instance);
1385: } else {
1386: pred = classifier.classifyInstance(classMissing);
1387: updateStatsForPredictor(pred, instance);
1388: }
1389: return pred;
1390: }
1391:
1392: /**
1393: * Evaluates the supplied distribution on a single instance.
1394: *
1395: * @param dist the supplied distribution
1396: * @param instance the test instance to be classified
1397: * @return the prediction
1398: * @throws Exception if model could not be evaluated
1399: * successfully
1400: */
1401: public double evaluateModelOnce(double[] dist, Instance instance)
1402: throws Exception {
1403: double pred;
1404: if (m_ClassIsNominal) {
1405: pred = Utils.maxIndex(dist);
1406: if (dist[(int) pred] <= 0) {
1407: pred = Instance.missingValue();
1408: }
1409: updateStatsForClassifier(dist, instance);
1410: } else {
1411: pred = dist[0];
1412: updateStatsForPredictor(pred, instance);
1413: }
1414: return pred;
1415: }
1416:
1417: /**
1418: * Evaluates the supplied distribution on a single instance.
1419: *
1420: * @param dist the supplied distribution
1421: * @param instance the test instance to be classified
1422: * @return the prediction
1423: * @throws Exception if model could not be evaluated
1424: * successfully
1425: */
1426: public double evaluateModelOnceAndRecordPrediction(double[] dist,
1427: Instance instance) throws Exception {
1428: double pred;
1429: if (m_ClassIsNominal) {
1430: if (m_Predictions == null) {
1431: m_Predictions = new FastVector();
1432: }
1433: pred = Utils.maxIndex(dist);
1434: if (dist[(int) pred] <= 0) {
1435: pred = Instance.missingValue();
1436: }
1437: updateStatsForClassifier(dist, instance);
1438: m_Predictions.addElement(new NominalPrediction(instance
1439: .classValue(), dist, instance.weight()));
1440: } else {
1441: pred = dist[0];
1442: updateStatsForPredictor(pred, instance);
1443: }
1444: return pred;
1445: }
1446:
1447: /**
1448: * Evaluates the supplied prediction on a single instance.
1449: *
1450: * @param prediction the supplied prediction
1451: * @param instance the test instance to be classified
1452: * @throws Exception if model could not be evaluated
1453: * successfully
1454: */
1455: public void evaluateModelOnce(double prediction, Instance instance)
1456: throws Exception {
1457:
1458: if (m_ClassIsNominal) {
1459: updateStatsForClassifier(makeDistribution(prediction),
1460: instance);
1461: } else {
1462: updateStatsForPredictor(prediction, instance);
1463: }
1464: }
1465:
1466: /**
1467: * Returns the predictions that have been collected.
1468: *
1469: * @return a reference to the FastVector containing the predictions
1470: * that have been collected. This should be null if no predictions
1471: * have been collected (e.g. if the class is numeric).
1472: */
1473: public FastVector predictions() {
1474:
1475: return m_Predictions;
1476: }
1477:
1478: /**
1479: * Wraps a static classifier in enough source to test using the weka
1480: * class libraries.
1481: *
1482: * @param classifier a Sourcable Classifier
1483: * @param className the name to give to the source code class
1484: * @return the source for a static classifier that can be tested with
1485: * weka libraries.
1486: * @throws Exception if code-generation fails
1487: */
1488: protected static String wekaStaticWrapper(Sourcable classifier,
1489: String className) throws Exception {
1490:
1491: //String className = "StaticClassifier";
1492: String staticClassifier = classifier.toSource(className);
1493: return "package weka.classifiers;\n\n"
1494: + "import weka.core.Attribute;\n"
1495: + "import weka.core.Instance;\n"
1496: + "import weka.core.Instances;\n"
1497: + "import weka.classifiers.Classifier;\n\n"
1498: + "public class WekaWrapper extends Classifier {\n\n"
1499: + " public void buildClassifier(Instances i) throws Exception {\n"
1500: + " }\n\n"
1501: + " public double classifyInstance(Instance i) throws Exception {\n\n"
1502: + " Object [] s = new Object [i.numAttributes()];\n"
1503: + " for (int j = 0; j < s.length; j++) {\n"
1504: + " if (!i.isMissing(j)) {\n"
1505: + " if (i.attribute(j).type() == Attribute.NOMINAL) {\n"
1506: + " s[j] = i.attribute(j).value((int) i.value(j));\n"
1507: + " } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n"
1508: + " s[j] = new Double(i.value(j));\n"
1509: + " }\n" + " }\n" + " }\n"
1510: + " return " + className + ".classify(s);\n"
1511: + " }\n\n" + "}\n\n" + staticClassifier; // The static classifer class
1512: }
1513:
1514: /**
1515: * Gets the number of test instances that had a known class value
1516: * (actually the sum of the weights of test instances with known
1517: * class value).
1518: *
1519: * @return the number of test instances with known class
1520: */
1521: public final double numInstances() {
1522:
1523: return m_WithClass;
1524: }
1525:
1526: /**
1527: * Gets the number of instances incorrectly classified (that is, for
1528: * which an incorrect prediction was made). (Actually the sum of the weights
1529: * of these instances)
1530: *
1531: * @return the number of incorrectly classified instances
1532: */
1533: public final double incorrect() {
1534:
1535: return m_Incorrect;
1536: }
1537:
1538: /**
1539: * Gets the percentage of instances incorrectly classified (that is, for
1540: * which an incorrect prediction was made).
1541: *
1542: * @return the percent of incorrectly classified instances
1543: * (between 0 and 100)
1544: */
1545: public final double pctIncorrect() {
1546:
1547: return 100 * m_Incorrect / m_WithClass;
1548: }
1549:
1550: /**
1551: * Gets the total cost, that is, the cost of each prediction times the
1552: * weight of the instance, summed over all instances.
1553: *
1554: * @return the total cost
1555: */
1556: public final double totalCost() {
1557:
1558: return m_TotalCost;
1559: }
1560:
1561: /**
1562: * Gets the average cost, that is, total cost of misclassifications
1563: * (incorrect plus unclassified) over the total number of instances.
1564: *
1565: * @return the average cost.
1566: */
1567: public final double avgCost() {
1568:
1569: return m_TotalCost / m_WithClass;
1570: }
1571:
1572: /**
1573: * Gets the number of instances correctly classified (that is, for
1574: * which a correct prediction was made). (Actually the sum of the weights
1575: * of these instances)
1576: *
1577: * @return the number of correctly classified instances
1578: */
1579: public final double correct() {
1580:
1581: return m_Correct;
1582: }
1583:
1584: /**
1585: * Gets the percentage of instances correctly classified (that is, for
1586: * which a correct prediction was made).
1587: *
1588: * @return the percent of correctly classified instances (between 0 and 100)
1589: */
1590: public final double pctCorrect() {
1591:
1592: return 100 * m_Correct / m_WithClass;
1593: }
1594:
1595: /**
1596: * Gets the number of instances not classified (that is, for
1597: * which no prediction was made by the classifier). (Actually the sum
1598: * of the weights of these instances)
1599: *
1600: * @return the number of unclassified instances
1601: */
1602: public final double unclassified() {
1603:
1604: return m_Unclassified;
1605: }
1606:
1607: /**
1608: * Gets the percentage of instances not classified (that is, for
1609: * which no prediction was made by the classifier).
1610: *
1611: * @return the percent of unclassified instances (between 0 and 100)
1612: */
1613: public final double pctUnclassified() {
1614:
1615: return 100 * m_Unclassified / m_WithClass;
1616: }
1617:
1618: /**
1619: * Returns the estimated error rate or the root mean squared error
1620: * (if the class is numeric). If a cost matrix was given this
1621: * error rate gives the average cost.
1622: *
1623: * @return the estimated error rate (between 0 and 1, or between 0 and
1624: * maximum cost)
1625: */
1626: public final double errorRate() {
1627:
1628: if (!m_ClassIsNominal) {
1629: return Math.sqrt(m_SumSqrErr / m_WithClass);
1630: }
1631: if (m_CostMatrix == null) {
1632: return m_Incorrect / m_WithClass;
1633: } else {
1634: return avgCost();
1635: }
1636: }
1637:
1638: /**
1639: * Returns value of kappa statistic if class is nominal.
1640: *
1641: * @return the value of the kappa statistic
1642: */
1643: public final double kappa() {
1644:
1645: double[] sumRows = new double[m_ConfusionMatrix.length];
1646: double[] sumColumns = new double[m_ConfusionMatrix.length];
1647: double sumOfWeights = 0;
1648: for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1649: for (int j = 0; j < m_ConfusionMatrix.length; j++) {
1650: sumRows[i] += m_ConfusionMatrix[i][j];
1651: sumColumns[j] += m_ConfusionMatrix[i][j];
1652: sumOfWeights += m_ConfusionMatrix[i][j];
1653: }
1654: }
1655: double correct = 0, chanceAgreement = 0;
1656: for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1657: chanceAgreement += (sumRows[i] * sumColumns[i]);
1658: correct += m_ConfusionMatrix[i][i];
1659: }
1660: chanceAgreement /= (sumOfWeights * sumOfWeights);
1661: correct /= sumOfWeights;
1662:
1663: if (chanceAgreement < 1) {
1664: return (correct - chanceAgreement) / (1 - chanceAgreement);
1665: } else {
1666: return 1;
1667: }
1668: }
1669:
1670: /**
1671: * Returns the correlation coefficient if the class is numeric.
1672: *
1673: * @return the correlation coefficient
1674: * @throws Exception if class is not numeric
1675: */
1676: public final double correlationCoefficient() throws Exception {
1677:
1678: if (m_ClassIsNominal) {
1679: throw new Exception(
1680: "Can't compute correlation coefficient: "
1681: + "class is nominal!");
1682: }
1683:
1684: double correlation = 0;
1685: double varActual = m_SumSqrClass - m_SumClass * m_SumClass
1686: / m_WithClass;
1687: double varPredicted = m_SumSqrPredicted - m_SumPredicted
1688: * m_SumPredicted / m_WithClass;
1689: double varProd = m_SumClassPredicted - m_SumClass
1690: * m_SumPredicted / m_WithClass;
1691:
1692: if (varActual * varPredicted <= 0) {
1693: correlation = 0.0;
1694: } else {
1695: correlation = varProd / Math.sqrt(varActual * varPredicted);
1696: }
1697:
1698: return correlation;
1699: }
1700:
1701: /**
1702: * Returns the mean absolute error. Refers to the error of the
1703: * predicted values for numeric classes, and the error of the
1704: * predicted probability distribution for nominal classes.
1705: *
1706: * @return the mean absolute error
1707: */
1708: public final double meanAbsoluteError() {
1709:
1710: return m_SumAbsErr / m_WithClass;
1711: }
1712:
1713: /**
1714: * Returns the mean absolute error of the prior.
1715: *
1716: * @return the mean absolute error
1717: */
1718: public final double meanPriorAbsoluteError() {
1719:
1720: if (m_NoPriors)
1721: return Double.NaN;
1722:
1723: return m_SumPriorAbsErr / m_WithClass;
1724: }
1725:
1726: /**
1727: * Returns the relative absolute error.
1728: *
1729: * @return the relative absolute error
1730: * @throws Exception if it can't be computed
1731: */
1732: public final double relativeAbsoluteError() throws Exception {
1733:
1734: if (m_NoPriors)
1735: return Double.NaN;
1736:
1737: return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
1738: }
1739:
1740: /**
1741: * Returns the root mean squared error.
1742: *
1743: * @return the root mean squared error
1744: */
1745: public final double rootMeanSquaredError() {
1746:
1747: return Math.sqrt(m_SumSqrErr / m_WithClass);
1748: }
1749:
1750: /**
1751: * Returns the root mean prior squared error.
1752: *
1753: * @return the root mean prior squared error
1754: */
1755: public final double rootMeanPriorSquaredError() {
1756:
1757: if (m_NoPriors)
1758: return Double.NaN;
1759:
1760: return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
1761: }
1762:
1763: /**
1764: * Returns the root relative squared error if the class is numeric.
1765: *
1766: * @return the root relative squared error
1767: */
1768: public final double rootRelativeSquaredError() {
1769:
1770: if (m_NoPriors)
1771: return Double.NaN;
1772:
1773: return 100.0 * rootMeanSquaredError()
1774: / rootMeanPriorSquaredError();
1775: }
1776:
1777: /**
1778: * Calculate the entropy of the prior distribution
1779: *
1780: * @return the entropy of the prior distribution
1781: * @throws Exception if the class is not nominal
1782: */
1783: public final double priorEntropy() throws Exception {
1784:
1785: if (!m_ClassIsNominal) {
1786: throw new Exception(
1787: "Can't compute entropy of class prior: "
1788: + "class numeric!");
1789: }
1790:
1791: if (m_NoPriors)
1792: return Double.NaN;
1793:
1794: double entropy = 0;
1795: for (int i = 0; i < m_NumClasses; i++) {
1796: entropy -= m_ClassPriors[i] / m_ClassPriorsSum
1797: * Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
1798: }
1799: return entropy;
1800: }
1801:
1802: /**
1803: * Return the total Kononenko & Bratko Information score in bits
1804: *
1805: * @return the K&B information score
1806: * @throws Exception if the class is not nominal
1807: */
1808: public final double KBInformation() throws Exception {
1809:
1810: if (!m_ClassIsNominal) {
1811: throw new Exception("Can't compute K&B Info score: "
1812: + "class numeric!");
1813: }
1814:
1815: if (m_NoPriors)
1816: return Double.NaN;
1817:
1818: return m_SumKBInfo;
1819: }
1820:
1821: /**
1822: * Return the Kononenko & Bratko Information score in bits per
1823: * instance.
1824: *
1825: * @return the K&B information score
1826: * @throws Exception if the class is not nominal
1827: */
1828: public final double KBMeanInformation() throws Exception {
1829:
1830: if (!m_ClassIsNominal) {
1831: throw new Exception("Can't compute K&B Info score: "
1832: + "class numeric!");
1833: }
1834:
1835: if (m_NoPriors)
1836: return Double.NaN;
1837:
1838: return m_SumKBInfo / m_WithClass;
1839: }
1840:
1841: /**
1842: * Return the Kononenko & Bratko Relative Information score
1843: *
1844: * @return the K&B relative information score
1845: * @throws Exception if the class is not nominal
1846: */
1847: public final double KBRelativeInformation() throws Exception {
1848:
1849: if (!m_ClassIsNominal) {
1850: throw new Exception("Can't compute K&B Info score: "
1851: + "class numeric!");
1852: }
1853:
1854: if (m_NoPriors)
1855: return Double.NaN;
1856:
1857: return 100.0 * KBInformation() / priorEntropy();
1858: }
1859:
1860: /**
1861: * Returns the total entropy for the null model
1862: *
1863: * @return the total null model entropy
1864: */
1865: public final double SFPriorEntropy() {
1866:
1867: if (m_NoPriors)
1868: return Double.NaN;
1869:
1870: return m_SumPriorEntropy;
1871: }
1872:
1873: /**
1874: * Returns the entropy per instance for the null model
1875: *
1876: * @return the null model entropy per instance
1877: */
1878: public final double SFMeanPriorEntropy() {
1879:
1880: if (m_NoPriors)
1881: return Double.NaN;
1882:
1883: return m_SumPriorEntropy / m_WithClass;
1884: }
1885:
1886: /**
1887: * Returns the total entropy for the scheme
1888: *
1889: * @return the total scheme entropy
1890: */
1891: public final double SFSchemeEntropy() {
1892:
1893: if (m_NoPriors)
1894: return Double.NaN;
1895:
1896: return m_SumSchemeEntropy;
1897: }
1898:
1899: /**
1900: * Returns the entropy per instance for the scheme
1901: *
1902: * @return the scheme entropy per instance
1903: */
1904: public final double SFMeanSchemeEntropy() {
1905:
1906: if (m_NoPriors)
1907: return Double.NaN;
1908:
1909: return m_SumSchemeEntropy / m_WithClass;
1910: }
1911:
1912: /**
1913: * Returns the total SF, which is the null model entropy minus
1914: * the scheme entropy.
1915: *
1916: * @return the total SF
1917: */
1918: public final double SFEntropyGain() {
1919:
1920: if (m_NoPriors)
1921: return Double.NaN;
1922:
1923: return m_SumPriorEntropy - m_SumSchemeEntropy;
1924: }
1925:
1926: /**
1927: * Returns the SF per instance, which is the null model entropy
1928: * minus the scheme entropy, per instance.
1929: *
1930: * @return the SF per instance
1931: */
1932: public final double SFMeanEntropyGain() {
1933:
1934: if (m_NoPriors)
1935: return Double.NaN;
1936:
1937: return (m_SumPriorEntropy - m_SumSchemeEntropy) / m_WithClass;
1938: }
1939:
1940: /**
1941: * Output the cumulative margin distribution as a string suitable
1942: * for input for gnuplot or similar package.
1943: *
1944: * @return the cumulative margin distribution
1945: * @throws Exception if the class attribute is nominal
1946: */
1947: public String toCumulativeMarginDistributionString()
1948: throws Exception {
1949:
1950: if (!m_ClassIsNominal) {
1951: throw new Exception(
1952: "Class must be nominal for margin distributions");
1953: }
1954: String result = "";
1955: double cumulativeCount = 0;
1956: double margin;
1957: for (int i = 0; i <= k_MarginResolution; i++) {
1958: if (m_MarginCounts[i] != 0) {
1959: cumulativeCount += m_MarginCounts[i];
1960: margin = (double) i * 2.0 / k_MarginResolution - 1.0;
1961: result = result
1962: + Utils.doubleToString(margin, 7, 3)
1963: + ' '
1964: + Utils.doubleToString(cumulativeCount * 100
1965: / m_WithClass, 7, 3) + '\n';
1966: } else if (i == 0) {
1967: result = Utils.doubleToString(-1.0, 7, 3) + ' '
1968: + Utils.doubleToString(0, 7, 3) + '\n';
1969: }
1970: }
1971: return result;
1972: }
1973:
1974: /**
1975: * Calls toSummaryString() with no title and no complexity stats
1976: *
1977: * @return a summary description of the classifier evaluation
1978: */
1979: public String toSummaryString() {
1980:
1981: return toSummaryString("", false);
1982: }
1983:
1984: /**
1985: * Calls toSummaryString() with a default title.
1986: *
1987: * @param printComplexityStatistics if true, complexity statistics are
1988: * returned as well
1989: * @return the summary string
1990: */
1991: public String toSummaryString(boolean printComplexityStatistics) {
1992:
1993: return toSummaryString("=== Summary ===\n",
1994: printComplexityStatistics);
1995: }
1996:
1997: /**
1998: * Outputs the performance statistics in summary form. Lists
1999: * number (and percentage) of instances classified correctly,
2000: * incorrectly and unclassified. Outputs the total number of
2001: * instances classified, and the number of instances (if any)
2002: * that had no class value provided.
2003: *
2004: * @param title the title for the statistics
2005: * @param printComplexityStatistics if true, complexity statistics are
2006: * returned as well
2007: * @return the summary as a String
2008: */
2009: public String toSummaryString(String title,
2010: boolean printComplexityStatistics) {
2011:
2012: StringBuffer text = new StringBuffer();
2013:
2014: if (printComplexityStatistics && m_NoPriors) {
2015: printComplexityStatistics = false;
2016: System.err
2017: .println("Priors disabled, cannot print complexity statistics!");
2018: }
2019:
2020: text.append(title + "\n");
2021: try {
2022: if (m_WithClass > 0) {
2023: if (m_ClassIsNominal) {
2024:
2025: text.append("Correctly Classified Instances ");
2026: text.append(Utils.doubleToString(correct(), 12, 4)
2027: + " "
2028: + Utils.doubleToString(pctCorrect(), 12, 4)
2029: + " %\n");
2030: text.append("Incorrectly Classified Instances ");
2031: text.append(Utils
2032: .doubleToString(incorrect(), 12, 4)
2033: + " "
2034: + Utils.doubleToString(pctIncorrect(), 12,
2035: 4) + " %\n");
2036: text.append("Kappa statistic ");
2037: text.append(Utils.doubleToString(kappa(), 12, 4)
2038: + "\n");
2039:
2040: if (m_CostMatrix != null) {
2041: text
2042: .append("Total Cost ");
2043: text.append(Utils.doubleToString(totalCost(),
2044: 12, 4)
2045: + "\n");
2046: text
2047: .append("Average Cost ");
2048: text.append(Utils.doubleToString(avgCost(), 12,
2049: 4)
2050: + "\n");
2051: }
2052: if (printComplexityStatistics) {
2053: text
2054: .append("K&B Relative Info Score ");
2055: text.append(Utils.doubleToString(
2056: KBRelativeInformation(), 12, 4)
2057: + " %\n");
2058: text
2059: .append("K&B Information Score ");
2060: text.append(Utils.doubleToString(
2061: KBInformation(), 12, 4)
2062: + " bits");
2063: text.append(Utils.doubleToString(
2064: KBMeanInformation(), 12, 4)
2065: + " bits/instance\n");
2066: }
2067: } else {
2068: text.append("Correlation coefficient ");
2069: text.append(Utils.doubleToString(
2070: correlationCoefficient(), 12, 4)
2071: + "\n");
2072: }
2073: if (printComplexityStatistics) {
2074: text.append("Class complexity | order 0 ");
2075: text.append(Utils.doubleToString(SFPriorEntropy(),
2076: 12, 4)
2077: + " bits");
2078: text.append(Utils.doubleToString(
2079: SFMeanPriorEntropy(), 12, 4)
2080: + " bits/instance\n");
2081: text.append("Class complexity | scheme ");
2082: text.append(Utils.doubleToString(SFSchemeEntropy(),
2083: 12, 4)
2084: + " bits");
2085: text.append(Utils.doubleToString(
2086: SFMeanSchemeEntropy(), 12, 4)
2087: + " bits/instance\n");
2088: text.append("Complexity improvement (Sf) ");
2089: text.append(Utils.doubleToString(SFEntropyGain(),
2090: 12, 4)
2091: + " bits");
2092: text.append(Utils.doubleToString(
2093: SFMeanEntropyGain(), 12, 4)
2094: + " bits/instance\n");
2095: }
2096:
2097: text.append("Mean absolute error ");
2098: text.append(Utils.doubleToString(meanAbsoluteError(),
2099: 12, 4)
2100: + "\n");
2101: text.append("Root mean squared error ");
2102: text.append(Utils.doubleToString(
2103: rootMeanSquaredError(), 12, 4)
2104: + "\n");
2105: if (!m_NoPriors) {
2106: text.append("Relative absolute error ");
2107: text.append(Utils.doubleToString(
2108: relativeAbsoluteError(), 12, 4)
2109: + " %\n");
2110: text.append("Root relative squared error ");
2111: text.append(Utils.doubleToString(
2112: rootRelativeSquaredError(), 12, 4)
2113: + " %\n");
2114: }
2115: }
2116: if (Utils.gr(unclassified(), 0)) {
2117: text.append("UnClassified Instances ");
2118: text.append(Utils.doubleToString(unclassified(), 12, 4)
2119: + " "
2120: + Utils
2121: .doubleToString(pctUnclassified(), 12,
2122: 4) + " %\n");
2123: }
2124: text.append("Total Number of Instances ");
2125: text
2126: .append(Utils.doubleToString(m_WithClass, 12, 4)
2127: + "\n");
2128: if (m_MissingClass > 0) {
2129: text
2130: .append("Ignored Class Unknown Instances ");
2131: text.append(Utils.doubleToString(m_MissingClass, 12, 4)
2132: + "\n");
2133: }
2134: } catch (Exception ex) {
2135: // Should never occur since the class is known to be nominal
2136: // here
2137: System.err
2138: .println("Arggh - Must be a bug in Evaluation class");
2139: }
2140:
2141: return text.toString();
2142: }
2143:
2144: /**
2145: * Calls toMatrixString() with a default title.
2146: *
2147: * @return the confusion matrix as a string
2148: * @throws Exception if the class is numeric
2149: */
2150: public String toMatrixString() throws Exception {
2151:
2152: return toMatrixString("=== Confusion Matrix ===\n");
2153: }
2154:
2155: /**
2156: * Outputs the performance statistics as a classification confusion
2157: * matrix. For each class value, shows the distribution of
2158: * predicted class values.
2159: *
2160: * @param title the title for the confusion matrix
2161: * @return the confusion matrix as a String
2162: * @throws Exception if the class is numeric
2163: */
2164: public String toMatrixString(String title) throws Exception {
2165:
2166: StringBuffer text = new StringBuffer();
2167: char[] IDChars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
2168: 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
2169: 'u', 'v', 'w', 'x', 'y', 'z' };
2170: int IDWidth;
2171: boolean fractional = false;
2172:
2173: if (!m_ClassIsNominal) {
2174: throw new Exception(
2175: "Evaluation: No confusion matrix possible!");
2176: }
2177:
2178: // Find the maximum value in the matrix
2179: // and check for fractional display requirement
2180: double maxval = 0;
2181: for (int i = 0; i < m_NumClasses; i++) {
2182: for (int j = 0; j < m_NumClasses; j++) {
2183: double current = m_ConfusionMatrix[i][j];
2184: if (current < 0) {
2185: current *= -10;
2186: }
2187: if (current > maxval) {
2188: maxval = current;
2189: }
2190: double fract = current - Math.rint(current);
2191: if (!fractional
2192: && ((Math.log(fract) / Math.log(10)) >= -2)) {
2193: fractional = true;
2194: }
2195: }
2196: }
2197:
2198: IDWidth = 1 + Math
2199: .max(
2200: (int) (Math.log(maxval) / Math.log(10) + (fractional ? 3
2201: : 0)),
2202: (int) (Math.log(m_NumClasses) / Math
2203: .log(IDChars.length)));
2204: text.append(title).append("\n");
2205: for (int i = 0; i < m_NumClasses; i++) {
2206: if (fractional) {
2207: text.append(" ").append(
2208: num2ShortID(i, IDChars, IDWidth - 3)).append(
2209: " ");
2210: } else {
2211: text.append(" ").append(
2212: num2ShortID(i, IDChars, IDWidth));
2213: }
2214: }
2215: text.append(" <-- classified as\n");
2216: for (int i = 0; i < m_NumClasses; i++) {
2217: for (int j = 0; j < m_NumClasses; j++) {
2218: text.append(" ").append(
2219: Utils.doubleToString(m_ConfusionMatrix[i][j],
2220: IDWidth, (fractional ? 2 : 0)));
2221: }
2222: text.append(" | ").append(num2ShortID(i, IDChars, IDWidth))
2223: .append(" = ").append(m_ClassNames[i]).append("\n");
2224: }
2225: return text.toString();
2226: }
2227:
2228: /**
2229: * Generates a breakdown of the accuracy for each class (with default title),
2230: * incorporating various information-retrieval statistics, such as
2231: * true/false positive rate, precision/recall/F-Measure. Should be
2232: * useful for ROC curves, recall/precision curves.
2233: *
2234: * @return the statistics presented as a string
2235: * @throws Exception if class is not nominal
2236: */
2237: public String toClassDetailsString() throws Exception {
2238:
2239: return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
2240: }
2241:
2242: /**
2243: * Generates a breakdown of the accuracy for each class,
2244: * incorporating various information-retrieval statistics, such as
2245: * true/false positive rate, precision/recall/F-Measure. Should be
2246: * useful for ROC curves, recall/precision curves.
2247: *
2248: * @param title the title to prepend the stats string with
2249: * @return the statistics presented as a string
2250: * @throws Exception if class is not nominal
2251: */
2252: public String toClassDetailsString(String title) throws Exception {
2253:
2254: if (!m_ClassIsNominal) {
2255: throw new Exception(
2256: "Evaluation: No confusion matrix possible!");
2257: }
2258: StringBuffer text = new StringBuffer(title
2259: + "\nTP Rate FP Rate" + " Precision Recall"
2260: + " F-Measure ROC Area Class\n");
2261: for (int i = 0; i < m_NumClasses; i++) {
2262: text
2263: .append(
2264: Utils.doubleToString(truePositiveRate(i),
2265: 7, 3)).append(" ");
2266: text.append(
2267: Utils.doubleToString(falsePositiveRate(i), 7, 3))
2268: .append(" ");
2269: text.append(Utils.doubleToString(precision(i), 7, 3))
2270: .append(" ");
2271: text.append(Utils.doubleToString(recall(i), 7, 3)).append(
2272: " ");
2273: text.append(Utils.doubleToString(fMeasure(i), 7, 3))
2274: .append(" ");
2275: double rocVal = areaUnderROC(i);
2276: if (Instance.isMissingValue(rocVal)) {
2277: text.append(" ? ").append(" ");
2278: } else {
2279: text.append(Utils.doubleToString(rocVal, 7, 3)).append(
2280: " ");
2281: }
2282: text.append(m_ClassNames[i]).append('\n');
2283: }
2284: return text.toString();
2285: }
2286:
2287: /**
2288: * Calculate the number of true positives with respect to a particular class.
2289: * This is defined as<p/>
2290: * <pre>
2291: * correctly classified positives
2292: * </pre>
2293: *
2294: * @param classIndex the index of the class to consider as "positive"
2295: * @return the true positive rate
2296: */
2297: public double numTruePositives(int classIndex) {
2298:
2299: double correct = 0;
2300: for (int j = 0; j < m_NumClasses; j++) {
2301: if (j == classIndex) {
2302: correct += m_ConfusionMatrix[classIndex][j];
2303: }
2304: }
2305: return correct;
2306: }
2307:
2308: /**
2309: * Calculate the true positive rate with respect to a particular class.
2310: * This is defined as<p/>
2311: * <pre>
2312: * correctly classified positives
2313: * ------------------------------
2314: * total positives
2315: * </pre>
2316: *
2317: * @param classIndex the index of the class to consider as "positive"
2318: * @return the true positive rate
2319: */
2320: public double truePositiveRate(int classIndex) {
2321:
2322: double correct = 0, total = 0;
2323: for (int j = 0; j < m_NumClasses; j++) {
2324: if (j == classIndex) {
2325: correct += m_ConfusionMatrix[classIndex][j];
2326: }
2327: total += m_ConfusionMatrix[classIndex][j];
2328: }
2329: if (total == 0) {
2330: return 0;
2331: }
2332: return correct / total;
2333: }
2334:
2335: /**
2336: * Calculate the number of true negatives with respect to a particular class.
2337: * This is defined as<p/>
2338: * <pre>
2339: * correctly classified negatives
2340: * </pre>
2341: *
2342: * @param classIndex the index of the class to consider as "positive"
2343: * @return the true positive rate
2344: */
2345: public double numTrueNegatives(int classIndex) {
2346:
2347: double correct = 0;
2348: for (int i = 0; i < m_NumClasses; i++) {
2349: if (i != classIndex) {
2350: for (int j = 0; j < m_NumClasses; j++) {
2351: if (j != classIndex) {
2352: correct += m_ConfusionMatrix[i][j];
2353: }
2354: }
2355: }
2356: }
2357: return correct;
2358: }
2359:
2360: /**
2361: * Calculate the true negative rate with respect to a particular class.
2362: * This is defined as<p/>
2363: * <pre>
2364: * correctly classified negatives
2365: * ------------------------------
2366: * total negatives
2367: * </pre>
2368: *
2369: * @param classIndex the index of the class to consider as "positive"
2370: * @return the true positive rate
2371: */
2372: public double trueNegativeRate(int classIndex) {
2373:
2374: double correct = 0, total = 0;
2375: for (int i = 0; i < m_NumClasses; i++) {
2376: if (i != classIndex) {
2377: for (int j = 0; j < m_NumClasses; j++) {
2378: if (j != classIndex) {
2379: correct += m_ConfusionMatrix[i][j];
2380: }
2381: total += m_ConfusionMatrix[i][j];
2382: }
2383: }
2384: }
2385: if (total == 0) {
2386: return 0;
2387: }
2388: return correct / total;
2389: }
2390:
2391: /**
2392: * Calculate number of false positives with respect to a particular class.
2393: * This is defined as<p/>
2394: * <pre>
2395: * incorrectly classified negatives
2396: * </pre>
2397: *
2398: * @param classIndex the index of the class to consider as "positive"
2399: * @return the false positive rate
2400: */
2401: public double numFalsePositives(int classIndex) {
2402:
2403: double incorrect = 0;
2404: for (int i = 0; i < m_NumClasses; i++) {
2405: if (i != classIndex) {
2406: for (int j = 0; j < m_NumClasses; j++) {
2407: if (j == classIndex) {
2408: incorrect += m_ConfusionMatrix[i][j];
2409: }
2410: }
2411: }
2412: }
2413: return incorrect;
2414: }
2415:
2416: /**
2417: * Calculate the false positive rate with respect to a particular class.
2418: * This is defined as<p/>
2419: * <pre>
2420: * incorrectly classified negatives
2421: * --------------------------------
2422: * total negatives
2423: * </pre>
2424: *
2425: * @param classIndex the index of the class to consider as "positive"
2426: * @return the false positive rate
2427: */
2428: public double falsePositiveRate(int classIndex) {
2429:
2430: double incorrect = 0, total = 0;
2431: for (int i = 0; i < m_NumClasses; i++) {
2432: if (i != classIndex) {
2433: for (int j = 0; j < m_NumClasses; j++) {
2434: if (j == classIndex) {
2435: incorrect += m_ConfusionMatrix[i][j];
2436: }
2437: total += m_ConfusionMatrix[i][j];
2438: }
2439: }
2440: }
2441: if (total == 0) {
2442: return 0;
2443: }
2444: return incorrect / total;
2445: }
2446:
2447: /**
2448: * Calculate number of false negatives with respect to a particular class.
2449: * This is defined as<p/>
2450: * <pre>
2451: * incorrectly classified positives
2452: * </pre>
2453: *
2454: * @param classIndex the index of the class to consider as "positive"
2455: * @return the false positive rate
2456: */
2457: public double numFalseNegatives(int classIndex) {
2458:
2459: double incorrect = 0;
2460: for (int i = 0; i < m_NumClasses; i++) {
2461: if (i == classIndex) {
2462: for (int j = 0; j < m_NumClasses; j++) {
2463: if (j != classIndex) {
2464: incorrect += m_ConfusionMatrix[i][j];
2465: }
2466: }
2467: }
2468: }
2469: return incorrect;
2470: }
2471:
2472: /**
2473: * Calculate the false negative rate with respect to a particular class.
2474: * This is defined as<p/>
2475: * <pre>
2476: * incorrectly classified positives
2477: * --------------------------------
2478: * total positives
2479: * </pre>
2480: *
2481: * @param classIndex the index of the class to consider as "positive"
2482: * @return the false positive rate
2483: */
2484: public double falseNegativeRate(int classIndex) {
2485:
2486: double incorrect = 0, total = 0;
2487: for (int i = 0; i < m_NumClasses; i++) {
2488: if (i == classIndex) {
2489: for (int j = 0; j < m_NumClasses; j++) {
2490: if (j != classIndex) {
2491: incorrect += m_ConfusionMatrix[i][j];
2492: }
2493: total += m_ConfusionMatrix[i][j];
2494: }
2495: }
2496: }
2497: if (total == 0) {
2498: return 0;
2499: }
2500: return incorrect / total;
2501: }
2502:
2503: /**
2504: * Calculate the recall with respect to a particular class.
2505: * This is defined as<p/>
2506: * <pre>
2507: * correctly classified positives
2508: * ------------------------------
2509: * total positives
2510: * </pre><p/>
2511: * (Which is also the same as the truePositiveRate.)
2512: *
2513: * @param classIndex the index of the class to consider as "positive"
2514: * @return the recall
2515: */
2516: public double recall(int classIndex) {
2517:
2518: return truePositiveRate(classIndex);
2519: }
2520:
2521: /**
2522: * Calculate the precision with respect to a particular class.
2523: * This is defined as<p/>
2524: * <pre>
2525: * correctly classified positives
2526: * ------------------------------
2527: * total predicted as positive
2528: * </pre>
2529: *
2530: * @param classIndex the index of the class to consider as "positive"
2531: * @return the precision
2532: */
2533: public double precision(int classIndex) {
2534:
2535: double correct = 0, total = 0;
2536: for (int i = 0; i < m_NumClasses; i++) {
2537: if (i == classIndex) {
2538: correct += m_ConfusionMatrix[i][classIndex];
2539: }
2540: total += m_ConfusionMatrix[i][classIndex];
2541: }
2542: if (total == 0) {
2543: return 0;
2544: }
2545: return correct / total;
2546: }
2547:
2548: /**
2549: * Calculate the F-Measure with respect to a particular class.
2550: * This is defined as<p/>
2551: * <pre>
2552: * 2 * recall * precision
2553: * ----------------------
2554: * recall + precision
2555: * </pre>
2556: *
2557: * @param classIndex the index of the class to consider as "positive"
2558: * @return the F-Measure
2559: */
2560: public double fMeasure(int classIndex) {
2561:
2562: double precision = precision(classIndex);
2563: double recall = recall(classIndex);
2564: if ((precision + recall) == 0) {
2565: return 0;
2566: }
2567: return 2 * precision * recall / (precision + recall);
2568: }
2569:
2570: /**
2571: * Sets the class prior probabilities
2572: *
2573: * @param train the training instances used to determine
2574: * the prior probabilities
2575: * @throws Exception if the class attribute of the instances is not
2576: * set
2577: */
2578: public void setPriors(Instances train) throws Exception {
2579: m_NoPriors = false;
2580:
2581: if (!m_ClassIsNominal) {
2582:
2583: m_NumTrainClassVals = 0;
2584: m_TrainClassVals = null;
2585: m_TrainClassWeights = null;
2586: m_PriorErrorEstimator = null;
2587: m_ErrorEstimator = null;
2588:
2589: for (int i = 0; i < train.numInstances(); i++) {
2590: Instance currentInst = train.instance(i);
2591: if (!currentInst.classIsMissing()) {
2592: addNumericTrainClass(currentInst.classValue(),
2593: currentInst.weight());
2594: }
2595: }
2596:
2597: } else {
2598: for (int i = 0; i < m_NumClasses; i++) {
2599: m_ClassPriors[i] = 1;
2600: }
2601: m_ClassPriorsSum = m_NumClasses;
2602: for (int i = 0; i < train.numInstances(); i++) {
2603: if (!train.instance(i).classIsMissing()) {
2604: m_ClassPriors[(int) train.instance(i).classValue()] += train
2605: .instance(i).weight();
2606: m_ClassPriorsSum += train.instance(i).weight();
2607: }
2608: }
2609: }
2610: }
2611:
2612: /**
2613: * Get the current weighted class counts
2614: *
2615: * @return the weighted class counts
2616: */
2617: public double[] getClassPriors() {
2618: return m_ClassPriors;
2619: }
2620:
2621: /**
2622: * Updates the class prior probabilities (when incrementally
2623: * training)
2624: *
2625: * @param instance the new training instance seen
2626: * @throws Exception if the class of the instance is not
2627: * set
2628: */
2629: public void updatePriors(Instance instance) throws Exception {
2630: if (!instance.classIsMissing()) {
2631: if (!m_ClassIsNominal) {
2632: if (!instance.classIsMissing()) {
2633: addNumericTrainClass(instance.classValue(),
2634: instance.weight());
2635: }
2636: } else {
2637: m_ClassPriors[(int) instance.classValue()] += instance
2638: .weight();
2639: m_ClassPriorsSum += instance.weight();
2640: }
2641: }
2642: }
2643:
2644: /**
2645: * disables the use of priors, e.g., in case of de-serialized schemes
2646: * that have no access to the original training set, but are evaluated
2647: * on a set set.
2648: */
2649: public void useNoPriors() {
2650: m_NoPriors = true;
2651: }
2652:
2653: /**
2654: * Tests whether the current evaluation object is equal to another
2655: * evaluation object
2656: *
2657: * @param obj the object to compare against
2658: * @return true if the two objects are equal
2659: */
2660: public boolean equals(Object obj) {
2661:
2662: if ((obj == null) || !(obj.getClass().equals(this .getClass()))) {
2663: return false;
2664: }
2665: Evaluation cmp = (Evaluation) obj;
2666: if (m_ClassIsNominal != cmp.m_ClassIsNominal)
2667: return false;
2668: if (m_NumClasses != cmp.m_NumClasses)
2669: return false;
2670:
2671: if (m_Incorrect != cmp.m_Incorrect)
2672: return false;
2673: if (m_Correct != cmp.m_Correct)
2674: return false;
2675: if (m_Unclassified != cmp.m_Unclassified)
2676: return false;
2677: if (m_MissingClass != cmp.m_MissingClass)
2678: return false;
2679: if (m_WithClass != cmp.m_WithClass)
2680: return false;
2681:
2682: if (m_SumErr != cmp.m_SumErr)
2683: return false;
2684: if (m_SumAbsErr != cmp.m_SumAbsErr)
2685: return false;
2686: if (m_SumSqrErr != cmp.m_SumSqrErr)
2687: return false;
2688: if (m_SumClass != cmp.m_SumClass)
2689: return false;
2690: if (m_SumSqrClass != cmp.m_SumSqrClass)
2691: return false;
2692: if (m_SumPredicted != cmp.m_SumPredicted)
2693: return false;
2694: if (m_SumSqrPredicted != cmp.m_SumSqrPredicted)
2695: return false;
2696: if (m_SumClassPredicted != cmp.m_SumClassPredicted)
2697: return false;
2698:
2699: if (m_ClassIsNominal) {
2700: for (int i = 0; i < m_NumClasses; i++) {
2701: for (int j = 0; j < m_NumClasses; j++) {
2702: if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
2703: return false;
2704: }
2705: }
2706: }
2707: }
2708:
2709: return true;
2710: }
2711:
2712: /**
2713: * Prints the predictions for the given dataset into a String variable.
2714: *
2715: * @param classifier the classifier to use
2716: * @param train the training data
2717: * @param testSource the test set
2718: * @param classIndex the class index (1-based), if -1 ot does not
2719: * override the class index is stored in the data
2720: * file (by using the last attribute)
2721: * @param attributesToOutput the indices of the attributes to output
2722: * @return the generated predictions for the attribute range
2723: * @throws Exception if test file cannot be opened
2724: */
2725: protected static String printClassifications(Classifier classifier,
2726: Instances train, DataSource testSource, int classIndex,
2727: Range attributesToOutput) throws Exception {
2728:
2729: return printClassifications(classifier, train, testSource,
2730: classIndex, attributesToOutput, false);
2731: }
2732:
2733: /**
2734: * Prints the predictions for the given dataset into a String variable.
2735: *
2736: * @param classifier the classifier to use
2737: * @param train the training data
2738: * @param testSource the test set
2739: * @param classIndex the class index (1-based), if -1 ot does not
2740: * override the class index is stored in the data
2741: * file (by using the last attribute)
2742: * @param attributesToOutput the indices of the attributes to output
2743: * @param printDistribution prints the complete distribution for nominal
2744: * classes, not just the predicted value
2745: * @return the generated predictions for the attribute range
2746: * @throws Exception if test file cannot be opened
2747: */
2748: protected static String printClassifications(Classifier classifier,
2749: Instances train, DataSource testSource, int classIndex,
2750: Range attributesToOutput, boolean printDistribution)
2751: throws Exception {
2752:
2753: StringBuffer text = new StringBuffer();
2754: if (testSource != null) {
2755: Instances test = testSource.getStructure();
2756: if (classIndex != -1) {
2757: test.setClassIndex(classIndex - 1);
2758: } else {
2759: if (test.classIndex() == -1)
2760: test.setClassIndex(test.numAttributes() - 1);
2761: }
2762:
2763: // print header
2764: if (test.classAttribute().isNominal())
2765: if (printDistribution)
2766: text
2767: .append(" inst# actual predicted error distribution");
2768: else
2769: text
2770: .append(" inst# actual predicted error prediction");
2771: else
2772: text.append(" inst# actual predicted error");
2773: if (attributesToOutput != null) {
2774: attributesToOutput.setUpper(test.numAttributes() - 1);
2775: text.append(" (");
2776: boolean first = true;
2777: for (int i = 0; i < test.numAttributes(); i++) {
2778: if (i == test.classIndex())
2779: continue;
2780:
2781: if (attributesToOutput.isInRange(i)) {
2782: if (!first)
2783: text.append(",");
2784: text.append(test.attribute(i).name());
2785: first = false;
2786: }
2787: }
2788: text.append(")");
2789: }
2790: text.append("\n");
2791:
2792: // print predictions
2793: int i = 0;
2794: testSource.reset();
2795: while (testSource.hasMoreElements(test)) {
2796: Instance inst = testSource.nextElement(test);
2797: text.append(predictionText(classifier, inst, i,
2798: attributesToOutput, printDistribution));
2799: i++;
2800: }
2801: }
2802: return text.toString();
2803: }
2804:
2805: /**
2806: * returns the prediction made by the classifier as a string
2807: *
2808: * @param classifier the classifier to use
2809: * @param inst the instance to generate text from
2810: * @param instNum the index in the dataset
2811: * @param attributesToOutput the indices of the attributes to output
2812: * @param printDistribution prints the complete distribution for nominal
2813: * classes, not just the predicted value
2814: * @return the generated text
2815: * @throws Exception if something goes wrong
2816: * @see #printClassifications(Classifier, Instances, String, int, Range, boolean)
2817: */
2818: protected static String predictionText(Classifier classifier,
2819: Instance inst, int instNum, Range attributesToOutput,
2820: boolean printDistribution) throws Exception {
2821:
2822: StringBuffer result = new StringBuffer();
2823: int width = 10;
2824: int prec = 3;
2825:
2826: Instance withMissing = (Instance) inst.copy();
2827: withMissing.setDataset(inst.dataset());
2828: double predValue = ((Classifier) classifier)
2829: .classifyInstance(withMissing);
2830:
2831: // index
2832: result.append(Utils.padLeft("" + (instNum + 1), 6));
2833:
2834: if (inst.dataset().classAttribute().isNumeric()) {
2835: // actual
2836: if (inst.classIsMissing())
2837: result.append(" " + Utils.padLeft("?", width));
2838: else
2839: result.append(" "
2840: + Utils.doubleToString(inst.classValue(),
2841: width, prec));
2842: // predicted
2843: if (Instance.isMissingValue(predValue))
2844: result.append(" " + Utils.padLeft("?", width));
2845: else
2846: result.append(" "
2847: + Utils.doubleToString(predValue, width, prec));
2848: // error
2849: if (Instance.isMissingValue(predValue)
2850: || inst.classIsMissing())
2851: result.append(" " + Utils.padLeft("?", width));
2852: else
2853: result.append(" "
2854: + Utils.doubleToString(predValue
2855: - inst.classValue(), width, prec));
2856: } else {
2857: // actual
2858: result.append(" "
2859: + Utils.padLeft(((int) inst.classValue() + 1) + ":"
2860: + inst.toString(inst.classIndex()), width));
2861: // predicted
2862: if (Instance.isMissingValue(predValue))
2863: result.append(" " + Utils.padLeft("?", width));
2864: else
2865: result
2866: .append(" "
2867: + Utils
2868: .padLeft(
2869: ((int) predValue + 1)
2870: + ":"
2871: + inst
2872: .dataset()
2873: .classAttribute()
2874: .value(
2875: (int) predValue),
2876: width));
2877: // error?
2878: if ((int) predValue + 1 != (int) inst.classValue() + 1)
2879: result.append(" " + " + ");
2880: else
2881: result.append(" " + " ");
2882: // prediction/distribution
2883: if (printDistribution) {
2884: if (Instance.isMissingValue(predValue)) {
2885: result.append(" " + "?");
2886: } else {
2887: result.append(" ");
2888: double[] dist = classifier
2889: .distributionForInstance(withMissing);
2890: for (int n = 0; n < dist.length; n++) {
2891: if (n > 0)
2892: result.append(",");
2893: if (n == (int) predValue)
2894: result.append("*");
2895: result.append(Utils.doubleToString(dist[n],
2896: prec));
2897: }
2898: }
2899: } else {
2900: if (Instance.isMissingValue(predValue))
2901: result.append(" " + "?");
2902: else
2903: result
2904: .append(" "
2905: + Utils
2906: .doubleToString(
2907: classifier
2908: .distributionForInstance(withMissing)[(int) predValue],
2909: prec));
2910: }
2911: }
2912:
2913: // attributes
2914: result
2915: .append(" "
2916: + attributeValuesString(withMissing,
2917: attributesToOutput) + "\n");
2918:
2919: return result.toString();
2920: }
2921:
2922: /**
2923: * Builds a string listing the attribute values in a specified range of indices,
2924: * separated by commas and enclosed in brackets.
2925: *
2926: * @param instance the instance to print the values from
2927: * @param attRange the range of the attributes to list
2928: * @return a string listing values of the attributes in the range
2929: */
2930: protected static String attributeValuesString(Instance instance,
2931: Range attRange) {
2932: StringBuffer text = new StringBuffer();
2933: if (attRange != null) {
2934: boolean firstOutput = true;
2935: attRange.setUpper(instance.numAttributes() - 1);
2936: for (int i = 0; i < instance.numAttributes(); i++)
2937: if (attRange.isInRange(i) && i != instance.classIndex()) {
2938: if (firstOutput)
2939: text.append("(");
2940: else
2941: text.append(",");
2942: text.append(instance.toString(i));
2943: firstOutput = false;
2944: }
2945: if (!firstOutput)
2946: text.append(")");
2947: }
2948: return text.toString();
2949: }
2950:
2951: /**
2952: * Make up the help string giving all the command line options
2953: *
2954: * @param classifier the classifier to include options for
2955: * @return a string detailing the valid command line options
2956: */
2957: protected static String makeOptionString(Classifier classifier) {
2958:
2959: StringBuffer optionsText = new StringBuffer("");
2960:
2961: // General options
2962: optionsText.append("\n\nGeneral options:\n\n");
2963: optionsText.append("-t <name of training file>\n");
2964: optionsText.append("\tSets training file.\n");
2965: optionsText.append("-T <name of test file>\n");
2966: optionsText
2967: .append("\tSets test file. If missing, a cross-validation will be performed\n");
2968: optionsText.append("\ton the training data.\n");
2969: optionsText.append("-c <class index>\n");
2970: optionsText
2971: .append("\tSets index of class attribute (default: last).\n");
2972: optionsText.append("-x <number of folds>\n");
2973: optionsText
2974: .append("\tSets number of folds for cross-validation (default: 10).\n");
2975: optionsText.append("-no-cv\n");
2976: optionsText.append("\tDo not perform any cross validation.\n");
2977: optionsText.append("-split-percentage <percentage>\n");
2978: optionsText
2979: .append("\tSets the percentage for the train/test set split, e.g., 66.\n");
2980: optionsText.append("-preserve-order\n");
2981: optionsText
2982: .append("\tPreserves the order in the percentage split.\n");
2983: optionsText.append("-s <random number seed>\n");
2984: optionsText
2985: .append("\tSets random number seed for cross-validation or percentage split\n");
2986: optionsText.append("\t(default: 1).\n");
2987: optionsText.append("-m <name of file with cost matrix>\n");
2988: optionsText.append("\tSets file with cost matrix.\n");
2989: optionsText.append("-l <name of input file>\n");
2990: optionsText
2991: .append("\tSets model input file. In case the filename ends with '.xml',\n");
2992: optionsText
2993: .append("\tthe options are loaded from the XML file.\n");
2994: optionsText.append("-d <name of output file>\n");
2995: optionsText
2996: .append("\tSets model output file. In case the filename ends with '.xml',\n");
2997: optionsText
2998: .append("\tonly the options are saved to the XML file, not the model.\n");
2999: optionsText.append("-v\n");
3000: optionsText
3001: .append("\tOutputs no statistics for training data.\n");
3002: optionsText.append("-o\n");
3003: optionsText
3004: .append("\tOutputs statistics only, not the classifier.\n");
3005: optionsText.append("-i\n");
3006: optionsText.append("\tOutputs detailed information-retrieval");
3007: optionsText.append(" statistics for each class.\n");
3008: optionsText.append("-k\n");
3009: optionsText
3010: .append("\tOutputs information-theoretic statistics.\n");
3011: optionsText.append("-p <attribute range>\n");
3012: optionsText
3013: .append("\tOnly outputs predictions for test instances (or the train\n"
3014: + "\tinstances if no test instances provided), along with attributes\n"
3015: + "\t(0 for none).\n");
3016: optionsText.append("-distribution\n");
3017: optionsText
3018: .append("\tOutputs the distribution instead of only the prediction\n");
3019: optionsText
3020: .append("\tin conjunction with the '-p' option (only nominal classes).\n");
3021: optionsText.append("-r\n");
3022: optionsText
3023: .append("\tOnly outputs cumulative margin distribution.\n");
3024: if (classifier instanceof Sourcable) {
3025: optionsText.append("-z <class name>\n");
3026: optionsText
3027: .append("\tOnly outputs the source representation"
3028: + " of the classifier,\n\tgiving it the supplied"
3029: + " name.\n");
3030: }
3031: if (classifier instanceof Drawable) {
3032: optionsText.append("-g\n");
3033: optionsText
3034: .append("\tOnly outputs the graph representation"
3035: + " of the classifier.\n");
3036: }
3037: optionsText.append("-xml filename | xml-string\n");
3038: optionsText
3039: .append("\tRetrieves the options from the XML-data instead of the "
3040: + "command line.\n");
3041: optionsText.append("-threshold-file <file>\n");
3042: optionsText
3043: .append("\tThe file to save the threshold data to.\n"
3044: + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
3045: + "\tformat or '.csv' for CSV.\n");
3046: optionsText.append("-threshold-label <label>\n");
3047: optionsText
3048: .append("\tThe class label to determine the threshold data for\n"
3049: + "\t(default is the first label)\n");
3050:
3051: // Get scheme-specific options
3052: if (classifier instanceof OptionHandler) {
3053: optionsText.append("\nOptions specific to "
3054: + classifier.getClass().getName() + ":\n\n");
3055: Enumeration enu = ((OptionHandler) classifier)
3056: .listOptions();
3057: while (enu.hasMoreElements()) {
3058: Option option = (Option) enu.nextElement();
3059: optionsText.append(option.synopsis() + '\n');
3060: optionsText.append(option.description() + "\n");
3061: }
3062: }
3063: return optionsText.toString();
3064: }
3065:
3066: /**
3067: * Method for generating indices for the confusion matrix.
3068: *
3069: * @param num integer to format
3070: * @param IDChars the characters to use
3071: * @param IDWidth the width of the entry
3072: * @return the formatted integer as a string
3073: */
3074: protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
3075:
3076: char ID[] = new char[IDWidth];
3077: int i;
3078:
3079: for (i = IDWidth - 1; i >= 0; i--) {
3080: ID[i] = IDChars[num % IDChars.length];
3081: num = num / IDChars.length - 1;
3082: if (num < 0) {
3083: break;
3084: }
3085: }
3086: for (i--; i >= 0; i--) {
3087: ID[i] = ' ';
3088: }
3089:
3090: return new String(ID);
3091: }
3092:
3093: /**
3094: * Convert a single prediction into a probability distribution
3095: * with all zero probabilities except the predicted value which
3096: * has probability 1.0;
3097: *
3098: * @param predictedClass the index of the predicted class
3099: * @return the probability distribution
3100: */
3101: protected double[] makeDistribution(double predictedClass) {
3102:
3103: double[] result = new double[m_NumClasses];
3104: if (Instance.isMissingValue(predictedClass)) {
3105: return result;
3106: }
3107: if (m_ClassIsNominal) {
3108: result[(int) predictedClass] = 1.0;
3109: } else {
3110: result[0] = predictedClass;
3111: }
3112: return result;
3113: }
3114:
3115: /**
3116: * Updates all the statistics about a classifiers performance for
3117: * the current test instance.
3118: *
3119: * @param predictedDistribution the probabilities assigned to
3120: * each class
3121: * @param instance the instance to be classified
3122: * @throws Exception if the class of the instance is not
3123: * set
3124: */
    protected void updateStatsForClassifier(
            double[] predictedDistribution, Instance instance)
            throws Exception {

        int actualClass = (int) instance.classValue();

        if (!instance.classIsMissing()) {
            // Record the classification margin for this instance
            updateMargins(predictedDistribution, actualClass, instance
                    .weight());

            // Determine the predicted class (doesn't detect multiple
            // classifications)
            int predictedClass = -1;
            double bestProb = 0.0;
            for (int i = 0; i < m_NumClasses; i++) {
                if (predictedDistribution[i] > bestProb) {
                    predictedClass = i;
                    bestProb = predictedDistribution[i];
                }
            }

            // Weighted count of instances that actually had a class value
            m_WithClass += instance.weight();

            // Determine misclassification cost
            if (m_CostMatrix != null) {
                if (predictedClass < 0) {
                    // For missing predictions, we assume the worst possible cost.
                    // This is pretty harsh.
                    // Perhaps we could take the negative of the cost of a correct
                    // prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
                    // although often this will be zero
                    m_TotalCost += instance.weight()
                            * m_CostMatrix.getMaxCost(actualClass,
                                    instance);
                } else {
                    m_TotalCost += instance.weight()
                            * m_CostMatrix.getElement(actualClass,
                                    predictedClass, instance);
                }
            }

            // Update counts when no class was predicted (all-zero
            // distribution); no further statistics can be computed then
            if (predictedClass < 0) {
                m_Unclassified += instance.weight();
                return;
            }

            // Clamp probabilities to MIN_SF_PROB so the log2 calls below
            // stay finite
            double predictedProb = Math.max(MIN_SF_PROB,
                    predictedDistribution[actualClass]);
            double priorProb = Math.max(MIN_SF_PROB,
                    m_ClassPriors[actualClass] / m_ClassPriorsSum);
            // KB information score: credit predictions that beat the class
            // prior, penalise those that fall below it
            if (predictedProb >= priorProb) {
                m_SumKBInfo += (Utils.log2(predictedProb) - Utils
                        .log2(priorProb))
                        * instance.weight();
            } else {
                m_SumKBInfo -= (Utils.log2(1.0 - predictedProb) - Utils
                        .log2(1.0 - priorProb))
                        * instance.weight();
            }

            // Entropy sums used by the information-theoretic statistics
            m_SumSchemeEntropy -= Utils.log2(predictedProb)
                    * instance.weight();
            m_SumPriorEntropy -= Utils.log2(priorProb)
                    * instance.weight();

            // Score the class probabilities against the 0/1 target
            // distribution (feeds the numeric error measures)
            updateNumericScores(predictedDistribution,
                    makeDistribution(instance.classValue()), instance
                            .weight());

            // Update other stats
            m_ConfusionMatrix[actualClass][predictedClass] += instance
                    .weight();
            if (predictedClass != actualClass) {
                m_Incorrect += instance.weight();
            } else {
                m_Correct += instance.weight();
            }
        } else {
            // Instances without a class value are only counted, not scored
            m_MissingClass += instance.weight();
        }
    }
3207:
3208: /**
3209: * Updates all the statistics about a predictors performance for
3210: * the current test instance.
3211: *
3212: * @param predictedValue the numeric value the classifier predicts
3213: * @param instance the instance to be classified
3214: * @throws Exception if the class of the instance is not
3215: * set
3216: */
    protected void updateStatsForPredictor(double predictedValue,
            Instance instance) throws Exception {

        if (!instance.classIsMissing()) {

            // Update stats
            m_WithClass += instance.weight();
            if (Instance.isMissingValue(predictedValue)) {
                // A missing prediction counts as unclassified; nothing else
                // can be updated for it
                m_Unclassified += instance.weight();
                return;
            }
            // Weighted sums, squares and cross-products of actual and
            // predicted values (used later for correlation/error statistics)
            m_SumClass += instance.weight() * instance.classValue();
            m_SumSqrClass += instance.weight() * instance.classValue()
                    * instance.classValue();
            m_SumClassPredicted += instance.weight()
                    * instance.classValue() * predictedValue;
            m_SumPredicted += instance.weight() * predictedValue;
            m_SumSqrPredicted += instance.weight() * predictedValue
                    * predictedValue;

            // Lazily build the kernel estimators from the buffered
            // training class values on first use
            if (m_ErrorEstimator == null) {
                setNumericPriorsFromBuffer();
            }
            // Clamp probabilities to MIN_SF_PROB so log2 stays finite
            double predictedProb = Math.max(m_ErrorEstimator
                    .getProbability(predictedValue
                            - instance.classValue()), MIN_SF_PROB);
            double priorProb = Math
                    .max(m_PriorErrorEstimator.getProbability(instance
                            .classValue()), MIN_SF_PROB);

            // Entropy sums for the information-theoretic statistics
            m_SumSchemeEntropy -= Utils.log2(predictedProb)
                    * instance.weight();
            m_SumPriorEntropy -= Utils.log2(priorProb)
                    * instance.weight();
            // Fold this prediction's error into the error estimator
            m_ErrorEstimator.addValue(predictedValue
                    - instance.classValue(), instance.weight());

            updateNumericScores(makeDistribution(predictedValue),
                    makeDistribution(instance.classValue()), instance
                            .weight());

        } else
            // Instances without a class value are only counted, not scored
            m_MissingClass += instance.weight();
    }
3261:
3262: /**
3263: * Update the cumulative record of classification margins
3264: *
3265: * @param predictedDistribution the probability distribution predicted for
3266: * the current instance
3267: * @param actualClass the index of the actual instance class
3268: * @param weight the weight assigned to the instance
3269: */
3270: protected void updateMargins(double[] predictedDistribution,
3271: int actualClass, double weight) {
3272:
3273: double probActual = predictedDistribution[actualClass];
3274: double probNext = 0;
3275:
3276: for (int i = 0; i < m_NumClasses; i++)
3277: if ((i != actualClass)
3278: && (predictedDistribution[i] > probNext))
3279: probNext = predictedDistribution[i];
3280:
3281: double margin = probActual - probNext;
3282: int bin = (int) ((margin + 1.0) / 2.0 * k_MarginResolution);
3283: m_MarginCounts[bin] += weight;
3284: }
3285:
3286: /**
3287: * Update the numeric accuracy measures. For numeric classes, the
3288: * accuracy is between the actual and predicted class values. For
3289: * nominal classes, the accuracy is between the actual and
3290: * predicted class probabilities.
3291: *
3292: * @param predicted the predicted values
3293: * @param actual the actual value
3294: * @param weight the weight associated with this prediction
3295: */
3296: protected void updateNumericScores(double[] predicted,
3297: double[] actual, double weight) {
3298:
3299: double diff;
3300: double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
3301: double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
3302: for (int i = 0; i < m_NumClasses; i++) {
3303: diff = predicted[i] - actual[i];
3304: sumErr += diff;
3305: sumAbsErr += Math.abs(diff);
3306: sumSqrErr += diff * diff;
3307: diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
3308: sumPriorAbsErr += Math.abs(diff);
3309: sumPriorSqrErr += diff * diff;
3310: }
3311: m_SumErr += weight * sumErr / m_NumClasses;
3312: m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
3313: m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
3314: m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
3315: m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
3316: }
3317:
3318: /**
3319: * Adds a numeric (non-missing) training class value and weight to
3320: * the buffer of stored values.
3321: *
3322: * @param classValue the class value
3323: * @param weight the instance weight
3324: */
3325: protected void addNumericTrainClass(double classValue, double weight) {
3326:
3327: if (m_TrainClassVals == null) {
3328: m_TrainClassVals = new double[100];
3329: m_TrainClassWeights = new double[100];
3330: }
3331: if (m_NumTrainClassVals == m_TrainClassVals.length) {
3332: double[] temp = new double[m_TrainClassVals.length * 2];
3333: System.arraycopy(m_TrainClassVals, 0, temp, 0,
3334: m_TrainClassVals.length);
3335: m_TrainClassVals = temp;
3336:
3337: temp = new double[m_TrainClassWeights.length * 2];
3338: System.arraycopy(m_TrainClassWeights, 0, temp, 0,
3339: m_TrainClassWeights.length);
3340: m_TrainClassWeights = temp;
3341: }
3342: m_TrainClassVals[m_NumTrainClassVals] = classValue;
3343: m_TrainClassWeights[m_NumTrainClassVals] = weight;
3344: m_NumTrainClassVals++;
3345: }
3346:
3347: /**
3348: * Sets up the priors for numeric class attributes from the
3349: * training class values that have been seen so far.
3350: */
3351: protected void setNumericPriorsFromBuffer() {
3352:
3353: double numPrecision = 0.01; // Default value
3354: if (m_NumTrainClassVals > 1) {
3355: double[] temp = new double[m_NumTrainClassVals];
3356: System.arraycopy(m_TrainClassVals, 0, temp, 0,
3357: m_NumTrainClassVals);
3358: int[] index = Utils.sort(temp);
3359: double lastVal = temp[index[0]];
3360: double deltaSum = 0;
3361: int distinct = 0;
3362: for (int i = 1; i < temp.length; i++) {
3363: double current = temp[index[i]];
3364: if (current != lastVal) {
3365: deltaSum += current - lastVal;
3366: lastVal = current;
3367: distinct++;
3368: }
3369: }
3370: if (distinct > 0) {
3371: numPrecision = deltaSum / distinct;
3372: }
3373: }
3374: m_PriorErrorEstimator = new KernelEstimator(numPrecision);
3375: m_ErrorEstimator = new KernelEstimator(numPrecision);
3376: m_ClassPriors[0] = m_ClassPriorsSum = 0;
3377: for (int i = 0; i < m_NumTrainClassVals; i++) {
3378: m_ClassPriors[0] += m_TrainClassVals[i]
3379: * m_TrainClassWeights[i];
3380: m_ClassPriorsSum += m_TrainClassWeights[i];
3381: m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
3382: m_TrainClassWeights[i]);
3383: }
3384: }
3385: }
|