001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * NaiveBayesSimple.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.bayes;
024:
025: import weka.classifiers.Classifier;
026: import weka.core.Attribute;
027: import weka.core.Capabilities;
028: import weka.core.Instance;
029: import weka.core.Instances;
030: import weka.core.TechnicalInformation;
031: import weka.core.TechnicalInformationHandler;
032: import weka.core.Utils;
033: import weka.core.Capabilities.Capability;
034: import weka.core.TechnicalInformation.Field;
035: import weka.core.TechnicalInformation.Type;
036:
037: import java.util.Enumeration;
038:
039: /**
040: <!-- globalinfo-start -->
041: * Class for building and using a simple Naive Bayes classifier.Numeric attributes are modelled by a normal distribution.<br/>
042: * <br/>
043: * For more information, see<br/>
044: * <br/>
045: * Richard Duda, Peter Hart (1973). Pattern Classification and Scene Analysis. Wiley, New York.
046: * <p/>
047: <!-- globalinfo-end -->
048: *
049: <!-- technical-bibtex-start -->
050: * BibTeX:
051: * <pre>
052: * @book{Duda1973,
053: * address = {New York},
054: * author = {Richard Duda and Peter Hart},
055: * publisher = {Wiley},
056: * title = {Pattern Classification and Scene Analysis},
057: * year = {1973}
058: * }
059: * </pre>
060: * <p/>
061: <!-- technical-bibtex-end -->
062: *
063: <!-- options-start -->
064: * Valid options are: <p/>
065: *
066: * <pre> -D
067: * If set, classifier is run in debug mode and
068: * may output additional info to the console</pre>
069: *
070: <!-- options-end -->
071: *
072: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
073: * @version $Revision: 1.19 $
074: */
075: public class NaiveBayesSimple extends Classifier implements
076: TechnicalInformationHandler {
077:
078: /** for serialization */
079: static final long serialVersionUID = -1478242251770381214L;
080:
081: /** All the counts for nominal attributes. */
082: protected double[][][] m_Counts;
083:
084: /** The means for numeric attributes. */
085: protected double[][] m_Means;
086:
087: /** The standard deviations for numeric attributes. */
088: protected double[][] m_Devs;
089:
090: /** The prior probabilities of the classes. */
091: protected double[] m_Priors;
092:
093: /** The instances used for training. */
094: protected Instances m_Instances;
095:
096: /** Constant for normal distribution. */
097: protected static double NORM_CONST = Math.sqrt(2 * Math.PI);
098:
099: /**
100: * Returns a string describing this classifier
101: * @return a description of the classifier suitable for
102: * displaying in the explorer/experimenter gui
103: */
104: public String globalInfo() {
105: return "Class for building and using a simple Naive Bayes classifier."
106: + "Numeric attributes are modelled by a normal distribution.\n\n"
107: + "For more information, see\n\n"
108: + getTechnicalInformation().toString();
109: }
110:
111: /**
112: * Returns an instance of a TechnicalInformation object, containing
113: * detailed information about the technical background of this class,
114: * e.g., paper reference or book this class is based on.
115: *
116: * @return the technical information about this class
117: */
118: public TechnicalInformation getTechnicalInformation() {
119: TechnicalInformation result;
120:
121: result = new TechnicalInformation(Type.BOOK);
122: result.setValue(Field.AUTHOR, "Richard Duda and Peter Hart");
123: result.setValue(Field.YEAR, "1973");
124: result.setValue(Field.TITLE,
125: "Pattern Classification and Scene Analysis");
126: result.setValue(Field.PUBLISHER, "Wiley");
127: result.setValue(Field.ADDRESS, "New York");
128:
129: return result;
130: }
131:
132: /**
133: * Returns default capabilities of the classifier.
134: *
135: * @return the capabilities of this classifier
136: */
137: public Capabilities getCapabilities() {
138: Capabilities result = super .getCapabilities();
139:
140: // attributes
141: result.enable(Capability.NOMINAL_ATTRIBUTES);
142: result.enable(Capability.NUMERIC_ATTRIBUTES);
143: result.enable(Capability.DATE_ATTRIBUTES);
144: result.enable(Capability.MISSING_VALUES);
145:
146: // class
147: result.enable(Capability.NOMINAL_CLASS);
148: result.enable(Capability.MISSING_CLASS_VALUES);
149:
150: return result;
151: }
152:
153: /**
154: * Generates the classifier.
155: *
156: * @param instances set of instances serving as training data
157: * @exception Exception if the classifier has not been generated successfully
158: */
159: public void buildClassifier(Instances instances) throws Exception {
160:
161: int attIndex = 0;
162: double sum;
163:
164: // can classifier handle the data?
165: getCapabilities().testWithFail(instances);
166:
167: // remove instances with missing class
168: instances = new Instances(instances);
169: instances.deleteWithMissingClass();
170:
171: m_Instances = new Instances(instances, 0);
172:
173: // Reserve space
174: m_Counts = new double[instances.numClasses()][instances
175: .numAttributes() - 1][0];
176: m_Means = new double[instances.numClasses()][instances
177: .numAttributes() - 1];
178: m_Devs = new double[instances.numClasses()][instances
179: .numAttributes() - 1];
180: m_Priors = new double[instances.numClasses()];
181: Enumeration enu = instances.enumerateAttributes();
182: while (enu.hasMoreElements()) {
183: Attribute attribute = (Attribute) enu.nextElement();
184: if (attribute.isNominal()) {
185: for (int j = 0; j < instances.numClasses(); j++) {
186: m_Counts[j][attIndex] = new double[attribute
187: .numValues()];
188: }
189: } else {
190: for (int j = 0; j < instances.numClasses(); j++) {
191: m_Counts[j][attIndex] = new double[1];
192: }
193: }
194: attIndex++;
195: }
196:
197: // Compute counts and sums
198: Enumeration enumInsts = instances.enumerateInstances();
199: while (enumInsts.hasMoreElements()) {
200: Instance instance = (Instance) enumInsts.nextElement();
201: if (!instance.classIsMissing()) {
202: Enumeration enumAtts = instances.enumerateAttributes();
203: attIndex = 0;
204: while (enumAtts.hasMoreElements()) {
205: Attribute attribute = (Attribute) enumAtts
206: .nextElement();
207: if (!instance.isMissing(attribute)) {
208: if (attribute.isNominal()) {
209: m_Counts[(int) instance.classValue()][attIndex][(int) instance
210: .value(attribute)]++;
211: } else {
212: m_Means[(int) instance.classValue()][attIndex] += instance
213: .value(attribute);
214: m_Counts[(int) instance.classValue()][attIndex][0]++;
215: }
216: }
217: attIndex++;
218: }
219: m_Priors[(int) instance.classValue()]++;
220: }
221: }
222:
223: // Compute means
224: Enumeration enumAtts = instances.enumerateAttributes();
225: attIndex = 0;
226: while (enumAtts.hasMoreElements()) {
227: Attribute attribute = (Attribute) enumAtts.nextElement();
228: if (attribute.isNumeric()) {
229: for (int j = 0; j < instances.numClasses(); j++) {
230: if (m_Counts[j][attIndex][0] < 2) {
231: throw new Exception("attribute "
232: + attribute.name()
233: + ": less than two values for class "
234: + instances.classAttribute().value(j));
235: }
236: m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
237: }
238: }
239: attIndex++;
240: }
241:
242: // Compute standard deviations
243: enumInsts = instances.enumerateInstances();
244: while (enumInsts.hasMoreElements()) {
245: Instance instance = (Instance) enumInsts.nextElement();
246: if (!instance.classIsMissing()) {
247: enumAtts = instances.enumerateAttributes();
248: attIndex = 0;
249: while (enumAtts.hasMoreElements()) {
250: Attribute attribute = (Attribute) enumAtts
251: .nextElement();
252: if (!instance.isMissing(attribute)) {
253: if (attribute.isNumeric()) {
254: m_Devs[(int) instance.classValue()][attIndex] += (m_Means[(int) instance
255: .classValue()][attIndex] - instance
256: .value(attribute))
257: * (m_Means[(int) instance
258: .classValue()][attIndex] - instance
259: .value(attribute));
260: }
261: }
262: attIndex++;
263: }
264: }
265: }
266: enumAtts = instances.enumerateAttributes();
267: attIndex = 0;
268: while (enumAtts.hasMoreElements()) {
269: Attribute attribute = (Attribute) enumAtts.nextElement();
270: if (attribute.isNumeric()) {
271: for (int j = 0; j < instances.numClasses(); j++) {
272: if (m_Devs[j][attIndex] <= 0) {
273: throw new Exception(
274: "attribute "
275: + attribute.name()
276: + ": standard deviation is 0 for class "
277: + instances.classAttribute()
278: .value(j));
279: } else {
280: m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
281: m_Devs[j][attIndex] = Math
282: .sqrt(m_Devs[j][attIndex]);
283: }
284: }
285: }
286: attIndex++;
287: }
288:
289: // Normalize counts
290: enumAtts = instances.enumerateAttributes();
291: attIndex = 0;
292: while (enumAtts.hasMoreElements()) {
293: Attribute attribute = (Attribute) enumAtts.nextElement();
294: if (attribute.isNominal()) {
295: for (int j = 0; j < instances.numClasses(); j++) {
296: sum = Utils.sum(m_Counts[j][attIndex]);
297: for (int i = 0; i < attribute.numValues(); i++) {
298: m_Counts[j][attIndex][i] = (m_Counts[j][attIndex][i] + 1)
299: / (sum + (double) attribute.numValues());
300: }
301: }
302: }
303: attIndex++;
304: }
305:
306: // Normalize priors
307: sum = Utils.sum(m_Priors);
308: for (int j = 0; j < instances.numClasses(); j++)
309: m_Priors[j] = (m_Priors[j] + 1)
310: / (sum + (double) instances.numClasses());
311: }
312:
313: /**
314: * Calculates the class membership probabilities for the given test instance.
315: *
316: * @param instance the instance to be classified
317: * @return predicted class probability distribution
318: * @exception Exception if distribution can't be computed
319: */
320: public double[] distributionForInstance(Instance instance)
321: throws Exception {
322:
323: double[] probs = new double[instance.numClasses()];
324: int attIndex;
325:
326: for (int j = 0; j < instance.numClasses(); j++) {
327: probs[j] = 1;
328: Enumeration enumAtts = instance.enumerateAttributes();
329: attIndex = 0;
330: while (enumAtts.hasMoreElements()) {
331: Attribute attribute = (Attribute) enumAtts
332: .nextElement();
333: if (!instance.isMissing(attribute)) {
334: if (attribute.isNominal()) {
335: probs[j] *= m_Counts[j][attIndex][(int) instance
336: .value(attribute)];
337: } else {
338: probs[j] *= normalDens(instance
339: .value(attribute),
340: m_Means[j][attIndex],
341: m_Devs[j][attIndex]);
342: }
343: }
344: attIndex++;
345: }
346: probs[j] *= m_Priors[j];
347: }
348:
349: // Normalize probabilities
350: Utils.normalize(probs);
351:
352: return probs;
353: }
354:
355: /**
356: * Returns a description of the classifier.
357: *
358: * @return a description of the classifier as a string.
359: */
360: public String toString() {
361:
362: if (m_Instances == null) {
363: return "Naive Bayes (simple): No model built yet.";
364: }
365: try {
366: StringBuffer text = new StringBuffer("Naive Bayes (simple)");
367: int attIndex;
368:
369: for (int i = 0; i < m_Instances.numClasses(); i++) {
370: text.append("\n\nClass "
371: + m_Instances.classAttribute().value(i)
372: + ": P(C) = "
373: + Utils.doubleToString(m_Priors[i], 10, 8)
374: + "\n\n");
375: Enumeration enumAtts = m_Instances
376: .enumerateAttributes();
377: attIndex = 0;
378: while (enumAtts.hasMoreElements()) {
379: Attribute attribute = (Attribute) enumAtts
380: .nextElement();
381: text.append("Attribute " + attribute.name() + "\n");
382: if (attribute.isNominal()) {
383: for (int j = 0; j < attribute.numValues(); j++) {
384: text.append(attribute.value(j) + "\t");
385: }
386: text.append("\n");
387: for (int j = 0; j < attribute.numValues(); j++)
388: text.append(Utils.doubleToString(
389: m_Counts[i][attIndex][j], 10, 8)
390: + "\t");
391: } else {
392: text.append("Mean: "
393: + Utils.doubleToString(
394: m_Means[i][attIndex], 10, 8)
395: + "\t");
396: text.append("Standard Deviation: "
397: + Utils.doubleToString(
398: m_Devs[i][attIndex], 10, 8));
399: }
400: text.append("\n\n");
401: attIndex++;
402: }
403: }
404:
405: return text.toString();
406: } catch (Exception e) {
407: return "Can't print Naive Bayes classifier!";
408: }
409: }
410:
411: /**
412: * Density function of normal distribution.
413: *
414: * @param x the value to get the density for
415: * @param mean the mean
416: * @param stdDev the standard deviation
417: * @return the density
418: */
419: protected double normalDens(double x, double mean, double stdDev) {
420:
421: double diff = x - mean;
422:
423: return (1 / (NORM_CONST * stdDev))
424: * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
425: }
426:
427: /**
428: * Main method for testing this class.
429: *
430: * @param argv the options
431: */
432: public static void main(String[] argv) {
433: runClassifier(new NaiveBayesSimple(), argv);
434: }
435: }
|