001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * GeneticSearch.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.bayes.net.search.local;
024:
025: import weka.classifiers.bayes.BayesNet;
026: import weka.classifiers.bayes.net.ParentSet;
027: import weka.core.Instances;
028: import weka.core.Option;
029: import weka.core.Utils;
030:
031: import java.util.Enumeration;
032: import java.util.Random;
033: import java.util.Vector;
034:
035: /**
036: <!-- globalinfo-start -->
037: * This Bayes Network learning algorithm uses genetic search for finding a well scoring Bayes network structure. Genetic search works by having a population of Bayes network structures and allow them to mutate and apply cross over to get offspring. The best network structure found during the process is returned.
038: * <p/>
039: <!-- globalinfo-end -->
040: *
041: <!-- options-start -->
042: * Valid options are: <p/>
043: *
044: * <pre> -L <integer>
045: * Population size</pre>
046: *
047: * <pre> -A <integer>
048: * Descendant population size</pre>
049: *
050: * <pre> -U <integer>
051: * Number of runs</pre>
052: *
053: * <pre> -M
054: * Use mutation.
055: * (default true)</pre>
056: *
057: * <pre> -C
058: * Use cross-over.
059: * (default true)</pre>
060: *
061: * <pre> -O
062: * Use tournament selection (true) or maximum subpopulatin (false).
063: * (default false)</pre>
064: *
065: * <pre> -R <seed>
066: * Random number seed</pre>
067: *
068: * <pre> -mbc
069: * Applies a Markov Blanket correction to the network structure,
070: * after a network structure is learned. This ensures that all
071: * nodes in the network are part of the Markov blanket of the
072: * classifier node.</pre>
073: *
074: * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
075: * Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
076: *
077: <!-- options-end -->
078: *
079: * @author Remco Bouckaert (rrb@xm.co.nz)
080: * @version $Revision: 1.4 $
081: */
082: public class GeneticSearch extends LocalScoreSearchAlgorithm {
083:
084: /** for serialization */
085: static final long serialVersionUID = -7037070678911459757L;
086:
087: /** number of runs **/
088: int m_nRuns = 10;
089:
090: /** size of population **/
091: int m_nPopulationSize = 10;
092:
093: /** size of descendant population **/
094: int m_nDescendantPopulationSize = 100;
095:
096: /** use cross-over? **/
097: boolean m_bUseCrossOver = true;
098:
099: /** use mutation? **/
100: boolean m_bUseMutation = true;
101:
102: /** use tournament selection or take best sub-population **/
103: boolean m_bUseTournamentSelection = false;
104:
105: /** random number seed **/
106: int m_nSeed = 1;
107:
108: /** random number generator **/
109: Random m_random = null;
110:
111: /** used in BayesNetRepresentation for efficiently determining
112: * whether a number is square
113: */
114: static boolean[] g_bIsSquare;
115:
116: class BayesNetRepresentation {
117: /** number of nodes in network **/
118: int m_nNodes = 0;
119:
120: /** bit representation of parent sets
121: * m_bits[iTail + iHead * m_nNodes] represents arc iTail->iHead
122: */
123: boolean[] m_bits;
124:
125: /** score of represented network structure **/
126: double m_fScore = 0.0f;
127:
128: /**
129: * return score of represented network structure
130: *
131: * @return the score
132: */
133: public double getScore() {
134: return m_fScore;
135: } // getScore
136:
137: /**
138: * c'tor
139: *
140: * @param nNodes the number of nodes
141: */
142: BayesNetRepresentation(int nNodes) {
143: m_nNodes = nNodes;
144: } // c'tor
145:
146: /** initialize with a random structure by randomly placing
147: * m_nNodes arcs.
148: */
149: public void randomInit() {
150: do {
151: m_bits = new boolean[m_nNodes * m_nNodes];
152: for (int i = 0; i < m_nNodes; i++) {
153: int iPos;
154: do {
155: iPos = m_random.nextInt(m_nNodes * m_nNodes);
156: } while (isSquare(iPos));
157: m_bits[iPos] = true;
158: }
159: } while (hasCycles());
160: calcScore();
161: }
162:
163: /** calculate score of current network representation
164: * As a side effect, the parent sets are set
165: */
166: void calcScore() {
167: // clear current network
168: for (int iNode = 0; iNode < m_nNodes; iNode++) {
169: ParentSet parentSet = m_BayesNet.getParentSet(iNode);
170: while (parentSet.getNrOfParents() > 0) {
171: parentSet.deleteLastParent(m_BayesNet.m_Instances);
172: }
173: }
174: // insert arrows
175: for (int iNode = 0; iNode < m_nNodes; iNode++) {
176: ParentSet parentSet = m_BayesNet.getParentSet(iNode);
177: for (int iNode2 = 0; iNode2 < m_nNodes; iNode2++) {
178: if (m_bits[iNode2 + iNode * m_nNodes]) {
179: parentSet.addParent(iNode2,
180: m_BayesNet.m_Instances);
181: }
182: }
183: }
184: // calc score
185: m_fScore = 0.0;
186: for (int iNode = 0; iNode < m_nNodes; iNode++) {
187: m_fScore += calcNodeScore(iNode);
188: }
189: } // calcScore
190:
191: /** check whether there are cycles in the network
192: *
193: * @return true if a cycle is found, false otherwise
194: */
195: public boolean hasCycles() {
196: // check for cycles
197: boolean[] bDone = new boolean[m_nNodes];
198: for (int iNode = 0; iNode < m_nNodes; iNode++) {
199:
200: // find a node for which all parents are 'done'
201: boolean bFound = false;
202:
203: for (int iNode2 = 0; !bFound && iNode2 < m_nNodes; iNode2++) {
204: if (!bDone[iNode2]) {
205: boolean bHasNoParents = true;
206: for (int iParent = 0; iParent < m_nNodes; iParent++) {
207: if (m_bits[iParent + iNode2 * m_nNodes]
208: && !bDone[iParent]) {
209: bHasNoParents = false;
210: }
211: }
212: if (bHasNoParents) {
213: bDone[iNode2] = true;
214: bFound = true;
215: }
216: }
217: }
218: if (!bFound) {
219: return true;
220: }
221: }
222: return false;
223: } // hasCycles
224:
225: /** create clone of current object
226: * @return cloned object
227: */
228: BayesNetRepresentation copy() {
229: BayesNetRepresentation b = new BayesNetRepresentation(
230: m_nNodes);
231: b.m_bits = new boolean[m_bits.length];
232: for (int i = 0; i < m_nNodes * m_nNodes; i++) {
233: b.m_bits[i] = m_bits[i];
234: }
235: b.m_fScore = m_fScore;
236: return b;
237: } // copy
238:
239: /** Apply mutation operation to BayesNet
240: * Calculate score and as a side effect sets BayesNet parent sets.
241: */
242: void mutate() {
243: // flip a bit
244: do {
245: int iBit;
246: do {
247: iBit = m_random.nextInt(m_nNodes * m_nNodes);
248: } while (isSquare(iBit));
249:
250: m_bits[iBit] = !m_bits[iBit];
251: } while (hasCycles());
252:
253: calcScore();
254: } // mutate
255:
256: /** Apply cross-over operation to BayesNet
257: * Calculate score and as a side effect sets BayesNet parent sets.
258: * @param other BayesNetRepresentation to cross over with
259: */
260: void crossOver(BayesNetRepresentation other) {
261: boolean[] bits = new boolean[m_bits.length];
262: for (int i = 0; i < m_bits.length; i++) {
263: bits[i] = m_bits[i];
264: }
265: int iCrossOverPoint = m_bits.length;
266: do {
267: // restore to original state
268: for (int i = iCrossOverPoint; i < m_bits.length; i++) {
269: m_bits[i] = bits[i];
270: }
271: // take all bits from cross-over points onwards
272: iCrossOverPoint = m_random.nextInt(m_bits.length);
273: for (int i = iCrossOverPoint; i < m_bits.length; i++) {
274: m_bits[i] = other.m_bits[i];
275: }
276: } while (hasCycles());
277: calcScore();
278: } // crossOver
279:
280: /** check if number is square and initialize g_bIsSquare structure
281: * if necessary
282: * @param nNum number to check (should be below m_nNodes * m_nNodes)
283: * @return true if number is square
284: */
285: boolean isSquare(int nNum) {
286: if (g_bIsSquare == null || g_bIsSquare.length < nNum) {
287: g_bIsSquare = new boolean[m_nNodes * m_nNodes];
288: for (int i = 0; i < m_nNodes; i++) {
289: g_bIsSquare[i * m_nNodes + i] = true;
290: }
291: }
292: return g_bIsSquare[nNum];
293: } // isSquare
294: } // class BayesNetRepresentation
295:
296: /**
297: * search determines the network structure/graph of the network
298: * with a genetic search algorithm.
299: *
300: * @param bayesNet the network to use
301: * @param instances the data to use
302: * @throws Exception if population size doesn fit or neither cross-over or mutation was chosen
303: */
304: protected void search(BayesNet bayesNet, Instances instances)
305: throws Exception {
306: // sanity check
307: if (getDescendantPopulationSize() < getPopulationSize()) {
308: throw new Exception(
309: "Descendant PopulationSize should be at least Population Size");
310: }
311: if (!getUseCrossOver() && !getUseMutation()) {
312: throw new Exception(
313: "At least one of mutation or cross-over should be used");
314: }
315:
316: m_random = new Random(m_nSeed);
317:
318: // keeps track of best structure found so far
319: BayesNet bestBayesNet;
320: // keeps track of score pf best structure found so far
321: double fBestScore = 0.0;
322: for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
323: fBestScore += calcNodeScore(iAttribute);
324: }
325:
326: // initialize bestBayesNet
327: bestBayesNet = new BayesNet();
328: bestBayesNet.m_Instances = instances;
329: bestBayesNet.initStructure();
330: copyParentSets(bestBayesNet, bayesNet);
331:
332: // initialize population
333: BayesNetRepresentation[] population = new BayesNetRepresentation[getPopulationSize()];
334: for (int i = 0; i < getPopulationSize(); i++) {
335: population[i] = new BayesNetRepresentation(instances
336: .numAttributes());
337: population[i].randomInit();
338: if (population[i].getScore() > fBestScore) {
339: copyParentSets(bestBayesNet, bayesNet);
340: fBestScore = population[i].getScore();
341:
342: }
343: }
344:
345: // go do the search
346: for (int iRun = 0; iRun < m_nRuns; iRun++) {
347: // create descendants
348: BayesNetRepresentation[] descendantPopulation = new BayesNetRepresentation[getDescendantPopulationSize()];
349: for (int i = 0; i < getDescendantPopulationSize(); i++) {
350: descendantPopulation[i] = population[m_random
351: .nextInt(getPopulationSize())].copy();
352: if (getUseMutation()) {
353: if (getUseCrossOver() && m_random.nextBoolean()) {
354: descendantPopulation[i]
355: .crossOver(population[m_random
356: .nextInt(getPopulationSize())]);
357: } else {
358: descendantPopulation[i].mutate();
359: }
360: } else {
361: // use crossover
362: descendantPopulation[i]
363: .crossOver(population[m_random
364: .nextInt(getPopulationSize())]);
365: }
366:
367: if (descendantPopulation[i].getScore() > fBestScore) {
368: copyParentSets(bestBayesNet, bayesNet);
369: fBestScore = descendantPopulation[i].getScore();
370: }
371: }
372: // select new population
373: boolean[] bSelected = new boolean[getDescendantPopulationSize()];
374: for (int i = 0; i < getPopulationSize(); i++) {
375: int iSelected = 0;
376: if (m_bUseTournamentSelection) {
377: // use tournament selection
378: iSelected = m_random
379: .nextInt(getDescendantPopulationSize());
380: while (bSelected[iSelected]) {
381: iSelected = (iSelected + 1)
382: % getDescendantPopulationSize();
383: }
384: int iSelected2 = m_random
385: .nextInt(getDescendantPopulationSize());
386: while (bSelected[iSelected2]) {
387: iSelected2 = (iSelected2 + 1)
388: % getDescendantPopulationSize();
389: }
390: if (descendantPopulation[iSelected2].getScore() > descendantPopulation[iSelected]
391: .getScore()) {
392: iSelected = iSelected2;
393: }
394: } else {
395: // find best scoring network in population
396: while (bSelected[iSelected]) {
397: iSelected++;
398: }
399: double fScore = descendantPopulation[iSelected]
400: .getScore();
401: for (int j = 0; j < getDescendantPopulationSize(); j++) {
402: if (!bSelected[j]
403: && descendantPopulation[j].getScore() > fScore) {
404: fScore = descendantPopulation[j].getScore();
405: iSelected = j;
406: }
407: }
408: }
409: population[i] = descendantPopulation[iSelected];
410: bSelected[iSelected] = true;
411: }
412: }
413:
414: // restore current network to best network
415: copyParentSets(bayesNet, bestBayesNet);
416:
417: // free up memory
418: bestBayesNet = null;
419: } // search
420:
421: /** copyParentSets copies parent sets of source to dest BayesNet
422: * @param dest destination network
423: * @param source source network
424: */
425: void copyParentSets(BayesNet dest, BayesNet source) {
426: int nNodes = source.getNrOfNodes();
427: // clear parent set first
428: for (int iNode = 0; iNode < nNodes; iNode++) {
429: dest.getParentSet(iNode).copy(source.getParentSet(iNode));
430: }
431: } // CopyParentSets
432:
433: /**
434: * @return number of runs
435: */
436: public int getRuns() {
437: return m_nRuns;
438: } // getRuns
439:
440: /**
441: * Sets the number of runs
442: * @param nRuns The number of runs to set
443: */
444: public void setRuns(int nRuns) {
445: m_nRuns = nRuns;
446: } // setRuns
447:
448: /**
449: * Returns an enumeration describing the available options.
450: *
451: * @return an enumeration of all the available options.
452: */
453: public Enumeration listOptions() {
454: Vector newVector = new Vector(7);
455:
456: newVector.addElement(new Option("\tPopulation size", "L", 1,
457: "-L <integer>"));
458: newVector.addElement(new Option("\tDescendant population size",
459: "A", 1, "-A <integer>"));
460: newVector.addElement(new Option("\tNumber of runs", "U", 1,
461: "-U <integer>"));
462: newVector.addElement(new Option(
463: "\tUse mutation.\n\t(default true)", "M", 0, "-M"));
464: newVector.addElement(new Option(
465: "\tUse cross-over.\n\t(default true)", "C", 0, "-C"));
466: newVector
467: .addElement(new Option(
468: "\tUse tournament selection (true) or maximum subpopulatin (false).\n\t(default false)",
469: "O", 0, "-O"));
470: newVector.addElement(new Option("\tRandom number seed", "R", 1,
471: "-R <seed>"));
472:
473: Enumeration enu = super .listOptions();
474: while (enu.hasMoreElements()) {
475: newVector.addElement(enu.nextElement());
476: }
477: return newVector.elements();
478: } // listOptions
479:
480: /**
481: * Parses a given list of options. <p/>
482: *
483: <!-- options-start -->
484: * Valid options are: <p/>
485: *
486: * <pre> -L <integer>
487: * Population size</pre>
488: *
489: * <pre> -A <integer>
490: * Descendant population size</pre>
491: *
492: * <pre> -U <integer>
493: * Number of runs</pre>
494: *
495: * <pre> -M
496: * Use mutation.
497: * (default true)</pre>
498: *
499: * <pre> -C
500: * Use cross-over.
501: * (default true)</pre>
502: *
503: * <pre> -O
504: * Use tournament selection (true) or maximum subpopulatin (false).
505: * (default false)</pre>
506: *
507: * <pre> -R <seed>
508: * Random number seed</pre>
509: *
510: * <pre> -mbc
511: * Applies a Markov Blanket correction to the network structure,
512: * after a network structure is learned. This ensures that all
513: * nodes in the network are part of the Markov blanket of the
514: * classifier node.</pre>
515: *
516: * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
517: * Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
518: *
519: <!-- options-end -->
520: *
521: * @param options the list of options as an array of strings
522: * @throws Exception if an option is not supported
523: */
524: public void setOptions(String[] options) throws Exception {
525: String sPopulationSize = Utils.getOption('L', options);
526: if (sPopulationSize.length() != 0) {
527: setPopulationSize(Integer.parseInt(sPopulationSize));
528: }
529: String sDescendantPopulationSize = Utils
530: .getOption('A', options);
531: if (sDescendantPopulationSize.length() != 0) {
532: setDescendantPopulationSize(Integer
533: .parseInt(sDescendantPopulationSize));
534: }
535: String sRuns = Utils.getOption('U', options);
536: if (sRuns.length() != 0) {
537: setRuns(Integer.parseInt(sRuns));
538: }
539: String sSeed = Utils.getOption('R', options);
540: if (sSeed.length() != 0) {
541: setSeed(Integer.parseInt(sSeed));
542: }
543: setUseMutation(Utils.getFlag('M', options));
544: setUseCrossOver(Utils.getFlag('C', options));
545: setUseTournamentSelection(Utils.getFlag('O', options));
546:
547: super .setOptions(options);
548: } // setOptions
549:
550: /**
551: * Gets the current settings of the search algorithm.
552: *
553: * @return an array of strings suitable for passing to setOptions
554: */
555: public String[] getOptions() {
556: String[] super Options = super .getOptions();
557: String[] options = new String[11 + super Options.length];
558: int current = 0;
559:
560: options[current++] = "-L";
561: options[current++] = "" + getPopulationSize();
562:
563: options[current++] = "-A";
564: options[current++] = "" + getDescendantPopulationSize();
565:
566: options[current++] = "-U";
567: options[current++] = "" + getRuns();
568:
569: options[current++] = "-R";
570: options[current++] = "" + getSeed();
571:
572: if (getUseMutation()) {
573: options[current++] = "-M";
574: }
575: if (getUseCrossOver()) {
576: options[current++] = "-C";
577: }
578: if (getUseTournamentSelection()) {
579: options[current++] = "-O";
580: }
581:
582: // insert options from parent class
583: for (int iOption = 0; iOption < super Options.length; iOption++) {
584: options[current++] = super Options[iOption];
585: }
586:
587: // Fill up rest with empty strings, not nulls!
588: while (current < options.length) {
589: options[current++] = "";
590: }
591: return options;
592: } // getOptions
593:
594: /**
595: * @return whether cross-over is used
596: */
597: public boolean getUseCrossOver() {
598: return m_bUseCrossOver;
599: }
600:
601: /**
602: * @return whether mutation is used
603: */
604: public boolean getUseMutation() {
605: return m_bUseMutation;
606: }
607:
608: /**
609: * @return descendant population size
610: */
611: public int getDescendantPopulationSize() {
612: return m_nDescendantPopulationSize;
613: }
614:
615: /**
616: * @return population size
617: */
618: public int getPopulationSize() {
619: return m_nPopulationSize;
620: }
621:
622: /**
623: * @param bUseCrossOver sets whether cross-over is used
624: */
625: public void setUseCrossOver(boolean bUseCrossOver) {
626: m_bUseCrossOver = bUseCrossOver;
627: }
628:
629: /**
630: * @param bUseMutation sets whether mutation is used
631: */
632: public void setUseMutation(boolean bUseMutation) {
633: m_bUseMutation = bUseMutation;
634: }
635:
636: /**
637: * @return whether Tournament Selection (true) or Maximum Sub-Population (false) should be used
638: */
639: public boolean getUseTournamentSelection() {
640: return m_bUseTournamentSelection;
641: }
642:
643: /**
644: * @param bUseTournamentSelection sets whether Tournament Selection or Maximum Sub-Population should be used
645: */
646: public void setUseTournamentSelection(
647: boolean bUseTournamentSelection) {
648: m_bUseTournamentSelection = bUseTournamentSelection;
649: }
650:
651: /**
652: * @param iDescendantPopulationSize sets descendant population size
653: */
654: public void setDescendantPopulationSize(
655: int iDescendantPopulationSize) {
656: m_nDescendantPopulationSize = iDescendantPopulationSize;
657: }
658:
659: /**
660: * @param iPopulationSize sets population size
661: */
662: public void setPopulationSize(int iPopulationSize) {
663: m_nPopulationSize = iPopulationSize;
664: }
665:
666: /**
667: * @return random number seed
668: */
669: public int getSeed() {
670: return m_nSeed;
671: } // getSeed
672:
673: /**
674: * Sets the random number seed
675: * @param nSeed The number of the seed to set
676: */
677: public void setSeed(int nSeed) {
678: m_nSeed = nSeed;
679: } // setSeed
680:
681: /**
682: * This will return a string describing the classifier.
683: * @return The string.
684: */
685: public String globalInfo() {
686: return "This Bayes Network learning algorithm uses genetic search for finding a well scoring "
687: + "Bayes network structure. Genetic search works by having a population of Bayes network structures "
688: + "and allow them to mutate and apply cross over to get offspring. The best network structure "
689: + "found during the process is returned.";
690: } // globalInfo
691:
692: /**
693: * @return a string to describe the Runs option.
694: */
695: public String runsTipText() {
696: return "Sets the number of generations of Bayes network structure populations.";
697: } // runsTipText
698:
699: /**
700: * @return a string to describe the Seed option.
701: */
702: public String seedTipText() {
703: return "Initialization value for random number generator."
704: + " Setting the seed allows replicability of experiments.";
705: } // seedTipText
706:
707: /**
708: * @return a string to describe the Population Size option.
709: */
710: public String populationSizeTipText() {
711: return "Sets the size of the population of network structures that is selected each generation.";
712: } // populationSizeTipText
713:
714: /**
715: * @return a string to describe the Descendant Population Size option.
716: */
717: public String descendantPopulationSizeTipText() {
718: return "Sets the size of the population of descendants that is created each generation.";
719: } // descendantPopulationSizeTipText
720:
721: /**
722: * @return a string to describe the Use Mutation option.
723: */
724: public String useMutationTipText() {
725: return "Determines whether mutation is allowed. Mutation flips a bit in the bit "
726: + "representation of the network structure. At least one of mutation or cross-over "
727: + "should be used.";
728: } // useMutationTipText
729:
730: /**
731: * @return a string to describe the Use Cross-Over option.
732: */
733: public String useCrossOverTipText() {
734: return "Determines whether cross-over is allowed. Cross over combined the bit "
735: + "representations of network structure by taking a random first k bits of one"
736: + "and adding the remainder of the other. At least one of mutation or cross-over "
737: + "should be used.";
738: } // useCrossOverTipText
739:
740: /**
741: * @return a string to describe the Use Tournament Selection option.
742: */
743: public String useTournamentSelectionTipText() {
744: return "Determines the method of selecting a population. When set to true, tournament "
745: + "selection is used (pick two at random and the highest is allowed to continue). "
746: + "When set to false, the top scoring network structures are selected.";
747: } // useTournamentSelectionTipText
748: } // GeneticSearch
|