0001: /*
0002: * This program is free software; you can redistribute it and/or modify
0003: * it under the terms of the GNU General Public License as published by
0004: * the Free Software Foundation; either version 2 of the License, or
0005: * (at your option) any later version.
0006: *
0007: * This program is distributed in the hope that it will be useful,
0008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
0009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
0010: * GNU General Public License for more details.
0011: *
0012: * You should have received a copy of the GNU General Public License
0013: * along with this program; if not, write to the Free Software
0014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
0015: */
0016:
0017: /*
0018: * BayesNet.java
0019: * Copyright (C) 2001 University of Waikato, Hamilton, New Zealand
0020: *
0021: */
0022: package weka.classifiers.bayes;
0023:
0024: import weka.classifiers.Classifier;
0025: import weka.classifiers.bayes.net.ADNode;
0026: import weka.classifiers.bayes.net.BIFReader;
0027: import weka.classifiers.bayes.net.ParentSet;
0028: import weka.classifiers.bayes.net.estimate.BayesNetEstimator;
0029: import weka.classifiers.bayes.net.estimate.DiscreteEstimatorBayes;
0030: import weka.classifiers.bayes.net.estimate.SimpleEstimator;
0031: import weka.classifiers.bayes.net.search.SearchAlgorithm;
0032: import weka.classifiers.bayes.net.search.local.K2;
0033: import weka.classifiers.bayes.net.search.local.LocalScoreSearchAlgorithm;
0034: import weka.classifiers.bayes.net.search.local.Scoreable;
0035: import weka.core.AdditionalMeasureProducer;
0036: import weka.core.Attribute;
0037: import weka.core.Capabilities;
0038: import weka.core.Drawable;
0039: import weka.core.Instance;
0040: import weka.core.Instances;
0041: import weka.core.Option;
0042: import weka.core.OptionHandler;
0043: import weka.core.Utils;
0044: import weka.core.WeightedInstancesHandler;
0045: import weka.core.Capabilities.Capability;
0046: import weka.estimators.Estimator;
0047: import weka.filters.Filter;
0048: import weka.filters.supervised.attribute.Discretize;
0049: import weka.filters.unsupervised.attribute.ReplaceMissingValues;
0050:
0051: import java.util.Enumeration;
0052: import java.util.Vector;
0053:
0054: /**
0055: <!-- globalinfo-start -->
0056: * Bayes Network learning using various search algorithms and quality measures.<br/>
0057: * Base class for a Bayes Network classifier. Provides datastructures (network structure, conditional probability distributions, etc.) and facilities common to Bayes Network learning algorithms like K2 and B.<br/>
0058: * <br/>
0059: * For more information see:<br/>
0060: * <br/>
0061: * http://www.cs.waikato.ac.nz/~remco/weka.pdf
0062: * <p/>
0063: <!-- globalinfo-end -->
0064: *
0065: <!-- options-start -->
0066: * Valid options are: <p/>
0067: *
0068: * <pre> -D
0069: * Do not use ADTree data structure
0070: * </pre>
0071: *
* <pre> -B &lt;BIF file&gt;
0073: * BIF file to compare with
0074: * </pre>
0075: *
0076: * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm
0077: * Search algorithm
0078: * </pre>
0079: *
0080: * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator
0081: * Estimator algorithm
0082: * </pre>
0083: *
0084: <!-- options-end -->
0085: *
0086: * @author Remco Bouckaert (rrb@xm.co.nz)
0087: * @version $Revision: 1.30 $
0088: */
0089: public class BayesNet extends Classifier implements OptionHandler,
0090: WeightedInstancesHandler, Drawable, AdditionalMeasureProducer {
0091:
/** for serialization */
static final long serialVersionUID = 746037443258775954L;

/**
 * The parent sets, one per attribute, describing the network structure.
 */
protected ParentSet[] m_ParentSets;

/**
 * The attribute estimators containing CPTs, indexed by attribute and
 * parent-value configuration.
 */
public Estimator[][] m_Distributions;

/** filter used to quantize continuous variables, if any **/
Discretize m_DiscretizeFilter = null;

/** attribute index of a non-nominal attribute (the last one encountered
 * during normalizeDataSet; -1 when all attributes are nominal) */
int m_nNonDiscreteAttribute = -1;

/** filter used to fill in missing values, if any **/
ReplaceMissingValues m_MissingValuesFilter = null;

/**
 * The number of classes
 */
protected int m_NumClasses;

/**
 * The dataset header for the purposes of printing out a semi-intelligible
 * model
 */
public Instances m_Instances;

/**
 * Datastructure containing ADTree representation of the database.
 * This may result in more efficient access to the data.
 */
ADNode m_ADTree;

/**
 * Bayes network to compare the structure with.
 */
protected BIFReader m_otherBayesNet = null;

/**
 * Use the experimental ADTree datastructure for calculating contingency tables
 */
boolean m_bUseADTree = false;

/**
 * Search algorithm used for learning the structure of a network.
 */
SearchAlgorithm m_SearchAlgorithm = new K2();

/**
 * Estimator used for calculating the conditional probability tables (CPTs).
 */
BayesNetEstimator m_BayesNetEstimator = new SimpleEstimator();
0150:
0151: /**
0152: * Returns default capabilities of the classifier.
0153: *
0154: * @return the capabilities of this classifier
0155: */
0156: public Capabilities getCapabilities() {
0157: Capabilities result = super .getCapabilities();
0158:
0159: // attributes
0160: result.enable(Capability.NOMINAL_ATTRIBUTES);
0161: result.enable(Capability.NUMERIC_ATTRIBUTES);
0162: result.enable(Capability.MISSING_VALUES);
0163:
0164: // class
0165: result.enable(Capability.NOMINAL_CLASS);
0166: result.enable(Capability.MISSING_CLASS_VALUES);
0167:
0168: // instances
0169: result.setMinimumNumberInstances(0);
0170:
0171: return result;
0172: }
0173:
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated
 * successfully
 */
public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // ensure we have a data set with discrete variables only and with no
    // missing values (discretizing / imputing where necessary)
    instances = normalizeDataSet(instances);

    // Copy the instances
    m_Instances = new Instances(instances);

    // cache the number of classes
    m_NumClasses = instances.numClasses();

    // initialize ADTree (optional counting speed-up; memory intensive)
    if (m_bUseADTree) {
        m_ADTree = ADNode.makeADTree(instances);
    }

    // initialize the network structure (empty parent sets)
    initStructure();

    // learn the network structure using the configured search algorithm
    buildStructure();

    // estimate the set of CPTs from the data
    estimateCPTs();

    // the ADTree is only needed while learning; drop it to save space
    m_ADTree = null;
} // buildClassifier
0218:
0219: /** ensure that all variables are nominal and that there are no missing values
0220: * @param instances data set to check and quantize and/or fill in missing values
0221: * @return filtered instances
0222: * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails
0223: */
0224: Instances normalizeDataSet(Instances instances) throws Exception {
0225: m_DiscretizeFilter = null;
0226: m_MissingValuesFilter = null;
0227:
0228: boolean bHasNonNominal = false;
0229: boolean bHasMissingValues = false;
0230:
0231: Enumeration enu = instances.enumerateAttributes();
0232: while (enu.hasMoreElements()) {
0233: Attribute attribute = (Attribute) enu.nextElement();
0234: if (attribute.type() != Attribute.NOMINAL) {
0235: m_nNonDiscreteAttribute = attribute.index();
0236: bHasNonNominal = true;
0237: //throw new UnsupportedAttributeTypeException("BayesNet handles nominal variables only. Non-nominal variable in dataset detected.");
0238: }
0239: Enumeration enum2 = instances.enumerateInstances();
0240: while (enum2.hasMoreElements()) {
0241: if (((Instance) enum2.nextElement())
0242: .isMissing(attribute)) {
0243: bHasMissingValues = true;
0244: // throw new NoSupportForMissingValuesException("BayesNet: no missing values, please.");
0245: }
0246: }
0247: }
0248:
0249: if (bHasNonNominal) {
0250: System.err.println("Warning: discretizing data set");
0251: m_DiscretizeFilter = new Discretize();
0252: m_DiscretizeFilter.setInputFormat(instances);
0253: instances = Filter.useFilter(instances, m_DiscretizeFilter);
0254: }
0255:
0256: if (bHasMissingValues) {
0257: System.err
0258: .println("Warning: filling in missing values in data set");
0259: m_MissingValuesFilter = new ReplaceMissingValues();
0260: m_MissingValuesFilter.setInputFormat(instances);
0261: instances = Filter.useFilter(instances,
0262: m_MissingValuesFilter);
0263: }
0264: return instances;
0265: } // normalizeDataSet
0266:
/** ensure that all variables of the instance are nominal and that there are
 * no missing values, applying the filters fitted during training
 * @param instance instance to check and quantize and/or fill in missing values
 * @return filtered instance
 * @throws Exception if a filter (Discretize, ReplaceMissingValues) fails
 */
Instance normalizeInstance(Instance instance) throws Exception {
    // discretize only when training data was discretized AND this instance
    // still carries a non-nominal value at the remembered attribute index
    if ((m_DiscretizeFilter != null)
            && (instance.attribute(m_nNonDiscreteAttribute).type() != Attribute.NOMINAL)) {
        m_DiscretizeFilter.input(instance);
        instance = m_DiscretizeFilter.output();
    }
    if (m_MissingValuesFilter != null) {
        m_MissingValuesFilter.input(instance);
        instance = m_MissingValuesFilter.output();
    } else {
        // is there a missing value in this instance?
        // this can happen when there is no missing value in the training set,
        // so no ReplaceMissingValues filter was fitted at training time
        for (int iAttribute = 0; iAttribute < m_Instances
                .numAttributes(); iAttribute++) {
            if (iAttribute != instance.classIndex()
                    && instance.isMissing(iAttribute)) {
                System.err
                        .println("Warning: Found missing value in test set, filling in values.");
                // lazily fit a ReplaceMissingValues filter on the training data
                m_MissingValuesFilter = new ReplaceMissingValues();
                m_MissingValuesFilter.setInputFormat(m_Instances);
                Filter.useFilter(m_Instances, m_MissingValuesFilter);
                m_MissingValuesFilter.input(instance);
                instance = m_MissingValuesFilter.output();
                // terminate the loop: the filter has handled all attributes
                iAttribute = m_Instances.numAttributes();
            }
        }
    }
    return instance;
} // normalizeInstance
0303:
0304: /**
0305: * Init structure initializes the structure to an empty graph or a Naive Bayes
0306: * graph (depending on the -N flag).
0307: *
0308: * @throws Exception in case of an error
0309: */
0310: public void initStructure() throws Exception {
0311:
0312: // initialize topological ordering
0313: // m_nOrder = new int[m_Instances.numAttributes()];
0314: // m_nOrder[0] = m_Instances.classIndex();
0315:
0316: int nAttribute = 0;
0317:
0318: for (int iOrder = 1; iOrder < m_Instances.numAttributes(); iOrder++) {
0319: if (nAttribute == m_Instances.classIndex()) {
0320: nAttribute++;
0321: }
0322:
0323: // m_nOrder[iOrder] = nAttribute++;
0324: }
0325:
0326: // reserve memory
0327: m_ParentSets = new ParentSet[m_Instances.numAttributes()];
0328:
0329: for (int iAttribute = 0; iAttribute < m_Instances
0330: .numAttributes(); iAttribute++) {
0331: m_ParentSets[iAttribute] = new ParentSet(m_Instances
0332: .numAttributes());
0333: }
0334: } // initStructure
0335:
0336: /**
0337: * buildStructure determines the network structure/graph of the network.
0338: * The default behavior is creating a network where all nodes have the first
0339: * node as its parent (i.e., a BayesNet that behaves like a naive Bayes classifier).
0340: * This method can be overridden by derived classes to restrict the class
0341: * of network structures that are acceptable.
0342: *
0343: * @throws Exception in case of an error
0344: */
0345: public void buildStructure() throws Exception {
0346: m_SearchAlgorithm.buildStructure(this , m_Instances);
0347: } // buildStructure
0348:
0349: /**
0350: * estimateCPTs estimates the conditional probability tables for the Bayes
0351: * Net using the network structure.
0352: *
0353: * @throws Exception in case of an error
0354: */
0355: public void estimateCPTs() throws Exception {
0356: m_BayesNetEstimator.estimateCPTs(this );
0357: } // estimateCPTs
0358:
/**
 * initializes the conditional probability tables
 *
 * @throws Exception in case of an error
 */
public void initCPTs() throws Exception {
    m_BayesNetEstimator.initCPTs(this);
} // initCPTs
0367:
0368: /**
0369: * Updates the classifier with the given instance.
0370: *
0371: * @param instance the new training instance to include in the model
0372: * @throws Exception if the instance could not be incorporated in
0373: * the model.
0374: */
0375: public void updateClassifier(Instance instance) throws Exception {
0376: instance = normalizeInstance(instance);
0377: m_BayesNetEstimator.updateClassifier(this , instance);
0378: } // updateClassifier
0379:
0380: /**
0381: * Calculates the class membership probabilities for the given test
0382: * instance.
0383: *
0384: * @param instance the instance to be classified
0385: * @return predicted class probability distribution
0386: * @throws Exception if there is a problem generating the prediction
0387: */
0388: public double[] distributionForInstance(Instance instance)
0389: throws Exception {
0390: instance = normalizeInstance(instance);
0391: return m_BayesNetEstimator.distributionForInstance(this ,
0392: instance);
0393: } // distributionForInstance
0394:
0395: /**
0396: * Calculates the counts for Dirichlet distribution for the
0397: * class membership probabilities for the given test instance.
0398: *
0399: * @param instance the instance to be classified
0400: * @return counts for Dirichlet distribution for class probability
0401: * @throws Exception if there is a problem generating the prediction
0402: */
0403: public double[] countsForInstance(Instance instance)
0404: throws Exception {
0405: double[] fCounts = new double[m_NumClasses];
0406:
0407: for (int iClass = 0; iClass < m_NumClasses; iClass++) {
0408: fCounts[iClass] = 0.0;
0409: }
0410:
0411: for (int iClass = 0; iClass < m_NumClasses; iClass++) {
0412: double fCount = 0;
0413:
0414: for (int iAttribute = 0; iAttribute < m_Instances
0415: .numAttributes(); iAttribute++) {
0416: double iCPT = 0;
0417:
0418: for (int iParent = 0; iParent < m_ParentSets[iAttribute]
0419: .getNrOfParents(); iParent++) {
0420: int nParent = m_ParentSets[iAttribute]
0421: .getParent(iParent);
0422:
0423: if (nParent == m_Instances.classIndex()) {
0424: iCPT = iCPT * m_NumClasses + iClass;
0425: } else {
0426: iCPT = iCPT
0427: * m_Instances.attribute(nParent)
0428: .numValues()
0429: + instance.value(nParent);
0430: }
0431: }
0432:
0433: if (iAttribute == m_Instances.classIndex()) {
0434: fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT])
0435: .getCount(iClass);
0436: } else {
0437: fCount += ((DiscreteEstimatorBayes) m_Distributions[iAttribute][(int) iCPT])
0438: .getCount(instance.value(iAttribute));
0439: }
0440: }
0441:
0442: fCounts[iClass] += fCount;
0443: }
0444: return fCounts;
0445: } // countsForInstance
0446:
0447: /**
0448: * Returns an enumeration describing the available options
0449: *
0450: * @return an enumeration of all the available options
0451: */
0452: public Enumeration listOptions() {
0453: Vector newVector = new Vector(4);
0454:
0455: newVector.addElement(new Option(
0456: "\tDo not use ADTree data structure\n", "D", 0, "-D"));
0457: newVector.addElement(new Option("\tBIF file to compare with\n",
0458: "B", 1, "-B <BIF file>"));
0459: newVector
0460: .addElement(new Option("\tSearch algorithm\n", "Q", 1,
0461: "-Q weka.classifiers.bayes.net.search.SearchAlgorithm"));
0462: newVector
0463: .addElement(new Option("\tEstimator algorithm\n", "E",
0464: 1,
0465: "-E weka.classifiers.bayes.net.estimate.SimpleEstimator"));
0466:
0467: return newVector.elements();
0468: } // listOptions
0469:
/**
 * Parses a given list of options. <p>
 *
<!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -D
 * Do not use ADTree data structure
 * </pre>
 *
 * <pre> -B &lt;BIF file&gt;
 * BIF file to compare with
 * </pre>
 *
 * <pre> -Q weka.classifiers.bayes.net.search.SearchAlgorithm
 * Search algorithm
 * </pre>
 *
 * <pre> -E weka.classifiers.bayes.net.estimate.SimpleEstimator
 * Estimator algorithm
 * </pre>
 *
<!-- options-end -->
 *
 * @param options the list of options as an array of strings
 * @throws Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
    // -D switches the ADTree data structure OFF (it is on by default)
    m_bUseADTree = !(Utils.getFlag('D', options));

    String sBIFFile = Utils.getOption('B', options);
    if (sBIFFile != null && !sBIFFile.equals("")) {
        setBIFFile(sBIFFile);
    }

    String searchAlgorithmName = Utils.getOption('Q', options);
    if (searchAlgorithmName.length() != 0) {
        // NOTE(review): the search algorithm receives BayesNet.partitionOptions
        // (the options between the first "--" and "-E"), while the estimator
        // below uses Utils.partitionOptions — presumably intentional, verify.
        setSearchAlgorithm((SearchAlgorithm) Utils.forName(
                SearchAlgorithm.class, searchAlgorithmName,
                partitionOptions(options)));
    } else {
        setSearchAlgorithm(new K2());
    }

    String estimatorName = Utils.getOption('E', options);
    if (estimatorName.length() != 0) {
        setEstimator((BayesNetEstimator) Utils.forName(
                BayesNetEstimator.class, estimatorName,
                Utils.partitionOptions(options)));
    } else {
        setEstimator(new SimpleEstimator());
    }

    Utils.checkForRemainingOptions(options);
} // setOptions
0525:
0526: /**
0527: * Returns the secondary set of options (if any) contained in
0528: * the supplied options array. The secondary set is defined to
0529: * be any options after the first "--" but before the "-E". These
0530: * options are removed from the original options array.
0531: *
0532: * @param options the input array of options
0533: * @return the array of secondary options
0534: */
0535: public static String[] partitionOptions(String[] options) {
0536:
0537: for (int i = 0; i < options.length; i++) {
0538: if (options[i].equals("--")) {
0539: // ensure it follows by a -E option
0540: int j = i;
0541: while ((j < options.length)
0542: && !(options[j].equals("-E"))) {
0543: j++;
0544: }
0545: if (j >= options.length) {
0546: return new String[0];
0547: }
0548: options[i++] = "";
0549: String[] result = new String[options.length - i];
0550: j = i;
0551: while ((j < options.length)
0552: && !(options[j].equals("-E"))) {
0553: result[j - i] = options[j];
0554: options[j] = "";
0555: j++;
0556: }
0557: while (j < options.length) {
0558: result[j - i] = "";
0559: j++;
0560: }
0561: return result;
0562: }
0563: }
0564: return new String[0];
0565: }
0566:
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {
    String[] searchOptions = m_SearchAlgorithm.getOptions();
    String[] estimatorOptions = m_BayesNetEstimator.getOptions();
    // 11 fixed slots cover the worst case of -D, -B <file>, -Q <class> --,
    // -E <class> -- (with slack); unused slots are padded with "" below
    String[] options = new String[11 + searchOptions.length
            + estimatorOptions.length];
    int current = 0;

    if (!m_bUseADTree) {
        options[current++] = "-D";
    }

    if (m_otherBayesNet != null) {
        options[current++] = "-B";
        options[current++] = ((BIFReader) m_otherBayesNet).getFileName();
    }

    // search algorithm class name, followed by its own options after "--"
    options[current++] = "-Q";
    options[current++] = "" + getSearchAlgorithm().getClass().getName();
    options[current++] = "--";
    for (int iOption = 0; iOption < searchOptions.length; iOption++) {
        options[current++] = searchOptions[iOption];
    }

    // estimator class name, followed by its own options after "--"
    options[current++] = "-E";
    options[current++] = "" + getEstimator().getClass().getName();
    options[current++] = "--";
    for (int iOption = 0; iOption < estimatorOptions.length; iOption++) {
        options[current++] = estimatorOptions[iOption];
    }

    // Fill up rest with empty strings, not nulls!
    while (current < options.length) {
        options[current++] = "";
    }

    return options;
} // getOptions
0611:
0612: /**
0613: * Set the SearchAlgorithm used in searching for network structures.
0614: * @param newSearchAlgorithm the SearchAlgorithm to use.
0615: */
0616: public void setSearchAlgorithm(SearchAlgorithm newSearchAlgorithm) {
0617: m_SearchAlgorithm = newSearchAlgorithm;
0618: }
0619:
0620: /**
0621: * Get the SearchAlgorithm used as the search algorithm
0622: * @return the SearchAlgorithm used as the search algorithm
0623: */
0624: public SearchAlgorithm getSearchAlgorithm() {
0625: return m_SearchAlgorithm;
0626: }
0627:
0628: /**
0629: * Set the Estimator Algorithm used in calculating the CPTs
0630: * @param newBayesNetEstimator the Estimator to use.
0631: */
0632: public void setEstimator(BayesNetEstimator newBayesNetEstimator) {
0633: m_BayesNetEstimator = newBayesNetEstimator;
0634: }
0635:
0636: /**
0637: * Get the BayesNetEstimator used for calculating the CPTs
0638: * @return the BayesNetEstimator used.
0639: */
0640: public BayesNetEstimator getEstimator() {
0641: return m_BayesNetEstimator;
0642: }
0643:
0644: /**
0645: * Set whether ADTree structure is used or not
0646: * @param bUseADTree true if an ADTree structure is used
0647: */
0648: public void setUseADTree(boolean bUseADTree) {
0649: m_bUseADTree = bUseADTree;
0650: }
0651:
0652: /**
0653: * Method declaration
0654: * @return whether ADTree structure is used or not
0655: */
0656: public boolean getUseADTree() {
0657: return m_bUseADTree;
0658: }
0659:
0660: /**
0661: * Set name of network in BIF file to compare with
0662: * @param sBIFFile the name of the BIF file
0663: */
0664: public void setBIFFile(String sBIFFile) {
0665: try {
0666: m_otherBayesNet = new BIFReader().processFile(sBIFFile);
0667: } catch (Throwable t) {
0668: m_otherBayesNet = null;
0669: }
0670: }
0671:
0672: /**
0673: * Get name of network in BIF file to compare with
0674: * @return BIF file name
0675: */
0676: public String getBIFFile() {
0677: if (m_otherBayesNet != null) {
0678: return m_otherBayesNet.getFileName();
0679: }
0680: return "";
0681: }
0682:
/**
 * Returns a description of the classifier: the network structure (each
 * node followed by its parents), several log-score quality measures, and
 * — when a reference network was loaded via setBIFFile — a structural
 * comparison against it.
 *
 * @return a description of the classifier as a string.
 */
public String toString() {
    StringBuffer text = new StringBuffer();

    text.append("Bayes Network Classifier");
    text.append("\n" + (m_bUseADTree ? "Using " : "not using ")
            + "ADTree");

    if (m_Instances == null) {
        text.append(": No model built yet.");
    } else {

        // flatten BayesNet down to text
        text.append("\n#attributes=");
        text.append(m_Instances.numAttributes());
        text.append(" #classindex=");
        text.append(m_Instances.classIndex());
        text
                .append("\nNetwork structure (nodes followed by parents)\n");

        // one line per node: name(cardinality): parent names
        for (int iAttribute = 0; iAttribute < m_Instances
                .numAttributes(); iAttribute++) {
            text.append(m_Instances.attribute(iAttribute).name()
                    + "("
                    + m_Instances.attribute(iAttribute).numValues()
                    + "): ");

            for (int iParent = 0; iParent < m_ParentSets[iAttribute]
                    .getNrOfParents(); iParent++) {
                text.append(m_Instances
                        .attribute(
                                m_ParentSets[iAttribute]
                                        .getParent(iParent)).name()
                        + " ");
            }

            text.append("\n");

            // Description of distributions tends to be too much detail,
            // so it is deliberately omitted here.
        }

        // quality of the learned network under various scoring metrics
        text
                .append("LogScore Bayes: " + measureBayesScore()
                        + "\n");
        text.append("LogScore BDeu: " + measureBDeuScore() + "\n");
        text.append("LogScore MDL: " + measureMDLScore() + "\n");
        text.append("LogScore ENTROPY: " + measureEntropyScore()
                + "\n");
        text.append("LogScore AIC: " + measureAICScore() + "\n");

        // structural comparison with the reference network, if one was loaded
        if (m_otherBayesNet != null) {
            text.append("Missing: "
                    + m_otherBayesNet.missingArcs(this)
                    + " Extra: " + m_otherBayesNet.extraArcs(this)
                    + " Reversed: "
                    + m_otherBayesNet.reversedArcs(this) + "\n");
            text.append("Divergence: "
                    + m_otherBayesNet.divergence(this) + "\n");
        }
    }

    return text.toString();
} // toString
0754:
/**
 * Returns the type of graph this classifier
 * represents.
 *
 * @return Drawable.BayesNet
 */
public int graphType() {
    return Drawable.BayesNet;
}
0763:
/**
 * Returns a BayesNet graph in XMLBIF ver 0.3 format.
 * The XMLBIF serialization doubles as the graph description.
 *
 * @return String representing this BayesNet in XMLBIF ver 0.3
 * @throws Exception in case BIF generation fails
 */
public String graph() throws Exception {
    return toXMLBIF03();
}
0772:
/**
 * Returns a description of the classifier in XML BIF 0.3 format.
 * See http://www-2.cs.cmu.edu/~fgcozman/Research/InterchangeFormat/
 * for details on XML BIF.
 * @return an XML BIF 0.3 description of the classifier as a string.
 */
public String toXMLBIF03() {
    if (m_Instances == null) {
        return ("<!--No model built yet-->");
    }

    StringBuffer text = new StringBuffer();

    // XML header plus inline DTD for the XMLBIF 0.3 format
    text.append("<?xml version=\"1.0\"?>\n");
    text.append("<!-- DTD for the XMLBIF 0.3 format -->\n");
    text.append("<!DOCTYPE BIF [\n");
    text.append(" <!ELEMENT BIF ( NETWORK )*>\n");
    text.append(" <!ATTLIST BIF VERSION CDATA #REQUIRED>\n");
    text
            .append(" <!ELEMENT NETWORK ( NAME, ( PROPERTY | VARIABLE | DEFINITION )* )>\n");
    text.append(" <!ELEMENT NAME (#PCDATA)>\n");
    text
            .append(" <!ELEMENT VARIABLE ( NAME, ( OUTCOME | PROPERTY )* ) >\n");
    text
            .append(" <!ATTLIST VARIABLE TYPE (nature|decision|utility) \"nature\">\n");
    text.append(" <!ELEMENT OUTCOME (#PCDATA)>\n");
    text
            .append(" <!ELEMENT DEFINITION ( FOR | GIVEN | TABLE | PROPERTY )* >\n");
    text.append(" <!ELEMENT FOR (#PCDATA)>\n");
    text.append(" <!ELEMENT GIVEN (#PCDATA)>\n");
    text.append(" <!ELEMENT TABLE (#PCDATA)>\n");
    text.append(" <!ELEMENT PROPERTY (#PCDATA)>\n");
    text.append("]>\n");
    text.append("\n");
    text.append("\n");
    text.append("<BIF VERSION=\"0.3\">\n");
    text.append("<NETWORK>\n");
    text.append("<NAME>" + XMLNormalize(m_Instances.relationName())
            + "</NAME>\n");

    // one VARIABLE element per attribute, listing its possible outcomes
    for (int iAttribute = 0; iAttribute < m_Instances
            .numAttributes(); iAttribute++) {
        text.append("<VARIABLE TYPE=\"nature\">\n");
        text.append("<NAME>"
                + XMLNormalize(m_Instances.attribute(iAttribute)
                        .name()) + "</NAME>\n");
        for (int iValue = 0; iValue < m_Instances.attribute(
                iAttribute).numValues(); iValue++) {
            text.append("<OUTCOME>"
                    + XMLNormalize(m_Instances
                            .attribute(iAttribute).value(iValue))
                    + "</OUTCOME>\n");
        }
        text.append("</VARIABLE>\n");
    }

    // one DEFINITION element per attribute: parents (GIVEN) and CPT (TABLE)
    for (int iAttribute = 0; iAttribute < m_Instances
            .numAttributes(); iAttribute++) {
        text.append("<DEFINITION>\n");
        text.append("<FOR>"
                + XMLNormalize(m_Instances.attribute(iAttribute)
                        .name()) + "</FOR>\n");
        for (int iParent = 0; iParent < m_ParentSets[iAttribute]
                .getNrOfParents(); iParent++) {
            text.append("<GIVEN>"
                    + XMLNormalize(m_Instances.attribute(
                            m_ParentSets[iAttribute]
                                    .getParent(iParent)).name())
                    + "</GIVEN>\n");
        }
        text.append("<TABLE>\n");
        // one row of probabilities per parent-value configuration
        for (int iParent = 0; iParent < m_ParentSets[iAttribute]
                .getCardinalityOfParents(); iParent++) {
            for (int iValue = 0; iValue < m_Instances.attribute(
                    iAttribute).numValues(); iValue++) {
                text.append(m_Distributions[iAttribute][iParent]
                        .getProbability(iValue));
                text.append(' ');
            }
            text.append('\n');
        }
        text.append("</TABLE>\n");
        text.append("</DEFINITION>\n");
    }
    text.append("</NETWORK>\n");
    text.append("</BIF>\n");
    return text.toString();
} // toXMLBIF03
0860:
0861: /** XMLNormalize converts the five standard XML entities in a string
0862: * g.e. the string V&D's is returned as V&D's
0863: * @param sStr string to normalize
0864: * @return normalized string
0865: */
0866: String XMLNormalize(String sStr) {
0867: StringBuffer sStr2 = new StringBuffer();
0868: for (int iStr = 0; iStr < sStr.length(); iStr++) {
0869: char c = sStr.charAt(iStr);
0870: switch (c) {
0871: case '&':
0872: sStr2.append("&");
0873: break;
0874: case '\'':
0875: sStr2.append("'");
0876: break;
0877: case '\"':
0878: sStr2.append(""");
0879: break;
0880: case '<':
0881: sStr2.append("<");
0882: break;
0883: case '>':
0884: sStr2.append(">");
0885: break;
0886: default:
0887: sStr2.append(c);
0888: }
0889: }
0890: return sStr2.toString();
0891: } // XMLNormalize
0892:
0893: /**
0894: * @return a string to describe the UseADTreeoption.
0895: */
0896: public String useADTreeTipText() {
0897: return "When ADTree (the data structure for increasing speed on counts,"
0898: + " not to be confused with the classifier under the same name) is used"
0899: + " learning time goes down typically. However, because ADTrees are memory"
0900: + " intensive, memory problems may occur. Switching this option off makes"
0901: + " the structure learning algorithms slower, and run with less memory."
0902: + " By default, ADTrees are used.";
0903: }
0904:
0905: /**
0906: * @return a string to describe the SearchAlgorithm.
0907: */
0908: public String searchAlgorithmTipText() {
0909: return "Select method used for searching network structures.";
0910: }
0911:
0912: /**
0913: * This will return a string describing the BayesNetEstimator.
0914: * @return The string.
0915: */
0916: public String estimatorTipText() {
0917: return "Select Estimator algorithm for finding the conditional probability tables"
0918: + " of the Bayes Network.";
0919: }
0920:
0921: /**
0922: * @return a string to describe the BIFFile.
0923: */
0924: public String BIFFileTipText() {
0925: return "Set the name of a file in BIF XML format. A Bayes network learned"
0926: + " from data can be compared with the Bayes network represented by the BIF file."
0927: + " Statistics calculated are o.a. the number of missing and extra arcs.";
0928: }
0929:
0930: /**
0931: * This will return a string describing the classifier.
0932: * @return The string.
0933: */
0934: public String globalInfo() {
0935: return "Bayes Network learning using various search algorithms and "
0936: + "quality measures.\n"
0937: + "Base class for a Bayes Network classifier. Provides "
0938: + "datastructures (network structure, conditional probability "
0939: + "distributions, etc.) and facilities common to Bayes Network "
0940: + "learning algorithms like K2 and B.\n\n"
0941: + "For more information see:\n\n"
0942: + "http://www.cs.waikato.ac.nz/~remco/weka.pdf";
0943: }
0944:
0945: /**
0946: * Main method for testing this class.
0947: *
0948: * @param argv the options
0949: */
0950: public static void main(String[] argv) {
0951: runClassifier(new BayesNet(), argv);
0952: } // main
0953:
0954: /** get name of the Bayes network
0955: * @return name of the Bayes net
0956: */
0957: public String getName() {
0958: return m_Instances.relationName();
0959: }
0960:
0961: /** get number of nodes in the Bayes network
0962: * @return number of nodes
0963: */
0964: public int getNrOfNodes() {
0965: return m_Instances.numAttributes();
0966: }
0967:
0968: /** get name of a node in the Bayes network
0969: * @param iNode index of the node
0970: * @return name of the specified node
0971: */
0972: public String getNodeName(int iNode) {
0973: return m_Instances.attribute(iNode).name();
0974: }
0975:
0976: /** get number of values a node can take
0977: * @param iNode index of the node
0978: * @return cardinality of the specified node
0979: */
0980: public int getCardinality(int iNode) {
0981: return m_Instances.attribute(iNode).numValues();
0982: }
0983:
0984: /** get name of a particular value of a node
0985: * @param iNode index of the node
0986: * @param iValue index of the value
0987: * @return cardinality of the specified node
0988: */
0989: public String getNodeValue(int iNode, int iValue) {
0990: return m_Instances.attribute(iNode).value(iValue);
0991: }
0992:
0993: /** get number of parents of a node in the network structure
0994: * @param iNode index of the node
0995: * @return number of parents of the specified node
0996: */
0997: public int getNrOfParents(int iNode) {
0998: return m_ParentSets[iNode].getNrOfParents();
0999: }
1000:
1001: /** get node index of a parent of a node in the network structure
1002: * @param iNode index of the node
1003: * @param iParent index of the parents, e.g., 0 is the first parent, 1 the second parent, etc.
1004: * @return node index of the iParent's parent of the specified node
1005: */
1006: public int getParent(int iNode, int iParent) {
1007: return m_ParentSets[iNode].getParent(iParent);
1008: }
1009:
1010: /** Get full set of parent sets.
1011: * @return parent sets;
1012: */
1013: public ParentSet[] getParentSets() {
1014: return m_ParentSets;
1015: }
1016:
1017: /** Get full set of estimators.
1018: * @return estimators;
1019: */
1020: public Estimator[][] getDistributions() {
1021: return m_Distributions;
1022: }
1023:
1024: /** get number of values the collection of parents of a node can take
1025: * @param iNode index of the node
1026: * @return cardinality of the parent set of the specified node
1027: */
1028: public int getParentCardinality(int iNode) {
1029: return m_ParentSets[iNode].getCardinalityOfParents();
1030: }
1031:
1032: /** get particular probability of the conditional probability distribtion
1033: * of a node given its parents.
1034: * @param iNode index of the node
1035: * @param iParent index of the parent set, 0 <= iParent <= getParentCardinality(iNode)
1036: * @param iValue index of the value, 0 <= iValue <= getCardinality(iNode)
1037: * @return probability
1038: */
1039: public double getProbability(int iNode, int iParent, int iValue) {
1040: return m_Distributions[iNode][iParent].getProbability(iValue);
1041: }
1042:
1043: /** get the parent set of a node
1044: * @param iNode index of the node
1045: * @return Parent set of the specified node.
1046: */
1047: public ParentSet getParentSet(int iNode) {
1048: return m_ParentSets[iNode];
1049: }
1050:
1051: /** get ADTree strucrture containing efficient representation of counts.
1052: * @return ADTree strucrture
1053: */
1054: public ADNode getADTree() {
1055: return m_ADTree;
1056: }
1057:
1058: // implementation of AdditionalMeasureProducer interface
1059: /**
1060: * Returns an enumeration of the measure names. Additional measures
1061: * must follow the naming convention of starting with "measure", eg.
1062: * double measureBlah()
1063: * @return an enumeration of the measure names
1064: */
1065: public Enumeration enumerateMeasures() {
1066: Vector newVector = new Vector(4);
1067: newVector.addElement("measureExtraArcs");
1068: newVector.addElement("measureMissingArcs");
1069: newVector.addElement("measureReversedArcs");
1070: newVector.addElement("measureDivergence");
1071: newVector.addElement("measureBayesScore");
1072: newVector.addElement("measureBDeuScore");
1073: newVector.addElement("measureMDLScore");
1074: newVector.addElement("measureAICScore");
1075: newVector.addElement("measureEntropyScore");
1076: return newVector.elements();
1077: } // enumerateMeasures
1078:
1079: public double measureExtraArcs() {
1080: if (m_otherBayesNet != null) {
1081: return m_otherBayesNet.extraArcs(this );
1082: }
1083: return 0;
1084: } // measureExtraArcs
1085:
1086: public double measureMissingArcs() {
1087: if (m_otherBayesNet != null) {
1088: return m_otherBayesNet.missingArcs(this );
1089: }
1090: return 0;
1091: } // measureMissingArcs
1092:
1093: public double measureReversedArcs() {
1094: if (m_otherBayesNet != null) {
1095: return m_otherBayesNet.reversedArcs(this );
1096: }
1097: return 0;
1098: } // measureReversedArcs
1099:
1100: public double measureDivergence() {
1101: if (m_otherBayesNet != null) {
1102: return m_otherBayesNet.divergence(this );
1103: }
1104: return 0;
1105: } // measureDivergence
1106:
1107: public double measureBayesScore() {
1108: LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(
1109: this , m_Instances);
1110: return s.logScore(Scoreable.BAYES);
1111: } // measureBayesScore
1112:
1113: public double measureBDeuScore() {
1114: LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(
1115: this , m_Instances);
1116: return s.logScore(Scoreable.BDeu);
1117: } // measureBDeuScore
1118:
1119: public double measureMDLScore() {
1120: LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(
1121: this , m_Instances);
1122: return s.logScore(Scoreable.MDL);
1123: } // measureMDLScore
1124:
1125: public double measureAICScore() {
1126: LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(
1127: this , m_Instances);
1128: return s.logScore(Scoreable.AIC);
1129: } // measureAICScore
1130:
1131: public double measureEntropyScore() {
1132: LocalScoreSearchAlgorithm s = new LocalScoreSearchAlgorithm(
1133: this , m_Instances);
1134: return s.logScore(Scoreable.ENTROPY);
1135: } // measureEntropyScore
1136:
1137: /**
1138: * Returns the value of the named measure
1139: * @param measureName the name of the measure to query for its value
1140: * @return the value of the named measure
1141: * @throws IllegalArgumentException if the named measure is not supported
1142: */
1143: public double getMeasure(String measureName) {
1144: if (measureName.equals("measureExtraArcs")) {
1145: return measureExtraArcs();
1146: }
1147: if (measureName.equals("measureMissingArcs")) {
1148: return measureMissingArcs();
1149: }
1150: if (measureName.equals("measureReversedArcs")) {
1151: return measureReversedArcs();
1152: }
1153: if (measureName.equals("measureDivergence")) {
1154: return measureDivergence();
1155: }
1156: if (measureName.equals("measureBayesScore")) {
1157: return measureBayesScore();
1158: }
1159: if (measureName.equals("measureBDeuScore")) {
1160: return measureBDeuScore();
1161: }
1162: if (measureName.equals("measureMDLScore")) {
1163: return measureMDLScore();
1164: }
1165: if (measureName.equals("measureAICScore")) {
1166: return measureAICScore();
1167: }
1168: if (measureName.equals("measureEntropyScore")) {
1169: return measureEntropyScore();
1170: }
1171: return 0;
1172: } // getMeasure
1173:
1174: } // class BayesNet
|