Source Code Cross Referenced for Evaluation.java in » Science » weka » weka » classifiers » Java Source Code / Java Documentation

Java Source Code / Java Documentation
1. 6.0 JDK Core
2. 6.0 JDK Modules
3. 6.0 JDK Modules com.sun
4. 6.0 JDK Modules com.sun.java
5. 6.0 JDK Modules sun
6. 6.0 JDK Platform
7. Ajax
8. Apache Harmony Java SE
9. Aspect oriented
10. Authentication Authorization
11. Blogger System
12. Build
13. Byte Code
14. Cache
15. Chart
16. Chat
17. Code Analyzer
18. Collaboration
19. Content Management System
20. Database Client
21. Database DBMS
22. Database JDBC Connection Pool
23. Database ORM
24. Development
25. EJB Server geronimo
26. EJB Server GlassFish
27. EJB Server JBoss 4.2.1
28. EJB Server resin 3.1.5
29. ERP CRM Financial
30. ESB
31. Forum
32. GIS
33. Graphic Library
34. Groupware
35. HTML Parser
36. IDE
37. IDE Eclipse
38. IDE Netbeans
39. Installer
40. Internationalization Localization
41. Inversion of Control
42. Issue Tracking
43. J2EE
44. JBoss
45. JMS
46. JMX
47. Library
48. Mail Clients
49. Net
50. Parser
51. PDF
52. Portal
53. Profiler
54. Project Management
55. Report
56. RSS RDF
57. Rule Engine
58. Science
59. Scripting
60. Search Engine
61. Security
62. Servlet Container
63. Source Control
64. Swing Library
65. Template Engine
66. Test Coverage
67. Testing
68. UML
69. Web Crawler
70. Web Framework
71. Web Mail
72. Web Server
73. Web Services
74. Web Services apache cxf 2.0.1
75. Web Services AXIS2
76. Wiki Engine
77. Workflow Engines
78. XML
79. XML UI
Java
Java Tutorial
Java Open Source
Jar File Download
Java Articles
Java Products
Java by API
Photoshop Tutorials
Maya Tutorials
Flash Tutorials
3ds-Max Tutorials
Illustrator Tutorials
GIMP Tutorials
C# / C Sharp
C# / CSharp Tutorial
C# / CSharp Open Source
ASP.Net
ASP.NET Tutorial
JavaScript DHTML
JavaScript Tutorial
JavaScript Reference
HTML / CSS
HTML CSS Reference
C / ANSI-C
C Tutorial
C++
C++ Tutorial
Ruby
PHP
Python
Python Tutorial
Python Open Source
SQL Server / T-SQL
SQL Server / T-SQL Tutorial
Oracle PL / SQL
Oracle PL/SQL Tutorial
PostgreSQL
SQL / MySQL
MySQL Tutorial
VB.Net
VB.Net Tutorial
Flash / Flex / ActionScript
VBA / Excel / Access / Word
XML
XML Tutorial
Microsoft Office PowerPoint 2007 Tutorial
Microsoft Office Excel 2007 Tutorial
Microsoft Office Word 2007 Tutorial
Java Source Code / Java Documentation » Science » weka » weka.classifiers 
Source Cross Referenced  Class Diagram Java Document (Java Doc) 


0001:        /*
0002:         *    This program is free software; you can redistribute it and/or modify
0003:         *    it under the terms of the GNU General Public License as published by
0004:         *    the Free Software Foundation; either version 2 of the License, or
0005:         *    (at your option) any later version.
0006:         *
0007:         *    This program is distributed in the hope that it will be useful,
0008:         *    but WITHOUT ANY WARRANTY; without even the implied warranty of
0009:         *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
0010:         *    GNU General Public License for more details.
0011:         *
0012:         *    You should have received a copy of the GNU General Public License
0013:         *    along with this program; if not, write to the Free Software
0014:         *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
0015:         */
0016:
0017:        /*
0018:         *    Evaluation.java
0019:         *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
0020:         *
0021:         */
0022:
0023:        package weka.classifiers;
0024:
0025:        import weka.classifiers.evaluation.NominalPrediction;
0026:        import weka.classifiers.evaluation.ThresholdCurve;
0027:        import weka.classifiers.xml.XMLClassifier;
0028:        import weka.core.Drawable;
0029:        import weka.core.FastVector;
0030:        import weka.core.Instance;
0031:        import weka.core.Instances;
0032:        import weka.core.Option;
0033:        import weka.core.OptionHandler;
0034:        import weka.core.Range;
0035:        import weka.core.Summarizable;
0036:        import weka.core.Utils;
0037:        import weka.core.converters.ConverterUtils.DataSink;
0038:        import weka.core.converters.ConverterUtils.DataSource;
0039:        import weka.core.xml.KOML;
0040:        import weka.core.xml.XMLOptions;
0041:        import weka.core.xml.XMLSerialization;
0042:        import weka.estimators.Estimator;
0043:        import weka.estimators.KernelEstimator;
0044:
0045:        import java.io.BufferedInputStream;
0046:        import java.io.BufferedOutputStream;
0047:        import java.io.BufferedReader;
0048:        import java.io.FileInputStream;
0049:        import java.io.FileOutputStream;
0050:        import java.io.FileReader;
0051:        import java.io.InputStream;
0052:        import java.io.ObjectInputStream;
0053:        import java.io.ObjectOutputStream;
0054:        import java.io.OutputStream;
0055:        import java.io.Reader;
0056:        import java.util.Enumeration;
0057:        import java.util.Random;
0058:        import java.util.zip.GZIPInputStream;
0059:        import java.util.zip.GZIPOutputStream;
0060:
0061:        /**
0062:         * Class for evaluating machine learning models. <p/>
0063:         *
0064:         * ------------------------------------------------------------------- <p/>
0065:         *
0066:         * General options when evaluating a learning scheme from the command-line: <p/>
0067:         *
0068:         * -t filename <br/>
0069:         * Name of the file with the training data. (required) <p/>
0070:         *
0071:         * -T filename <br/>
0072:         * Name of the file with the test data. If missing a cross-validation 
0073:         * is performed. <p/>
0074:         *
0075:         * -c index <br/>
0076:         * Index of the class attribute (1, 2, ...; default: last). <p/>
0077:         *
0078:         * -x number <br/>
0079:         * The number of folds for the cross-validation (default: 10). <p/>
0080:         *
0081:         * -no-cv <br/>
0082:         * No cross validation.  If no test file is provided, no evaluation
0083:         * is done. <p/>
0084:         * 
0085:         * -split-percentage percentage <br/>
0086:         * Sets the percentage for the train/test set split, e.g., 66. <p/>
0087:         * 
0088:         * -preserve-order <br/>
0089:         * Preserves the order in the percentage split instead of randomizing
0090:         * the data first with the seed value ('-s'). <p/>
0091:         *
0092:         * -s seed <br/>
0093:         * Random number seed for the cross-validation and percentage split
0094:         * (default: 1). <p/>
0095:         *
0096:         * -m filename <br/>
0097:         * The name of a file containing a cost matrix. <p/>
0098:         *
0099:         * -l filename <br/>
0100:         * Loads classifier from the given file. In case the filename ends with ".xml" 
0101:         * the options are loaded from XML. <p/>
0102:         *
0103:         * -d filename <br/>
0104:         * Saves classifier built from the training data into the given file. In case 
0105:         * the filename ends with ".xml" the options are saved XML, not the model. <p/>
0106:         *
0107:         * -v <br/>
0108:         * Outputs no statistics for the training data. <p/>
0109:         *
0110:         * -o <br/>
0111:         * Outputs statistics only, not the classifier. <p/>
0112:         * 
0113:         * -i <br/>
0114:         * Outputs information-retrieval statistics per class. <p/>
0115:         *
0116:         * -k <br/>
0117:         * Outputs information-theoretic statistics. <p/>
0118:         *
0119:         * -p range <br/>
0120:         * Outputs predictions for test instances (or the train instances if no test
0121:         * instances provided), along with the attributes in the specified range 
0122:         * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
0123:         * 
0124:         * -distribution <br/>
0125:         * Outputs the distribution instead of only the prediction
0126:         * in conjunction with the '-p' option (only nominal classes). <p/>
0127:         *
0128:         * -r <br/>
0129:         * Outputs cumulative margin distribution (and nothing else). <p/>
0130:         *
0131:         * -g <br/> 
0132:         * Only for classifiers that implement "Graphable." Outputs
0133:         * the graph representation of the classifier (and nothing
0134:         * else). <p/>
0135:         * 
0136:         * -xml filename | xml-string <br/>
0137:         * Retrieves the options from the XML-data instead of the command line. <p/>
0138:         * 
0139:         * -threshold-file file <br/>
0140:         * The file to save the threshold data to.
0141:         * The format is determined by the extensions, e.g., '.arff' for ARFF
0142:         * format or '.csv' for CSV. <p/>
0143:         *         
0144:         * -threshold-label label <br/>
0145:         * The class label to determine the threshold data for
0146:         * (default is the first label) <p/>
0147:         *         
0148:         * ------------------------------------------------------------------- <p/>
0149:         *
0150:         * Example usage as the main of a classifier (called FunkyClassifier):
0151:         * <code> <pre>
0152:         * public static void main(String [] args) {
0153:         *   runClassifier(new FunkyClassifier(), args);
0154:         * }
0155:         * </pre> </code> 
0156:         * <p/>
0157:         *
0158:         * ------------------------------------------------------------------ <p/>
0159:         *
0160:         * Example usage from within an application:
0161:         * <code> <pre>
0162:         * Instances trainInstances = ... instances got from somewhere
0163:         * Instances testInstances = ... instances got from somewhere
0164:         * Classifier scheme = ... scheme got from somewhere
0165:         *
0166:         * Evaluation evaluation = new Evaluation(trainInstances);
0167:         * evaluation.evaluateModel(scheme, testInstances);
0168:         * System.out.println(evaluation.toSummaryString());
0169:         * </pre> </code> 
0170:         *
0171:         *
0172:         * @author   Eibe Frank (eibe@cs.waikato.ac.nz)
0173:         * @author   Len Trigg (trigg@cs.waikato.ac.nz)
0174:         * @version  $Revision: 1.77 $
0175:         */
0176:        public class Evaluation implements  Summarizable {
0177:
0178:            /** The number of classes. */
0179:            protected int m_NumClasses;
0180:
0181:            /** The number of folds for a cross-validation. */
0182:            protected int m_NumFolds;
0183:
0184:            /** The weight of all incorrectly classified instances. */
0185:            protected double m_Incorrect;
0186:
0187:            /** The weight of all correctly classified instances. */
0188:            protected double m_Correct;
0189:
0190:            /** The weight of all unclassified instances. */
0191:            protected double m_Unclassified;
0192:
0193:            /*** The weight of all instances that had no class assigned to them. */
0194:            protected double m_MissingClass;
0195:
0196:            /** The weight of all instances that had a class assigned to them. */
0197:            protected double m_WithClass;
0198:
0199:            /** Array for storing the confusion matrix. */
0200:            protected double[][] m_ConfusionMatrix;
0201:
0202:            /** The names of the classes. */
0203:            protected String[] m_ClassNames;
0204:
0205:            /** Is the class nominal or numeric? */
0206:            protected boolean m_ClassIsNominal;
0207:
0208:            /** The prior probabilities of the classes */
0209:            protected double[] m_ClassPriors;
0210:
0211:            /** The sum of counts for priors */
0212:            protected double m_ClassPriorsSum;
0213:
0214:            /** The cost matrix (if given). */
0215:            protected CostMatrix m_CostMatrix;
0216:
0217:            /** The total cost of predictions (includes instance weights) */
0218:            protected double m_TotalCost;
0219:
0220:            /** Sum of errors. */
0221:            protected double m_SumErr;
0222:
0223:            /** Sum of absolute errors. */
0224:            protected double m_SumAbsErr;
0225:
0226:            /** Sum of squared errors. */
0227:            protected double m_SumSqrErr;
0228:
0229:            /** Sum of class values. */
0230:            protected double m_SumClass;
0231:
0232:            /** Sum of squared class values. */
0233:            protected double m_SumSqrClass;
0234:
0235:            /*** Sum of predicted values. */
0236:            protected double m_SumPredicted;
0237:
0238:            /** Sum of squared predicted values. */
0239:            protected double m_SumSqrPredicted;
0240:
0241:            /** Sum of predicted * class values. */
0242:            protected double m_SumClassPredicted;
0243:
0244:            /** Sum of absolute errors of the prior */
0245:            protected double m_SumPriorAbsErr;
0246:
0247:            /** Sum of absolute errors of the prior */
0248:            protected double m_SumPriorSqrErr;
0249:
0250:            /** Total Kononenko & Bratko Information */
0251:            protected double m_SumKBInfo;
0252:
0253:            /*** Resolution of the margin histogram */
0254:            protected static int k_MarginResolution = 500;
0255:
0256:            /** Cumulative margin distribution */
0257:            protected double m_MarginCounts[];
0258:
0259:            /** Number of non-missing class training instances seen */
0260:            protected int m_NumTrainClassVals;
0261:
0262:            /** Array containing all numeric training class values seen */
0263:            protected double[] m_TrainClassVals;
0264:
0265:            /** Array containing all numeric training class weights */
0266:            protected double[] m_TrainClassWeights;
0267:
0268:            /** Numeric class error estimator for prior */
0269:            protected Estimator m_PriorErrorEstimator;
0270:
0271:            /** Numeric class error estimator for scheme */
0272:            protected Estimator m_ErrorEstimator;
0273:
0274:            /**
0275:             * The minimum probablility accepted from an estimator to avoid
0276:             * taking log(0) in Sf calculations.
0277:             */
0278:            protected static final double MIN_SF_PROB = Double.MIN_VALUE;
0279:
0280:            /** Total entropy of prior predictions */
0281:            protected double m_SumPriorEntropy;
0282:
0283:            /** Total entropy of scheme predictions */
0284:            protected double m_SumSchemeEntropy;
0285:
0286:            /** The list of predictions that have been generated (for computing AUC) */
0287:            private FastVector m_Predictions;
0288:
0289:            /** enables/disables the use of priors, e.g., if no training set is
0290:             * present in case of de-serialized schemes */
0291:            protected boolean m_NoPriors = false;
0292:
0293:            /**
0294:             * Initializes all the counters for the evaluation. 
0295:             * Use <code>useNoPriors()</code> if the dataset is the test set and you
0296:             * can't initialize with the priors from the training set via 
0297:             * <code>setPriors(Instances)</code>.
0298:             *
0299:             * @param data 	set of training instances, to get some header 
0300:             * 			information and prior class distribution information
0301:             * @throws Exception 	if the class is not defined
0302:             * @see 		#useNoPriors()
0303:             * @see 		#setPriors(Instances)
0304:             */
0305:            public Evaluation(Instances data) throws Exception {
0306:
0307:                this (data, null);
0308:            }
0309:
0310:            /**
0311:             * Initializes all the counters for the evaluation and also takes a
0312:             * cost matrix as parameter.
0313:             * Use <code>useNoPriors()</code> if the dataset is the test set and you
0314:             * can't initialize with the priors from the training set via 
0315:             * <code>setPriors(Instances)</code>.
0316:             *
0317:             * @param data 	set of training instances, to get some header 
0318:             * 			information and prior class distribution information
0319:             * @param costMatrix 	the cost matrix---if null, default costs will be used
0320:             * @throws Exception 	if cost matrix is not compatible with 
0321:             * 			data, the class is not defined or the class is numeric
0322:             * @see 		#useNoPriors()
0323:             * @see 		#setPriors(Instances)
0324:             */
0325:            public Evaluation(Instances data, CostMatrix costMatrix)
0326:                    throws Exception {
0327:
0328:                m_NumClasses = data.numClasses();
0329:                m_NumFolds = 1;
0330:                m_ClassIsNominal = data.classAttribute().isNominal();
0331:
0332:                if (m_ClassIsNominal) {
0333:                    m_ConfusionMatrix = new double[m_NumClasses][m_NumClasses];
0334:                    m_ClassNames = new String[m_NumClasses];
0335:                    for (int i = 0; i < m_NumClasses; i++) {
0336:                        m_ClassNames[i] = data.classAttribute().value(i);
0337:                    }
0338:                }
0339:                m_CostMatrix = costMatrix;
0340:                if (m_CostMatrix != null) {
0341:                    if (!m_ClassIsNominal) {
0342:                        throw new Exception(
0343:                                "Class has to be nominal if cost matrix "
0344:                                        + "given!");
0345:                    }
0346:                    if (m_CostMatrix.size() != m_NumClasses) {
0347:                        throw new Exception(
0348:                                "Cost matrix not compatible with data!");
0349:                    }
0350:                }
0351:                m_ClassPriors = new double[m_NumClasses];
0352:                setPriors(data);
0353:                m_MarginCounts = new double[k_MarginResolution + 1];
0354:            }
0355:
0356:            /**
0357:             * Returns the area under ROC for those predictions that have been collected
0358:             * in the evaluateClassifier(Classifier, Instances) method. Returns 
0359:             * Instance.missingValue() if the area is not available.
0360:             *
0361:             * @param classIndex the index of the class to consider as "positive"
0362:             * @return the area under the ROC curve or not a number
0363:             */
0364:            public double areaUnderROC(int classIndex) {
0365:
0366:                // Check if any predictions have been collected
0367:                if (m_Predictions == null) {
0368:                    return Instance.missingValue();
0369:                } else {
0370:                    ThresholdCurve tc = new ThresholdCurve();
0371:                    Instances result = tc.getCurve(m_Predictions, classIndex);
0372:                    return ThresholdCurve.getROCArea(result);
0373:                }
0374:            }
0375:
0376:            /**
0377:             * Returns a copy of the confusion matrix.
0378:             *
0379:             * @return a copy of the confusion matrix as a two-dimensional array
0380:             */
0381:            public double[][] confusionMatrix() {
0382:
0383:                double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
0384:
0385:                for (int i = 0; i < m_ConfusionMatrix.length; i++) {
0386:                    newMatrix[i] = new double[m_ConfusionMatrix[i].length];
0387:                    System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
0388:                            m_ConfusionMatrix[i].length);
0389:                }
0390:                return newMatrix;
0391:            }
0392:
0393:            /**
0394:             * Performs a (stratified if class is nominal) cross-validation 
0395:             * for a classifier on a set of instances. Now performs
0396:             * a deep copy of the classifier before each call to 
0397:             * buildClassifier() (just in case the classifier is not
0398:             * initialized properly).
0399:             *
0400:             * @param classifier the classifier with any options set.
0401:             * @param data the data on which the cross-validation is to be 
0402:             * performed 
0403:             * @param numFolds the number of folds for the cross-validation
0404:             * @param random random number generator for randomization 
0405:             * @throws Exception if a classifier could not be generated 
0406:             * successfully or the class is not defined
0407:             */
0408:            public void crossValidateModel(Classifier classifier,
0409:                    Instances data, int numFolds, Random random)
0410:                    throws Exception {
0411:
0412:                // Make a copy of the data we can reorder
0413:                data = new Instances(data);
0414:                data.randomize(random);
0415:                if (data.classAttribute().isNominal()) {
0416:                    data.stratify(numFolds);
0417:                }
0418:                // Do the folds
0419:                for (int i = 0; i < numFolds; i++) {
0420:                    Instances train = data.trainCV(numFolds, i, random);
0421:                    setPriors(train);
0422:                    Classifier copiedClassifier = Classifier
0423:                            .makeCopy(classifier);
0424:                    copiedClassifier.buildClassifier(train);
0425:                    Instances test = data.testCV(numFolds, i);
0426:                    evaluateModel(copiedClassifier, test);
0427:                }
0428:                m_NumFolds = numFolds;
0429:            }
0430:
0431:            /**
0432:             * Performs a (stratified if class is nominal) cross-validation 
0433:             * for a classifier on a set of instances.
0434:             *
0435:             * @param classifierString a string naming the class of the classifier
0436:             * @param data the data on which the cross-validation is to be 
0437:             * performed 
0438:             * @param numFolds the number of folds for the cross-validation
0439:             * @param options the options to the classifier. Any options
0440:             * @param random the random number generator for randomizing the data
0441:             * accepted by the classifier will be removed from this array.
0442:             * @throws Exception if a classifier could not be generated 
0443:             * successfully or the class is not defined
0444:             */
0445:            public void crossValidateModel(String classifierString,
0446:                    Instances data, int numFolds, String[] options,
0447:                    Random random) throws Exception {
0448:
0449:                crossValidateModel(Classifier
0450:                        .forName(classifierString, options), data, numFolds,
0451:                        random);
0452:            }
0453:
0454:            /**
0455:             * Evaluates a classifier with the options given in an array of
0456:             * strings. <p/>
0457:             *
0458:             * Valid options are: <p/>
0459:             *
0460:             * -t filename <br/>
0461:             * Name of the file with the training data. (required) <p/>
0462:             *
0463:             * -T filename <br/>
0464:             * Name of the file with the test data. If missing a cross-validation 
0465:             * is performed. <p/>
0466:             *
0467:             * -c index <br/>
0468:             * Index of the class attribute (1, 2, ...; default: last). <p/>
0469:             *
0470:             * -x number <br/>
0471:             * The number of folds for the cross-validation (default: 10). <p/>
0472:             *
0473:             * -no-cv <br/>
0474:             * No cross validation.  If no test file is provided, no evaluation
0475:             * is done. <p/>
0476:             * 
0477:             * -split-percentage percentage <br/>
0478:             * Sets the percentage for the train/test set split, e.g., 66. <p/>
0479:             * 
0480:             * -preserve-order <br/>
0481:             * Preserves the order in the percentage split instead of randomizing
0482:             * the data first with the seed value ('-s'). <p/>
0483:             *
0484:             * -s seed <br/>
0485:             * Random number seed for the cross-validation and percentage split
0486:             * (default: 1). <p/>
0487:             *
0488:             * -m filename <br/>
0489:             * The name of a file containing a cost matrix. <p/>
0490:             *
0491:             * -l filename <br/>
0492:             * Loads classifier from the given file. In case the filename ends with
0493:             * ".xml" the options are loaded from XML. <p/>
0494:             *
0495:             * -d filename <br/>
0496:             * Saves classifier built from the training data into the given file. In case 
0497:             * the filename ends with ".xml" the options are saved XML, not the model. <p/>
0498:             *
0499:             * -v <br/>
0500:             * Outputs no statistics for the training data. <p/>
0501:             *
0502:             * -o <br/>
0503:             * Outputs statistics only, not the classifier. <p/>
0504:             * 
0505:             * -i <br/>
0506:             * Outputs detailed information-retrieval statistics per class. <p/>
0507:             *
0508:             * -k <br/>
0509:             * Outputs information-theoretic statistics. <p/>
0510:             *
0511:             * -p range <br/>
0512:             * Outputs predictions for test instances (or the train instances if no test
0513:             * instances provided), along with the attributes in the specified range (and 
0514:             *  nothing else). Use '-p 0' if no attributes are desired. <p/>
0515:             *
0516:             * -distribution <br/>
0517:             * Outputs the distribution instead of only the prediction
0518:             * in conjunction with the '-p' option (only nominal classes). <p/>
0519:             *
0520:             * -r <br/>
0521:             * Outputs cumulative margin distribution (and nothing else). <p/>
0522:             *
0523:             * -g <br/> 
0524:             * Only for classifiers that implement "Graphable." Outputs
0525:             * the graph representation of the classifier (and nothing
0526:             * else). <p/>
0527:             *
0528:             * -xml filename | xml-string <br/>
0529:             * Retrieves the options from the XML-data instead of the command line. <p/>
0530:             * 
0531:             * -threshold-file file <br/>
0532:             * The file to save the threshold data to.
0533:             * The format is determined by the extensions, e.g., '.arff' for ARFF
0534:             * format or '.csv' for CSV. <p/>
0535:             *         
0536:             * -threshold-label label <br/>
0537:             * The class label to determine the threshold data for
0538:             * (default is the first label) <p/>
0539:             *
0540:             * @param classifierString class of machine learning classifier as a string
0541:             * @param options the array of string containing the options
0542:             * @throws Exception if model could not be evaluated successfully
0543:             * @return a string describing the results 
0544:             */
0545:            public static String evaluateModel(String classifierString,
0546:                    String[] options) throws Exception {
0547:
0548:                Classifier classifier;
0549:
0550:                // Create classifier
0551:                try {
0552:                    classifier = (Classifier) Class.forName(classifierString)
0553:                            .newInstance();
0554:                } catch (Exception e) {
0555:                    throw new Exception("Can't find class with name "
0556:                            + classifierString + '.');
0557:                }
0558:                return evaluateModel(classifier, options);
0559:            }
0560:
0561:            /**
0562:             * A test method for this class. Just extracts the first command line
0563:             * argument as a classifier class name and calls evaluateModel.
0564:             * @param args an array of command line arguments, the first of which
0565:             * must be the class name of a classifier.
0566:             */
0567:            public static void main(String[] args) {
0568:
0569:                try {
0570:                    if (args.length == 0) {
0571:                        throw new Exception(
0572:                                "The first argument must be the class name"
0573:                                        + " of a classifier");
0574:                    }
0575:                    String classifier = args[0];
0576:                    args[0] = "";
0577:                    System.out.println(evaluateModel(classifier, args));
0578:                } catch (Exception ex) {
0579:                    ex.printStackTrace();
0580:                    System.err.println(ex.getMessage());
0581:                }
0582:            }
0583:
0584:            /**
0585:             * Evaluates a classifier with the options given in an array of
0586:             * strings. <p/>
0587:             *
0588:             * Valid options are: <p/>
0589:             *
0590:             * -t name of training file <br/>
0591:             * Name of the file with the training data. (required) <p/>
0592:             *
0593:             * -T name of test file <br/>
0594:             * Name of the file with the test data. If missing a cross-validation 
0595:             * is performed. <p/>
0596:             *
0597:             * -c class index <br/>
0598:             * Index of the class attribute (1, 2, ...; default: last). <p/>
0599:             *
0600:             * -x number of folds <br/>
0601:             * The number of folds for the cross-validation (default: 10). <p/>
0602:             *
0603:             * -no-cv <br/>
0604:             * No cross validation.  If no test file is provided, no evaluation
0605:             * is done. <p/>
0606:             * 
0607:             * -split-percentage percentage <br/>
0608:             * Sets the percentage for the train/test set split, e.g., 66. <p/>
0609:             * 
0610:             * -preserve-order <br/>
0611:             * Preserves the order in the percentage split instead of randomizing
0612:             * the data first with the seed value ('-s'). <p/>
0613:             *
0614:             * -s seed <br/>
0615:             * Random number seed for the cross-validation and percentage split
0616:             * (default: 1). <p/>
0617:             *
0618:             * -m file with cost matrix <br/>
0619:             * The name of a file containing a cost matrix. <p/>
0620:             *
0621:             * -l filename <br/>
0622:             * Loads classifier from the given file. In case the filename ends with
0623:             * ".xml" the options are loaded from XML. <p/>
0624:             *
0625:             * -d filename <br/>
0626:             * Saves classifier built from the training data into the given file. In case 
0627:             * the filename ends with ".xml" the options are saved XML, not the model. <p/>
0628:             *
0629:             * -v <br/>
0630:             * Outputs no statistics for the training data. <p/>
0631:             *
0632:             * -o <br/>
0633:             * Outputs statistics only, not the classifier. <p/>
0634:             * 
0635:             * -i <br/>
0636:             * Outputs detailed information-retrieval statistics per class. <p/>
0637:             *
0638:             * -k <br/>
0639:             * Outputs information-theoretic statistics. <p/>
0640:             *
0641:             * -p range <br/>
0642:             * Outputs predictions for test instances (or the train instances if no test
0643:             * instances provided), along with the attributes in the specified range 
0644:             * (and nothing else). Use '-p 0' if no attributes are desired. <p/>
0645:             *
0646:             * -distribution <br/>
0647:             * Outputs the distribution instead of only the prediction
0648:             * in conjunction with the '-p' option (only nominal classes). <p/>
0649:             *
0650:             * -r <br/>
0651:             * Outputs cumulative margin distribution (and nothing else). <p/>
0652:             *
0653:             * -g <br/> 
0654:             * Only for classifiers that implement "Graphable." Outputs
0655:             * the graph representation of the classifier (and nothing
0656:             * else). <p/>
0657:             *
0658:             * -xml filename | xml-string <br/>
0659:             * Retrieves the options from the XML-data instead of the command line. <p/>
0660:             *
0661:             * @param classifier machine learning classifier
0662:             * @param options the array of string containing the options
0663:             * @throws Exception if model could not be evaluated successfully
0664:             * @return a string describing the results 
0665:             */
0666:            public static String evaluateModel(Classifier classifier,
0667:                    String[] options) throws Exception {
0668:
0669:                Instances train = null, tempTrain, test = null, template = null;
0670:                int seed = 1, folds = 10, classIndex = -1;
0671:                boolean noCrossValidation = false;
0672:                String trainFileName, testFileName, sourceClass, classIndexString, seedString, foldsString, objectInputFileName, objectOutputFileName, attributeRangeString;
0673:                boolean noOutput = false, printClassifications = false, trainStatistics = true, printMargins = false, printComplexityStatistics = false, printGraph = false, classStatistics = false, printSource = false;
0674:                StringBuffer text = new StringBuffer();
0675:                DataSource trainSource = null, testSource = null;
0676:                ObjectInputStream objectInputStream = null;
0677:                BufferedInputStream xmlInputStream = null;
0678:                CostMatrix costMatrix = null;
0679:                StringBuffer schemeOptionsText = null;
0680:                Range attributesToOutput = null;
0681:                long trainTimeStart = 0, trainTimeElapsed = 0, testTimeStart = 0, testTimeElapsed = 0;
0682:                String xml = "";
0683:                String[] optionsTmp = null;
0684:                Classifier classifierBackup;
0685:                Classifier classifierClassifications = null;
0686:                boolean printDistribution = false;
0687:                int actualClassIndex = -1; // 0-based class index
0688:                String splitPercentageString = "";
0689:                int splitPercentage = -1;
0690:                boolean preserveOrder = false;
0691:                boolean trainSetPresent = false;
0692:                boolean testSetPresent = false;
0693:                String thresholdFile;
0694:                String thresholdLabel;
0695:
0696:                // help requested?
0697:                if (Utils.getFlag("h", options)
0698:                        || Utils.getFlag("help", options)) {
0699:                    throw new Exception("\nHelp requested."
0700:                            + makeOptionString(classifier));
0701:                }
0702:
0703:                try {
0704:                    // do we get the input from XML instead of normal parameters?
0705:                    xml = Utils.getOption("xml", options);
0706:                    if (!xml.equals(""))
0707:                        options = new XMLOptions(xml).toArray();
0708:
0709:                    // is the input model only the XML-Options, i.e. w/o built model?
0710:                    optionsTmp = new String[options.length];
0711:                    for (int i = 0; i < options.length; i++)
0712:                        optionsTmp[i] = options[i];
0713:
0714:                    if (Utils.getOption('l', optionsTmp).toLowerCase()
0715:                            .endsWith(".xml")) {
0716:                        // load options from serialized data ('-l' is automatically erased!)
0717:                        XMLClassifier xmlserial = new XMLClassifier();
0718:                        Classifier cl = (Classifier) xmlserial.read(Utils
0719:                                .getOption('l', options));
0720:                        // merge options
0721:                        optionsTmp = new String[options.length
0722:                                + cl.getOptions().length];
0723:                        System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl
0724:                                .getOptions().length);
0725:                        System.arraycopy(options, 0, optionsTmp, cl
0726:                                .getOptions().length, options.length);
0727:                        options = optionsTmp;
0728:                    }
0729:
0730:                    noCrossValidation = Utils.getFlag("no-cv", options);
0731:                    // Get basic options (options the same for all schemes)
0732:                    classIndexString = Utils.getOption('c', options);
0733:                    if (classIndexString.length() != 0) {
0734:                        if (classIndexString.equals("first"))
0735:                            classIndex = 1;
0736:                        else if (classIndexString.equals("last"))
0737:                            classIndex = -1;
0738:                        else
0739:                            classIndex = Integer.parseInt(classIndexString);
0740:                    }
0741:                    trainFileName = Utils.getOption('t', options);
0742:                    objectInputFileName = Utils.getOption('l', options);
0743:                    objectOutputFileName = Utils.getOption('d', options);
0744:                    testFileName = Utils.getOption('T', options);
0745:                    foldsString = Utils.getOption('x', options);
0746:                    if (foldsString.length() != 0) {
0747:                        folds = Integer.parseInt(foldsString);
0748:                    }
0749:                    seedString = Utils.getOption('s', options);
0750:                    if (seedString.length() != 0) {
0751:                        seed = Integer.parseInt(seedString);
0752:                    }
0753:                    if (trainFileName.length() == 0) {
0754:                        if (objectInputFileName.length() == 0) {
0755:                            throw new Exception(
0756:                                    "No training file and no object "
0757:                                            + "input file given.");
0758:                        }
0759:                        if (testFileName.length() == 0) {
0760:                            throw new Exception("No training file and no test "
0761:                                    + "file given.");
0762:                        }
0763:                    } else if ((objectInputFileName.length() != 0)
0764:                            && ((!(classifier instanceof  UpdateableClassifier)) || (testFileName
0765:                                    .length() == 0))) {
0766:                        throw new Exception(
0767:                                "Classifier not incremental, or no "
0768:                                        + "test file provided: can't "
0769:                                        + "use both train and model file.");
0770:                    }
0771:                    try {
0772:                        if (trainFileName.length() != 0) {
0773:                            trainSetPresent = true;
0774:                            trainSource = new DataSource(trainFileName);
0775:                        }
0776:                        if (testFileName.length() != 0) {
0777:                            testSetPresent = true;
0778:                            testSource = new DataSource(testFileName);
0779:                        }
0780:                        if (objectInputFileName.length() != 0) {
0781:                            InputStream is = new FileInputStream(
0782:                                    objectInputFileName);
0783:                            if (objectInputFileName.endsWith(".gz")) {
0784:                                is = new GZIPInputStream(is);
0785:                            }
0786:                            // load from KOML?
0787:                            if (!(objectInputFileName.endsWith(".koml") && KOML
0788:                                    .isPresent())) {
0789:                                objectInputStream = new ObjectInputStream(is);
0790:                                xmlInputStream = null;
0791:                            } else {
0792:                                objectInputStream = null;
0793:                                xmlInputStream = new BufferedInputStream(is);
0794:                            }
0795:                        }
0796:                    } catch (Exception e) {
0797:                        throw new Exception("Can't open file " + e.getMessage()
0798:                                + '.');
0799:                    }
0800:                    if (testSetPresent) {
0801:                        template = test = testSource.getStructure();
0802:                        if (classIndex != -1) {
0803:                            test.setClassIndex(classIndex - 1);
0804:                        } else {
0805:                            if ((test.classIndex() == -1)
0806:                                    || (classIndexString.length() != 0))
0807:                                test.setClassIndex(test.numAttributes() - 1);
0808:                        }
0809:                        actualClassIndex = test.classIndex();
0810:                    } else {
0811:                        // percentage split
0812:                        splitPercentageString = Utils.getOption(
0813:                                "split-percentage", options);
0814:                        if (splitPercentageString.length() != 0) {
0815:                            if (foldsString.length() != 0)
0816:                                throw new Exception(
0817:                                        "Percentage split cannot be used in conjunction with "
0818:                                                + "cross-validation ('-x').");
0819:                            splitPercentage = Integer
0820:                                    .parseInt(splitPercentageString);
0821:                            if ((splitPercentage <= 0)
0822:                                    || (splitPercentage >= 100))
0823:                                throw new Exception(
0824:                                        "Percentage split value needs be >0 and <100.");
0825:                        } else {
0826:                            splitPercentage = -1;
0827:                        }
0828:                        preserveOrder = Utils
0829:                                .getFlag("preserve-order", options);
0830:                        if (preserveOrder) {
0831:                            if (splitPercentage == -1)
0832:                                throw new Exception(
0833:                                        "Percentage split ('-percentage-split') is missing.");
0834:                        }
0835:                        // create new train/test sources
0836:                        if (splitPercentage > 0) {
0837:                            testSetPresent = true;
0838:                            Instances tmpInst = trainSource
0839:                                    .getDataSet(actualClassIndex);
0840:                            if (!preserveOrder)
0841:                                tmpInst.randomize(new Random(seed));
0842:                            int trainSize = tmpInst.numInstances()
0843:                                    * splitPercentage / 100;
0844:                            int testSize = tmpInst.numInstances() - trainSize;
0845:                            Instances trainInst = new Instances(tmpInst, 0,
0846:                                    trainSize);
0847:                            Instances testInst = new Instances(tmpInst,
0848:                                    trainSize, testSize);
0849:                            trainSource = new DataSource(trainInst);
0850:                            testSource = new DataSource(testInst);
0851:                            template = test = testSource.getStructure();
0852:                            if (classIndex != -1) {
0853:                                test.setClassIndex(classIndex - 1);
0854:                            } else {
0855:                                if ((test.classIndex() == -1)
0856:                                        || (classIndexString.length() != 0))
0857:                                    test
0858:                                            .setClassIndex(test.numAttributes() - 1);
0859:                            }
0860:                            actualClassIndex = test.classIndex();
0861:                        }
0862:                    }
0863:                    if (trainSetPresent) {
0864:                        template = train = trainSource.getStructure();
0865:                        if (classIndex != -1) {
0866:                            train.setClassIndex(classIndex - 1);
0867:                        } else {
0868:                            if ((train.classIndex() == -1)
0869:                                    || (classIndexString.length() != 0))
0870:                                train.setClassIndex(train.numAttributes() - 1);
0871:                        }
0872:                        actualClassIndex = train.classIndex();
0873:                        if ((testSetPresent) && !test.equalHeaders(train)) {
0874:                            throw new IllegalArgumentException(
0875:                                    "Train and test file not compatible!");
0876:                        }
0877:                    }
0878:                    if (template == null) {
0879:                        throw new Exception(
0880:                                "No actual dataset provided to use as template");
0881:                    }
0882:                    costMatrix = handleCostOption(
0883:                            Utils.getOption('m', options), template
0884:                                    .numClasses());
0885:
0886:                    classStatistics = Utils.getFlag('i', options);
0887:                    noOutput = Utils.getFlag('o', options);
0888:                    trainStatistics = !Utils.getFlag('v', options);
0889:                    printComplexityStatistics = Utils.getFlag('k', options);
0890:                    printMargins = Utils.getFlag('r', options);
0891:                    printGraph = Utils.getFlag('g', options);
0892:                    sourceClass = Utils.getOption('z', options);
0893:                    printSource = (sourceClass.length() != 0);
0894:                    printDistribution = Utils.getFlag("distribution", options);
0895:                    thresholdFile = Utils.getOption("threshold-file", options);
0896:                    thresholdLabel = Utils
0897:                            .getOption("threshold-label", options);
0898:
0899:                    // Check -p option
0900:                    try {
0901:                        attributeRangeString = Utils.getOption('p', options);
0902:                    } catch (Exception e) {
0903:                        throw new Exception(
0904:                                e.getMessage()
0905:                                        + "\nNOTE: the -p option has changed. "
0906:                                        + "It now expects a parameter specifying a range of attributes "
0907:                                        + "to list with the predictions. Use '-p 0' for none.");
0908:                    }
0909:                    if (attributeRangeString.length() != 0) {
0910:                        printClassifications = true;
0911:                        if (!attributeRangeString.equals("0"))
0912:                            attributesToOutput = new Range(attributeRangeString);
0913:                    }
0914:
0915:                    if (!printClassifications && printDistribution)
0916:                        throw new Exception(
0917:                                "Cannot print distribution without '-p' option!");
0918:
0919:                    // if no training file given, we don't have any priors
0920:                    if ((!trainSetPresent) && (printComplexityStatistics))
0921:                        throw new Exception(
0922:                                "Cannot print complexity statistics ('-k') without training file ('-t')!");
0923:
0924:                    // If a model file is given, we can't process 
0925:                    // scheme-specific options
0926:                    if (objectInputFileName.length() != 0) {
0927:                        Utils.checkForRemainingOptions(options);
0928:                    } else {
0929:
0930:                        // Set options for classifier
0931:                        if (classifier instanceof  OptionHandler) {
0932:                            for (int i = 0; i < options.length; i++) {
0933:                                if (options[i].length() != 0) {
0934:                                    if (schemeOptionsText == null) {
0935:                                        schemeOptionsText = new StringBuffer();
0936:                                    }
0937:                                    if (options[i].indexOf(' ') != -1) {
0938:                                        schemeOptionsText.append('"'
0939:                                                + options[i] + "\" ");
0940:                                    } else {
0941:                                        schemeOptionsText.append(options[i]
0942:                                                + " ");
0943:                                    }
0944:                                }
0945:                            }
0946:                            ((OptionHandler) classifier).setOptions(options);
0947:                        }
0948:                    }
0949:                    Utils.checkForRemainingOptions(options);
0950:                } catch (Exception e) {
0951:                    throw new Exception("\nWeka exception: " + e.getMessage()
0952:                            + makeOptionString(classifier));
0953:                }
0954:
0955:                // Setup up evaluation objects
0956:                Evaluation trainingEvaluation = new Evaluation(new Instances(
0957:                        template, 0), costMatrix);
0958:                Evaluation testingEvaluation = new Evaluation(new Instances(
0959:                        template, 0), costMatrix);
0960:
0961:                // disable use of priors if no training file given
0962:                if (!trainSetPresent)
0963:                    testingEvaluation.useNoPriors();
0964:
0965:                if (objectInputFileName.length() != 0) {
0966:                    // Load classifier from file
0967:                    if (objectInputStream != null) {
0968:                        classifier = (Classifier) objectInputStream
0969:                                .readObject();
0970:                        objectInputStream.close();
0971:                    } else {
0972:                        // whether KOML is available has already been checked (objectInputStream would null otherwise)!
0973:                        classifier = (Classifier) KOML.read(xmlInputStream);
0974:                        xmlInputStream.close();
0975:                    }
0976:                }
0977:
0978:                // backup of fully setup classifier for cross-validation
0979:                classifierBackup = Classifier.makeCopy(classifier);
0980:
0981:                // Build the classifier if no object file provided
0982:                if ((classifier instanceof  UpdateableClassifier)
0983:                        && (testSetPresent) && (costMatrix == null)
0984:                        && (trainSetPresent)) {
0985:
0986:                    // Build classifier incrementally
0987:                    trainingEvaluation.setPriors(train);
0988:                    testingEvaluation.setPriors(train);
0989:                    trainTimeStart = System.currentTimeMillis();
0990:                    if (objectInputFileName.length() == 0) {
0991:                        classifier.buildClassifier(train);
0992:                    }
0993:                    Instance trainInst;
0994:                    while (trainSource.hasMoreElements(train)) {
0995:                        trainInst = trainSource.nextElement(train);
0996:                        trainingEvaluation.updatePriors(trainInst);
0997:                        testingEvaluation.updatePriors(trainInst);
0998:                        ((UpdateableClassifier) classifier)
0999:                                .updateClassifier(trainInst);
1000:                    }
1001:                    trainTimeElapsed = System.currentTimeMillis()
1002:                            - trainTimeStart;
1003:                } else if (objectInputFileName.length() == 0) {
1004:                    // Build classifier in one go
1005:                    tempTrain = trainSource.getDataSet(actualClassIndex);
1006:                    trainingEvaluation.setPriors(tempTrain);
1007:                    testingEvaluation.setPriors(tempTrain);
1008:                    trainTimeStart = System.currentTimeMillis();
1009:                    classifier.buildClassifier(tempTrain);
1010:                    trainTimeElapsed = System.currentTimeMillis()
1011:                            - trainTimeStart;
1012:                }
1013:
1014:                // backup of fully trained classifier for printing the classifications
1015:                if (printClassifications)
1016:                    classifierClassifications = Classifier.makeCopy(classifier);
1017:
1018:                // Save the classifier if an object output file is provided
1019:                if (objectOutputFileName.length() != 0) {
1020:                    OutputStream os = new FileOutputStream(objectOutputFileName);
1021:                    // binary
1022:                    if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName
1023:                            .endsWith(".koml") && KOML.isPresent()))) {
1024:                        if (objectOutputFileName.endsWith(".gz")) {
1025:                            os = new GZIPOutputStream(os);
1026:                        }
1027:                        ObjectOutputStream objectOutputStream = new ObjectOutputStream(
1028:                                os);
1029:                        objectOutputStream.writeObject(classifier);
1030:                        objectOutputStream.flush();
1031:                        objectOutputStream.close();
1032:                    }
1033:                    // KOML/XML
1034:                    else {
1035:                        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(
1036:                                os);
1037:                        if (objectOutputFileName.endsWith(".xml")) {
1038:                            XMLSerialization xmlSerial = new XMLClassifier();
1039:                            xmlSerial.write(xmlOutputStream, classifier);
1040:                        } else
1041:                        // whether KOML is present has already been checked
1042:                        // if not present -> ".koml" is interpreted as binary - see above
1043:                        if (objectOutputFileName.endsWith(".koml")) {
1044:                            KOML.write(xmlOutputStream, classifier);
1045:                        }
1046:                        xmlOutputStream.close();
1047:                    }
1048:                }
1049:
1050:                // If classifier is drawable output string describing graph
1051:                if ((classifier instanceof  Drawable) && (printGraph)) {
1052:                    return ((Drawable) classifier).graph();
1053:                }
1054:
1055:                // Output the classifier as equivalent source
1056:                if ((classifier instanceof  Sourcable) && (printSource)) {
1057:                    return wekaStaticWrapper((Sourcable) classifier,
1058:                            sourceClass);
1059:                }
1060:
1061:                // Output model
1062:                if (!(noOutput || printMargins)) {
1063:                    if (classifier instanceof  OptionHandler) {
1064:                        if (schemeOptionsText != null) {
1065:                            text.append("\nOptions: " + schemeOptionsText);
1066:                            text.append("\n");
1067:                        }
1068:                    }
1069:                    text.append("\n" + classifier.toString() + "\n");
1070:                }
1071:
1072:                if (!printMargins && (costMatrix != null)) {
1073:                    text.append("\n=== Evaluation Cost Matrix ===\n\n");
1074:                    text.append(costMatrix.toString());
1075:                }
1076:
1077:                // Output test instance predictions only
1078:                if (printClassifications) {
1079:                    DataSource source = testSource;
1080:                    // no test set -> use train set
1081:                    if (source == null)
1082:                        source = trainSource;
1083:                    return printClassifications(classifierClassifications,
1084:                            new Instances(template, 0), source,
1085:                            actualClassIndex + 1, attributesToOutput,
1086:                            printDistribution);
1087:                }
1088:
1089:                // Compute error estimate from training data
1090:                if ((trainStatistics) && (trainSetPresent)) {
1091:
1092:                    if ((classifier instanceof  UpdateableClassifier)
1093:                            && (testSetPresent) && (costMatrix == null)) {
1094:
1095:                        // Classifier was trained incrementally, so we have to 
1096:                        // reset the source.
1097:                        trainSource.reset();
1098:
1099:                        // Incremental testing
1100:                        train = trainSource.getStructure(actualClassIndex);
1101:                        testTimeStart = System.currentTimeMillis();
1102:                        Instance trainInst;
1103:                        while (trainSource.hasMoreElements(train)) {
1104:                            trainInst = trainSource.nextElement(train);
1105:                            trainingEvaluation.evaluateModelOnce(
1106:                                    (Classifier) classifier, trainInst);
1107:                        }
1108:                        testTimeElapsed = System.currentTimeMillis()
1109:                                - testTimeStart;
1110:                    } else {
1111:                        testTimeStart = System.currentTimeMillis();
1112:                        trainingEvaluation.evaluateModel(classifier,
1113:                                trainSource.getDataSet(actualClassIndex));
1114:                        testTimeElapsed = System.currentTimeMillis()
1115:                                - testTimeStart;
1116:                    }
1117:
1118:                    // Print the results of the training evaluation
1119:                    if (printMargins) {
1120:                        return trainingEvaluation
1121:                                .toCumulativeMarginDistributionString();
1122:                    } else {
1123:                        text.append("\nTime taken to build model: "
1124:                                + Utils.doubleToString(
1125:                                        trainTimeElapsed / 1000.0, 2)
1126:                                + " seconds");
1127:
1128:                        if (splitPercentage > 0)
1129:                            text
1130:                                    .append("\nTime taken to test model on training split: ");
1131:                        else
1132:                            text
1133:                                    .append("\nTime taken to test model on training data: ");
1134:                        text.append(Utils.doubleToString(
1135:                                testTimeElapsed / 1000.0, 2)
1136:                                + " seconds");
1137:
1138:                        if (splitPercentage > 0)
1139:                            text.append(trainingEvaluation.toSummaryString(
1140:                                    "\n\n=== Error on training"
1141:                                            + " split ===\n",
1142:                                    printComplexityStatistics));
1143:                        else
1144:                            text.append(trainingEvaluation
1145:                                    .toSummaryString(
1146:                                            "\n\n=== Error on training"
1147:                                                    + " data ===\n",
1148:                                            printComplexityStatistics));
1149:
1150:                        if (template.classAttribute().isNominal()) {
1151:                            if (classStatistics) {
1152:                                text.append("\n\n"
1153:                                        + trainingEvaluation
1154:                                                .toClassDetailsString());
1155:                            }
1156:                            if (!noCrossValidation)
1157:                                text.append("\n\n"
1158:                                        + trainingEvaluation.toMatrixString());
1159:                        }
1160:
1161:                    }
1162:                }
1163:
1164:                // Compute proper error estimates
1165:                if (testSource != null) {
1166:                    // Testing is on the supplied test data
1167:                    Instance testInst;
1168:                    while (testSource.hasMoreElements(test)) {
1169:                        testInst = testSource.nextElement(test);
1170:                        testingEvaluation.evaluateModelOnceAndRecordPrediction(
1171:                                (Classifier) classifier, testInst);
1172:                    }
1173:
1174:                    if (splitPercentage > 0)
1175:                        text.append("\n\n"
1176:                                + testingEvaluation.toSummaryString(
1177:                                        "=== Error on test split ===\n",
1178:                                        printComplexityStatistics));
1179:                    else
1180:                        text.append("\n\n"
1181:                                + testingEvaluation.toSummaryString(
1182:                                        "=== Error on test data ===\n",
1183:                                        printComplexityStatistics));
1184:
1185:                } else if (trainSource != null) {
1186:                    if (!noCrossValidation) {
1187:                        // Testing is via cross-validation on training data
1188:                        Random random = new Random(seed);
1189:                        // use untrained (!) classifier for cross-validation
1190:                        classifier = Classifier.makeCopy(classifierBackup);
1191:                        testingEvaluation.crossValidateModel(classifier,
1192:                                trainSource.getDataSet(actualClassIndex),
1193:                                folds, random);
1194:                        if (template.classAttribute().isNumeric()) {
1195:                            text.append("\n\n\n"
1196:                                    + testingEvaluation.toSummaryString(
1197:                                            "=== Cross-validation ===\n",
1198:                                            printComplexityStatistics));
1199:                        } else {
1200:                            text.append("\n\n\n"
1201:                                    + testingEvaluation.toSummaryString(
1202:                                            "=== Stratified "
1203:                                                    + "cross-validation ===\n",
1204:                                            printComplexityStatistics));
1205:                        }
1206:                    }
1207:                }
1208:                if (template.classAttribute().isNominal()) {
1209:                    if (classStatistics) {
1210:                        text.append("\n\n"
1211:                                + testingEvaluation.toClassDetailsString());
1212:                    }
1213:                    if (!noCrossValidation)
1214:                        text
1215:                                .append("\n\n"
1216:                                        + testingEvaluation.toMatrixString());
1217:                }
1218:
1219:                if ((thresholdFile.length() != 0)
1220:                        && template.classAttribute().isNominal()) {
1221:                    int labelIndex = 0;
1222:                    if (thresholdLabel.length() != 0)
1223:                        labelIndex = template.classAttribute().indexOfValue(
1224:                                thresholdLabel);
1225:                    if (labelIndex == -1)
1226:                        throw new IllegalArgumentException("Class label '"
1227:                                + thresholdLabel + "' is unknown!");
1228:                    ThresholdCurve tc = new ThresholdCurve();
1229:                    Instances result = tc.getCurve(testingEvaluation
1230:                            .predictions(), labelIndex);
1231:                    DataSink.write(thresholdFile, result);
1232:                }
1233:
1234:                return text.toString();
1235:            }
1236:
1237:            /**
1238:             * Attempts to load a cost matrix.
1239:             *
1240:             * @param costFileName the filename of the cost matrix
1241:             * @param numClasses the number of classes that should be in the cost matrix
1242:             * (only used if the cost file is in old format).
1243:             * @return a <code>CostMatrix</code> value, or null if costFileName is empty
1244:             * @throws Exception if an error occurs.
1245:             */
1246:            protected static CostMatrix handleCostOption(String costFileName,
1247:                    int numClasses) throws Exception {
1248:
1249:                if ((costFileName != null) && (costFileName.length() != 0)) {
1250:                    System.out
1251:                            .println("NOTE: The behaviour of the -m option has changed between WEKA 3.0"
1252:                                    + " and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
1253:                                    + " only. For cost-sensitive *prediction*, use one of the"
1254:                                    + " cost-sensitive metaschemes such as"
1255:                                    + " weka.classifiers.meta.CostSensitiveClassifier or"
1256:                                    + " weka.classifiers.meta.MetaCost");
1257:
1258:                    Reader costReader = null;
1259:                    try {
1260:                        costReader = new BufferedReader(new FileReader(
1261:                                costFileName));
1262:                    } catch (Exception e) {
1263:                        throw new Exception("Can't open file " + e.getMessage()
1264:                                + '.');
1265:                    }
1266:                    try {
1267:                        // First try as a proper cost matrix format
1268:                        return new CostMatrix(costReader);
1269:                    } catch (Exception ex) {
1270:                        try {
1271:                            // Now try as the poxy old format :-)
1272:                            //System.err.println("Attempting to read old format cost file");
1273:                            try {
1274:                                costReader.close(); // Close the old one
1275:                                costReader = new BufferedReader(new FileReader(
1276:                                        costFileName));
1277:                            } catch (Exception e) {
1278:                                throw new Exception("Can't open file "
1279:                                        + e.getMessage() + '.');
1280:                            }
1281:                            CostMatrix costMatrix = new CostMatrix(numClasses);
1282:                            //System.err.println("Created default cost matrix");
1283:                            costMatrix.readOldFormat(costReader);
1284:                            return costMatrix;
1285:                            //System.err.println("Read old format");
1286:                        } catch (Exception e2) {
1287:                            // re-throw the original exception
1288:                            //System.err.println("Re-throwing original exception");
1289:                            throw ex;
1290:                        }
1291:                    }
1292:                } else {
1293:                    return null;
1294:                }
1295:            }
1296:
1297:            /**
1298:             * Evaluates the classifier on a given set of instances. Note that
1299:             * the data must have exactly the same format (e.g. order of
1300:             * attributes) as the data used to train the classifier! Otherwise
1301:             * the results will generally be meaningless.
1302:             *
1303:             * @param classifier machine learning classifier
1304:             * @param data set of test instances for evaluation
1305:             * @return the predictions
1306:             * @throws Exception if model could not be evaluated 
1307:             * successfully 
1308:             */
1309:            public double[] evaluateModel(Classifier classifier, Instances data)
1310:                    throws Exception {
1311:
1312:                double predictions[] = new double[data.numInstances()];
1313:
1314:                // Need to be able to collect predictions if appropriate (for AUC)
1315:
1316:                for (int i = 0; i < data.numInstances(); i++) {
1317:                    predictions[i] = evaluateModelOnceAndRecordPrediction(
1318:                            (Classifier) classifier, data.instance(i));
1319:                }
1320:
1321:                return predictions;
1322:            }
1323:
1324:            /**
1325:             * Evaluates the classifier on a single instance and records the
1326:             * prediction (if the class is nominal).
1327:             *
1328:             * @param classifier machine learning classifier
1329:             * @param instance the test instance to be classified
1330:             * @return the prediction made by the clasifier
1331:             * @throws Exception if model could not be evaluated 
1332:             * successfully or the data contains string attributes
1333:             */
1334:            public double evaluateModelOnceAndRecordPrediction(
1335:                    Classifier classifier, Instance instance) throws Exception {
1336:
1337:                Instance classMissing = (Instance) instance.copy();
1338:                double pred = 0;
1339:                classMissing.setDataset(instance.dataset());
1340:                classMissing.setClassMissing();
1341:                if (m_ClassIsNominal) {
1342:                    if (m_Predictions == null) {
1343:                        m_Predictions = new FastVector();
1344:                    }
1345:                    double[] dist = classifier
1346:                            .distributionForInstance(classMissing);
1347:                    pred = Utils.maxIndex(dist);
1348:                    if (dist[(int) pred] <= 0) {
1349:                        pred = Instance.missingValue();
1350:                    }
1351:                    updateStatsForClassifier(dist, instance);
1352:                    m_Predictions.addElement(new NominalPrediction(instance
1353:                            .classValue(), dist, instance.weight()));
1354:                } else {
1355:                    pred = classifier.classifyInstance(classMissing);
1356:                    updateStatsForPredictor(pred, instance);
1357:                }
1358:                return pred;
1359:            }
1360:
1361:            /**
1362:             * Evaluates the classifier on a single instance.
1363:             *
1364:             * @param classifier machine learning classifier
1365:             * @param instance the test instance to be classified
1366:             * @return the prediction made by the clasifier
1367:             * @throws Exception if model could not be evaluated 
1368:             * successfully or the data contains string attributes
1369:             */
1370:            public double evaluateModelOnce(Classifier classifier,
1371:                    Instance instance) throws Exception {
1372:
1373:                Instance classMissing = (Instance) instance.copy();
1374:                double pred = 0;
1375:                classMissing.setDataset(instance.dataset());
1376:                classMissing.setClassMissing();
1377:                if (m_ClassIsNominal) {
1378:                    double[] dist = classifier
1379:                            .distributionForInstance(classMissing);
1380:                    pred = Utils.maxIndex(dist);
1381:                    if (dist[(int) pred] <= 0) {
1382:                        pred = Instance.missingValue();
1383:                    }
1384:                    updateStatsForClassifier(dist, instance);
1385:                } else {
1386:                    pred = classifier.classifyInstance(classMissing);
1387:                    updateStatsForPredictor(pred, instance);
1388:                }
1389:                return pred;
1390:            }
1391:
1392:            /**
1393:             * Evaluates the supplied distribution on a single instance.
1394:             *
1395:             * @param dist the supplied distribution
1396:             * @param instance the test instance to be classified
1397:             * @return the prediction
1398:             * @throws Exception if model could not be evaluated 
1399:             * successfully
1400:             */
1401:            public double evaluateModelOnce(double[] dist, Instance instance)
1402:                    throws Exception {
1403:                double pred;
1404:                if (m_ClassIsNominal) {
1405:                    pred = Utils.maxIndex(dist);
1406:                    if (dist[(int) pred] <= 0) {
1407:                        pred = Instance.missingValue();
1408:                    }
1409:                    updateStatsForClassifier(dist, instance);
1410:                } else {
1411:                    pred = dist[0];
1412:                    updateStatsForPredictor(pred, instance);
1413:                }
1414:                return pred;
1415:            }
1416:
1417:            /**
1418:             * Evaluates the supplied distribution on a single instance.
1419:             *
1420:             * @param dist the supplied distribution
1421:             * @param instance the test instance to be classified
1422:             * @return the prediction
1423:             * @throws Exception if model could not be evaluated 
1424:             * successfully
1425:             */
1426:            public double evaluateModelOnceAndRecordPrediction(double[] dist,
1427:                    Instance instance) throws Exception {
1428:                double pred;
1429:                if (m_ClassIsNominal) {
1430:                    if (m_Predictions == null) {
1431:                        m_Predictions = new FastVector();
1432:                    }
1433:                    pred = Utils.maxIndex(dist);
1434:                    if (dist[(int) pred] <= 0) {
1435:                        pred = Instance.missingValue();
1436:                    }
1437:                    updateStatsForClassifier(dist, instance);
1438:                    m_Predictions.addElement(new NominalPrediction(instance
1439:                            .classValue(), dist, instance.weight()));
1440:                } else {
1441:                    pred = dist[0];
1442:                    updateStatsForPredictor(pred, instance);
1443:                }
1444:                return pred;
1445:            }
1446:
1447:            /**
1448:             * Evaluates the supplied prediction on a single instance.
1449:             *
1450:             * @param prediction the supplied prediction
1451:             * @param instance the test instance to be classified
1452:             * @throws Exception if model could not be evaluated 
1453:             * successfully
1454:             */
1455:            public void evaluateModelOnce(double prediction, Instance instance)
1456:                    throws Exception {
1457:
1458:                if (m_ClassIsNominal) {
1459:                    updateStatsForClassifier(makeDistribution(prediction),
1460:                            instance);
1461:                } else {
1462:                    updateStatsForPredictor(prediction, instance);
1463:                }
1464:            }
1465:
    /**
     * Returns the predictions that have been collected.
     *
     * Note: this is the live internal FastVector, not a copy — callers
     * mutating it affect this Evaluation's recorded predictions.
     *
     * @return a reference to the FastVector containing the predictions
     * that have been collected. This should be null if no predictions
     * have been collected (e.g. if the class is numeric).
     */
    public FastVector predictions() {

        return m_Predictions;
    }
1477:
1478:            /**
1479:             * Wraps a static classifier in enough source to test using the weka
1480:             * class libraries.
1481:             *
1482:             * @param classifier a Sourcable Classifier
1483:             * @param className the name to give to the source code class
1484:             * @return the source for a static classifier that can be tested with
1485:             * weka libraries.
1486:             * @throws Exception if code-generation fails
1487:             */
1488:            protected static String wekaStaticWrapper(Sourcable classifier,
1489:                    String className) throws Exception {
1490:
1491:                //String className = "StaticClassifier";
1492:                String staticClassifier = classifier.toSource(className);
1493:                return "package weka.classifiers;\n\n"
1494:                        + "import weka.core.Attribute;\n"
1495:                        + "import weka.core.Instance;\n"
1496:                        + "import weka.core.Instances;\n"
1497:                        + "import weka.classifiers.Classifier;\n\n"
1498:                        + "public class WekaWrapper extends Classifier {\n\n"
1499:                        + "  public void buildClassifier(Instances i) throws Exception {\n"
1500:                        + "  }\n\n"
1501:                        + "  public double classifyInstance(Instance i) throws Exception {\n\n"
1502:                        + "    Object [] s = new Object [i.numAttributes()];\n"
1503:                        + "    for (int j = 0; j < s.length; j++) {\n"
1504:                        + "      if (!i.isMissing(j)) {\n"
1505:                        + "        if (i.attribute(j).type() == Attribute.NOMINAL) {\n"
1506:                        + "          s[j] = i.attribute(j).value((int) i.value(j));\n"
1507:                        + "        } else if (i.attribute(j).type() == Attribute.NUMERIC) {\n"
1508:                        + "          s[j] = new Double(i.value(j));\n"
1509:                        + "        }\n" + "      }\n" + "    }\n"
1510:                        + "    return " + className + ".classify(s);\n"
1511:                        + "  }\n\n" + "}\n\n" + staticClassifier; // The static classifer class
1512:            }
1513:
    /**
     * Gets the number of test instances that had a known class value
     * (actually the sum of the weights of test instances with known
     * class value).
     *
     * @return the number of test instances with known class
     */
    public final double numInstances() {

        // m_WithClass accumulates instance weights, so this is a weighted
        // count rather than a raw instance count.
        return m_WithClass;
    }
1525:
    /**
     * Gets the number of instances incorrectly classified (that is, for
     * which an incorrect prediction was made). (Actually the sum of the weights
     * of these instances)
     *
     * @return the number of incorrectly classified instances
     */
    public final double incorrect() {

        // Weighted count; fractional values are possible with weighted data.
        return m_Incorrect;
    }
1537:
    /**
     * Gets the percentage of instances incorrectly classified (that is, for
     * which an incorrect prediction was made).
     *
     * @return the percent of incorrectly classified instances
     * (between 0 and 100; NaN if no instances with a known class were seen)
     */
    public final double pctIncorrect() {

        return 100 * m_Incorrect / m_WithClass;
    }
1549:
    /**
     * Gets the total cost, that is, the cost of each prediction times the
     * weight of the instance, summed over all instances.
     *
     * @return the total cost
     */
    public final double totalCost() {

        return m_TotalCost;
    }
1560:
    /**
     * Gets the average cost, that is, total cost of misclassifications
     * (incorrect plus unclassified) over the total number of instances.
     *
     * @return the average cost (NaN if no instances with a known class
     * were seen)
     */
    public final double avgCost() {

        // Divides by the weighted instance count, not the raw count.
        return m_TotalCost / m_WithClass;
    }
1571:
    /**
     * Gets the number of instances correctly classified (that is, for
     * which a correct prediction was made). (Actually the sum of the weights
     * of these instances)
     *
     * @return the number of correctly classified instances
     */
    public final double correct() {

        // Weighted count; fractional values are possible with weighted data.
        return m_Correct;
    }
1583:
    /**
     * Gets the percentage of instances correctly classified (that is, for
     * which a correct prediction was made).
     *
     * @return the percent of correctly classified instances (between 0 and
     * 100; NaN if no instances with a known class were seen)
     */
    public final double pctCorrect() {

        return 100 * m_Correct / m_WithClass;
    }
1594:
    /**
     * Gets the number of instances not classified (that is, for
     * which no prediction was made by the classifier). (Actually the sum
     * of the weights of these instances)
     *
     * @return the number of unclassified instances
     */
    public final double unclassified() {

        return m_Unclassified;
    }
1606:
    /**
     * Gets the percentage of instances not classified (that is, for
     * which no prediction was made by the classifier).
     *
     * @return the percent of unclassified instances (between 0 and 100;
     * NaN if no instances with a known class were seen)
     */
    public final double pctUnclassified() {

        return 100 * m_Unclassified / m_WithClass;
    }
1617:
1618:            /**
1619:             * Returns the estimated error rate or the root mean squared error
1620:             * (if the class is numeric). If a cost matrix was given this
1621:             * error rate gives the average cost.
1622:             *
1623:             * @return the estimated error rate (between 0 and 1, or between 0 and 
1624:             * maximum cost)
1625:             */
1626:            public final double errorRate() {
1627:
1628:                if (!m_ClassIsNominal) {
1629:                    return Math.sqrt(m_SumSqrErr / m_WithClass);
1630:                }
1631:                if (m_CostMatrix == null) {
1632:                    return m_Incorrect / m_WithClass;
1633:                } else {
1634:                    return avgCost();
1635:                }
1636:            }
1637:
1638:            /**
1639:             * Returns value of kappa statistic if class is nominal.
1640:             *
1641:             * @return the value of the kappa statistic
1642:             */
1643:            public final double kappa() {
1644:
1645:                double[] sumRows = new double[m_ConfusionMatrix.length];
1646:                double[] sumColumns = new double[m_ConfusionMatrix.length];
1647:                double sumOfWeights = 0;
1648:                for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1649:                    for (int j = 0; j < m_ConfusionMatrix.length; j++) {
1650:                        sumRows[i] += m_ConfusionMatrix[i][j];
1651:                        sumColumns[j] += m_ConfusionMatrix[i][j];
1652:                        sumOfWeights += m_ConfusionMatrix[i][j];
1653:                    }
1654:                }
1655:                double correct = 0, chanceAgreement = 0;
1656:                for (int i = 0; i < m_ConfusionMatrix.length; i++) {
1657:                    chanceAgreement += (sumRows[i] * sumColumns[i]);
1658:                    correct += m_ConfusionMatrix[i][i];
1659:                }
1660:                chanceAgreement /= (sumOfWeights * sumOfWeights);
1661:                correct /= sumOfWeights;
1662:
1663:                if (chanceAgreement < 1) {
1664:                    return (correct - chanceAgreement) / (1 - chanceAgreement);
1665:                } else {
1666:                    return 1;
1667:                }
1668:            }
1669:
1670:            /**
1671:             * Returns the correlation coefficient if the class is numeric.
1672:             *
1673:             * @return the correlation coefficient
1674:             * @throws Exception if class is not numeric
1675:             */
1676:            public final double correlationCoefficient() throws Exception {
1677:
1678:                if (m_ClassIsNominal) {
1679:                    throw new Exception(
1680:                            "Can't compute correlation coefficient: "
1681:                                    + "class is nominal!");
1682:                }
1683:
1684:                double correlation = 0;
1685:                double varActual = m_SumSqrClass - m_SumClass * m_SumClass
1686:                        / m_WithClass;
1687:                double varPredicted = m_SumSqrPredicted - m_SumPredicted
1688:                        * m_SumPredicted / m_WithClass;
1689:                double varProd = m_SumClassPredicted - m_SumClass
1690:                        * m_SumPredicted / m_WithClass;
1691:
1692:                if (varActual * varPredicted <= 0) {
1693:                    correlation = 0.0;
1694:                } else {
1695:                    correlation = varProd / Math.sqrt(varActual * varPredicted);
1696:                }
1697:
1698:                return correlation;
1699:            }
1700:
    /**
     * Returns the mean absolute error. Refers to the error of the
     * predicted values for numeric classes, and the error of the
     * predicted probability distribution for nominal classes.
     *
     * @return the mean absolute error (weighted average; NaN if no
     * instances with a known class were seen)
     */
    public final double meanAbsoluteError() {

        return m_SumAbsErr / m_WithClass;
    }
1712:
    /**
     * Returns the mean absolute error of the prior.
     *
     * @return the mean absolute error of the prior's predictions, or
     * Double.NaN when no class priors are available
     */
    public final double meanPriorAbsoluteError() {

        if (m_NoPriors)
            return Double.NaN;

        return m_SumPriorAbsErr / m_WithClass;
    }
1725:
    /**
     * Returns the relative absolute error, i.e. the scheme's mean absolute
     * error as a percentage of the prior's mean absolute error.
     *
     * @return the relative absolute error, or Double.NaN when no class
     * priors are available
     * @throws Exception if it can't be computed
     */
    public final double relativeAbsoluteError() throws Exception {

        if (m_NoPriors)
            return Double.NaN;

        return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
    }
1739:
    /**
     * Returns the root mean squared error.
     *
     * @return the root mean squared error (weighted; NaN if no instances
     * with a known class were seen)
     */
    public final double rootMeanSquaredError() {

        return Math.sqrt(m_SumSqrErr / m_WithClass);
    }
1749:
    /**
     * Returns the root mean prior squared error.
     *
     * @return the root mean prior squared error, or Double.NaN when no
     * class priors are available
     */
    public final double rootMeanPriorSquaredError() {

        if (m_NoPriors)
            return Double.NaN;

        return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
    }
1762:
    /**
     * Returns the root relative squared error if the class is numeric,
     * i.e. the scheme's RMSE as a percentage of the prior's RMSE.
     *
     * @return the root relative squared error, or Double.NaN when no
     * class priors are available
     */
    public final double rootRelativeSquaredError() {

        if (m_NoPriors)
            return Double.NaN;

        return 100.0 * rootMeanSquaredError()
                / rootMeanPriorSquaredError();
    }
1776:
1777:            /**
1778:             * Calculate the entropy of the prior distribution
1779:             *
1780:             * @return the entropy of the prior distribution
1781:             * @throws Exception if the class is not nominal
1782:             */
1783:            public final double priorEntropy() throws Exception {
1784:
1785:                if (!m_ClassIsNominal) {
1786:                    throw new Exception(
1787:                            "Can't compute entropy of class prior: "
1788:                                    + "class numeric!");
1789:                }
1790:
1791:                if (m_NoPriors)
1792:                    return Double.NaN;
1793:
1794:                double entropy = 0;
1795:                for (int i = 0; i < m_NumClasses; i++) {
1796:                    entropy -= m_ClassPriors[i] / m_ClassPriorsSum
1797:                            * Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
1798:                }
1799:                return entropy;
1800:            }
1801:
    /**
     * Return the total Kononenko & Bratko Information score in bits
     *
     * @return the K&B information score (Double.NaN when no class priors
     * are available)
     * @throws Exception if the class is not nominal
     */
    public final double KBInformation() throws Exception {

        if (!m_ClassIsNominal) {
            throw new Exception("Can't compute K&B Info score: "
                    + "class numeric!");
        }

        // Note: the nominal-class check deliberately precedes the priors
        // check, so a numeric class always throws rather than returning NaN.
        if (m_NoPriors)
            return Double.NaN;

        return m_SumKBInfo;
    }
1820:
    /**
     * Return the Kononenko & Bratko Information score in bits per
     * instance.
     *
     * @return the K&B information score per (weighted) instance
     * (Double.NaN when no class priors are available)
     * @throws Exception if the class is not nominal
     */
    public final double KBMeanInformation() throws Exception {

        if (!m_ClassIsNominal) {
            throw new Exception("Can't compute K&B Info score: "
                    + "class numeric!");
        }

        if (m_NoPriors)
            return Double.NaN;

        return m_SumKBInfo / m_WithClass;
    }
1840:
    /**
     * Return the Kononenko & Bratko Relative Information score, i.e. the
     * K&B information score as a percentage of the prior entropy.
     *
     * @return the K&B relative information score (Double.NaN when no
     * class priors are available)
     * @throws Exception if the class is not nominal
     */
    public final double KBRelativeInformation() throws Exception {

        if (!m_ClassIsNominal) {
            throw new Exception("Can't compute K&B Info score: "
                    + "class numeric!");
        }

        if (m_NoPriors)
            return Double.NaN;

        return 100.0 * KBInformation() / priorEntropy();
    }
1859:
    /**
     * Returns the total entropy for the null model
     *
     * @return the total null model entropy, or Double.NaN when no class
     * priors are available
     */
    public final double SFPriorEntropy() {

        if (m_NoPriors)
            return Double.NaN;

        return m_SumPriorEntropy;
    }
1872:
    /**
     * Returns the entropy per instance for the null model
     *
     * @return the null model entropy per (weighted) instance, or
     * Double.NaN when no class priors are available
     */
    public final double SFMeanPriorEntropy() {

        if (m_NoPriors)
            return Double.NaN;

        return m_SumPriorEntropy / m_WithClass;
    }
1885:
    /**
     * Returns the total entropy for the scheme
     *
     * @return the total scheme entropy, or Double.NaN when no class
     * priors are available
     */
    public final double SFSchemeEntropy() {

        if (m_NoPriors)
            return Double.NaN;

        return m_SumSchemeEntropy;
    }
1898:
1899:            /**
1900:             * Returns the entropy per instance for the scheme
1901:             * 
1902:             * @return the scheme entropy per instance
1903:             */
1904:            public final double SFMeanSchemeEntropy() {
1905:
1906:                if (m_NoPriors)
1907:                    return Double.NaN;
1908:
1909:                return m_SumSchemeEntropy / m_WithClass;
1910:            }
1911:
1912:            /**
1913:             * Returns the total SF, which is the null model entropy minus
1914:             * the scheme entropy.
1915:             * 
1916:             * @return the total SF
1917:             */
1918:            public final double SFEntropyGain() {
1919:
1920:                if (m_NoPriors)
1921:                    return Double.NaN;
1922:
1923:                return m_SumPriorEntropy - m_SumSchemeEntropy;
1924:            }
1925:
1926:            /**
1927:             * Returns the SF per instance, which is the null model entropy
1928:             * minus the scheme entropy, per instance.
1929:             * 
1930:             * @return the SF per instance
1931:             */
1932:            public final double SFMeanEntropyGain() {
1933:
1934:                if (m_NoPriors)
1935:                    return Double.NaN;
1936:
1937:                return (m_SumPriorEntropy - m_SumSchemeEntropy) / m_WithClass;
1938:            }
1939:
            /**
             * Output the cumulative margin distribution as a string suitable
             * for input for gnuplot or similar package. Each output line is a
             * "margin cumulative-percentage" pair.
             *
             * @return the cumulative margin distribution
             * @throws Exception if the class attribute is not nominal
             */
            public String toCumulativeMarginDistributionString()
                    throws Exception {

                if (!m_ClassIsNominal) {
                    throw new Exception(
                            "Class must be nominal for margin distributions");
                }
                String result = "";
                double cumulativeCount = 0;
                double margin;
                // Walk the margin histogram; bucket i maps linearly onto a
                // margin value in [-1, 1].
                for (int i = 0; i <= k_MarginResolution; i++) {
                    if (m_MarginCounts[i] != 0) {
                        cumulativeCount += m_MarginCounts[i];
                        margin = (double) i * 2.0 / k_MarginResolution - 1.0;
                        // Append "margin cumulative%" for this non-empty bucket.
                        result = result
                                + Utils.doubleToString(margin, 7, 3)
                                + ' '
                                + Utils.doubleToString(cumulativeCount * 100
                                        / m_WithClass, 7, 3) + '\n';
                    } else if (i == 0) {
                        // Ensure the plot starts at (-1, 0) even when the first
                        // bucket is empty. Plain assignment is safe here:
                        // result is still "" at i == 0.
                        result = Utils.doubleToString(-1.0, 7, 3) + ' '
                                + Utils.doubleToString(0, 7, 3) + '\n';
                    }
                }
                return result;
            }
1973:
1974:            /**
1975:             * Calls toSummaryString() with no title and no complexity stats
1976:             *
1977:             * @return a summary description of the classifier evaluation
1978:             */
1979:            public String toSummaryString() {
1980:
1981:                return toSummaryString("", false);
1982:            }
1983:
1984:            /**
1985:             * Calls toSummaryString() with a default title.
1986:             *
1987:             * @param printComplexityStatistics if true, complexity statistics are
1988:             * returned as well
1989:             * @return the summary string
1990:             */
1991:            public String toSummaryString(boolean printComplexityStatistics) {
1992:
1993:                return toSummaryString("=== Summary ===\n",
1994:                        printComplexityStatistics);
1995:            }
1996:
            /**
             * Outputs the performance statistics in summary form. Lists 
             * number (and percentage) of instances classified correctly, 
             * incorrectly and unclassified. Outputs the total number of 
             * instances classified, and the number of instances (if any) 
             * that had no class value provided. 
             *
             * @param title the title for the statistics
             * @param printComplexityStatistics if true, complexity statistics are
             * returned as well
             * @return the summary as a String
             */
            public String toSummaryString(String title,
                    boolean printComplexityStatistics) {

                StringBuffer text = new StringBuffer();

                // Complexity statistics are derived from the class priors,
                // so they cannot be reported when priors are disabled.
                if (printComplexityStatistics && m_NoPriors) {
                    printComplexityStatistics = false;
                    System.err
                            .println("Priors disabled, cannot print complexity statistics!");
                }

                text.append(title + "\n");
                try {
                    if (m_WithClass > 0) {
                        if (m_ClassIsNominal) {
                            // Nominal class: accuracy-style statistics.
                            text.append("Correctly Classified Instances     ");
                            text.append(Utils.doubleToString(correct(), 12, 4)
                                    + "     "
                                    + Utils.doubleToString(pctCorrect(), 12, 4)
                                    + " %\n");
                            text.append("Incorrectly Classified Instances   ");
                            text.append(Utils
                                    .doubleToString(incorrect(), 12, 4)
                                    + "     "
                                    + Utils.doubleToString(pctIncorrect(), 12,
                                            4) + " %\n");
                            text.append("Kappa statistic                    ");
                            text.append(Utils.doubleToString(kappa(), 12, 4)
                                    + "\n");

                            // Cost figures only make sense when a cost matrix
                            // was supplied for the evaluation.
                            if (m_CostMatrix != null) {
                                text
                                        .append("Total Cost                         ");
                                text.append(Utils.doubleToString(totalCost(),
                                        12, 4)
                                        + "\n");
                                text
                                        .append("Average Cost                       ");
                                text.append(Utils.doubleToString(avgCost(), 12,
                                        4)
                                        + "\n");
                            }
                            // Kononenko & Bratko information scores.
                            if (printComplexityStatistics) {
                                text
                                        .append("K&B Relative Info Score            ");
                                text.append(Utils.doubleToString(
                                        KBRelativeInformation(), 12, 4)
                                        + " %\n");
                                text
                                        .append("K&B Information Score              ");
                                text.append(Utils.doubleToString(
                                        KBInformation(), 12, 4)
                                        + " bits");
                                text.append(Utils.doubleToString(
                                        KBMeanInformation(), 12, 4)
                                        + " bits/instance\n");
                            }
                        } else {
                            // Numeric class: correlation instead of accuracy.
                            text.append("Correlation coefficient            ");
                            text.append(Utils.doubleToString(
                                    correlationCoefficient(), 12, 4)
                                    + "\n");
                        }
                        // SF (entropy) complexity statistics, reported for
                        // both nominal and numeric classes.
                        if (printComplexityStatistics) {
                            text.append("Class complexity | order 0         ");
                            text.append(Utils.doubleToString(SFPriorEntropy(),
                                    12, 4)
                                    + " bits");
                            text.append(Utils.doubleToString(
                                    SFMeanPriorEntropy(), 12, 4)
                                    + " bits/instance\n");
                            text.append("Class complexity | scheme          ");
                            text.append(Utils.doubleToString(SFSchemeEntropy(),
                                    12, 4)
                                    + " bits");
                            text.append(Utils.doubleToString(
                                    SFMeanSchemeEntropy(), 12, 4)
                                    + " bits/instance\n");
                            text.append("Complexity improvement     (Sf)    ");
                            text.append(Utils.doubleToString(SFEntropyGain(),
                                    12, 4)
                                    + " bits");
                            text.append(Utils.doubleToString(
                                    SFMeanEntropyGain(), 12, 4)
                                    + " bits/instance\n");
                        }

                        // Error measures common to nominal and numeric classes.
                        text.append("Mean absolute error                ");
                        text.append(Utils.doubleToString(meanAbsoluteError(),
                                12, 4)
                                + "\n");
                        text.append("Root mean squared error            ");
                        text.append(Utils.doubleToString(
                                rootMeanSquaredError(), 12, 4)
                                + "\n");
                        // Relative errors require the prior (null model).
                        if (!m_NoPriors) {
                            text.append("Relative absolute error            ");
                            text.append(Utils.doubleToString(
                                    relativeAbsoluteError(), 12, 4)
                                    + " %\n");
                            text.append("Root relative squared error        ");
                            text.append(Utils.doubleToString(
                                    rootRelativeSquaredError(), 12, 4)
                                    + " %\n");
                        }
                    }
                    if (Utils.gr(unclassified(), 0)) {
                        text.append("UnClassified Instances             ");
                        text.append(Utils.doubleToString(unclassified(), 12, 4)
                                + "     "
                                + Utils
                                        .doubleToString(pctUnclassified(), 12,
                                                4) + " %\n");
                    }
                    text.append("Total Number of Instances          ");
                    text
                            .append(Utils.doubleToString(m_WithClass, 12, 4)
                                    + "\n");
                    // Instances whose class value was missing were skipped
                    // during evaluation; report how many.
                    if (m_MissingClass > 0) {
                        text
                                .append("Ignored Class Unknown Instances            ");
                        text.append(Utils.doubleToString(m_MissingClass, 12, 4)
                                + "\n");
                    }
                } catch (Exception ex) {
                    // Should never occur since the class is known to be nominal 
                    // here
                    System.err
                            .println("Arggh - Must be a bug in Evaluation class");
                }

                return text.toString();
            }
2143:
2144:            /**
2145:             * Calls toMatrixString() with a default title.
2146:             *
2147:             * @return the confusion matrix as a string
2148:             * @throws Exception if the class is numeric
2149:             */
2150:            public String toMatrixString() throws Exception {
2151:
2152:                return toMatrixString("=== Confusion Matrix ===\n");
2153:            }
2154:
            /**
             * Outputs the performance statistics as a classification confusion
             * matrix. For each class value, shows the distribution of 
             * predicted class values. Rows are actual classes, columns are
             * predicted classes; columns are labelled with short letter IDs.
             *
             * @param title the title for the confusion matrix
             * @return the confusion matrix as a String
             * @throws Exception if the class is numeric
             */
            public String toMatrixString(String title) throws Exception {

                StringBuffer text = new StringBuffer();
                // Letters used to label matrix columns (base-26 IDs).
                char[] IDChars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i',
                        'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
                        'u', 'v', 'w', 'x', 'y', 'z' };
                int IDWidth;
                boolean fractional = false;

                if (!m_ClassIsNominal) {
                    throw new Exception(
                            "Evaluation: No confusion matrix possible!");
                }

                // Find the maximum value in the matrix
                // and check for fractional display requirement 
                double maxval = 0;
                for (int i = 0; i < m_NumClasses; i++) {
                    for (int j = 0; j < m_NumClasses; j++) {
                        double current = m_ConfusionMatrix[i][j];
                        if (current < 0) {
                            // NOTE(review): negative entries are folded to a
                            // positive value inflated 10x — presumably to
                            // reserve column width for a minus sign; confirm.
                            current *= -10;
                        }
                        if (current > maxval) {
                            maxval = current;
                        }
                        // Switch to fractional display if any entry has a
                        // fractional part >= 0.01 (log10(fract) >= -2).
                        // For fract == 0 the log is -Infinity and for negative
                        // fract it is NaN, so both fail the test as intended.
                        double fract = current - Math.rint(current);
                        if (!fractional
                                && ((Math.log(fract) / Math.log(10)) >= -2)) {
                            fractional = true;
                        }
                    }
                }

                // Column width: wide enough for the largest count (plus ".xx"
                // when fractional) and for the longest base-26 class ID.
                IDWidth = 1 + Math
                        .max(
                                (int) (Math.log(maxval) / Math.log(10) + (fractional ? 3
                                        : 0)),
                                (int) (Math.log(m_NumClasses) / Math
                                        .log(IDChars.length)));
                text.append(title).append("\n");
                // Header row of column IDs.
                for (int i = 0; i < m_NumClasses; i++) {
                    if (fractional) {
                        text.append(" ").append(
                                num2ShortID(i, IDChars, IDWidth - 3)).append(
                                "   ");
                    } else {
                        text.append(" ").append(
                                num2ShortID(i, IDChars, IDWidth));
                    }
                }
                text.append("   <-- classified as\n");
                // One row per actual class, annotated with its ID and name.
                for (int i = 0; i < m_NumClasses; i++) {
                    for (int j = 0; j < m_NumClasses; j++) {
                        text.append(" ").append(
                                Utils.doubleToString(m_ConfusionMatrix[i][j],
                                        IDWidth, (fractional ? 2 : 0)));
                    }
                    text.append(" | ").append(num2ShortID(i, IDChars, IDWidth))
                            .append(" = ").append(m_ClassNames[i]).append("\n");
                }
                return text.toString();
            }
2227:
2228:            /**
2229:             * Generates a breakdown of the accuracy for each class (with default title),
2230:             * incorporating various information-retrieval statistics, such as
2231:             * true/false positive rate, precision/recall/F-Measure.  Should be
2232:             * useful for ROC curves, recall/precision curves.  
2233:             * 
2234:             * @return the statistics presented as a string
2235:             * @throws Exception if class is not nominal
2236:             */
2237:            public String toClassDetailsString() throws Exception {
2238:
2239:                return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
2240:            }
2241:
            /**
             * Generates a breakdown of the accuracy for each class,
             * incorporating various information-retrieval statistics, such as
             * true/false positive rate, precision/recall/F-Measure.  Should be
             * useful for ROC curves, recall/precision curves. One row per
             * class value; columns are fixed-width (7.3) for alignment.
             * 
             * @param title the title to prepend the stats string with 
             * @return the statistics presented as a string
             * @throws Exception if class is not nominal
             */
            public String toClassDetailsString(String title) throws Exception {

                if (!m_ClassIsNominal) {
                    throw new Exception(
                            "Evaluation: No confusion matrix possible!");
                }
                StringBuffer text = new StringBuffer(title
                        + "\nTP Rate   FP Rate" + "   Precision   Recall"
                        + "  F-Measure   ROC Area  Class\n");
                for (int i = 0; i < m_NumClasses; i++) {
                    text
                            .append(
                                    Utils.doubleToString(truePositiveRate(i),
                                            7, 3)).append("   ");
                    text.append(
                            Utils.doubleToString(falsePositiveRate(i), 7, 3))
                            .append("    ");
                    text.append(Utils.doubleToString(precision(i), 7, 3))
                            .append("   ");
                    text.append(Utils.doubleToString(recall(i), 7, 3)).append(
                            "   ");
                    text.append(Utils.doubleToString(fMeasure(i), 7, 3))
                            .append("    ");
                    // ROC area can be undefined (missing value); print "?"
                    // in that case so the column stays aligned.
                    double rocVal = areaUnderROC(i);
                    if (Instance.isMissingValue(rocVal)) {
                        text.append("  ?    ").append("    ");
                    } else {
                        text.append(Utils.doubleToString(rocVal, 7, 3)).append(
                                "    ");
                    }
                    text.append(m_ClassNames[i]).append('\n');
                }
                return text.toString();
            }
2286:
2287:            /**
2288:             * Calculate the number of true positives with respect to a particular class. 
2289:             * This is defined as<p/>
2290:             * <pre>
2291:             * correctly classified positives
2292:             * </pre>
2293:             *
2294:             * @param classIndex the index of the class to consider as "positive"
2295:             * @return the true positive rate
2296:             */
2297:            public double numTruePositives(int classIndex) {
2298:
2299:                double correct = 0;
2300:                for (int j = 0; j < m_NumClasses; j++) {
2301:                    if (j == classIndex) {
2302:                        correct += m_ConfusionMatrix[classIndex][j];
2303:                    }
2304:                }
2305:                return correct;
2306:            }
2307:
2308:            /**
2309:             * Calculate the true positive rate with respect to a particular class. 
2310:             * This is defined as<p/>
2311:             * <pre>
2312:             * correctly classified positives
2313:             * ------------------------------
2314:             *       total positives
2315:             * </pre>
2316:             *
2317:             * @param classIndex the index of the class to consider as "positive"
2318:             * @return the true positive rate
2319:             */
2320:            public double truePositiveRate(int classIndex) {
2321:
2322:                double correct = 0, total = 0;
2323:                for (int j = 0; j < m_NumClasses; j++) {
2324:                    if (j == classIndex) {
2325:                        correct += m_ConfusionMatrix[classIndex][j];
2326:                    }
2327:                    total += m_ConfusionMatrix[classIndex][j];
2328:                }
2329:                if (total == 0) {
2330:                    return 0;
2331:                }
2332:                return correct / total;
2333:            }
2334:
2335:            /**
2336:             * Calculate the number of true negatives with respect to a particular class. 
2337:             * This is defined as<p/>
2338:             * <pre>
2339:             * correctly classified negatives
2340:             * </pre>
2341:             *
2342:             * @param classIndex the index of the class to consider as "positive"
2343:             * @return the true positive rate
2344:             */
2345:            public double numTrueNegatives(int classIndex) {
2346:
2347:                double correct = 0;
2348:                for (int i = 0; i < m_NumClasses; i++) {
2349:                    if (i != classIndex) {
2350:                        for (int j = 0; j < m_NumClasses; j++) {
2351:                            if (j != classIndex) {
2352:                                correct += m_ConfusionMatrix[i][j];
2353:                            }
2354:                        }
2355:                    }
2356:                }
2357:                return correct;
2358:            }
2359:
2360:            /**
2361:             * Calculate the true negative rate with respect to a particular class. 
2362:             * This is defined as<p/>
2363:             * <pre>
2364:             * correctly classified negatives
2365:             * ------------------------------
2366:             *       total negatives
2367:             * </pre>
2368:             *
2369:             * @param classIndex the index of the class to consider as "positive"
2370:             * @return the true positive rate
2371:             */
2372:            public double trueNegativeRate(int classIndex) {
2373:
2374:                double correct = 0, total = 0;
2375:                for (int i = 0; i < m_NumClasses; i++) {
2376:                    if (i != classIndex) {
2377:                        for (int j = 0; j < m_NumClasses; j++) {
2378:                            if (j != classIndex) {
2379:                                correct += m_ConfusionMatrix[i][j];
2380:                            }
2381:                            total += m_ConfusionMatrix[i][j];
2382:                        }
2383:                    }
2384:                }
2385:                if (total == 0) {
2386:                    return 0;
2387:                }
2388:                return correct / total;
2389:            }
2390:
2391:            /**
2392:             * Calculate number of false positives with respect to a particular class. 
2393:             * This is defined as<p/>
2394:             * <pre>
2395:             * incorrectly classified negatives
2396:             * </pre>
2397:             *
2398:             * @param classIndex the index of the class to consider as "positive"
2399:             * @return the false positive rate
2400:             */
2401:            public double numFalsePositives(int classIndex) {
2402:
2403:                double incorrect = 0;
2404:                for (int i = 0; i < m_NumClasses; i++) {
2405:                    if (i != classIndex) {
2406:                        for (int j = 0; j < m_NumClasses; j++) {
2407:                            if (j == classIndex) {
2408:                                incorrect += m_ConfusionMatrix[i][j];
2409:                            }
2410:                        }
2411:                    }
2412:                }
2413:                return incorrect;
2414:            }
2415:
2416:            /**
2417:             * Calculate the false positive rate with respect to a particular class. 
2418:             * This is defined as<p/>
2419:             * <pre>
2420:             * incorrectly classified negatives
2421:             * --------------------------------
2422:             *        total negatives
2423:             * </pre>
2424:             *
2425:             * @param classIndex the index of the class to consider as "positive"
2426:             * @return the false positive rate
2427:             */
2428:            public double falsePositiveRate(int classIndex) {
2429:
2430:                double incorrect = 0, total = 0;
2431:                for (int i = 0; i < m_NumClasses; i++) {
2432:                    if (i != classIndex) {
2433:                        for (int j = 0; j < m_NumClasses; j++) {
2434:                            if (j == classIndex) {
2435:                                incorrect += m_ConfusionMatrix[i][j];
2436:                            }
2437:                            total += m_ConfusionMatrix[i][j];
2438:                        }
2439:                    }
2440:                }
2441:                if (total == 0) {
2442:                    return 0;
2443:                }
2444:                return incorrect / total;
2445:            }
2446:
2447:            /**
2448:             * Calculate number of false negatives with respect to a particular class. 
2449:             * This is defined as<p/>
2450:             * <pre>
2451:             * incorrectly classified positives
2452:             * </pre>
2453:             *
2454:             * @param classIndex the index of the class to consider as "positive"
2455:             * @return the false positive rate
2456:             */
2457:            public double numFalseNegatives(int classIndex) {
2458:
2459:                double incorrect = 0;
2460:                for (int i = 0; i < m_NumClasses; i++) {
2461:                    if (i == classIndex) {
2462:                        for (int j = 0; j < m_NumClasses; j++) {
2463:                            if (j != classIndex) {
2464:                                incorrect += m_ConfusionMatrix[i][j];
2465:                            }
2466:                        }
2467:                    }
2468:                }
2469:                return incorrect;
2470:            }
2471:
2472:            /**
2473:             * Calculate the false negative rate with respect to a particular class. 
2474:             * This is defined as<p/>
2475:             * <pre>
2476:             * incorrectly classified positives
2477:             * --------------------------------
2478:             *        total positives
2479:             * </pre>
2480:             *
2481:             * @param classIndex the index of the class to consider as "positive"
2482:             * @return the false positive rate
2483:             */
2484:            public double falseNegativeRate(int classIndex) {
2485:
2486:                double incorrect = 0, total = 0;
2487:                for (int i = 0; i < m_NumClasses; i++) {
2488:                    if (i == classIndex) {
2489:                        for (int j = 0; j < m_NumClasses; j++) {
2490:                            if (j != classIndex) {
2491:                                incorrect += m_ConfusionMatrix[i][j];
2492:                            }
2493:                            total += m_ConfusionMatrix[i][j];
2494:                        }
2495:                    }
2496:                }
2497:                if (total == 0) {
2498:                    return 0;
2499:                }
2500:                return incorrect / total;
2501:            }
2502:
2503:            /**
2504:             * Calculate the recall with respect to a particular class. 
2505:             * This is defined as<p/>
2506:             * <pre>
2507:             * correctly classified positives
2508:             * ------------------------------
2509:             *       total positives
2510:             * </pre><p/>
2511:             * (Which is also the same as the truePositiveRate.)
2512:             *
2513:             * @param classIndex the index of the class to consider as "positive"
2514:             * @return the recall
2515:             */
2516:            public double recall(int classIndex) {
2517:
2518:                return truePositiveRate(classIndex);
2519:            }
2520:
2521:            /**
2522:             * Calculate the precision with respect to a particular class. 
2523:             * This is defined as<p/>
2524:             * <pre>
2525:             * correctly classified positives
2526:             * ------------------------------
2527:             *  total predicted as positive
2528:             * </pre>
2529:             *
2530:             * @param classIndex the index of the class to consider as "positive"
2531:             * @return the precision
2532:             */
2533:            public double precision(int classIndex) {
2534:
2535:                double correct = 0, total = 0;
2536:                for (int i = 0; i < m_NumClasses; i++) {
2537:                    if (i == classIndex) {
2538:                        correct += m_ConfusionMatrix[i][classIndex];
2539:                    }
2540:                    total += m_ConfusionMatrix[i][classIndex];
2541:                }
2542:                if (total == 0) {
2543:                    return 0;
2544:                }
2545:                return correct / total;
2546:            }
2547:
2548:            /**
2549:             * Calculate the F-Measure with respect to a particular class. 
2550:             * This is defined as<p/>
2551:             * <pre>
2552:             * 2 * recall * precision
2553:             * ----------------------
2554:             *   recall + precision
2555:             * </pre>
2556:             *
2557:             * @param classIndex the index of the class to consider as "positive"
2558:             * @return the F-Measure
2559:             */
2560:            public double fMeasure(int classIndex) {
2561:
2562:                double precision = precision(classIndex);
2563:                double recall = recall(classIndex);
2564:                if ((precision + recall) == 0) {
2565:                    return 0;
2566:                }
2567:                return 2 * precision * recall / (precision + recall);
2568:            }
2569:
            /**
             * Sets the class prior probabilities. For a nominal class the
             * priors are weighted class counts starting from 1 per class
             * (a Laplace-style correction); for a numeric class the training
             * class values and weights are recorded for later estimation.
             *
             * @param train the training instances used to determine
             * the prior probabilities
             * @throws Exception if the class attribute of the instances is not
             * set
             */
            public void setPriors(Instances train) throws Exception {
                m_NoPriors = false;

                if (!m_ClassIsNominal) {

                    // Numeric class: reset the recorded training class values
                    // and estimators, then re-collect them from the data.
                    m_NumTrainClassVals = 0;
                    m_TrainClassVals = null;
                    m_TrainClassWeights = null;
                    m_PriorErrorEstimator = null;
                    m_ErrorEstimator = null;

                    // Record each non-missing class value with its weight.
                    for (int i = 0; i < train.numInstances(); i++) {
                        Instance currentInst = train.instance(i);
                        if (!currentInst.classIsMissing()) {
                            addNumericTrainClass(currentInst.classValue(),
                                    currentInst.weight());
                        }
                    }

                } else {
                    // Nominal class: start every class count at 1 so no class
                    // gets a zero prior, then add instance weights.
                    for (int i = 0; i < m_NumClasses; i++) {
                        m_ClassPriors[i] = 1;
                    }
                    m_ClassPriorsSum = m_NumClasses;
                    for (int i = 0; i < train.numInstances(); i++) {
                        if (!train.instance(i).classIsMissing()) {
                            m_ClassPriors[(int) train.instance(i).classValue()] += train
                                    .instance(i).weight();
                            m_ClassPriorsSum += train.instance(i).weight();
                        }
                    }
                }
            }
2611:
2612:            /**
2613:             * Get the current weighted class counts
2614:             * 
2615:             * @return the weighted class counts
2616:             */
2617:            public double[] getClassPriors() {
2618:                return m_ClassPriors;
2619:            }
2620:
2621:            /**
2622:             * Updates the class prior probabilities (when incrementally 
2623:             * training)
2624:             *
2625:             * @param instance the new training instance seen
2626:             * @throws Exception if the class of the instance is not
2627:             * set
2628:             */
2629:            public void updatePriors(Instance instance) throws Exception {
2630:                if (!instance.classIsMissing()) {
2631:                    if (!m_ClassIsNominal) {
2632:                        if (!instance.classIsMissing()) {
2633:                            addNumericTrainClass(instance.classValue(),
2634:                                    instance.weight());
2635:                        }
2636:                    } else {
2637:                        m_ClassPriors[(int) instance.classValue()] += instance
2638:                                .weight();
2639:                        m_ClassPriorsSum += instance.weight();
2640:                    }
2641:                }
2642:            }
2643:
2644:            /**
2645:             * disables the use of priors, e.g., in case of de-serialized schemes
2646:             * that have no access to the original training set, but are evaluated
2647:             * on a set set.
2648:             */
2649:            public void useNoPriors() {
2650:                m_NoPriors = true;
2651:            }
2652:
2653:            /**
2654:             * Tests whether the current evaluation object is equal to another
2655:             * evaluation object
2656:             *
2657:             * @param obj the object to compare against
2658:             * @return true if the two objects are equal
2659:             */
2660:            public boolean equals(Object obj) {
2661:
2662:                if ((obj == null) || !(obj.getClass().equals(this .getClass()))) {
2663:                    return false;
2664:                }
2665:                Evaluation cmp = (Evaluation) obj;
2666:                if (m_ClassIsNominal != cmp.m_ClassIsNominal)
2667:                    return false;
2668:                if (m_NumClasses != cmp.m_NumClasses)
2669:                    return false;
2670:
2671:                if (m_Incorrect != cmp.m_Incorrect)
2672:                    return false;
2673:                if (m_Correct != cmp.m_Correct)
2674:                    return false;
2675:                if (m_Unclassified != cmp.m_Unclassified)
2676:                    return false;
2677:                if (m_MissingClass != cmp.m_MissingClass)
2678:                    return false;
2679:                if (m_WithClass != cmp.m_WithClass)
2680:                    return false;
2681:
2682:                if (m_SumErr != cmp.m_SumErr)
2683:                    return false;
2684:                if (m_SumAbsErr != cmp.m_SumAbsErr)
2685:                    return false;
2686:                if (m_SumSqrErr != cmp.m_SumSqrErr)
2687:                    return false;
2688:                if (m_SumClass != cmp.m_SumClass)
2689:                    return false;
2690:                if (m_SumSqrClass != cmp.m_SumSqrClass)
2691:                    return false;
2692:                if (m_SumPredicted != cmp.m_SumPredicted)
2693:                    return false;
2694:                if (m_SumSqrPredicted != cmp.m_SumSqrPredicted)
2695:                    return false;
2696:                if (m_SumClassPredicted != cmp.m_SumClassPredicted)
2697:                    return false;
2698:
2699:                if (m_ClassIsNominal) {
2700:                    for (int i = 0; i < m_NumClasses; i++) {
2701:                        for (int j = 0; j < m_NumClasses; j++) {
2702:                            if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
2703:                                return false;
2704:                            }
2705:                        }
2706:                    }
2707:                }
2708:
2709:                return true;
2710:            }
2711:
2712:            /**
2713:             * Prints the predictions for the given dataset into a String variable.
2714:             * 
2715:             * @param classifier		the classifier to use
2716:             * @param train		the training data
2717:             * @param testSource		the test set
2718:             * @param classIndex		the class index (1-based), if -1 ot does not 
2719:             * 				override the class index is stored in the data 
2720:             * 				file (by using the last attribute)
2721:             * @param attributesToOutput	the indices of the attributes to output
2722:             * @return			the generated predictions for the attribute range
2723:             * @throws Exception 		if test file cannot be opened
2724:             */
2725:            protected static String printClassifications(Classifier classifier,
2726:                    Instances train, DataSource testSource, int classIndex,
2727:                    Range attributesToOutput) throws Exception {
2728:
2729:                return printClassifications(classifier, train, testSource,
2730:                        classIndex, attributesToOutput, false);
2731:            }
2732:
2733:            /**
2734:             * Prints the predictions for the given dataset into a String variable.
2735:             * 
2736:             * @param classifier		the classifier to use
2737:             * @param train		the training data
2738:             * @param testSource		the test set
2739:             * @param classIndex		the class index (1-based), if -1 ot does not 
2740:             * 				override the class index is stored in the data 
2741:             * 				file (by using the last attribute)
2742:             * @param attributesToOutput	the indices of the attributes to output
2743:             * @param printDistribution	prints the complete distribution for nominal 
2744:             * 				classes, not just the predicted value
2745:             * @return			the generated predictions for the attribute range
2746:             * @throws Exception 		if test file cannot be opened
2747:             */
2748:            protected static String printClassifications(Classifier classifier,
2749:                    Instances train, DataSource testSource, int classIndex,
2750:                    Range attributesToOutput, boolean printDistribution)
2751:                    throws Exception {
2752:
2753:                StringBuffer text = new StringBuffer();
2754:                if (testSource != null) {
2755:                    Instances test = testSource.getStructure();
2756:                    if (classIndex != -1) {
2757:                        test.setClassIndex(classIndex - 1);
2758:                    } else {
2759:                        if (test.classIndex() == -1)
2760:                            test.setClassIndex(test.numAttributes() - 1);
2761:                    }
2762:
2763:                    // print header
2764:                    if (test.classAttribute().isNominal())
2765:                        if (printDistribution)
2766:                            text
2767:                                    .append(" inst#     actual  predicted error distribution");
2768:                        else
2769:                            text
2770:                                    .append(" inst#     actual  predicted error prediction");
2771:                    else
2772:                        text.append(" inst#     actual  predicted      error");
2773:                    if (attributesToOutput != null) {
2774:                        attributesToOutput.setUpper(test.numAttributes() - 1);
2775:                        text.append(" (");
2776:                        boolean first = true;
2777:                        for (int i = 0; i < test.numAttributes(); i++) {
2778:                            if (i == test.classIndex())
2779:                                continue;
2780:
2781:                            if (attributesToOutput.isInRange(i)) {
2782:                                if (!first)
2783:                                    text.append(",");
2784:                                text.append(test.attribute(i).name());
2785:                                first = false;
2786:                            }
2787:                        }
2788:                        text.append(")");
2789:                    }
2790:                    text.append("\n");
2791:
2792:                    // print predictions
2793:                    int i = 0;
2794:                    testSource.reset();
2795:                    while (testSource.hasMoreElements(test)) {
2796:                        Instance inst = testSource.nextElement(test);
2797:                        text.append(predictionText(classifier, inst, i,
2798:                                attributesToOutput, printDistribution));
2799:                        i++;
2800:                    }
2801:                }
2802:                return text.toString();
2803:            }
2804:
2805:            /**
2806:             * returns the prediction made by the classifier as a string
2807:             * 
2808:             * @param classifier		the classifier to use
2809:             * @param inst		the instance to generate text from
2810:             * @param instNum		the index in the dataset
2811:             * @param attributesToOutput	the indices of the attributes to output
2812:             * @param printDistribution	prints the complete distribution for nominal 
2813:             * 				classes, not just the predicted value
2814:             * @return			the generated text
2815:             * @throws Exception		if something goes wrong
2816:             * @see			#printClassifications(Classifier, Instances, String, int, Range, boolean)
2817:             */
2818:            protected static String predictionText(Classifier classifier,
2819:                    Instance inst, int instNum, Range attributesToOutput,
2820:                    boolean printDistribution) throws Exception {
2821:
2822:                StringBuffer result = new StringBuffer();
2823:                int width = 10;
2824:                int prec = 3;
2825:
2826:                Instance withMissing = (Instance) inst.copy();
2827:                withMissing.setDataset(inst.dataset());
2828:                double predValue = ((Classifier) classifier)
2829:                        .classifyInstance(withMissing);
2830:
2831:                // index
2832:                result.append(Utils.padLeft("" + (instNum + 1), 6));
2833:
2834:                if (inst.dataset().classAttribute().isNumeric()) {
2835:                    // actual
2836:                    if (inst.classIsMissing())
2837:                        result.append(" " + Utils.padLeft("?", width));
2838:                    else
2839:                        result.append(" "
2840:                                + Utils.doubleToString(inst.classValue(),
2841:                                        width, prec));
2842:                    // predicted
2843:                    if (Instance.isMissingValue(predValue))
2844:                        result.append(" " + Utils.padLeft("?", width));
2845:                    else
2846:                        result.append(" "
2847:                                + Utils.doubleToString(predValue, width, prec));
2848:                    // error
2849:                    if (Instance.isMissingValue(predValue)
2850:                            || inst.classIsMissing())
2851:                        result.append(" " + Utils.padLeft("?", width));
2852:                    else
2853:                        result.append(" "
2854:                                + Utils.doubleToString(predValue
2855:                                        - inst.classValue(), width, prec));
2856:                } else {
2857:                    // actual
2858:                    result.append(" "
2859:                            + Utils.padLeft(((int) inst.classValue() + 1) + ":"
2860:                                    + inst.toString(inst.classIndex()), width));
2861:                    // predicted
2862:                    if (Instance.isMissingValue(predValue))
2863:                        result.append(" " + Utils.padLeft("?", width));
2864:                    else
2865:                        result
2866:                                .append(" "
2867:                                        + Utils
2868:                                                .padLeft(
2869:                                                        ((int) predValue + 1)
2870:                                                                + ":"
2871:                                                                + inst
2872:                                                                        .dataset()
2873:                                                                        .classAttribute()
2874:                                                                        .value(
2875:                                                                                (int) predValue),
2876:                                                        width));
2877:                    // error?
2878:                    if ((int) predValue + 1 != (int) inst.classValue() + 1)
2879:                        result.append(" " + "  +  ");
2880:                    else
2881:                        result.append(" " + "     ");
2882:                    // prediction/distribution
2883:                    if (printDistribution) {
2884:                        if (Instance.isMissingValue(predValue)) {
2885:                            result.append(" " + "?");
2886:                        } else {
2887:                            result.append(" ");
2888:                            double[] dist = classifier
2889:                                    .distributionForInstance(withMissing);
2890:                            for (int n = 0; n < dist.length; n++) {
2891:                                if (n > 0)
2892:                                    result.append(",");
2893:                                if (n == (int) predValue)
2894:                                    result.append("*");
2895:                                result.append(Utils.doubleToString(dist[n],
2896:                                        prec));
2897:                            }
2898:                        }
2899:                    } else {
2900:                        if (Instance.isMissingValue(predValue))
2901:                            result.append(" " + "?");
2902:                        else
2903:                            result
2904:                                    .append(" "
2905:                                            + Utils
2906:                                                    .doubleToString(
2907:                                                            classifier
2908:                                                                    .distributionForInstance(withMissing)[(int) predValue],
2909:                                                            prec));
2910:                    }
2911:                }
2912:
2913:                // attributes
2914:                result
2915:                        .append(" "
2916:                                + attributeValuesString(withMissing,
2917:                                        attributesToOutput) + "\n");
2918:
2919:                return result.toString();
2920:            }
2921:
2922:            /**
2923:             * Builds a string listing the attribute values in a specified range of indices,
2924:             * separated by commas and enclosed in brackets.
2925:             *
2926:             * @param instance the instance to print the values from
2927:             * @param attRange the range of the attributes to list
2928:             * @return a string listing values of the attributes in the range
2929:             */
2930:            protected static String attributeValuesString(Instance instance,
2931:                    Range attRange) {
2932:                StringBuffer text = new StringBuffer();
2933:                if (attRange != null) {
2934:                    boolean firstOutput = true;
2935:                    attRange.setUpper(instance.numAttributes() - 1);
2936:                    for (int i = 0; i < instance.numAttributes(); i++)
2937:                        if (attRange.isInRange(i) && i != instance.classIndex()) {
2938:                            if (firstOutput)
2939:                                text.append("(");
2940:                            else
2941:                                text.append(",");
2942:                            text.append(instance.toString(i));
2943:                            firstOutput = false;
2944:                        }
2945:                    if (!firstOutput)
2946:                        text.append(")");
2947:                }
2948:                return text.toString();
2949:            }
2950:
2951:            /**
2952:             * Make up the help string giving all the command line options
2953:             *
2954:             * @param classifier the classifier to include options for
2955:             * @return a string detailing the valid command line options
2956:             */
2957:            protected static String makeOptionString(Classifier classifier) {
2958:
2959:                StringBuffer optionsText = new StringBuffer("");
2960:
2961:                // General options
2962:                optionsText.append("\n\nGeneral options:\n\n");
2963:                optionsText.append("-t <name of training file>\n");
2964:                optionsText.append("\tSets training file.\n");
2965:                optionsText.append("-T <name of test file>\n");
2966:                optionsText
2967:                        .append("\tSets test file. If missing, a cross-validation will be performed\n");
2968:                optionsText.append("\ton the training data.\n");
2969:                optionsText.append("-c <class index>\n");
2970:                optionsText
2971:                        .append("\tSets index of class attribute (default: last).\n");
2972:                optionsText.append("-x <number of folds>\n");
2973:                optionsText
2974:                        .append("\tSets number of folds for cross-validation (default: 10).\n");
2975:                optionsText.append("-no-cv\n");
2976:                optionsText.append("\tDo not perform any cross validation.\n");
2977:                optionsText.append("-split-percentage <percentage>\n");
2978:                optionsText
2979:                        .append("\tSets the percentage for the train/test set split, e.g., 66.\n");
2980:                optionsText.append("-preserve-order\n");
2981:                optionsText
2982:                        .append("\tPreserves the order in the percentage split.\n");
2983:                optionsText.append("-s <random number seed>\n");
2984:                optionsText
2985:                        .append("\tSets random number seed for cross-validation or percentage split\n");
2986:                optionsText.append("\t(default: 1).\n");
2987:                optionsText.append("-m <name of file with cost matrix>\n");
2988:                optionsText.append("\tSets file with cost matrix.\n");
2989:                optionsText.append("-l <name of input file>\n");
2990:                optionsText
2991:                        .append("\tSets model input file. In case the filename ends with '.xml',\n");
2992:                optionsText
2993:                        .append("\tthe options are loaded from the XML file.\n");
2994:                optionsText.append("-d <name of output file>\n");
2995:                optionsText
2996:                        .append("\tSets model output file. In case the filename ends with '.xml',\n");
2997:                optionsText
2998:                        .append("\tonly the options are saved to the XML file, not the model.\n");
2999:                optionsText.append("-v\n");
3000:                optionsText
3001:                        .append("\tOutputs no statistics for training data.\n");
3002:                optionsText.append("-o\n");
3003:                optionsText
3004:                        .append("\tOutputs statistics only, not the classifier.\n");
3005:                optionsText.append("-i\n");
3006:                optionsText.append("\tOutputs detailed information-retrieval");
3007:                optionsText.append(" statistics for each class.\n");
3008:                optionsText.append("-k\n");
3009:                optionsText
3010:                        .append("\tOutputs information-theoretic statistics.\n");
3011:                optionsText.append("-p <attribute range>\n");
3012:                optionsText
3013:                        .append("\tOnly outputs predictions for test instances (or the train\n"
3014:                                + "\tinstances if no test instances provided), along with attributes\n"
3015:                                + "\t(0 for none).\n");
3016:                optionsText.append("-distribution\n");
3017:                optionsText
3018:                        .append("\tOutputs the distribution instead of only the prediction\n");
3019:                optionsText
3020:                        .append("\tin conjunction with the '-p' option (only nominal classes).\n");
3021:                optionsText.append("-r\n");
3022:                optionsText
3023:                        .append("\tOnly outputs cumulative margin distribution.\n");
3024:                if (classifier instanceof  Sourcable) {
3025:                    optionsText.append("-z <class name>\n");
3026:                    optionsText
3027:                            .append("\tOnly outputs the source representation"
3028:                                    + " of the classifier,\n\tgiving it the supplied"
3029:                                    + " name.\n");
3030:                }
3031:                if (classifier instanceof  Drawable) {
3032:                    optionsText.append("-g\n");
3033:                    optionsText
3034:                            .append("\tOnly outputs the graph representation"
3035:                                    + " of the classifier.\n");
3036:                }
3037:                optionsText.append("-xml filename | xml-string\n");
3038:                optionsText
3039:                        .append("\tRetrieves the options from the XML-data instead of the "
3040:                                + "command line.\n");
3041:                optionsText.append("-threshold-file <file>\n");
3042:                optionsText
3043:                        .append("\tThe file to save the threshold data to.\n"
3044:                                + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
3045:                                + "\tformat or '.csv' for CSV.\n");
3046:                optionsText.append("-threshold-label <label>\n");
3047:                optionsText
3048:                        .append("\tThe class label to determine the threshold data for\n"
3049:                                + "\t(default is the first label)\n");
3050:
3051:                // Get scheme-specific options
3052:                if (classifier instanceof  OptionHandler) {
3053:                    optionsText.append("\nOptions specific to "
3054:                            + classifier.getClass().getName() + ":\n\n");
3055:                    Enumeration enu = ((OptionHandler) classifier)
3056:                            .listOptions();
3057:                    while (enu.hasMoreElements()) {
3058:                        Option option = (Option) enu.nextElement();
3059:                        optionsText.append(option.synopsis() + '\n');
3060:                        optionsText.append(option.description() + "\n");
3061:                    }
3062:                }
3063:                return optionsText.toString();
3064:            }
3065:
3066:            /**
3067:             * Method for generating indices for the confusion matrix.
3068:             *
3069:             * @param num 	integer to format
3070:             * @param IDChars	the characters to use
3071:             * @param IDWidth	the width of the entry
3072:             * @return 		the formatted integer as a string
3073:             */
3074:            protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
3075:
3076:                char ID[] = new char[IDWidth];
3077:                int i;
3078:
3079:                for (i = IDWidth - 1; i >= 0; i--) {
3080:                    ID[i] = IDChars[num % IDChars.length];
3081:                    num = num / IDChars.length - 1;
3082:                    if (num < 0) {
3083:                        break;
3084:                    }
3085:                }
3086:                for (i--; i >= 0; i--) {
3087:                    ID[i] = ' ';
3088:                }
3089:
3090:                return new String(ID);
3091:            }
3092:
3093:            /**
3094:             * Convert a single prediction into a probability distribution
3095:             * with all zero probabilities except the predicted value which
3096:             * has probability 1.0;
3097:             *
3098:             * @param predictedClass the index of the predicted class
3099:             * @return the probability distribution
3100:             */
3101:            protected double[] makeDistribution(double predictedClass) {
3102:
3103:                double[] result = new double[m_NumClasses];
3104:                if (Instance.isMissingValue(predictedClass)) {
3105:                    return result;
3106:                }
3107:                if (m_ClassIsNominal) {
3108:                    result[(int) predictedClass] = 1.0;
3109:                } else {
3110:                    result[0] = predictedClass;
3111:                }
3112:                return result;
3113:            }
3114:
3115:            /**
3116:             * Updates all the statistics about a classifiers performance for 
3117:             * the current test instance.
3118:             *
3119:             * @param predictedDistribution the probabilities assigned to 
3120:             * each class
3121:             * @param instance the instance to be classified
3122:             * @throws Exception if the class of the instance is not
3123:             * set
3124:             */
3125:            protected void updateStatsForClassifier(
3126:                    double[] predictedDistribution, Instance instance)
3127:                    throws Exception {
3128:
3129:                int actualClass = (int) instance.classValue();
3130:
3131:                if (!instance.classIsMissing()) {
3132:                    updateMargins(predictedDistribution, actualClass, instance
3133:                            .weight());
3134:
3135:                    // Determine the predicted class (doesn't detect multiple 
3136:                    // classifications)
3137:                    int predictedClass = -1;
3138:                    double bestProb = 0.0;
3139:                    for (int i = 0; i < m_NumClasses; i++) {
3140:                        if (predictedDistribution[i] > bestProb) {
3141:                            predictedClass = i;
3142:                            bestProb = predictedDistribution[i];
3143:                        }
3144:                    }
3145:
3146:                    m_WithClass += instance.weight();
3147:
3148:                    // Determine misclassification cost
3149:                    if (m_CostMatrix != null) {
3150:                        if (predictedClass < 0) {
3151:                            // For missing predictions, we assume the worst possible cost.
3152:                            // This is pretty harsh.
3153:                            // Perhaps we could take the negative of the cost of a correct
3154:                            // prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
3155:                            // although often this will be zero
3156:                            m_TotalCost += instance.weight()
3157:                                    * m_CostMatrix.getMaxCost(actualClass,
3158:                                            instance);
3159:                        } else {
3160:                            m_TotalCost += instance.weight()
3161:                                    * m_CostMatrix.getElement(actualClass,
3162:                                            predictedClass, instance);
3163:                        }
3164:                    }
3165:
3166:                    // Update counts when no class was predicted
3167:                    if (predictedClass < 0) {
3168:                        m_Unclassified += instance.weight();
3169:                        return;
3170:                    }
3171:
3172:                    double predictedProb = Math.max(MIN_SF_PROB,
3173:                            predictedDistribution[actualClass]);
3174:                    double priorProb = Math.max(MIN_SF_PROB,
3175:                            m_ClassPriors[actualClass] / m_ClassPriorsSum);
3176:                    if (predictedProb >= priorProb) {
3177:                        m_SumKBInfo += (Utils.log2(predictedProb) - Utils
3178:                                .log2(priorProb))
3179:                                * instance.weight();
3180:                    } else {
3181:                        m_SumKBInfo -= (Utils.log2(1.0 - predictedProb) - Utils
3182:                                .log2(1.0 - priorProb))
3183:                                * instance.weight();
3184:                    }
3185:
3186:                    m_SumSchemeEntropy -= Utils.log2(predictedProb)
3187:                            * instance.weight();
3188:                    m_SumPriorEntropy -= Utils.log2(priorProb)
3189:                            * instance.weight();
3190:
3191:                    updateNumericScores(predictedDistribution,
3192:                            makeDistribution(instance.classValue()), instance
3193:                                    .weight());
3194:
3195:                    // Update other stats
3196:                    m_ConfusionMatrix[actualClass][predictedClass] += instance
3197:                            .weight();
3198:                    if (predictedClass != actualClass) {
3199:                        m_Incorrect += instance.weight();
3200:                    } else {
3201:                        m_Correct += instance.weight();
3202:                    }
3203:                } else {
3204:                    m_MissingClass += instance.weight();
3205:                }
3206:            }
3207:
    /**
     * Updates all the statistics about a predictor's performance for
     * the current test instance (numeric class case).
     *
     * @param predictedValue the numeric value the classifier predicts
     * @param instance the instance to be classified
     * @throws Exception if the class of the instance is not
     * set
     */
    protected void updateStatsForPredictor(double predictedValue,
            Instance instance) throws Exception {

        if (!instance.classIsMissing()) {

            // Update stats
            m_WithClass += instance.weight();
            // A missing prediction counts as unclassified and contributes
            // nothing to the error/entropy sums below.
            if (Instance.isMissingValue(predictedValue)) {
                m_Unclassified += instance.weight();
                return;
            }
            // Weighted raw sums later used to derive correlation and the
            // mean/variance of actual vs. predicted class values.
            m_SumClass += instance.weight() * instance.classValue();
            m_SumSqrClass += instance.weight() * instance.classValue()
                    * instance.classValue();
            m_SumClassPredicted += instance.weight()
                    * instance.classValue() * predictedValue;
            m_SumPredicted += instance.weight() * predictedValue;
            m_SumSqrPredicted += instance.weight() * predictedValue
                    * predictedValue;

            // Lazily initialise the kernel estimators from the buffered
            // training class values on the first test instance.
            if (m_ErrorEstimator == null) {
                setNumericPriorsFromBuffer();
            }
            // Clamp both probabilities to MIN_SF_PROB so log2 below never
            // sees zero.
            double predictedProb = Math.max(m_ErrorEstimator
                    .getProbability(predictedValue
                            - instance.classValue()), MIN_SF_PROB);
            double priorProb = Math
                    .max(m_PriorErrorEstimator.getProbability(instance
                            .classValue()), MIN_SF_PROB);

            m_SumSchemeEntropy -= Utils.log2(predictedProb)
                    * instance.weight();
            m_SumPriorEntropy -= Utils.log2(priorProb)
                    * instance.weight();
            // Add the residual AFTER reading predictedProb, so the current
            // instance does not influence its own probability estimate.
            m_ErrorEstimator.addValue(predictedValue
                    - instance.classValue(), instance.weight());

            // Feed the scalar values through the distribution-based error
            // accumulators shared with the nominal-class path.
            updateNumericScores(makeDistribution(predictedValue),
                    makeDistribution(instance.classValue()), instance
                            .weight());

        } else
            m_MissingClass += instance.weight();
    }
3261:
3262:            /**
3263:             * Update the cumulative record of classification margins
3264:             *
3265:             * @param predictedDistribution the probability distribution predicted for
3266:             * the current instance
3267:             * @param actualClass the index of the actual instance class
3268:             * @param weight the weight assigned to the instance
3269:             */
3270:            protected void updateMargins(double[] predictedDistribution,
3271:                    int actualClass, double weight) {
3272:
3273:                double probActual = predictedDistribution[actualClass];
3274:                double probNext = 0;
3275:
3276:                for (int i = 0; i < m_NumClasses; i++)
3277:                    if ((i != actualClass)
3278:                            && (predictedDistribution[i] > probNext))
3279:                        probNext = predictedDistribution[i];
3280:
3281:                double margin = probActual - probNext;
3282:                int bin = (int) ((margin + 1.0) / 2.0 * k_MarginResolution);
3283:                m_MarginCounts[bin] += weight;
3284:            }
3285:
3286:            /**
3287:             * Update the numeric accuracy measures. For numeric classes, the
3288:             * accuracy is between the actual and predicted class values. For 
3289:             * nominal classes, the accuracy is between the actual and 
3290:             * predicted class probabilities.
3291:             *
3292:             * @param predicted the predicted values
3293:             * @param actual the actual value
3294:             * @param weight the weight associated with this prediction
3295:             */
3296:            protected void updateNumericScores(double[] predicted,
3297:                    double[] actual, double weight) {
3298:
3299:                double diff;
3300:                double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
3301:                double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
3302:                for (int i = 0; i < m_NumClasses; i++) {
3303:                    diff = predicted[i] - actual[i];
3304:                    sumErr += diff;
3305:                    sumAbsErr += Math.abs(diff);
3306:                    sumSqrErr += diff * diff;
3307:                    diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
3308:                    sumPriorAbsErr += Math.abs(diff);
3309:                    sumPriorSqrErr += diff * diff;
3310:                }
3311:                m_SumErr += weight * sumErr / m_NumClasses;
3312:                m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
3313:                m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
3314:                m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
3315:                m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
3316:            }
3317:
3318:            /**
3319:             * Adds a numeric (non-missing) training class value and weight to 
3320:             * the buffer of stored values.
3321:             *
3322:             * @param classValue the class value
3323:             * @param weight the instance weight
3324:             */
3325:            protected void addNumericTrainClass(double classValue, double weight) {
3326:
3327:                if (m_TrainClassVals == null) {
3328:                    m_TrainClassVals = new double[100];
3329:                    m_TrainClassWeights = new double[100];
3330:                }
3331:                if (m_NumTrainClassVals == m_TrainClassVals.length) {
3332:                    double[] temp = new double[m_TrainClassVals.length * 2];
3333:                    System.arraycopy(m_TrainClassVals, 0, temp, 0,
3334:                            m_TrainClassVals.length);
3335:                    m_TrainClassVals = temp;
3336:
3337:                    temp = new double[m_TrainClassWeights.length * 2];
3338:                    System.arraycopy(m_TrainClassWeights, 0, temp, 0,
3339:                            m_TrainClassWeights.length);
3340:                    m_TrainClassWeights = temp;
3341:                }
3342:                m_TrainClassVals[m_NumTrainClassVals] = classValue;
3343:                m_TrainClassWeights[m_NumTrainClassVals] = weight;
3344:                m_NumTrainClassVals++;
3345:            }
3346:
3347:            /**
3348:             * Sets up the priors for numeric class attributes from the 
3349:             * training class values that have been seen so far.
3350:             */
3351:            protected void setNumericPriorsFromBuffer() {
3352:
3353:                double numPrecision = 0.01; // Default value
3354:                if (m_NumTrainClassVals > 1) {
3355:                    double[] temp = new double[m_NumTrainClassVals];
3356:                    System.arraycopy(m_TrainClassVals, 0, temp, 0,
3357:                            m_NumTrainClassVals);
3358:                    int[] index = Utils.sort(temp);
3359:                    double lastVal = temp[index[0]];
3360:                    double deltaSum = 0;
3361:                    int distinct = 0;
3362:                    for (int i = 1; i < temp.length; i++) {
3363:                        double current = temp[index[i]];
3364:                        if (current != lastVal) {
3365:                            deltaSum += current - lastVal;
3366:                            lastVal = current;
3367:                            distinct++;
3368:                        }
3369:                    }
3370:                    if (distinct > 0) {
3371:                        numPrecision = deltaSum / distinct;
3372:                    }
3373:                }
3374:                m_PriorErrorEstimator = new KernelEstimator(numPrecision);
3375:                m_ErrorEstimator = new KernelEstimator(numPrecision);
3376:                m_ClassPriors[0] = m_ClassPriorsSum = 0;
3377:                for (int i = 0; i < m_NumTrainClassVals; i++) {
3378:                    m_ClassPriors[0] += m_TrainClassVals[i]
3379:                            * m_TrainClassWeights[i];
3380:                    m_ClassPriorsSum += m_TrainClassWeights[i];
3381:                    m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
3382:                            m_TrainClassWeights[i]);
3383:                }
3384:            }
3385:        }
www.java2java.com | Contact Us
Copyright 2009 - 12 Demo Source and Support. All rights reserved.
All other trademarks are property of their respective owners.