001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * HNB.java
019: * Copyright (C) 2004 Liangxiao Jiang
020: */
021:
022: package weka.classifiers.bayes;
023:
024: import weka.classifiers.Classifier;
025: import weka.core.Capabilities;
026: import weka.core.Instance;
027: import weka.core.Instances;
028: import weka.core.TechnicalInformation;
029: import weka.core.TechnicalInformationHandler;
030: import weka.core.Utils;
031: import weka.core.Capabilities.Capability;
032: import weka.core.TechnicalInformation.Field;
033: import weka.core.TechnicalInformation.Type;
034:
035: /**
036: <!-- globalinfo-start -->
 * Constructs Hidden Naive Bayes classification model with high classification accuracy and AUC.<br/>
038: * <br/>
039: * For more information refer to:<br/>
040: * <br/>
041: * H. Zhang, L. Jiang, J. Su: Hidden Naive Bayes. In: Twentieth National Conference on Artificial Intelligence, 919-924, 2005.
042: * <p/>
043: <!-- globalinfo-end -->
044: *
045: <!-- technical-bibtex-start -->
046: * BibTeX:
047: * <pre>
048: * @inproceedings{Zhang2005,
049: * author = {H. Zhang and L. Jiang and J. Su},
050: * booktitle = {Twentieth National Conference on Artificial Intelligence},
051: * pages = {919-924},
052: * publisher = {AAAI Press},
053: * title = {Hidden Naive Bayes},
054: * year = {2005}
055: * }
056: * </pre>
057: * <p/>
058: <!-- technical-bibtex-end -->
059: *
060: <!-- options-start -->
061: * Valid options are: <p/>
062: *
063: * <pre> -D
064: * If set, classifier is run in debug mode and
065: * may output additional info to the console</pre>
066: *
067: <!-- options-end -->
068: *
069: * @author H. Zhang (hzhang@unb.ca)
070: * @author Liangxiao Jiang (ljiang@cug.edu.cn)
071: * @version $Revision: 1.8 $
072: */
073: public class HNB extends Classifier implements
074: TechnicalInformationHandler {
075:
076: /** for serialization */
077: static final long serialVersionUID = -4503874444306113214L;
078:
079: /** The number of each class value occurs in the dataset */
080: private double[] m_ClassCounts;
081:
082: /** The number of class and two attributes values occurs in the dataset */
083: private double[][][] m_ClassAttAttCounts;
084:
085: /** The number of values for each attribute in the dataset */
086: private int[] m_NumAttValues;
087:
088: /** The number of values for all attributes in the dataset */
089: private int m_TotalAttValues;
090:
091: /** The number of classes in the dataset */
092: private int m_NumClasses;
093:
094: /** The number of attributes including class in the dataset */
095: private int m_NumAttributes;
096:
097: /** The number of instances in the dataset */
098: private int m_NumInstances;
099:
100: /** The index of the class attribute in the dataset */
101: private int m_ClassIndex;
102:
103: /** The starting index of each attribute in the dataset */
104: private int[] m_StartAttIndex;
105:
106: /** The 2D array of conditional mutual information of each pair attributes */
107: private double[][] m_condiMutualInfo;
108:
109: /**
110: * Returns a string describing this classifier.
111: *
112: * @return a description of the data generator suitable for
113: * displaying in the explorer/experimenter gui
114: */
115: public String globalInfo() {
116:
117: return "Contructs Hidden Naive Bayes classification model with high "
118: + "classification accuracy and AUC.\n\n"
119: + "For more information refer to:\n\n"
120: + getTechnicalInformation().toString();
121: }
122:
123: /**
124: * Returns an instance of a TechnicalInformation object, containing
125: * detailed information about the technical background of this class,
126: * e.g., paper reference or book this class is based on.
127: *
128: * @return the technical information about this class
129: */
130: public TechnicalInformation getTechnicalInformation() {
131: TechnicalInformation result;
132:
133: result = new TechnicalInformation(Type.INPROCEEDINGS);
134: result
135: .setValue(Field.AUTHOR,
136: "H. Zhang and L. Jiang and J. Su");
137: result.setValue(Field.TITLE, "Hidden Naive Bayes");
138: result
139: .setValue(Field.BOOKTITLE,
140: "Twentieth National Conference on Artificial Intelligence");
141: result.setValue(Field.YEAR, "2005");
142: result.setValue(Field.PAGES, "919-924");
143: result.setValue(Field.PUBLISHER, "AAAI Press");
144:
145: return result;
146: }
147:
148: /**
149: * Returns default capabilities of the classifier.
150: *
151: * @return the capabilities of this classifier
152: */
153: public Capabilities getCapabilities() {
154: Capabilities result = super .getCapabilities();
155:
156: // attributes
157: result.enable(Capability.NOMINAL_ATTRIBUTES);
158:
159: // class
160: result.enable(Capability.NOMINAL_CLASS);
161: result.enable(Capability.MISSING_CLASS_VALUES);
162:
163: return result;
164: }
165:
166: /**
167: * Generates the classifier.
168: *
169: * @param instances set of instances serving as training data
170: * @exception Exception if the classifier has not been generated successfully
171: */
172: public void buildClassifier(Instances instances) throws Exception {
173:
174: // can classifier handle the data?
175: getCapabilities().testWithFail(instances);
176:
177: // remove instances with missing class
178: instances = new Instances(instances);
179: instances.deleteWithMissingClass();
180:
181: // reset variable
182: m_NumClasses = instances.numClasses();
183: m_ClassIndex = instances.classIndex();
184: m_NumAttributes = instances.numAttributes();
185: m_NumInstances = instances.numInstances();
186: m_TotalAttValues = 0;
187:
188: // allocate space for attribute reference arrays
189: m_StartAttIndex = new int[m_NumAttributes];
190: m_NumAttValues = new int[m_NumAttributes];
191:
192: // set the starting index of each attribute and the number of values for
193: // each attribute and the total number of values for all attributes (not including class).
194: for (int i = 0; i < m_NumAttributes; i++) {
195: if (i != m_ClassIndex) {
196: m_StartAttIndex[i] = m_TotalAttValues;
197: m_NumAttValues[i] = instances.attribute(i).numValues();
198: m_TotalAttValues += m_NumAttValues[i];
199: } else {
200: m_StartAttIndex[i] = -1;
201: m_NumAttValues[i] = m_NumClasses;
202: }
203: }
204:
205: // allocate space for counts and frequencies
206: m_ClassCounts = new double[m_NumClasses];
207: m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues];
208:
209: // Calculate the counts
210: for (int k = 0; k < m_NumInstances; k++) {
211: int classVal = (int) instances.instance(k).classValue();
212: m_ClassCounts[classVal]++;
213: int[] attIndex = new int[m_NumAttributes];
214: for (int i = 0; i < m_NumAttributes; i++) {
215: if (i == m_ClassIndex)
216: attIndex[i] = -1;
217: else
218: attIndex[i] = m_StartAttIndex[i]
219: + (int) instances.instance(k).value(i);
220: }
221: for (int Att1 = 0; Att1 < m_NumAttributes; Att1++) {
222: if (attIndex[Att1] == -1)
223: continue;
224: for (int Att2 = 0; Att2 < m_NumAttributes; Att2++) {
225: if ((attIndex[Att2] != -1)) {
226: m_ClassAttAttCounts[classVal][attIndex[Att1]][attIndex[Att2]]++;
227: }
228: }
229: }
230: }
231:
232: //compute conditional mutual information of each pair attributes (not including class)
233: m_condiMutualInfo = new double[m_NumAttributes][m_NumAttributes];
234: for (int son = 0; son < m_NumAttributes; son++) {
235: if (son == m_ClassIndex)
236: continue;
237: for (int parent = 0; parent < m_NumAttributes; parent++) {
238: if (parent == m_ClassIndex || son == parent)
239: continue;
240: m_condiMutualInfo[son][parent] = conditionalMutualInfo(
241: son, parent);
242: }
243: }
244: }
245:
246: /**
247: * Computes conditional mutual information between a pair of attributes.
248: *
249: * @param son the son attribute
250: * @param parent the parent attribute
251: * @return the conditional mutual information between son and parent given class
252: * @throws Exception if computation fails
253: */
254: private double conditionalMutualInfo(int son, int parent)
255: throws Exception {
256:
257: double CondiMutualInfo = 0;
258: int sIndex = m_StartAttIndex[son];
259: int pIndex = m_StartAttIndex[parent];
260: double[] PriorsClass = new double[m_NumClasses];
261: double[][] PriorsClassSon = new double[m_NumClasses][m_NumAttValues[son]];
262: double[][] PriorsClassParent = new double[m_NumClasses][m_NumAttValues[parent]];
263: double[][][] PriorsClassParentSon = new double[m_NumClasses][m_NumAttValues[parent]][m_NumAttValues[son]];
264:
265: for (int i = 0; i < m_NumClasses; i++) {
266: PriorsClass[i] = m_ClassCounts[i] / m_NumInstances;
267: }
268:
269: for (int i = 0; i < m_NumClasses; i++) {
270: for (int j = 0; j < m_NumAttValues[son]; j++) {
271: PriorsClassSon[i][j] = m_ClassAttAttCounts[i][sIndex
272: + j][sIndex + j]
273: / m_NumInstances;
274: }
275: }
276:
277: for (int i = 0; i < m_NumClasses; i++) {
278: for (int j = 0; j < m_NumAttValues[parent]; j++) {
279: PriorsClassParent[i][j] = m_ClassAttAttCounts[i][pIndex
280: + j][pIndex + j]
281: / m_NumInstances;
282: }
283: }
284:
285: for (int i = 0; i < m_NumClasses; i++) {
286: for (int j = 0; j < m_NumAttValues[parent]; j++) {
287: for (int k = 0; k < m_NumAttValues[son]; k++) {
288: PriorsClassParentSon[i][j][k] = m_ClassAttAttCounts[i][pIndex
289: + j][sIndex + k]
290: / m_NumInstances;
291: }
292: }
293: }
294:
295: for (int i = 0; i < m_NumClasses; i++) {
296: for (int j = 0; j < m_NumAttValues[parent]; j++) {
297: for (int k = 0; k < m_NumAttValues[son]; k++) {
298: CondiMutualInfo += PriorsClassParentSon[i][j][k]
299: * log2(PriorsClassParentSon[i][j][k]
300: * PriorsClass[i],
301: PriorsClassParent[i][j]
302: * PriorsClassSon[i][k]);
303: }
304: }
305: }
306: return CondiMutualInfo;
307: }
308:
309: /**
310: * compute the logarithm whose base is 2.
311: *
312: * @param x numerator of the fraction.
313: * @param y denominator of the fraction.
314: * @return the natual logarithm of this fraction.
315: */
316: private double log2(double x, double y) {
317:
318: if (x < 1e-6 || y < 1e-6)
319: return 0.0;
320: else
321: return Math.log(x / y) / Math.log(2);
322: }
323:
324: /**
325: * Calculates the class membership probabilities for the given test instance
326: *
327: * @param instance the instance to be classified
328: * @return predicted class probability distribution
329: * @exception Exception if there is a problem generating the prediction
330: */
331: public double[] distributionForInstance(Instance instance)
332: throws Exception {
333:
334: //Definition of local variables
335: double[] probs = new double[m_NumClasses];
336: int sIndex;
337: double prob;
338: double condiMutualInfoSum;
339:
340: // store instance's att values in an int array
341: int[] attIndex = new int[m_NumAttributes];
342: for (int att = 0; att < m_NumAttributes; att++) {
343: if (att == m_ClassIndex)
344: attIndex[att] = -1;
345: else
346: attIndex[att] = m_StartAttIndex[att]
347: + (int) instance.value(att);
348: }
349:
350: // calculate probabilities for each possible class value
351: for (int classVal = 0; classVal < m_NumClasses; classVal++) {
352: probs[classVal] = (m_ClassCounts[classVal] + 1.0 / m_NumClasses)
353: / (m_NumInstances + 1.0);
354: for (int son = 0; son < m_NumAttributes; son++) {
355: if (attIndex[son] == -1)
356: continue;
357: sIndex = attIndex[son];
358: attIndex[son] = -1;
359: prob = 0;
360: condiMutualInfoSum = 0;
361: for (int parent = 0; parent < m_NumAttributes; parent++) {
362: if (attIndex[parent] == -1)
363: continue;
364: condiMutualInfoSum += m_condiMutualInfo[son][parent];
365: prob += m_condiMutualInfo[son][parent]
366: * (m_ClassAttAttCounts[classVal][attIndex[parent]][sIndex] + 1.0 / m_NumAttValues[son])
367: / (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[parent]] + 1.0);
368: }
369: if (condiMutualInfoSum > 0) {
370: prob = prob / condiMutualInfoSum;
371: probs[classVal] *= prob;
372: } else {
373: prob = (m_ClassAttAttCounts[classVal][sIndex][sIndex] + 1.0 / m_NumAttValues[son])
374: / (m_ClassCounts[classVal] + 1.0);
375: probs[classVal] *= prob;
376: }
377: attIndex[son] = sIndex;
378: }
379: }
380: Utils.normalize(probs);
381: return probs;
382: }
383:
384: /**
385: * returns a string representation of the classifier
386: *
387: * @return a representation of the classifier
388: */
389: public String toString() {
390:
391: return "HNB (Hidden Naive Bayes)";
392: }
393:
394: /**
395: * Main method for testing this class.
396: *
397: * @param args the options
398: */
399: public static void main(String[] args) {
400: runClassifier(new HNB(), args);
401: }
402: }
|