001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * CostMatrix.java
019: * Copyright (C) 2006 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers;
024:
025: import weka.core.AttributeExpression;
026: import weka.core.Instance;
027: import weka.core.Instances;
028: import weka.core.Matrix;
029: import weka.core.Utils;
030:
031: import java.io.LineNumberReader;
032: import java.io.Reader;
033: import java.io.Serializable;
034: import java.io.StreamTokenizer;
035: import java.io.Writer;
036: import java.util.Random;
037: import java.util.StringTokenizer;
038:
039: /**
040: * Class for storing and manipulating a misclassification cost matrix.
041: * The element at position i,j in the matrix is the penalty for classifying
042: * an instance of class j as class i. Cost values can be fixed or
043: * computed on a per-instance basis (cost sensitive evaluation only)
044: * from the value of an attribute or an expression involving
045: * attribute(s).
046: *
047: * @author Mark Hall
048: * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
049: * @version $Revision: 1.15 $
050: */
051: public class CostMatrix implements Serializable {
052:
053: /** for serialization */
054: private static final long serialVersionUID = -1973792250544554965L;
055:
056: private int m_size;
057:
058: /** [rows][columns] */
059: protected Object[][] m_matrix;
060:
061: /** The deafult file extension for cost matrix files */
062: public static String FILE_EXTENSION = ".cost";
063:
064: /**
065: * Creates a default cost matrix of a particular size.
066: * All diagonal values will be 0 and all non-diagonal values 1.
067: *
068: * @param numOfClasses the number of classes that the cost matrix holds.
069: */
070: public CostMatrix(int numOfClasses) {
071: m_size = numOfClasses;
072: initialize();
073: }
074:
075: /**
076: * Creates a cost matrix that is a copy of another.
077: *
078: * @param toCopy the matrix to copy.
079: */
080: public CostMatrix(CostMatrix toCopy) {
081: this (toCopy.size());
082:
083: for (int i = 0; i < m_size; i++) {
084: for (int j = 0; j < m_size; j++) {
085: setCell(i, j, toCopy.getCell(i, j));
086: }
087: }
088: }
089:
090: /**
091: * Initializes the matrix
092: */
093: public void initialize() {
094: m_matrix = new Object[m_size][m_size];
095: for (int i = 0; i < m_size; i++) {
096: for (int j = 0; j < m_size; j++) {
097: setCell(i, j, i == j ? new Double(0.0)
098: : new Double(1.0));
099: }
100: }
101: }
102:
103: /**
104: * The number of rows (and columns)
105: * @return the size of the matrix
106: */
107: public int size() {
108: return m_size;
109: }
110:
111: /**
112: * Same as size
113: * @return the number of columns
114: */
115: public int numColumns() {
116: return size();
117: }
118:
119: /**
120: * Same as size
121: * @return the number of rows
122: */
123: public int numRows() {
124: return size();
125: }
126:
127: private boolean replaceStrings() throws Exception {
128: boolean nonDouble = false;
129:
130: for (int i = 0; i < m_size; i++) {
131: for (int j = 0; j < m_size; j++) {
132: if (getCell(i, j) instanceof String) {
133: AttributeExpression temp = new AttributeExpression();
134: temp.convertInfixToPostfix((String) getCell(i, j));
135: setCell(i, j, temp);
136: nonDouble = true;
137: } else if (getCell(i, j) instanceof AttributeExpression) {
138: nonDouble = true;
139: }
140: }
141: }
142:
143: return nonDouble;
144: }
145:
146: /**
147: * Applies the cost matrix to a set of instances. If a random number generator is
148: * supplied the instances will be resampled, otherwise they will be rewighted.
149: * Adapted from code once sitting in Instances.java
150: *
151: * @param data the instances to reweight.
152: * @param random a random number generator for resampling, if null then instances are
153: * rewighted.
154: * @return a new dataset reflecting the cost of misclassification.
155: * @exception Exception if the data has no class or the matrix in inappropriate.
156: */
157: public Instances applyCostMatrix(Instances data, Random random)
158: throws Exception {
159:
160: if (replaceStrings()) {
161: // could reweight in the two class case
162: throw new Exception(
163: "Can't resample/reweight instances using "
164: + "non-fixed cost values!");
165: }
166:
167: double sumOfWeightFactors = 0, sumOfMissClassWeights, sumOfWeights;
168: double[] weightOfInstancesInClass, weightFactor, weightOfInstances;
169: Instances newData;
170:
171: if (data.classIndex() < 0) {
172: throw new Exception("Class index is not set!");
173: }
174:
175: if (size() != data.numClasses()) {
176: throw new Exception("Misclassification cost matrix has "
177: + "wrong format!");
178: }
179:
180: weightFactor = new double[data.numClasses()];
181: weightOfInstancesInClass = new double[data.numClasses()];
182: for (int j = 0; j < data.numInstances(); j++) {
183: weightOfInstancesInClass[(int) data.instance(j)
184: .classValue()] += data.instance(j).weight();
185: }
186: sumOfWeights = Utils.sum(weightOfInstancesInClass);
187:
188: // normalize the matrix if not already
189: for (int i = 0; i < m_size; i++) {
190: if (!Utils.eq(((Double) getCell(i, i)).doubleValue(), 0)) {
191: CostMatrix normMatrix = new CostMatrix(this );
192: normMatrix.normalize();
193: return normMatrix.applyCostMatrix(data, random);
194: }
195: }
196:
197: for (int i = 0; i < data.numClasses(); i++) {
198: // Using Kai Ming Ting's formula for deriving weights for
199: // the classes and Breiman's heuristic for multiclass
200: // problems.
201:
202: sumOfMissClassWeights = 0;
203: for (int j = 0; j < data.numClasses(); j++) {
204: if (Utils.sm(((Double) getCell(i, j)).doubleValue(), 0)) {
205: throw new Exception(
206: "Neg. weights in misclassification "
207: + "cost matrix!");
208: }
209: sumOfMissClassWeights += ((Double) getCell(i, j))
210: .doubleValue();
211: }
212: weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
213: sumOfWeightFactors += sumOfMissClassWeights
214: * weightOfInstancesInClass[i];
215: }
216: for (int i = 0; i < data.numClasses(); i++) {
217: weightFactor[i] /= sumOfWeightFactors;
218: }
219:
220: // Store new weights
221: weightOfInstances = new double[data.numInstances()];
222: for (int i = 0; i < data.numInstances(); i++) {
223: weightOfInstances[i] = data.instance(i).weight()
224: * weightFactor[(int) data.instance(i).classValue()];
225: }
226:
227: // Change instances weight or do resampling
228: if (random != null) {
229: return data.resampleWithWeights(random, weightOfInstances);
230: } else {
231: Instances instances = new Instances(data);
232: for (int i = 0; i < data.numInstances(); i++) {
233: instances.instance(i).setWeight(weightOfInstances[i]);
234: }
235: return instances;
236: }
237: }
238:
239: /**
240: * Calculates the expected misclassification cost for each possible class value,
241: * given class probability estimates.
242: *
243: * @param classProbs the class probability estimates.
244: * @return the expected costs.
245: * @exception Exception if the wrong number of class probabilities is supplied.
246: */
247: public double[] expectedCosts(double[] classProbs) throws Exception {
248:
249: if (classProbs.length != m_size) {
250: throw new Exception(
251: "Length of probability estimates don't "
252: + "match cost matrix");
253: }
254:
255: double[] costs = new double[m_size];
256:
257: for (int x = 0; x < m_size; x++) {
258: for (int y = 0; y < m_size; y++) {
259: Object element = getCell(y, x);
260: if (!(element instanceof Double)) {
261: throw new Exception("Can't use non-fixed costs in "
262: + "computing expected costs.");
263: }
264: costs[x] += classProbs[y]
265: * ((Double) element).doubleValue();
266: }
267: }
268:
269: return costs;
270: }
271:
272: /**
273: * Calculates the expected misclassification cost for each possible class value,
274: * given class probability estimates.
275: *
276: * @param classProbs the class probability estimates.
277: * @param inst the current instance for which the class probabilites
278: * apply. Is used for computing any non-fixed cost values.
279: * @return the expected costs.
280: * @exception Exception if something goes wrong
281: */
282: public double[] expectedCosts(double[] classProbs, Instance inst)
283: throws Exception {
284:
285: if (classProbs.length != m_size) {
286: throw new Exception(
287: "Length of probability estimates don't "
288: + "match cost matrix");
289: }
290:
291: if (!replaceStrings()) {
292: return expectedCosts(classProbs);
293: }
294:
295: double[] costs = new double[m_size];
296:
297: for (int x = 0; x < m_size; x++) {
298: for (int y = 0; y < m_size; y++) {
299: Object element = getCell(y, x);
300: double costVal;
301: if (!(element instanceof Double)) {
302: costVal = ((AttributeExpression) element)
303: .evaluateExpression(inst);
304: } else {
305: costVal = ((Double) element).doubleValue();
306: }
307: costs[x] += classProbs[y] * costVal;
308: }
309: }
310:
311: return costs;
312: }
313:
314: /**
315: * Gets the maximum cost for a particular class value.
316: *
317: * @param classVal the class value.
318: * @return the maximum cost.
319: * @exception Exception if cost matrix contains non-fixed
320: * costs
321: */
322: public double getMaxCost(int classVal) throws Exception {
323:
324: double maxCost = Double.NEGATIVE_INFINITY;
325:
326: for (int i = 0; i < m_size; i++) {
327: Object element = getCell(classVal, i);
328: if (!(element instanceof Double)) {
329: throw new Exception("Can't use non-fixed costs when "
330: + "getting max cost.");
331: }
332: double cost = ((Double) element).doubleValue();
333: if (cost > maxCost)
334: maxCost = cost;
335: }
336:
337: return maxCost;
338: }
339:
340: /**
341: * Gets the maximum cost for a particular class value.
342: *
343: * @param classVal the class value.
344: * @return the maximum cost.
345: * @exception Exception if cost matrix contains non-fixed
346: * costs
347: */
348: public double getMaxCost(int classVal, Instance inst)
349: throws Exception {
350:
351: if (!replaceStrings()) {
352: return getMaxCost(classVal);
353: }
354:
355: double maxCost = Double.NEGATIVE_INFINITY;
356: double cost;
357: for (int i = 0; i < m_size; i++) {
358: Object element = getCell(classVal, i);
359: if (!(element instanceof Double)) {
360: cost = ((AttributeExpression) element)
361: .evaluateExpression(inst);
362: } else {
363: cost = ((Double) element).doubleValue();
364: }
365: if (cost > maxCost)
366: maxCost = cost;
367: }
368:
369: return maxCost;
370: }
371:
372: /**
373: * Normalizes the matrix so that the diagonal contains zeros.
374: *
375: */
376: public void normalize() {
377:
378: for (int y = 0; y < m_size; y++) {
379: double diag = ((Double) getCell(y, y)).doubleValue();
380: for (int x = 0; x < m_size; x++) {
381: setCell(x, y, new Double(((Double) getCell(x, y))
382: .doubleValue()
383: - diag));
384: }
385: }
386: }
387:
388: /**
389: * Loads a cost matrix in the old format from a reader. Adapted from code once sitting
390: * in Instances.java
391: *
392: * @param reader the reader to get the values from.
393: * @exception Exception if the matrix cannot be read correctly.
394: */
395: public void readOldFormat(Reader reader) throws Exception {
396:
397: StreamTokenizer tokenizer;
398: int currentToken;
399: double firstIndex, secondIndex, weight;
400:
401: tokenizer = new StreamTokenizer(reader);
402:
403: initialize();
404:
405: tokenizer.commentChar('%');
406: tokenizer.eolIsSignificant(true);
407: while (StreamTokenizer.TT_EOF != (currentToken = tokenizer
408: .nextToken())) {
409:
410: // Skip empty lines
411: if (currentToken == StreamTokenizer.TT_EOL) {
412: continue;
413: }
414:
415: // Get index of first class.
416: if (currentToken != StreamTokenizer.TT_NUMBER) {
417: throw new Exception(
418: "Only numbers and comments allowed "
419: + "in cost file!");
420: }
421: firstIndex = tokenizer.nval;
422: if (!Utils.eq((double) (int) firstIndex, firstIndex)) {
423: throw new Exception("First number in line has to be "
424: + "index of a class!");
425: }
426: if ((int) firstIndex >= size()) {
427: throw new Exception("Class index out of range!");
428: }
429:
430: // Get index of second class.
431: if (StreamTokenizer.TT_EOF == (currentToken = tokenizer
432: .nextToken())) {
433: throw new Exception("Premature end of file!");
434: }
435: if (currentToken == StreamTokenizer.TT_EOL) {
436: throw new Exception("Premature end of line!");
437: }
438: if (currentToken != StreamTokenizer.TT_NUMBER) {
439: throw new Exception(
440: "Only numbers and comments allowed "
441: + "in cost file!");
442: }
443: secondIndex = tokenizer.nval;
444: if (!Utils.eq((double) (int) secondIndex, secondIndex)) {
445: throw new Exception("Second number in line has to be "
446: + "index of a class!");
447: }
448: if ((int) secondIndex >= size()) {
449: throw new Exception("Class index out of range!");
450: }
451: if ((int) secondIndex == (int) firstIndex) {
452: throw new Exception("Diagonal of cost matrix non-zero!");
453: }
454:
455: // Get cost factor.
456: if (StreamTokenizer.TT_EOF == (currentToken = tokenizer
457: .nextToken())) {
458: throw new Exception("Premature end of file!");
459: }
460: if (currentToken == StreamTokenizer.TT_EOL) {
461: throw new Exception("Premature end of line!");
462: }
463: if (currentToken != StreamTokenizer.TT_NUMBER) {
464: throw new Exception(
465: "Only numbers and comments allowed "
466: + "in cost file!");
467: }
468: weight = tokenizer.nval;
469: if (!Utils.gr(weight, 0)) {
470: throw new Exception("Only positive weights allowed!");
471: }
472: setCell((int) firstIndex, (int) secondIndex, new Double(
473: weight));
474: }
475: }
476:
477: /**
478: * Reads a matrix from a reader. The first line in the file should
479: * contain the number of rows and columns. Subsequent lines
480: * contain elements of the matrix.
481: * (FracPete: taken from old weka.core.Matrix class)
482: *
483: * @param reader the reader containing the matrix
484: * @throws Exception if an error occurs
485: * @see #write(Writer)
486: */
487: public CostMatrix(Reader reader) throws Exception {
488: LineNumberReader lnr = new LineNumberReader(reader);
489: String line;
490: int currentRow = -1;
491:
492: while ((line = lnr.readLine()) != null) {
493:
494: // Comments
495: if (line.startsWith("%")) {
496: continue;
497: }
498:
499: StringTokenizer st = new StringTokenizer(line);
500: // Ignore blank lines
501: if (!st.hasMoreTokens()) {
502: continue;
503: }
504:
505: if (currentRow < 0) {
506: int rows = Integer.parseInt(st.nextToken());
507: if (!st.hasMoreTokens()) {
508: throw new Exception("Line " + lnr.getLineNumber()
509: + ": expected number of columns");
510: }
511:
512: int cols = Integer.parseInt(st.nextToken());
513: if (rows != cols) {
514: throw new Exception(
515: "Trying to create a non-square cost "
516: + "matrix");
517: }
518: // m_matrix = new Object[rows][cols];
519: m_size = rows;
520: initialize();
521: currentRow++;
522: continue;
523:
524: } else {
525: if (currentRow == m_size) {
526: throw new Exception("Line " + lnr.getLineNumber()
527: + ": too many rows provided");
528: }
529:
530: for (int i = 0; i < m_size; i++) {
531: if (!st.hasMoreTokens()) {
532: throw new Exception("Line "
533: + lnr.getLineNumber()
534: + ": too few matrix elements provided");
535: }
536:
537: String nextTok = st.nextToken();
538: // try to parse as a double first
539: Double val = null;
540: try {
541: val = new Double(nextTok);
542: double value = val.doubleValue();
543: } catch (Exception ex) {
544: val = null;
545: }
546: if (val == null) {
547: setCell(currentRow, i, nextTok);
548: } else {
549: setCell(currentRow, i, val);
550: }
551: }
552: currentRow++;
553: }
554: }
555:
556: if (currentRow == -1) {
557: throw new Exception("Line " + lnr.getLineNumber()
558: + ": expected number of rows");
559: } else if (currentRow != m_size) {
560: throw new Exception("Line " + lnr.getLineNumber()
561: + ": too few rows provided");
562: }
563: }
564:
565: /**
566: * Writes out a matrix. The format can be read via the
567: * CostMatrix(Reader) constructor.
568: * (FracPete: taken from old weka.core.Matrix class)
569: *
570: * @param w the output Writer
571: * @throws Exception if an error occurs
572: */
573: public void write(Writer w) throws Exception {
574: w.write("% Rows\tColumns\n");
575: w.write("" + m_size + "\t" + m_size + "\n");
576: w.write("% Matrix elements\n");
577: for (int i = 0; i < m_size; i++) {
578: for (int j = 0; j < m_size; j++) {
579: w.write("" + getCell(i, j) + "\t");
580: }
581: w.write("\n");
582: }
583: w.flush();
584: }
585:
586: /**
587: * converts the Matrix into a single line Matlab string: matrix is enclosed
588: * by parentheses, rows are separated by semicolon and single cells by
589: * blanks, e.g., [1 2; 3 4].
590: * @return the matrix in Matlab single line format
591: */
592: public String toMatlab() {
593: StringBuffer result;
594: int i;
595: int n;
596:
597: result = new StringBuffer();
598:
599: result.append("[");
600:
601: for (i = 0; i < m_size; i++) {
602: if (i > 0) {
603: result.append("; ");
604: }
605:
606: for (n = 0; n < m_size; n++) {
607: if (n > 0) {
608: result.append(" ");
609: }
610: result.append(getCell(i, n));
611: }
612: }
613:
614: result.append("]");
615:
616: return result.toString();
617: }
618:
619: /**
620: * Set the value of a particular cell in the matrix
621: *
622: * @param rowIndex the row
623: * @param columnIndex the column
624: * @param value the value to set
625: */
626: public final void setCell(int rowIndex, int columnIndex,
627: Object value) {
628: m_matrix[rowIndex][columnIndex] = value;
629: }
630:
631: /**
632: * Return the contents of a particular cell. Note: this
633: * method returns the Object stored at a particular cell.
634: *
635: * @param rowIndex the row
636: * @param columnIndex the column
637: * @return the value at the cell
638: */
639: public final Object getCell(int rowIndex, int columnIndex) {
640: return m_matrix[rowIndex][columnIndex];
641: }
642:
643: /**
644: * Return the value of a cell as a double (for legacy code)
645: *
646: * @param rowIndex the row
647: * @param columnIndex the column
648: * @return the value at a particular cell as a double
649: * @exception Exception if the value is not a double
650: */
651: public final double getElement(int rowIndex, int columnIndex)
652: throws Exception {
653: if (!(m_matrix[rowIndex][columnIndex] instanceof Double)) {
654: throw new Exception("Cost matrix contains non-fixed costs!");
655: }
656: return ((Double) m_matrix[rowIndex][columnIndex]).doubleValue();
657: }
658:
659: /**
660: * Return the value of a cell as a double. Computes the
661: * value for non-fixed costs using the supplied Instance
662: *
663: * @param rowIndex the row
664: * @param columnIndex the column
665: * @return the value from a particular cell
666: * @exception Exception if something goes wrong
667: */
668: public final double getElement(int rowIndex, int columnIndex,
669: Instance inst) throws Exception {
670:
671: if (m_matrix[rowIndex][columnIndex] instanceof Double) {
672: return ((Double) m_matrix[rowIndex][columnIndex])
673: .doubleValue();
674: } else if (m_matrix[rowIndex][columnIndex] instanceof String) {
675: replaceStrings();
676: }
677:
678: return ((AttributeExpression) m_matrix[rowIndex][columnIndex])
679: .evaluateExpression(inst);
680: }
681:
682: /**
683: * Set the value of a cell as a double
684: *
685: * @param rowIndex the row
686: * @param columnIndex the column
687: * @param value the value (double) to set
688: */
689: public final void setElement(int rowIndex, int columnIndex,
690: double value) {
691: m_matrix[rowIndex][columnIndex] = new Double(value);
692: }
693:
694: /**
695: * creates a matrix from the given Matlab string.
696: * @param matlab the matrix in matlab format
697: * @return the matrix represented by the given string
698: */
699: public static Matrix parseMatlab(String matlab) throws Exception {
700: return Matrix.parseMatlab(matlab);
701: }
702:
703: /**
704: * Converts a matrix to a string.
705: * (FracPete: taken from old weka.core.Matrix class)
706: *
707: * @return the converted string
708: */
709: public String toString() {
710: // Determine the width required for the maximum element,
711: // and check for fractional display requirement.
712: double maxval = 0;
713: boolean fractional = false;
714: Object element = null;
715: int widthNumber = 0;
716: int widthExpression = 0;
717: for (int i = 0; i < size(); i++) {
718: for (int j = 0; j < size(); j++) {
719: element = getCell(i, j);
720: if (element instanceof Double) {
721: double current = ((Double) element).doubleValue();
722:
723: if (current < 0)
724: current *= -11;
725: if (current > maxval)
726: maxval = current;
727: double fract = Math.abs(current
728: - Math.rint(current));
729: if (!fractional
730: && ((Math.log(fract) / Math.log(10)) >= -2)) {
731: fractional = true;
732: }
733: } else {
734: if (element.toString().length() > widthExpression) {
735: widthExpression = element.toString().length();
736: }
737: }
738: }
739: }
740: if (maxval > 0) {
741: widthNumber = (int) (Math.log(maxval) / Math.log(10) + (fractional ? 4
742: : 1));
743: }
744:
745: int width = (widthNumber > widthExpression) ? widthNumber
746: : widthExpression;
747:
748: StringBuffer text = new StringBuffer();
749: for (int i = 0; i < size(); i++) {
750: for (int j = 0; j < size(); j++) {
751: element = getCell(i, j);
752: if (element instanceof Double) {
753: text.append(" ").append(
754: Utils.doubleToString(((Double) element)
755: .doubleValue(), width,
756: (fractional ? 2 : 0)));
757: } else {
758: int diff = width - element.toString().length();
759: if (diff > 0) {
760: int left = diff % 2;
761: left += diff / 2;
762: String temp = Utils.padLeft(element.toString(),
763: element.toString().length() + left);
764: temp = Utils.padRight(temp, width);
765: text.append(" ").append(temp);
766: } else {
767: text.append(" ").append(element.toString());
768: }
769: }
770: }
771: text.append("\n");
772: }
773:
774: return text.toString();
775: }
776: }
|