001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * LocalScoreSearchAlgorithm.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.bayes.net.search.local;
024:
025: import weka.classifiers.bayes.BayesNet;
026: import weka.classifiers.bayes.net.ParentSet;
027: import weka.classifiers.bayes.net.search.SearchAlgorithm;
028: import weka.core.Instances;
029: import weka.core.Instance;
030: import weka.core.Utils;
031: import weka.core.Statistics;
032: import weka.core.Tag;
033: import weka.core.Option;
034: import weka.core.SelectedTag;
035:
036: import java.util.Vector;
037: import java.util.Enumeration;
038:
039: /**
040: <!-- globalinfo-start -->
041: * The ScoreBasedSearchAlgorithm class supports Bayes net structure search algorithms that are based on maximizing scores (as opposed to for example conditional independence based search algorithms).
042: * <p/>
043: <!-- globalinfo-end -->
044: *
045: <!-- options-start -->
046: * Valid options are: <p/>
047: *
048: * <pre> -mbc
049: * Applies a Markov Blanket correction to the network structure,
050: * after a network structure is learned. This ensures that all
051: * nodes in the network are part of the Markov blanket of the
052: * classifier node.</pre>
053: *
054: * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
055: * Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
056: *
057: <!-- options-end -->
058: *
059: * @author Remco Bouckaert
060: * @version $Revision: 1.7 $
061: */
062: public class LocalScoreSearchAlgorithm extends SearchAlgorithm {
063:
/** for serialization */
static final long serialVersionUID = 3325995552474190374L;

/** points to Bayes network for which a structure is searched for **/
BayesNet m_BayesNet;

/**
 * default constructor
 */
public LocalScoreSearchAlgorithm() {
} // c'tor

/**
 * constructor
 *
 * @param bayesNet the network
 * @param instances the data; NOTE(review): currently unused — the network
 * is expected to carry its own Instances (see commented-out assignment) —
 * confirm no caller relies on it being stored here
 */
public LocalScoreSearchAlgorithm(BayesNet bayesNet,
        Instances instances) {
    m_BayesNet = bayesNet;
    // m_Instances = instances;
} // c'tor

/**
 * Holds prior on count (symmetric Dirichlet prior used by the BAYES metric)
 */
double m_fAlpha = 0.5;

/** the score types; IDs come from the Scoreable interface */
public static final Tag[] TAGS_SCORE_TYPE = {
    new Tag(Scoreable.BAYES, "BAYES"),
    new Tag(Scoreable.BDeu, "BDeu"),
    new Tag(Scoreable.MDL, "MDL"),
    new Tag(Scoreable.ENTROPY, "ENTROPY"),
    new Tag(Scoreable.AIC, "AIC") };

/**
 * Holds the score type used to measure quality of network
 */
int m_nScoreType = Scoreable.BAYES;
105:
106: /**
107: * logScore returns the log of the quality of a network
108: * (e.g. the posterior probability of the network, or the MDL
109: * value).
110: * @param nType score type (Bayes, MDL, etc) to calculate score with
111: * @return log score.
112: */
113: public double logScore(int nType) {
114: if (m_BayesNet.m_Distributions == null) {
115: return 0;
116: }
117: if (nType < 0) {
118: nType = m_nScoreType;
119: }
120:
121: double fLogScore = 0.0;
122:
123: Instances instances = m_BayesNet.m_Instances;
124:
125: for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
126: int nCardinality = m_BayesNet.getParentSet(iAttribute)
127: .getCardinalityOfParents();
128: for (int iParent = 0; iParent < nCardinality; iParent++) {
129: fLogScore += ((Scoreable) m_BayesNet.m_Distributions[iAttribute][iParent])
130: .logScore(nType, nCardinality);
131: }
132:
133: switch (nType) {
134: case (Scoreable.MDL): {
135: fLogScore -= 0.5
136: * m_BayesNet.getParentSet(iAttribute)
137: .getCardinalityOfParents()
138: * (instances.attribute(iAttribute).numValues() - 1)
139: * Math.log(instances.numInstances());
140: }
141: break;
142: case (Scoreable.AIC): {
143: fLogScore -= m_BayesNet.getParentSet(iAttribute)
144: .getCardinalityOfParents()
145: * (instances.attribute(iAttribute).numValues() - 1);
146: }
147: break;
148: }
149: }
150:
151: return fLogScore;
152: } // logScore
153:
154: /**
155: * buildStructure determines the network structure/graph of the network
156: * with the K2 algorithm, restricted by its initial structure (which can
157: * be an empty graph, or a Naive Bayes graph.
158: *
159: * @param bayesNet the network
160: * @param instances the data to use
161: * @throws Exception if something goes wrong
162: */
163: public void buildStructure(BayesNet bayesNet, Instances instances)
164: throws Exception {
165: m_BayesNet = bayesNet;
166: super .buildStructure(bayesNet, instances);
167: } // buildStructure
168:
169: /**
170: * Calc Node Score for given parent set
171: *
172: * @param nNode node for which the score is calculate
173: * @return log score
174: */
175: public double calcNodeScore(int nNode) {
176: if (m_BayesNet.getUseADTree() && m_BayesNet.getADTree() != null) {
177: return calcNodeScoreADTree(nNode);
178: } else {
179: return calcNodeScorePlain(nNode);
180: }
181: }
182:
/**
 * helper function for CalcNodeScore above using the ADTree data structure
 *
 * @param nNode node for which the score is calculate
 * @return log score
 */
private double calcNodeScoreADTree(int nNode) {
    Instances instances = m_BayesNet.m_Instances;
    ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
    // get set of parents, insert iNode: the query node is appended
    // after its parents so it varies fastest in the count array
    int nNrOfParents = oParentSet.getNrOfParents();
    int[] nNodes = new int[nNrOfParents + 1];
    for (int iParent = 0; iParent < nNrOfParents; iParent++) {
        nNodes[iParent] = oParentSet.getParent(iParent);
    }
    nNodes[nNrOfParents] = nNode;

    // calculate offsets: nOffsets[i] is the stride of nNodes[i] in the
    // flattened count array (row-major over parent values, node value last)
    int[] nOffsets = new int[nNrOfParents + 1];
    int nOffset = 1;
    nOffsets[nNrOfParents] = 1;
    nOffset *= instances.attribute(nNode).numValues();
    for (int iNode = nNrOfParents - 1; iNode >= 0; iNode--) {
        nOffsets[iNode] = nOffset;
        nOffset *= instances.attribute(nNodes[iNode]).numValues();
    }

    // sort nNodes & offsets: insertion sort on attribute index, swapping
    // both arrays in lock-step so each node keeps its stride (the ADTree
    // query expects attribute indices in ascending order)
    for (int iNode = 1; iNode < nNodes.length; iNode++) {
        int iNode2 = iNode;
        while (iNode2 > 0 && nNodes[iNode2] < nNodes[iNode2 - 1]) {
            int h = nNodes[iNode2];
            nNodes[iNode2] = nNodes[iNode2 - 1];
            nNodes[iNode2 - 1] = h;
            h = nOffsets[iNode2];
            nOffsets[iNode2] = nOffsets[iNode2 - 1];
            nOffsets[iNode2 - 1] = h;
            iNode2--;
        }
    }

    // get counts from ADTree: one cell per (parent configuration, node value)
    int nCardinality = oParentSet.getCardinalityOfParents();
    int numValues = instances.attribute(nNode).numValues();
    int[] nCounts = new int[nCardinality * numValues];

    m_BayesNet.getADTree().getCounts(nCounts, nNodes, nOffsets, 0,
            0, false);

    return calcScoreOfCounts(nCounts, nCardinality, numValues,
            instances);
} // calcNodeScoreADTree
236:
237: private double calcNodeScorePlain(int nNode) {
238: Instances instances = m_BayesNet.m_Instances;
239: ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
240:
241: // determine cardinality of parent set & reserve space for frequency counts
242: int nCardinality = oParentSet.getCardinalityOfParents();
243: int numValues = instances.attribute(nNode).numValues();
244: int[] nCounts = new int[nCardinality * numValues];
245:
246: // initialize (don't need this?)
247: for (int iParent = 0; iParent < nCardinality * numValues; iParent++) {
248: nCounts[iParent] = 0;
249: }
250:
251: // estimate distributions
252: Enumeration enumInsts = instances.enumerateInstances();
253:
254: while (enumInsts.hasMoreElements()) {
255: Instance instance = (Instance) enumInsts.nextElement();
256:
257: // updateClassifier;
258: double iCPT = 0;
259:
260: for (int iParent = 0; iParent < oParentSet.getNrOfParents(); iParent++) {
261: int nParent = oParentSet.getParent(iParent);
262:
263: iCPT = iCPT * instances.attribute(nParent).numValues()
264: + instance.value(nParent);
265: }
266:
267: nCounts[numValues * ((int) iCPT)
268: + (int) instance.value(nNode)]++;
269: }
270:
271: return calcScoreOfCounts(nCounts, nCardinality, numValues,
272: instances);
273: } // CalcNodeScore
274:
275: /**
276: * utility function used by CalcScore and CalcNodeScore to determine the score
277: * based on observed frequencies.
278: *
279: * @param nCounts array with observed frequencies
* @param nCardinality cardinality of parent set
281: * @param numValues number of values a node can take
282: * @param instances to calc score with
283: * @return log score
284: */
protected double calcScoreOfCounts(int[] nCounts, int nCardinality,
        int numValues, Instances instances) {

    // calculate scores using the distributions; nCounts is flattened as
    // nCounts[iParent * numValues + iSymbol]
    double fLogScore = 0.0;

    for (int iParent = 0; iParent < nCardinality; iParent++) {
        switch (m_nScoreType) {

        case (Scoreable.BAYES): {
            // Bayesian metric with symmetric Dirichlet prior m_fAlpha:
            // sum lnGamma(alpha + count) minus lnGamma of the column total,
            // normalized by the prior-only terms below
            double nSumOfCounts = 0;

            for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
                if (m_fAlpha
                        + nCounts[iParent * numValues + iSymbol] != 0) {
                    fLogScore += Statistics
                            .lnGamma(m_fAlpha
                                    + nCounts[iParent * numValues
                                            + iSymbol]);
                    nSumOfCounts += m_fAlpha
                            + nCounts[iParent * numValues + iSymbol];
                }
            }

            if (nSumOfCounts != 0) {
                fLogScore -= Statistics.lnGamma(nSumOfCounts);
            }

            if (m_fAlpha != 0) {
                // prior normalization:
                // -numValues*lnGamma(alpha) + lnGamma(numValues*alpha)
                fLogScore -= numValues
                        * Statistics.lnGamma(m_fAlpha);
                fLogScore += Statistics.lnGamma(numValues
                        * m_fAlpha);
            }
        }

            break;
        case (Scoreable.BDeu): {
            // BDeu: like BAYES but with an equivalent-sample-size prior of
            // 1 / (numValues * nCardinality) per cell
            double nSumOfCounts = 0;

            for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
                if (m_fAlpha
                        + nCounts[iParent * numValues + iSymbol] != 0) {
                    fLogScore += Statistics
                            .lnGamma(1.0
                                    / (numValues * nCardinality)
                                    + nCounts[iParent * numValues
                                            + iSymbol]);
                    nSumOfCounts += 1.0
                            / (numValues * nCardinality)
                            + nCounts[iParent * numValues + iSymbol];
                }
            }
            fLogScore -= Statistics.lnGamma(nSumOfCounts);

            fLogScore -= numValues
                    * Statistics
                            .lnGamma(1.0 / (numValues * nCardinality));
            fLogScore += Statistics.lnGamma(1.0 / nCardinality);
        }
            break;

        case (Scoreable.MDL):

        case (Scoreable.AIC):

        case (Scoreable.ENTROPY): {
            // MDL, AIC and ENTROPY all start from the log-likelihood term
            // sum count * log(count / total); MDL and AIC additionally get
            // the complexity penalty applied after the loop
            double nSumOfCounts = 0;

            for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
                nSumOfCounts += nCounts[iParent * numValues
                        + iSymbol];
            }

            for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
                if (nCounts[iParent * numValues + iSymbol] > 0) {
                    fLogScore += nCounts[iParent * numValues
                            + iSymbol]
                            * Math.log(nCounts[iParent * numValues
                                    + iSymbol]
                                    / nSumOfCounts);
                }
            }
        }

            break;

        default: {
        }
        }
    }

    // complexity penalty on the number of free parameters
    // nCardinality * (numValues - 1)
    switch (m_nScoreType) {

    case (Scoreable.MDL): {
        fLogScore -= 0.5 * nCardinality * (numValues - 1)
                * Math.log(instances.numInstances());

        // it seems safe to assume that numInstances>0 here
    }

        break;

    case (Scoreable.AIC): {
        fLogScore -= nCardinality * (numValues - 1);
    }

        break;
    }

    return fLogScore;
} // calcScoreOfCounts
397:
398: protected double calcScoreOfCounts2(int[][] nCounts,
399: int nCardinality, int numValues, Instances instances) {
400:
401: // calculate scores using the distributions
402: double fLogScore = 0.0;
403:
404: for (int iParent = 0; iParent < nCardinality; iParent++) {
405: switch (m_nScoreType) {
406:
407: case (Scoreable.BAYES): {
408: double nSumOfCounts = 0;
409:
410: for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
411: if (m_fAlpha + nCounts[iParent][iSymbol] != 0) {
412: fLogScore += Statistics.lnGamma(m_fAlpha
413: + nCounts[iParent][iSymbol]);
414: nSumOfCounts += m_fAlpha
415: + nCounts[iParent][iSymbol];
416: }
417: }
418:
419: if (nSumOfCounts != 0) {
420: fLogScore -= Statistics.lnGamma(nSumOfCounts);
421: }
422:
423: if (m_fAlpha != 0) {
424: fLogScore -= numValues
425: * Statistics.lnGamma(m_fAlpha);
426: fLogScore += Statistics.lnGamma(numValues
427: * m_fAlpha);
428: }
429: }
430:
431: break;
432:
433: case (Scoreable.BDeu): {
434: double nSumOfCounts = 0;
435:
436: for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
437: if (m_fAlpha
438: + nCounts[iParent * numValues][iSymbol] != 0) {
439: fLogScore += Statistics
440: .lnGamma(1.0
441: / (numValues * nCardinality)
442: + nCounts[iParent * numValues][iSymbol]);
443: nSumOfCounts += 1.0
444: / (numValues * nCardinality)
445: + nCounts[iParent * numValues][iSymbol];
446: }
447: }
448: fLogScore -= Statistics.lnGamma(nSumOfCounts);
449:
450: fLogScore -= numValues
451: * Statistics
452: .lnGamma(1.0 / (nCardinality * numValues));
453: fLogScore += Statistics.lnGamma(1.0 / nCardinality);
454: }
455: break;
456:
457: case (Scoreable.MDL):
458:
459: case (Scoreable.AIC):
460:
461: case (Scoreable.ENTROPY): {
462: double nSumOfCounts = 0;
463:
464: for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
465: nSumOfCounts += nCounts[iParent][iSymbol];
466: }
467:
468: for (int iSymbol = 0; iSymbol < numValues; iSymbol++) {
469: if (nCounts[iParent][iSymbol] > 0) {
470: fLogScore += nCounts[iParent][iSymbol]
471: * Math.log(nCounts[iParent][iSymbol]
472: / nSumOfCounts);
473: }
474: }
475: }
476:
477: break;
478:
479: default: {
480: }
481: }
482: }
483:
484: switch (m_nScoreType) {
485:
486: case (Scoreable.MDL): {
487: fLogScore -= 0.5 * nCardinality * (numValues - 1)
488: * Math.log(instances.numInstances());
489:
490: // it seems safe to assume that numInstances>0 here
491: }
492:
493: break;
494:
495: case (Scoreable.AIC): {
496: fLogScore -= nCardinality * (numValues - 1);
497: }
498:
499: break;
500: }
501:
502: return fLogScore;
503: } // CalcNodeScore
504:
505: /**
506: * Calc Node Score With AddedParent
507: *
508: * @param nNode node for which the score is calculate
509: * @param nCandidateParent candidate parent to add to the existing parent set
510: * @return log score
511: */
512: public double calcScoreWithExtraParent(int nNode,
513: int nCandidateParent) {
514: ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
515:
516: // sanity check: nCandidateParent should not be in parent set already
517: if (oParentSet.contains(nCandidateParent)) {
518: return -1e100;
519: }
520:
521: // set up candidate parent
522: oParentSet.addParent(nCandidateParent, m_BayesNet.m_Instances);
523:
524: // calculate the score
525: double logScore = calcNodeScore(nNode);
526:
527: // delete temporarily added parent
528: oParentSet.deleteLastParent(m_BayesNet.m_Instances);
529:
530: return logScore;
531: } // CalcScoreWithExtraParent
532:
533: /**
534: * Calc Node Score With Parent Deleted
535: *
536: * @param nNode node for which the score is calculate
537: * @param nCandidateParent candidate parent to delete from the existing parent set
538: * @return log score
539: */
540: public double calcScoreWithMissingParent(int nNode,
541: int nCandidateParent) {
542: ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
543:
544: // sanity check: nCandidateParent should be in parent set already
545: if (!oParentSet.contains(nCandidateParent)) {
546: return -1e100;
547: }
548:
549: // set up candidate parent
550: int iParent = oParentSet.deleteParent(nCandidateParent,
551: m_BayesNet.m_Instances);
552:
553: // calculate the score
554: double logScore = calcNodeScore(nNode);
555:
556: // restore temporarily deleted parent
557: oParentSet.addParent(nCandidateParent, iParent,
558: m_BayesNet.m_Instances);
559:
560: return logScore;
561: } // CalcScoreWithMissingParent
562:
563: /**
564: * set quality measure to be used in searching for networks.
565: *
566: * @param newScoreType the new score type
567: */
568: public void setScoreType(SelectedTag newScoreType) {
569: if (newScoreType.getTags() == TAGS_SCORE_TYPE) {
570: m_nScoreType = newScoreType.getSelectedTag().getID();
571: }
572: }
573:
574: /**
575: * get quality measure to be used in searching for networks.
576: * @return quality measure
577: */
578: public SelectedTag getScoreType() {
579: return new SelectedTag(m_nScoreType, TAGS_SCORE_TYPE);
580: }
581:
582: /**
583: *
584: * @param bMarkovBlanketClassifier
585: */
586: public void setMarkovBlanketClassifier(
587: boolean bMarkovBlanketClassifier) {
588: super .setMarkovBlanketClassifier(bMarkovBlanketClassifier);
589: }
590:
591: /**
592: *
593: * @return
594: */
595: public boolean getMarkovBlanketClassifier() {
596: return super .getMarkovBlanketClassifier();
597: }
598:
599: /**
600: * Returns an enumeration describing the available options
601: *
602: * @return an enumeration of all the available options
603: */
604: public Enumeration listOptions() {
605: Vector newVector = new Vector();
606:
607: newVector
608: .addElement(new Option(
609: "\tApplies a Markov Blanket correction to the network structure, \n"
610: + "\tafter a network structure is learned. This ensures that all \n"
611: + "\tnodes in the network are part of the Markov blanket of the \n"
612: + "\tclassifier node.", "mbc", 0,
613: "-mbc"));
614:
615: newVector
616: .addElement(new Option(
617: "\tScore type (BAYES, BDeu, MDL, ENTROPY and AIC)",
618: "S", 1,
619: "-S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]"));
620:
621: return newVector.elements();
622: } // listOptions
623:
624: /**
625: * Parses a given list of options. <p/>
626: *
627: <!-- options-start -->
628: * Valid options are: <p/>
629: *
630: * <pre> -mbc
631: * Applies a Markov Blanket correction to the network structure,
632: * after a network structure is learned. This ensures that all
633: * nodes in the network are part of the Markov blanket of the
634: * classifier node.</pre>
635: *
636: * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
637: * Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
638: *
639: <!-- options-end -->
640: *
641: * @param options the list of options as an array of strings
642: * @throws Exception if an option is not supported
643: */
644: public void setOptions(String[] options) throws Exception {
645:
646: setMarkovBlanketClassifier(Utils.getFlag("mbc", options));
647:
648: String sScore = Utils.getOption('S', options);
649:
650: if (sScore.compareTo("BAYES") == 0) {
651: setScoreType(new SelectedTag(Scoreable.BAYES,
652: TAGS_SCORE_TYPE));
653: }
654: if (sScore.compareTo("BDeu") == 0) {
655: setScoreType(new SelectedTag(Scoreable.BDeu,
656: TAGS_SCORE_TYPE));
657: }
658: if (sScore.compareTo("MDL") == 0) {
659: setScoreType(new SelectedTag(Scoreable.MDL, TAGS_SCORE_TYPE));
660: }
661: if (sScore.compareTo("ENTROPY") == 0) {
662: setScoreType(new SelectedTag(Scoreable.ENTROPY,
663: TAGS_SCORE_TYPE));
664: }
665: if (sScore.compareTo("AIC") == 0) {
666: setScoreType(new SelectedTag(Scoreable.AIC, TAGS_SCORE_TYPE));
667: }
668: } // setOptions
669:
670: /**
671: * Gets the current settings of the search algorithm.
672: *
673: * @return an array of strings suitable for passing to setOptions
674: */
675: public String[] getOptions() {
676: String[] super Options = super .getOptions();
677: String[] options = new String[3 + super Options.length];
678: int current = 0;
679:
680: if (getMarkovBlanketClassifier())
681: options[current++] = "-mbc";
682:
683: options[current++] = "-S";
684:
685: switch (m_nScoreType) {
686:
687: case (Scoreable.BAYES):
688: options[current++] = "BAYES";
689: break;
690:
691: case (Scoreable.BDeu):
692: options[current++] = "BDeu";
693: break;
694:
695: case (Scoreable.MDL):
696: options[current++] = "MDL";
697: break;
698:
699: case (Scoreable.ENTROPY):
700: options[current++] = "ENTROPY";
701:
702: break;
703:
704: case (Scoreable.AIC):
705: options[current++] = "AIC";
706: break;
707: }
708:
709: // insert options from parent class
710: for (int iOption = 0; iOption < super Options.length; iOption++) {
711: options[current++] = super Options[iOption];
712: }
713:
714: // Fill up rest with empty strings, not nulls!
715: while (current < options.length) {
716: options[current++] = "";
717: }
718:
719: return options;
720: } // getOptions
721:
722: /**
723: * @return a string to describe the ScoreType option.
724: */
725: public String scoreTypeTipText() {
726: return "The score type determines the measure used to judge the quality of a"
727: + " network structure. It can be one of Bayes, BDeu, Minimum Description Length (MDL),"
728: + " Akaike Information Criterion (AIC), and Entropy.";
729: }
730:
731: /**
732: * @return a string to describe the MarkovBlanketClassifier option.
733: */
734: public String markovBlanketClassifierTipText() {
735: return super .markovBlanketClassifierTipText();
736: }
737:
738: /**
739: * This will return a string describing the search algorithm.
740: * @return The string.
741: */
742: public String globalInfo() {
743: return "The ScoreBasedSearchAlgorithm class supports Bayes net "
744: + "structure search algorithms that are based on maximizing "
745: + "scores (as opposed to for example conditional independence "
746: + "based search algorithms).";
747: } // globalInfo
748: }
|