001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * NNConditionalEstimator.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: import java.util.Random;
026: import java.util.Vector;
027:
028: import weka.core.matrix.Matrix;
029: import weka.core.Utils;
030:
031: /**
032: * Conditional probability estimator for a numeric domain conditional upon
033: * a numeric domain (using Mahalanobis distance).
034: *
035: * @author Len Trigg (trigg@cs.waikato.ac.nz)
036: * @version $Revision: 1.7 $
037: */
038: public class NNConditionalEstimator implements ConditionalEstimator {
039:
040: /** Vector containing all of the values seen */
041: private Vector m_Values = new Vector();
042:
043: /** Vector containing all of the conditioning values seen */
044: private Vector m_CondValues = new Vector();
045:
046: /** Vector containing the associated weights */
047: private Vector m_Weights = new Vector();
048:
049: /** The sum of the weights so far */
050: private double m_SumOfWeights;
051:
052: /** Current Conditional mean */
053: private double m_CondMean;
054:
055: /** Current Values mean */
056: private double m_ValueMean;
057:
058: /** Current covariance matrix */
059: private Matrix m_Covariance;
060:
061: /** Whether we can optimise the kernel summation */
062: private boolean m_AllWeightsOne = true;
063:
064: /** 2 * PI */
065: private static double TWO_PI = 2 * Math.PI;
066:
067: // ===============
068: // Private methods
069: // ===============
070:
071: /**
072: * Execute a binary search to locate the nearest data value
073: *
074: * @param key the data value to locate
075: * @param secondaryKey the data value to locate
076: * @return the index of the nearest data value
077: */
078: private int findNearestPair(double key, double secondaryKey) {
079:
080: int low = 0;
081: int high = m_CondValues.size();
082: int middle = 0;
083: while (low < high) {
084: middle = (low + high) / 2;
085: double current = ((Double) m_CondValues.elementAt(middle))
086: .doubleValue();
087: if (current == key) {
088: double secondary = ((Double) m_Values.elementAt(middle))
089: .doubleValue();
090: if (secondary == secondaryKey) {
091: return middle;
092: }
093: if (secondary > secondaryKey) {
094: high = middle;
095: } else if (secondary < secondaryKey) {
096: low = middle + 1;
097: }
098: }
099: if (current > key) {
100: high = middle;
101: } else if (current < key) {
102: low = middle + 1;
103: }
104: }
105: return low;
106: }
107:
108: /** Calculate covariance and value means */
109: private void calculateCovariance() {
110:
111: double sumValues = 0, sumConds = 0;
112: for (int i = 0; i < m_Values.size(); i++) {
113: sumValues += ((Double) m_Values.elementAt(i)).doubleValue()
114: * ((Double) m_Weights.elementAt(i)).doubleValue();
115: sumConds += ((Double) m_CondValues.elementAt(i))
116: .doubleValue()
117: * ((Double) m_Weights.elementAt(i)).doubleValue();
118: }
119: m_ValueMean = sumValues / m_SumOfWeights;
120: m_CondMean = sumConds / m_SumOfWeights;
121: double c00 = 0, c01 = 0, c10 = 0, c11 = 0;
122: for (int i = 0; i < m_Values.size(); i++) {
123: double x = ((Double) m_Values.elementAt(i)).doubleValue();
124: double y = ((Double) m_CondValues.elementAt(i))
125: .doubleValue();
126: double weight = ((Double) m_Weights.elementAt(i))
127: .doubleValue();
128: c00 += (x - m_ValueMean) * (x - m_ValueMean) * weight;
129: c01 += (x - m_ValueMean) * (y - m_CondMean) * weight;
130: c11 += (y - m_CondMean) * (y - m_CondMean) * weight;
131: }
132: c00 /= (m_SumOfWeights - 1.0);
133: c01 /= (m_SumOfWeights - 1.0);
134: c10 = c01;
135: c11 /= (m_SumOfWeights - 1.0);
136: m_Covariance = new Matrix(2, 2);
137: m_Covariance.set(0, 0, c00);
138: m_Covariance.set(0, 1, c01);
139: m_Covariance.set(1, 0, c10);
140: m_Covariance.set(1, 1, c11);
141: }
142:
143: /**
144: * Returns value for normal kernel
145: *
146: * @param x the argument to the kernel function
147: * @param variance the variance
148: * @return the value for a normal kernel
149: */
150: private double normalKernel(double x, double variance) {
151:
152: return Math.exp(-x * x / (2 * variance))
153: / Math.sqrt(variance * TWO_PI);
154: }
155:
156: /**
157: * Add a new data value to the current estimator.
158: *
159: * @param data the new data value
160: * @param given the new value that data is conditional upon
161: * @param weight the weight assigned to the data value
162: */
163: public void addValue(double data, double given, double weight) {
164:
165: int insertIndex = findNearestPair(given, data);
166: if ((m_Values.size() <= insertIndex)
167: || (((Double) m_CondValues.elementAt(insertIndex))
168: .doubleValue() != given)
169: || (((Double) m_Values.elementAt(insertIndex))
170: .doubleValue() != data)) {
171: m_CondValues
172: .insertElementAt(new Double(given), insertIndex);
173: m_Values.insertElementAt(new Double(data), insertIndex);
174: m_Weights.insertElementAt(new Double(weight), insertIndex);
175: if (weight != 1) {
176: m_AllWeightsOne = false;
177: }
178: } else {
179: double newWeight = ((Double) m_Weights
180: .elementAt(insertIndex)).doubleValue();
181: newWeight += weight;
182: m_Weights.setElementAt(new Double(newWeight), insertIndex);
183: m_AllWeightsOne = false;
184: }
185: m_SumOfWeights += weight;
186: // Invalidate any previously calculated covariance matrix
187: m_Covariance = null;
188: }
189:
190: /**
191: * Get a probability estimator for a value
192: *
193: * @param given the new value that data is conditional upon
194: * @return the estimator for the supplied value given the condition
195: */
196: public Estimator getEstimator(double given) {
197:
198: if (m_Covariance == null) {
199: calculateCovariance();
200: }
201: Estimator result = new MahalanobisEstimator(m_Covariance, given
202: - m_CondMean, m_ValueMean);
203: return result;
204: }
205:
206: /**
207: * Get a probability estimate for a value
208: *
209: * @param data the value to estimate the probability of
210: * @param given the new value that data is conditional upon
211: * @return the estimated probability of the supplied value
212: */
213: public double getProbability(double data, double given) {
214:
215: return getEstimator(given).getProbability(data);
216: }
217:
218: /** Display a representation of this estimator */
219: public String toString() {
220:
221: if (m_Covariance == null) {
222: calculateCovariance();
223: }
224: String result = "NN Conditional Estimator. "
225: + m_CondValues.size() + " data points. Mean = "
226: + Utils.doubleToString(m_ValueMean, 4, 2)
227: + " Conditional mean = "
228: + Utils.doubleToString(m_CondMean, 4, 2);
229: result += " Covariance Matrix: \n" + m_Covariance;
230: return result;
231: }
232:
233: /**
234: * Main method for testing this class.
235: *
236: * @param argv should contain a sequence of numeric values
237: */
238: public static void main(String[] argv) {
239:
240: try {
241: int seed = 42;
242: if (argv.length > 0) {
243: seed = Integer.parseInt(argv[0]);
244: }
245: NNConditionalEstimator newEst = new NNConditionalEstimator();
246:
247: // Create 100 random points and add them
248: Random r = new Random(seed);
249:
250: int numPoints = 50;
251: if (argv.length > 2) {
252: numPoints = Integer.parseInt(argv[2]);
253: }
254: for (int i = 0; i < numPoints; i++) {
255: int x = Math.abs(r.nextInt() % 100);
256: int y = Math.abs(r.nextInt() % 100);
257: System.out.println("# " + x + " " + y);
258: newEst.addValue(x, y, 1);
259: }
260: // System.out.println(newEst);
261: int cond;
262: if (argv.length > 1) {
263: cond = Integer.parseInt(argv[1]);
264: } else
265: cond = Math.abs(r.nextInt() % 100);
266: System.out.println("## Conditional = " + cond);
267: Estimator result = newEst.getEstimator(cond);
268: for (int i = 0; i <= 100; i += 5) {
269: System.out.println(" " + i + " "
270: + result.getProbability(i));
271: }
272: } catch (Exception e) {
273: System.out.println(e.getMessage());
274: }
275: }
276: }
|