001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * KernelEstimator.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: import weka.core.Capabilities.Capability;
026: import weka.core.Capabilities;
027: import weka.core.Utils;
028: import weka.core.Statistics;
029:
030: /**
031: * Simple kernel density estimator. Uses one gaussian kernel per observed
032: * data value.
033: *
034: * @author Len Trigg (trigg@cs.waikato.ac.nz)
035: * @version $Revision: 1.7 $
036: */
037: public class KernelEstimator extends Estimator implements
038: IncrementalEstimator {
039:
040: /** for serialization */
041: private static final long serialVersionUID = 3646923563367683925L;
042:
043: /** Vector containing all of the values seen */
044: private double[] m_Values;
045:
046: /** Vector containing the associated weights */
047: private double[] m_Weights;
048:
049: /** Number of values stored in m_Weights and m_Values so far */
050: private int m_NumValues;
051:
052: /** The sum of the weights so far */
053: private double m_SumOfWeights;
054:
055: /** The standard deviation */
056: private double m_StandardDev;
057:
058: /** The precision of data values */
059: private double m_Precision;
060:
061: /** Whether we can optimise the kernel summation */
062: private boolean m_AllWeightsOne;
063:
064: /** Maximum percentage error permitted in probability calculations */
065: private static double MAX_ERROR = 0.01;
066:
067: /**
068: * Execute a binary search to locate the nearest data value
069: *
070: * @param the data value to locate
071: * @return the index of the nearest data value
072: */
073: private int findNearestValue(double key) {
074:
075: int low = 0;
076: int high = m_NumValues;
077: int middle = 0;
078: while (low < high) {
079: middle = (low + high) / 2;
080: double current = m_Values[middle];
081: if (current == key) {
082: return middle;
083: }
084: if (current > key) {
085: high = middle;
086: } else if (current < key) {
087: low = middle + 1;
088: }
089: }
090: return low;
091: }
092:
093: /**
094: * Round a data value using the defined precision for this estimator
095: *
096: * @param data the value to round
097: * @return the rounded data value
098: */
099: private double round(double data) {
100:
101: return Math.rint(data / m_Precision) * m_Precision;
102: }
103:
104: // ===============
105: // Public methods.
106: // ===============
107:
108: /**
109: * Constructor that takes a precision argument.
110: *
111: * @param precision the precision to which numeric values are given. For
112: * example, if the precision is stated to be 0.1, the values in the
113: * interval (0.25,0.35] are all treated as 0.3.
114: */
115: public KernelEstimator(double precision) {
116:
117: m_Values = new double[50];
118: m_Weights = new double[50];
119: m_NumValues = 0;
120: m_SumOfWeights = 0;
121: m_AllWeightsOne = true;
122: m_Precision = precision;
123: // precision cannot be zero
124: if (m_Precision < Utils.SMALL)
125: m_Precision = Utils.SMALL;
126: // m_StandardDev = 1e10 * m_Precision; // Set the standard deviation initially very wide
127: m_StandardDev = m_Precision / (2 * 3);
128: }
129:
130: /**
131: * Add a new data value to the current estimator.
132: *
133: * @param data the new data value
134: * @param weight the weight assigned to the data value
135: */
136: public void addValue(double data, double weight) {
137:
138: if (weight == 0) {
139: return;
140: }
141: data = round(data);
142: int insertIndex = findNearestValue(data);
143: if ((m_NumValues <= insertIndex)
144: || (m_Values[insertIndex] != data)) {
145: if (m_NumValues < m_Values.length) {
146: int left = m_NumValues - insertIndex;
147: System.arraycopy(m_Values, insertIndex, m_Values,
148: insertIndex + 1, left);
149: System.arraycopy(m_Weights, insertIndex, m_Weights,
150: insertIndex + 1, left);
151:
152: m_Values[insertIndex] = data;
153: m_Weights[insertIndex] = weight;
154: m_NumValues++;
155: } else {
156: double[] newValues = new double[m_Values.length * 2];
157: double[] newWeights = new double[m_Values.length * 2];
158: int left = m_NumValues - insertIndex;
159: System
160: .arraycopy(m_Values, 0, newValues, 0,
161: insertIndex);
162: System.arraycopy(m_Weights, 0, newWeights, 0,
163: insertIndex);
164: newValues[insertIndex] = data;
165: newWeights[insertIndex] = weight;
166: System.arraycopy(m_Values, insertIndex, newValues,
167: insertIndex + 1, left);
168: System.arraycopy(m_Weights, insertIndex, newWeights,
169: insertIndex + 1, left);
170: m_NumValues++;
171: m_Values = newValues;
172: m_Weights = newWeights;
173: }
174: if (weight != 1) {
175: m_AllWeightsOne = false;
176: }
177: } else {
178: m_Weights[insertIndex] += weight;
179: m_AllWeightsOne = false;
180: }
181: m_SumOfWeights += weight;
182: double range = m_Values[m_NumValues - 1] - m_Values[0];
183: if (range > 0) {
184: m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
185: // allow at most 3 sds within one interval
186: m_Precision / (2 * 3));
187: }
188: }
189:
190: /**
191: * Get a probability estimate for a value.
192: *
193: * @param data the value to estimate the probability of
194: * @return the estimated probability of the supplied value
195: */
196: public double getProbability(double data) {
197:
198: double delta = 0, sum = 0, currentProb = 0;
199: double zLower = 0, zUpper = 0;
200: if (m_NumValues == 0) {
201: zLower = (data - (m_Precision / 2)) / m_StandardDev;
202: zUpper = (data + (m_Precision / 2)) / m_StandardDev;
203: return (Statistics.normalProbability(zUpper) - Statistics
204: .normalProbability(zLower));
205: }
206: double weightSum = 0;
207: int start = findNearestValue(data);
208: for (int i = start; i < m_NumValues; i++) {
209: delta = m_Values[i] - data;
210: zLower = (delta - (m_Precision / 2)) / m_StandardDev;
211: zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
212: currentProb = (Statistics.normalProbability(zUpper) - Statistics
213: .normalProbability(zLower));
214: sum += currentProb * m_Weights[i];
215: /*
216: System.out.print("zL" + (i + 1) + ": " + zLower + " ");
217: System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
218: System.out.print("P" + (i + 1) + ": " + currentProb + " ");
219: System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
220: */
221: weightSum += m_Weights[i];
222: if (currentProb * (m_SumOfWeights - weightSum) < sum
223: * MAX_ERROR) {
224: break;
225: }
226: }
227: for (int i = start - 1; i >= 0; i--) {
228: delta = m_Values[i] - data;
229: zLower = (delta - (m_Precision / 2)) / m_StandardDev;
230: zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
231: currentProb = (Statistics.normalProbability(zUpper) - Statistics
232: .normalProbability(zLower));
233: sum += currentProb * m_Weights[i];
234: weightSum += m_Weights[i];
235: if (currentProb * (m_SumOfWeights - weightSum) < sum
236: * MAX_ERROR) {
237: break;
238: }
239: }
240: return sum / m_SumOfWeights;
241: }
242:
243: /** Display a representation of this estimator */
244: public String toString() {
245:
246: String result = m_NumValues
247: + " Normal Kernels. \nStandardDev = "
248: + Utils.doubleToString(m_StandardDev, 6, 4)
249: + " Precision = " + m_Precision;
250: if (m_NumValues == 0) {
251: result += " \nMean = 0";
252: } else {
253: result += " \nMeans =";
254: for (int i = 0; i < m_NumValues; i++) {
255: result += " " + m_Values[i];
256: }
257: if (!m_AllWeightsOne) {
258: result += "\nWeights = ";
259: for (int i = 0; i < m_NumValues; i++) {
260: result += " " + m_Weights[i];
261: }
262: }
263: }
264: return result + "\n";
265: }
266:
267: /**
268: * Returns default capabilities of the classifier.
269: *
270: * @return the capabilities of this classifier
271: */
272: public Capabilities getCapabilities() {
273: Capabilities result = super .getCapabilities();
274:
275: // attributes
276: result.enable(Capability.NUMERIC_ATTRIBUTES);
277: return result;
278: }
279:
280: /**
281: * Main method for testing this class.
282: *
283: * @param argv should contain a sequence of numeric values
284: */
285: public static void main(String[] argv) {
286:
287: try {
288: if (argv.length < 2) {
289: System.out
290: .println("Please specify a set of instances.");
291: return;
292: }
293: KernelEstimator newEst = new KernelEstimator(0.01);
294: for (int i = 0; i < argv.length - 3; i += 2) {
295: newEst.addValue(Double.valueOf(argv[i]).doubleValue(),
296: Double.valueOf(argv[i + 1]).doubleValue());
297: }
298: System.out.println(newEst);
299:
300: double start = Double.valueOf(argv[argv.length - 2])
301: .doubleValue();
302: double finish = Double.valueOf(argv[argv.length - 1])
303: .doubleValue();
304: for (double current = start; current < finish; current += (finish - start) / 50) {
305: System.out.println("Data: " + current + " "
306: + newEst.getProbability(current));
307: }
308: } catch (Exception e) {
309: System.out.println(e.getMessage());
310: }
311: }
312: }
|