001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * WAODE.java
019: * Copyright 2006 Liangxiao Jiang
020: */
021:
022: package weka.classifiers.bayes;
023:
024: import weka.classifiers.Classifier;
025: import weka.core.Capabilities;
026: import weka.core.Instance;
027: import weka.core.Instances;
028: import weka.core.Option;
029: import weka.core.TechnicalInformation;
030: import weka.core.TechnicalInformationHandler;
031: import weka.core.Utils;
032: import weka.core.Capabilities.Capability;
033: import weka.core.TechnicalInformation.Field;
034: import weka.core.TechnicalInformation.Type;
035:
036: import java.util.Enumeration;
037: import java.util.Vector;
038:
039: /**
040: <!-- globalinfo-start -->
041: * WAODE contructs the model called Weightily Averaged One-Dependence Estimators.<br/>
042: * <br/>
043: * For more information, see<br/>
044: * <br/>
045: * L. Jiang, H. Zhang: Weightily Averaged One-Dependence Estimators. In: Proceedings of the 9th Biennial Pacific Rim International Conference on Artificial Intelligence, PRICAI 2006, 970-974, 2006.
046: * <p/>
047: <!-- globalinfo-end -->
048: *
049: <!-- technical-bibtex-start -->
050: * BibTeX:
051: * <pre>
052: * @inproceedings{Jiang2006,
053: * author = {L. Jiang and H. Zhang},
054: * booktitle = {Proceedings of the 9th Biennial Pacific Rim International Conference on Artificial Intelligence, PRICAI 2006},
055: * pages = {970-974},
056: * series = {LNAI},
057: * title = {Weightily Averaged One-Dependence Estimators},
058: * volume = {4099},
059: * year = {2006}
060: * }
061: * </pre>
062: * <p/>
063: <!-- technical-bibtex-end -->
064: *
065: <!-- options-start -->
066: * Valid options are: <p/>
067: *
068: * <pre> -D
069: * If set, classifier is run in debug mode and
070: * may output additional info to the console</pre>
071: *
072: * <pre> -I
073: * Whether to print some more internals.
074: * (default: no)</pre>
075: *
076: <!-- options-end -->
077: *
078: * @author Liangxiao Jiang (ljiang@cug.edu.cn)
079: * @author H. Zhang (hzhang@unb.ca)
080: * @version $Revision: 1.2 $
081: */
082: public class WAODE extends Classifier implements
083: TechnicalInformationHandler {
084:
085: /** for serialization */
086: private static final long serialVersionUID = 2170978824284697882L;
087:
088: /** The number of each class value occurs in the dataset */
089: private double[] m_ClassCounts;
090:
091: /** The number of each attribute value occurs in the dataset */
092: private double[] m_AttCounts;
093:
094: /** The number of two attributes values occurs in the dataset */
095: private double[][] m_AttAttCounts;
096:
097: /** The number of class and two attributes values occurs in the dataset */
098: private double[][][] m_ClassAttAttCounts;
099:
100: /** The number of values for each attribute in the dataset */
101: private int[] m_NumAttValues;
102:
103: /** The number of values for all attributes in the dataset */
104: private int m_TotalAttValues;
105:
106: /** The number of classes in the dataset */
107: private int m_NumClasses;
108:
109: /** The number of attributes including class in the dataset */
110: private int m_NumAttributes;
111:
112: /** The number of instances in the dataset */
113: private int m_NumInstances;
114:
115: /** The index of the class attribute in the dataset */
116: private int m_ClassIndex;
117:
118: /** The starting index of each attribute in the dataset */
119: private int[] m_StartAttIndex;
120:
121: /** The array of mutual information between each attribute and class */
122: private double[] m_mutualInformation;
123:
124: /** the header information of the training data */
125: private Instances m_Header = null;
126:
127: /** whether to print more internals in the toString method
128: * @see #toString() */
129: private boolean m_Internals = false;
130:
131: /** a ZeroR model in case no model can be built from the data */
132: private Classifier m_ZeroR;
133:
134: /**
135: * Returns a string describing this classifier
136: *
137: * @return a description of the classifier suitable for
138: * displaying in the explorer/experimenter gui
139: */
140: public String globalInfo() {
141: return "WAODE contructs the model called Weightily Averaged One-Dependence "
142: + "Estimators.\n\n"
143: + "For more information, see\n\n"
144: + getTechnicalInformation().toString();
145: }
146:
147: /**
148: * Gets an enumeration describing the available options.
149: *
150: * @return an enumeration of all the available options.
151: */
152: public Enumeration listOptions() {
153: Vector result = new Vector();
154: Enumeration enm = super .listOptions();
155: while (enm.hasMoreElements())
156: result.add(enm.nextElement());
157:
158: result.addElement(new Option(
159: "\tWhether to print some more internals.\n"
160: + "\t(default: no)", "I", 0, "-I"));
161:
162: return result.elements();
163: }
164:
165: /**
166: * Parses a given list of options. <p/>
167: *
168: <!-- options-start -->
169: * Valid options are: <p/>
170: *
171: * <pre> -D
172: * If set, classifier is run in debug mode and
173: * may output additional info to the console</pre>
174: *
175: * <pre> -I
176: * Whether to print some more internals.
177: * (default: no)</pre>
178: *
179: <!-- options-end -->
180: *
181: * @param options the list of options as an array of strings
182: * @throws Exception if an option is not supported
183: */
184: public void setOptions(String[] options) throws Exception {
185: super .setOptions(options);
186:
187: setInternals(Utils.getFlag('I', options));
188: }
189:
190: /**
191: * Gets the current settings of the filter.
192: *
193: * @return an array of strings suitable for passing to setOptions
194: */
195: public String[] getOptions() {
196: Vector result;
197: String[] options;
198: int i;
199:
200: result = new Vector();
201:
202: options = super .getOptions();
203: for (i = 0; i < options.length; i++)
204: result.add(options[i]);
205:
206: if (getInternals())
207: result.add("-I");
208:
209: return (String[]) result.toArray(new String[result.size()]);
210: }
211:
212: /**
213: * Returns the tip text for this property
214: *
215: * @return tip text for this property suitable for
216: * displaying in the explorer/experimenter gui
217: */
218: public String internalsTipText() {
219: return "Prints more internals of the classifier.";
220: }
221:
222: /**
223: * Sets whether internals about classifier are printed via toString().
224: *
225: * @param value if internals should be printed
226: * @see #toString()
227: */
228: public void setInternals(boolean value) {
229: m_Internals = value;
230: }
231:
232: /**
233: * Gets whether more internals of the classifier are printed.
234: *
235: * @return true if more internals are printed
236: */
237: public boolean getInternals() {
238: return m_Internals;
239: }
240:
241: /**
242: * Returns an instance of a TechnicalInformation object, containing
243: * detailed information about the technical background of this class,
244: * e.g., paper reference or book this class is based on.
245: *
246: * @return the technical information about this class
247: */
248: public TechnicalInformation getTechnicalInformation() {
249: TechnicalInformation result;
250:
251: result = new TechnicalInformation(Type.INPROCEEDINGS);
252: result.setValue(Field.AUTHOR, "L. Jiang and H. Zhang");
253: result.setValue(Field.TITLE,
254: "Weightily Averaged One-Dependence Estimators");
255: result
256: .setValue(
257: Field.BOOKTITLE,
258: "Proceedings of the 9th Biennial Pacific Rim International Conference on Artificial Intelligence, PRICAI 2006");
259: result.setValue(Field.YEAR, "2006");
260: result.setValue(Field.PAGES, "970-974");
261: result.setValue(Field.SERIES, "LNAI");
262: result.setValue(Field.VOLUME, "4099");
263:
264: return result;
265: }
266:
267: /**
268: * Returns default capabilities of the classifier.
269: *
270: * @return the capabilities of this classifier
271: */
272: public Capabilities getCapabilities() {
273: Capabilities result = super .getCapabilities();
274:
275: // attributes
276: result.enable(Capability.NOMINAL_ATTRIBUTES);
277:
278: // class
279: result.enable(Capability.NOMINAL_CLASS);
280:
281: return result;
282: }
283:
284: /**
285: * Generates the classifier.
286: *
287: * @param instances set of instances serving as training data
288: * @throws Exception if the classifier has not been generated successfully
289: */
290: public void buildClassifier(Instances instances) throws Exception {
291:
292: // can classifier handle the data?
293: getCapabilities().testWithFail(instances);
294:
295: // only class? -> build ZeroR model
296: if (instances.numAttributes() == 1) {
297: System.err
298: .println("Cannot build model (only class attribute present in data!), "
299: + "using ZeroR model instead!");
300: m_ZeroR = new weka.classifiers.rules.ZeroR();
301: m_ZeroR.buildClassifier(instances);
302: return;
303: } else {
304: m_ZeroR = null;
305: }
306:
307: // reset variable
308: m_NumClasses = instances.numClasses();
309: m_ClassIndex = instances.classIndex();
310: m_NumAttributes = instances.numAttributes();
311: m_NumInstances = instances.numInstances();
312: m_TotalAttValues = 0;
313:
314: // allocate space for attribute reference arrays
315: m_StartAttIndex = new int[m_NumAttributes];
316: m_NumAttValues = new int[m_NumAttributes];
317:
318: // set the starting index of each attribute and the number of values for
319: // each attribute and the total number of values for all attributes (not including class).
320: for (int i = 0; i < m_NumAttributes; i++) {
321: if (i != m_ClassIndex) {
322: m_StartAttIndex[i] = m_TotalAttValues;
323: m_NumAttValues[i] = instances.attribute(i).numValues();
324: m_TotalAttValues += m_NumAttValues[i];
325: } else {
326: m_StartAttIndex[i] = -1;
327: m_NumAttValues[i] = m_NumClasses;
328: }
329: }
330:
331: // allocate space for counts and frequencies
332: m_ClassCounts = new double[m_NumClasses];
333: m_AttCounts = new double[m_TotalAttValues];
334: m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues];
335: m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues];
336: m_Header = new Instances(instances, 0);
337:
338: // Calculate the counts
339: for (int k = 0; k < m_NumInstances; k++) {
340: int classVal = (int) instances.instance(k).classValue();
341: m_ClassCounts[classVal]++;
342: int[] attIndex = new int[m_NumAttributes];
343: for (int i = 0; i < m_NumAttributes; i++) {
344: if (i == m_ClassIndex) {
345: attIndex[i] = -1;
346: } else {
347: attIndex[i] = m_StartAttIndex[i]
348: + (int) instances.instance(k).value(i);
349: m_AttCounts[attIndex[i]]++;
350: }
351: }
352: for (int Att1 = 0; Att1 < m_NumAttributes; Att1++) {
353: if (attIndex[Att1] == -1)
354: continue;
355: for (int Att2 = 0; Att2 < m_NumAttributes; Att2++) {
356: if ((attIndex[Att2] != -1)) {
357: m_AttAttCounts[attIndex[Att1]][attIndex[Att2]]++;
358: m_ClassAttAttCounts[classVal][attIndex[Att1]][attIndex[Att2]]++;
359: }
360: }
361: }
362: }
363:
364: //compute mutual information between each attribute and class
365: m_mutualInformation = new double[m_NumAttributes];
366: for (int att = 0; att < m_NumAttributes; att++) {
367: if (att == m_ClassIndex)
368: continue;
369: m_mutualInformation[att] = mutualInfo(att);
370: }
371: }
372:
373: /**
374: * Computes mutual information between each attribute and class attribute.
375: *
376: * @param att is the attribute
377: * @return the conditional mutual information between son and parent given class
378: */
379: private double mutualInfo(int att) {
380:
381: double mutualInfo = 0;
382: int attIndex = m_StartAttIndex[att];
383: double[] PriorsClass = new double[m_NumClasses];
384: double[] PriorsAttribute = new double[m_NumAttValues[att]];
385: double[][] PriorsClassAttribute = new double[m_NumClasses][m_NumAttValues[att]];
386:
387: for (int i = 0; i < m_NumClasses; i++) {
388: PriorsClass[i] = m_ClassCounts[i] / m_NumInstances;
389: }
390:
391: for (int j = 0; j < m_NumAttValues[att]; j++) {
392: PriorsAttribute[j] = m_AttCounts[attIndex + j]
393: / m_NumInstances;
394: }
395:
396: for (int i = 0; i < m_NumClasses; i++) {
397: for (int j = 0; j < m_NumAttValues[att]; j++) {
398: PriorsClassAttribute[i][j] = m_ClassAttAttCounts[i][attIndex
399: + j][attIndex + j]
400: / m_NumInstances;
401: }
402: }
403:
404: for (int i = 0; i < m_NumClasses; i++) {
405: for (int j = 0; j < m_NumAttValues[att]; j++) {
406: mutualInfo += PriorsClassAttribute[i][j]
407: * log2(PriorsClassAttribute[i][j],
408: PriorsClass[i] * PriorsAttribute[j]);
409: }
410: }
411: return mutualInfo;
412: }
413:
414: /**
415: * compute the logarithm whose base is 2.
416: *
417: * @param x numerator of the fraction.
418: * @param y denominator of the fraction.
419: * @return the natual logarithm of this fraction.
420: */
421: private double log2(double x, double y) {
422:
423: if (x < Utils.SMALL || y < Utils.SMALL)
424: return 0.0;
425: else
426: return Math.log(x / y) / Math.log(2);
427: }
428:
429: /**
430: * Calculates the class membership probabilities for the given test instance
431: *
432: * @param instance the instance to be classified
433: * @return predicted class probability distribution
434: * @throws Exception if there is a problem generating the prediction
435: */
436: public double[] distributionForInstance(Instance instance)
437: throws Exception {
438:
439: // default model?
440: if (m_ZeroR != null) {
441: return m_ZeroR.distributionForInstance(instance);
442: }
443:
444: //Definition of local variables
445: double[] probs = new double[m_NumClasses];
446: double prob;
447: double mutualInfoSum;
448:
449: // store instance's att values in an int array
450: int[] attIndex = new int[m_NumAttributes];
451: for (int att = 0; att < m_NumAttributes; att++) {
452: if (att == m_ClassIndex)
453: attIndex[att] = -1;
454: else
455: attIndex[att] = m_StartAttIndex[att]
456: + (int) instance.value(att);
457: }
458:
459: // calculate probabilities for each possible class value
460: for (int classVal = 0; classVal < m_NumClasses; classVal++) {
461: probs[classVal] = 0;
462: prob = 1;
463: mutualInfoSum = 0.0;
464: for (int parent = 0; parent < m_NumAttributes; parent++) {
465: if (attIndex[parent] == -1)
466: continue;
467: prob = (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[parent]] + 1.0 / (m_NumClasses * m_NumAttValues[parent]))
468: / (m_NumInstances + 1.0);
469: for (int son = 0; son < m_NumAttributes; son++) {
470: if (attIndex[son] == -1 || son == parent)
471: continue;
472: prob *= (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[son]] + 1.0 / m_NumAttValues[son])
473: / (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[parent]] + 1.0);
474: }
475: mutualInfoSum += m_mutualInformation[parent];
476: probs[classVal] += m_mutualInformation[parent] * prob;
477: }
478: probs[classVal] /= mutualInfoSum;
479: }
480: if (!Double.isNaN(Utils.sum(probs)))
481: Utils.normalize(probs);
482: return probs;
483: }
484:
485: /**
486: * returns a string representation of the classifier
487: *
488: * @return string representation of the classifier
489: */
490: public String toString() {
491: StringBuffer result;
492: String classname;
493: int i;
494:
495: // only ZeroR model?
496: if (m_ZeroR != null) {
497: result = new StringBuffer();
498: result.append(this .getClass().getName().replaceAll(".*\\.",
499: "")
500: + "\n");
501: result.append(this .getClass().getName().replaceAll(".*\\.",
502: "").replaceAll(".", "=")
503: + "\n\n");
504: result
505: .append("Warning: No model could be built, hence ZeroR model is used:\n\n");
506: result.append(m_ZeroR.toString());
507: } else {
508: classname = this .getClass().getName().replaceAll(".*\\.",
509: "");
510: result = new StringBuffer();
511: result.append(classname + "\n");
512: result.append(classname.replaceAll(".", "=") + "\n\n");
513:
514: if (m_Header == null) {
515: result.append("No Model built yet.\n");
516: } else {
517: if (getInternals()) {
518: result
519: .append("Mutual information of attributes with class attribute:\n");
520: for (i = 0; i < m_Header.numAttributes(); i++) {
521: // skip class
522: if (i == m_Header.classIndex())
523: continue;
524:
525: result.append((i + 1)
526: + ". "
527: + m_Header.attribute(i).name()
528: + ": "
529: + Utils.doubleToString(
530: m_mutualInformation[i], 6)
531: + "\n");
532: }
533: } else {
534: result.append("Model built successfully.\n");
535: }
536: }
537: }
538:
539: return result.toString();
540: }
541:
542: /**
543: * Main method for testing this class.
544: *
545: * @param argv the commandline options, use -h to list all options
546: */
547: public static void main(String[] argv) {
548: runClassifier(new WAODE(), argv);
549: }
550: }
|