001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * NormalEstimator.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: import weka.core.Capabilities.Capability;
026: import weka.core.Capabilities;
027: import weka.core.Statistics;
028: import weka.core.Utils;
029:
030: /**
031: * Simple probability estimator that places a single normal distribution
032: * over the observed values.
033: *
034: * @author Len Trigg (trigg@cs.waikato.ac.nz)
035: * @version $Revision: 1.8 $
036: */
037: public class NormalEstimator extends Estimator implements
038: IncrementalEstimator {
039:
040: /** for serialization */
041: private static final long serialVersionUID = 93584379632315841L;
042:
043: /** The sum of the weights */
044: private double m_SumOfWeights;
045:
046: /** The sum of the values seen */
047: private double m_SumOfValues;
048:
049: /** The sum of the values squared */
050: private double m_SumOfValuesSq;
051:
052: /** The current mean */
053: private double m_Mean;
054:
055: /** The current standard deviation */
056: private double m_StandardDev;
057:
058: /** The precision of numeric values ( = minimum std dev permitted) */
059: private double m_Precision;
060:
061: /**
062: * Round a data value using the defined precision for this estimator
063: *
064: * @param data the value to round
065: * @return the rounded data value
066: */
067: private double round(double data) {
068:
069: return Math.rint(data / m_Precision) * m_Precision;
070: }
071:
072: // ===============
073: // Public methods.
074: // ===============
075:
076: /**
077: * Constructor that takes a precision argument.
078: *
079: * @param precision the precision to which numeric values are given. For
080: * example, if the precision is stated to be 0.1, the values in the
081: * interval (0.25,0.35] are all treated as 0.3.
082: */
083: public NormalEstimator(double precision) {
084:
085: m_Precision = precision;
086:
087: // Allow at most 3 sd's within one interval
088: m_StandardDev = m_Precision / (2 * 3);
089: }
090:
091: /**
092: * Add a new data value to the current estimator.
093: *
094: * @param data the new data value
095: * @param weight the weight assigned to the data value
096: */
097: public void addValue(double data, double weight) {
098:
099: if (weight == 0) {
100: return;
101: }
102: data = round(data);
103: m_SumOfWeights += weight;
104: m_SumOfValues += data * weight;
105: m_SumOfValuesSq += data * data * weight;
106:
107: if (m_SumOfWeights > 0) {
108: m_Mean = m_SumOfValues / m_SumOfWeights;
109: double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq - m_Mean
110: * m_SumOfValues)
111: / m_SumOfWeights);
112: // If the stdDev ~= 0, we really have no idea of scale yet,
113: // so stick with the default. Otherwise...
114: if (stdDev > 1e-10) {
115: m_StandardDev = Math.max(m_Precision / (2 * 3),
116: // allow at most 3sd's within one interval
117: stdDev);
118: }
119: }
120: }
121:
122: /**
123: * Get a probability estimate for a value
124: *
125: * @param data the value to estimate the probability of
126: * @return the estimated probability of the supplied value
127: */
128: public double getProbability(double data) {
129:
130: data = round(data);
131: double zLower = (data - m_Mean - (m_Precision / 2))
132: / m_StandardDev;
133: double zUpper = (data - m_Mean + (m_Precision / 2))
134: / m_StandardDev;
135:
136: double pLower = Statistics.normalProbability(zLower);
137: double pUpper = Statistics.normalProbability(zUpper);
138: return pUpper - pLower;
139: }
140:
141: /**
142: * Display a representation of this estimator
143: */
144: public String toString() {
145:
146: return "Normal Distribution. Mean = "
147: + Utils.doubleToString(m_Mean, 4) + " StandardDev = "
148: + Utils.doubleToString(m_StandardDev, 4)
149: + " WeightSum = "
150: + Utils.doubleToString(m_SumOfWeights, 4)
151: + " Precision = " + m_Precision + "\n";
152: }
153:
154: /**
155: * Returns default capabilities of the classifier.
156: *
157: * @return the capabilities of this classifier
158: */
159: public Capabilities getCapabilities() {
160: Capabilities result = super .getCapabilities();
161:
162: // attributes
163: result.enable(Capability.NUMERIC_ATTRIBUTES);
164: return result;
165: }
166:
167: /**
168: * Main method for testing this class.
169: *
170: * @param argv should contain a sequence of numeric values
171: */
172: public static void main(String[] argv) {
173:
174: try {
175:
176: if (argv.length == 0) {
177: System.out
178: .println("Please specify a set of instances.");
179: return;
180: }
181: NormalEstimator newEst = new NormalEstimator(0.01);
182: for (int i = 0; i < argv.length; i++) {
183: double current = Double.valueOf(argv[i]).doubleValue();
184: System.out.println(newEst);
185: System.out.println("Prediction for " + current + " = "
186: + newEst.getProbability(current));
187: newEst.addValue(current, 1);
188: }
189: } catch (Exception e) {
190: System.out.println(e.getMessage());
191: }
192: }
193: }
|