001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * Distribution.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.core.Instance;
026: import weka.core.Instances;
027: import weka.core.Utils;
028:
029: import java.io.Serializable;
030: import java.util.Enumeration;
031:
032: /**
033: * Class for handling a distribution of class values.
034: *
035: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
036: * @version $Revision: 1.11 $
037: */
038: public class Distribution implements Cloneable, Serializable {
039:
040: /** for serialization */
041: private static final long serialVersionUID = 8526859638230806576L;
042:
043: /** Weight of instances per class per bag. */
044: private double m_perClassPerBag[][];
045:
046: /** Weight of instances per bag. */
047: private double m_perBag[];
048:
049: /** Weight of instances per class. */
050: private double m_perClass[];
051:
052: /** Total weight of instances. */
053: private double totaL;
054:
055: /**
056: * Creates and initializes a new distribution.
057: */
058: public Distribution(int numBags, int numClasses) {
059:
060: int i;
061:
062: m_perClassPerBag = new double[numBags][0];
063: m_perBag = new double[numBags];
064: m_perClass = new double[numClasses];
065: for (i = 0; i < numBags; i++)
066: m_perClassPerBag[i] = new double[numClasses];
067: totaL = 0;
068: }
069:
070: /**
071: * Creates and initializes a new distribution using the given
072: * array. WARNING: it just copies a reference to this array.
073: */
074: public Distribution(double[][] table) {
075:
076: int i, j;
077:
078: m_perClassPerBag = table;
079: m_perBag = new double[table.length];
080: m_perClass = new double[table[0].length];
081: for (i = 0; i < table.length; i++)
082: for (j = 0; j < table[i].length; j++) {
083: m_perBag[i] += table[i][j];
084: m_perClass[j] += table[i][j];
085: totaL += table[i][j];
086: }
087: }
088:
089: /**
090: * Creates a distribution with only one bag according
091: * to instances in source.
092: *
093: * @exception Exception if something goes wrong
094: */
095: public Distribution(Instances source) throws Exception {
096:
097: m_perClassPerBag = new double[1][0];
098: m_perBag = new double[1];
099: totaL = 0;
100: m_perClass = new double[source.numClasses()];
101: m_perClassPerBag[0] = new double[source.numClasses()];
102: Enumeration enu = source.enumerateInstances();
103: while (enu.hasMoreElements())
104: add(0, (Instance) enu.nextElement());
105: }
106:
107: /**
108: * Creates a distribution according to given instances and
109: * split model.
110: *
111: * @exception Exception if something goes wrong
112: */
113:
114: public Distribution(Instances source,
115: ClassifierSplitModel modelToUse) throws Exception {
116:
117: int index;
118: Instance instance;
119: double[] weights;
120:
121: m_perClassPerBag = new double[modelToUse.numSubsets()][0];
122: m_perBag = new double[modelToUse.numSubsets()];
123: totaL = 0;
124: m_perClass = new double[source.numClasses()];
125: for (int i = 0; i < modelToUse.numSubsets(); i++)
126: m_perClassPerBag[i] = new double[source.numClasses()];
127: Enumeration enu = source.enumerateInstances();
128: while (enu.hasMoreElements()) {
129: instance = (Instance) enu.nextElement();
130: index = modelToUse.whichSubset(instance);
131: if (index != -1)
132: add(index, instance);
133: else {
134: weights = modelToUse.weights(instance);
135: addWeights(instance, weights);
136: }
137: }
138: }
139:
140: /**
141: * Creates distribution with only one bag by merging all
142: * bags of given distribution.
143: */
144: public Distribution(Distribution toMerge) {
145:
146: totaL = toMerge.totaL;
147: m_perClass = new double[toMerge.numClasses()];
148: System.arraycopy(toMerge.m_perClass, 0, m_perClass, 0, toMerge
149: .numClasses());
150: m_perClassPerBag = new double[1][0];
151: m_perClassPerBag[0] = new double[toMerge.numClasses()];
152: System.arraycopy(toMerge.m_perClass, 0, m_perClassPerBag[0], 0,
153: toMerge.numClasses());
154: m_perBag = new double[1];
155: m_perBag[0] = totaL;
156: }
157:
158: /**
159: * Creates distribution with two bags by merging all bags apart of
160: * the indicated one.
161: */
162: public Distribution(Distribution toMerge, int index) {
163:
164: int i;
165:
166: totaL = toMerge.totaL;
167: m_perClass = new double[toMerge.numClasses()];
168: System.arraycopy(toMerge.m_perClass, 0, m_perClass, 0, toMerge
169: .numClasses());
170: m_perClassPerBag = new double[2][0];
171: m_perClassPerBag[0] = new double[toMerge.numClasses()];
172: System.arraycopy(toMerge.m_perClassPerBag[index], 0,
173: m_perClassPerBag[0], 0, toMerge.numClasses());
174: m_perClassPerBag[1] = new double[toMerge.numClasses()];
175: for (i = 0; i < toMerge.numClasses(); i++)
176: m_perClassPerBag[1][i] = toMerge.m_perClass[i]
177: - m_perClassPerBag[0][i];
178: m_perBag = new double[2];
179: m_perBag[0] = toMerge.m_perBag[index];
180: m_perBag[1] = totaL - m_perBag[0];
181: }
182:
183: /**
184: * Returns number of non-empty bags of distribution.
185: */
186: public final int actualNumBags() {
187:
188: int returnValue = 0;
189: int i;
190:
191: for (i = 0; i < m_perBag.length; i++)
192: if (Utils.gr(m_perBag[i], 0))
193: returnValue++;
194:
195: return returnValue;
196: }
197:
198: /**
199: * Returns number of classes actually occuring in distribution.
200: */
201: public final int actualNumClasses() {
202:
203: int returnValue = 0;
204: int i;
205:
206: for (i = 0; i < m_perClass.length; i++)
207: if (Utils.gr(m_perClass[i], 0))
208: returnValue++;
209:
210: return returnValue;
211: }
212:
213: /**
214: * Returns number of classes actually occuring in given bag.
215: */
216: public final int actualNumClasses(int bagIndex) {
217:
218: int returnValue = 0;
219: int i;
220:
221: for (i = 0; i < m_perClass.length; i++)
222: if (Utils.gr(m_perClassPerBag[bagIndex][i], 0))
223: returnValue++;
224:
225: return returnValue;
226: }
227:
228: /**
229: * Adds given instance to given bag.
230: *
231: * @exception Exception if something goes wrong
232: */
233: public final void add(int bagIndex, Instance instance)
234: throws Exception {
235:
236: int classIndex;
237: double weight;
238:
239: classIndex = (int) instance.classValue();
240: weight = instance.weight();
241: m_perClassPerBag[bagIndex][classIndex] = m_perClassPerBag[bagIndex][classIndex]
242: + weight;
243: m_perBag[bagIndex] = m_perBag[bagIndex] + weight;
244: m_perClass[classIndex] = m_perClass[classIndex] + weight;
245: totaL = totaL + weight;
246: }
247:
248: /**
249: * Subtracts given instance from given bag.
250: *
251: * @exception Exception if something goes wrong
252: */
253: public final void sub(int bagIndex, Instance instance)
254: throws Exception {
255:
256: int classIndex;
257: double weight;
258:
259: classIndex = (int) instance.classValue();
260: weight = instance.weight();
261: m_perClassPerBag[bagIndex][classIndex] = m_perClassPerBag[bagIndex][classIndex]
262: - weight;
263: m_perBag[bagIndex] = m_perBag[bagIndex] - weight;
264: m_perClass[classIndex] = m_perClass[classIndex] - weight;
265: totaL = totaL - weight;
266: }
267:
268: /**
269: * Adds counts to given bag.
270: */
271: public final void add(int bagIndex, double[] counts) {
272:
273: double sum = Utils.sum(counts);
274:
275: for (int i = 0; i < counts.length; i++)
276: m_perClassPerBag[bagIndex][i] += counts[i];
277: m_perBag[bagIndex] = m_perBag[bagIndex] + sum;
278: for (int i = 0; i < counts.length; i++)
279: m_perClass[i] = m_perClass[i] + counts[i];
280: totaL = totaL + sum;
281: }
282:
283: /**
284: * Adds all instances with unknown values for given attribute, weighted
285: * according to frequency of instances in each bag.
286: *
287: * @exception Exception if something goes wrong
288: */
289: public final void addInstWithUnknown(Instances source, int attIndex)
290: throws Exception {
291:
292: double[] probs;
293: double weight, newWeight;
294: int classIndex;
295: Instance instance;
296: int j;
297:
298: probs = new double[m_perBag.length];
299: for (j = 0; j < m_perBag.length; j++) {
300: if (Utils.eq(totaL, 0)) {
301: probs[j] = 1.0 / probs.length;
302: } else {
303: probs[j] = m_perBag[j] / totaL;
304: }
305: }
306: Enumeration enu = source.enumerateInstances();
307: while (enu.hasMoreElements()) {
308: instance = (Instance) enu.nextElement();
309: if (instance.isMissing(attIndex)) {
310: classIndex = (int) instance.classValue();
311: weight = instance.weight();
312: m_perClass[classIndex] = m_perClass[classIndex]
313: + weight;
314: totaL = totaL + weight;
315: for (j = 0; j < m_perBag.length; j++) {
316: newWeight = probs[j] * weight;
317: m_perClassPerBag[j][classIndex] = m_perClassPerBag[j][classIndex]
318: + newWeight;
319: m_perBag[j] = m_perBag[j] + newWeight;
320: }
321: }
322: }
323: }
324:
325: /**
326: * Adds all instances in given range to given bag.
327: *
328: * @exception Exception if something goes wrong
329: */
330: public final void addRange(int bagIndex, Instances source,
331: int startIndex, int lastPlusOne) throws Exception {
332:
333: double sumOfWeights = 0;
334: int classIndex;
335: Instance instance;
336: int i;
337:
338: for (i = startIndex; i < lastPlusOne; i++) {
339: instance = (Instance) source.instance(i);
340: classIndex = (int) instance.classValue();
341: sumOfWeights = sumOfWeights + instance.weight();
342: m_perClassPerBag[bagIndex][classIndex] += instance.weight();
343: m_perClass[classIndex] += instance.weight();
344: }
345: m_perBag[bagIndex] += sumOfWeights;
346: totaL += sumOfWeights;
347: }
348:
349: /**
350: * Adds given instance to all bags weighting it according to given weights.
351: *
352: * @exception Exception if something goes wrong
353: */
354: public final void addWeights(Instance instance, double[] weights)
355: throws Exception {
356:
357: int classIndex;
358: int i;
359:
360: classIndex = (int) instance.classValue();
361: for (i = 0; i < m_perBag.length; i++) {
362: double weight = instance.weight() * weights[i];
363: m_perClassPerBag[i][classIndex] = m_perClassPerBag[i][classIndex]
364: + weight;
365: m_perBag[i] = m_perBag[i] + weight;
366: m_perClass[classIndex] = m_perClass[classIndex] + weight;
367: totaL = totaL + weight;
368: }
369: }
370:
371: /**
372: * Checks if at least two bags contain a minimum number of instances.
373: */
374: public final boolean check(double minNoObj) {
375:
376: int counter = 0;
377: int i;
378:
379: for (i = 0; i < m_perBag.length; i++)
380: if (Utils.grOrEq(m_perBag[i], minNoObj))
381: counter++;
382: if (counter > 1)
383: return true;
384: else
385: return false;
386: }
387:
388: /**
389: * Clones distribution (Deep copy of distribution).
390: */
391: public final Object clone() {
392:
393: int i, j;
394:
395: Distribution newDistribution = new Distribution(
396: m_perBag.length, m_perClass.length);
397: for (i = 0; i < m_perBag.length; i++) {
398: newDistribution.m_perBag[i] = m_perBag[i];
399: for (j = 0; j < m_perClass.length; j++)
400: newDistribution.m_perClassPerBag[i][j] = m_perClassPerBag[i][j];
401: }
402: for (j = 0; j < m_perClass.length; j++)
403: newDistribution.m_perClass[j] = m_perClass[j];
404: newDistribution.totaL = totaL;
405:
406: return newDistribution;
407: }
408:
409: /**
410: * Deletes given instance from given bag.
411: *
412: * @exception Exception if something goes wrong
413: */
414: public final void del(int bagIndex, Instance instance)
415: throws Exception {
416:
417: int classIndex;
418: double weight;
419:
420: classIndex = (int) instance.classValue();
421: weight = instance.weight();
422: m_perClassPerBag[bagIndex][classIndex] = m_perClassPerBag[bagIndex][classIndex]
423: - weight;
424: m_perBag[bagIndex] = m_perBag[bagIndex] - weight;
425: m_perClass[classIndex] = m_perClass[classIndex] - weight;
426: totaL = totaL - weight;
427: }
428:
429: /**
430: * Deletes all instances in given range from given bag.
431: *
432: * @exception Exception if something goes wrong
433: */
434: public final void delRange(int bagIndex, Instances source,
435: int startIndex, int lastPlusOne) throws Exception {
436:
437: double sumOfWeights = 0;
438: int classIndex;
439: Instance instance;
440: int i;
441:
442: for (i = startIndex; i < lastPlusOne; i++) {
443: instance = (Instance) source.instance(i);
444: classIndex = (int) instance.classValue();
445: sumOfWeights = sumOfWeights + instance.weight();
446: m_perClassPerBag[bagIndex][classIndex] -= instance.weight();
447: m_perClass[classIndex] -= instance.weight();
448: }
449: m_perBag[bagIndex] -= sumOfWeights;
450: totaL -= sumOfWeights;
451: }
452:
453: /**
454: * Prints distribution.
455: */
456:
457: public final String dumpDistribution() {
458:
459: StringBuffer text;
460: int i, j;
461:
462: text = new StringBuffer();
463: for (i = 0; i < m_perBag.length; i++) {
464: text.append("Bag num " + i + "\n");
465: for (j = 0; j < m_perClass.length; j++)
466: text.append("Class num " + j + " "
467: + m_perClassPerBag[i][j] + "\n");
468: }
469: return text.toString();
470: }
471:
472: /**
473: * Sets all counts to zero.
474: */
475: public final void initialize() {
476:
477: for (int i = 0; i < m_perClass.length; i++)
478: m_perClass[i] = 0;
479: for (int i = 0; i < m_perBag.length; i++)
480: m_perBag[i] = 0;
481: for (int i = 0; i < m_perBag.length; i++)
482: for (int j = 0; j < m_perClass.length; j++)
483: m_perClassPerBag[i][j] = 0;
484: totaL = 0;
485: }
486:
487: /**
488: * Returns matrix with distribution of class values.
489: */
490: public final double[][] matrix() {
491:
492: return m_perClassPerBag;
493: }
494:
495: /**
496: * Returns index of bag containing maximum number of instances.
497: */
498: public final int maxBag() {
499:
500: double max;
501: int maxIndex;
502: int i;
503:
504: max = 0;
505: maxIndex = -1;
506: for (i = 0; i < m_perBag.length; i++)
507: if (Utils.grOrEq(m_perBag[i], max)) {
508: max = m_perBag[i];
509: maxIndex = i;
510: }
511: return maxIndex;
512: }
513:
514: /**
515: * Returns class with highest frequency over all bags.
516: */
517: public final int maxClass() {
518:
519: double maxCount = 0;
520: int maxIndex = 0;
521: int i;
522:
523: for (i = 0; i < m_perClass.length; i++)
524: if (Utils.gr(m_perClass[i], maxCount)) {
525: maxCount = m_perClass[i];
526: maxIndex = i;
527: }
528:
529: return maxIndex;
530: }
531:
532: /**
533: * Returns class with highest frequency for given bag.
534: */
535: public final int maxClass(int index) {
536:
537: double maxCount = 0;
538: int maxIndex = 0;
539: int i;
540:
541: if (Utils.gr(m_perBag[index], 0)) {
542: for (i = 0; i < m_perClass.length; i++)
543: if (Utils.gr(m_perClassPerBag[index][i], maxCount)) {
544: maxCount = m_perClassPerBag[index][i];
545: maxIndex = i;
546: }
547: return maxIndex;
548: } else
549: return maxClass();
550: }
551:
552: /**
553: * Returns number of bags.
554: */
555: public final int numBags() {
556:
557: return m_perBag.length;
558: }
559:
560: /**
561: * Returns number of classes.
562: */
563: public final int numClasses() {
564:
565: return m_perClass.length;
566: }
567:
568: /**
569: * Returns perClass(maxClass()).
570: */
571: public final double numCorrect() {
572:
573: return m_perClass[maxClass()];
574: }
575:
576: /**
577: * Returns perClassPerBag(index,maxClass(index)).
578: */
579: public final double numCorrect(int index) {
580:
581: return m_perClassPerBag[index][maxClass(index)];
582: }
583:
584: /**
585: * Returns total-numCorrect().
586: */
587: public final double numIncorrect() {
588:
589: return totaL - numCorrect();
590: }
591:
592: /**
593: * Returns perBag(index)-numCorrect(index).
594: */
595: public final double numIncorrect(int index) {
596:
597: return m_perBag[index] - numCorrect(index);
598: }
599:
600: /**
601: * Returns number of (possibly fractional) instances of given class in
602: * given bag.
603: */
604: public final double perClassPerBag(int bagIndex, int classIndex) {
605:
606: return m_perClassPerBag[bagIndex][classIndex];
607: }
608:
609: /**
610: * Returns number of (possibly fractional) instances in given bag.
611: */
612: public final double perBag(int bagIndex) {
613:
614: return m_perBag[bagIndex];
615: }
616:
617: /**
618: * Returns number of (possibly fractional) instances of given class.
619: */
620: public final double perClass(int classIndex) {
621:
622: return m_perClass[classIndex];
623: }
624:
625: /**
626: * Returns relative frequency of class over all bags with
627: * Laplace correction.
628: */
629: public final double laplaceProb(int classIndex) {
630:
631: return (m_perClass[classIndex] + 1)
632: / (totaL + (double) m_perClass.length);
633: }
634:
635: /**
636: * Returns relative frequency of class for given bag.
637: */
638: public final double laplaceProb(int classIndex, int intIndex) {
639:
640: if (Utils.gr(m_perBag[intIndex], 0))
641: return (m_perClassPerBag[intIndex][classIndex] + 1.0)
642: / (m_perBag[intIndex] + (double) m_perClass.length);
643: else
644: return laplaceProb(classIndex);
645:
646: }
647:
648: /**
649: * Returns relative frequency of class over all bags.
650: */
651: public final double prob(int classIndex) {
652:
653: if (!Utils.eq(totaL, 0)) {
654: return m_perClass[classIndex] / totaL;
655: } else {
656: return 0;
657: }
658: }
659:
660: /**
661: * Returns relative frequency of class for given bag.
662: */
663: public final double prob(int classIndex, int intIndex) {
664:
665: if (Utils.gr(m_perBag[intIndex], 0))
666: return m_perClassPerBag[intIndex][classIndex]
667: / m_perBag[intIndex];
668: else
669: return prob(classIndex);
670: }
671:
672: /**
673: * Subtracts the given distribution from this one. The results
674: * has only one bag.
675: */
676: public final Distribution subtract(Distribution toSubstract) {
677:
678: Distribution newDist = new Distribution(1, m_perClass.length);
679:
680: newDist.m_perBag[0] = totaL - toSubstract.totaL;
681: newDist.totaL = newDist.m_perBag[0];
682: for (int i = 0; i < m_perClass.length; i++) {
683: newDist.m_perClassPerBag[0][i] = m_perClass[i]
684: - toSubstract.m_perClass[i];
685: newDist.m_perClass[i] = newDist.m_perClassPerBag[0][i];
686: }
687: return newDist;
688: }
689:
690: /**
691: * Returns total number of (possibly fractional) instances.
692: */
693: public final double total() {
694:
695: return totaL;
696: }
697:
698: /**
699: * Shifts given instance from one bag to another one.
700: *
701: * @exception Exception if something goes wrong
702: */
703: public final void shift(int from, int to, Instance instance)
704: throws Exception {
705:
706: int classIndex;
707: double weight;
708:
709: classIndex = (int) instance.classValue();
710: weight = instance.weight();
711: m_perClassPerBag[from][classIndex] -= weight;
712: m_perClassPerBag[to][classIndex] += weight;
713: m_perBag[from] -= weight;
714: m_perBag[to] += weight;
715: }
716:
717: /**
718: * Shifts all instances in given range from one bag to another one.
719: *
720: * @exception Exception if something goes wrong
721: */
722: public final void shiftRange(int from, int to, Instances source,
723: int startIndex, int lastPlusOne) throws Exception {
724:
725: int classIndex;
726: double weight;
727: Instance instance;
728: int i;
729:
730: for (i = startIndex; i < lastPlusOne; i++) {
731: instance = (Instance) source.instance(i);
732: classIndex = (int) instance.classValue();
733: weight = instance.weight();
734: m_perClassPerBag[from][classIndex] -= weight;
735: m_perClassPerBag[to][classIndex] += weight;
736: m_perBag[from] -= weight;
737: m_perBag[to] += weight;
738: }
739: }
740: }
|