001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * DDConditionalEstimator.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: /**
026: * Conditional probability estimator for a discrete domain conditional upon
027: * a discrete domain.
028: *
029: * @author Len Trigg (trigg@cs.waikato.ac.nz)
030: * @version $Revision: 1.7 $
031: */
032: public class DDConditionalEstimator implements ConditionalEstimator {
033:
034: /** Hold the sub-estimators */
035: private DiscreteEstimator[] m_Estimators;
036:
037: /**
038: * Constructor
039: *
040: * @param numSymbols the number of possible symbols (remember to include 0)
041: * @param numCondSymbols the number of conditioning symbols
042: * @param laplace if true, sub-estimators will use laplace
043: */
044: public DDConditionalEstimator(int numSymbols, int numCondSymbols,
045: boolean laplace) {
046:
047: m_Estimators = new DiscreteEstimator[numCondSymbols];
048: for (int i = 0; i < numCondSymbols; i++) {
049: m_Estimators[i] = new DiscreteEstimator(numSymbols, laplace);
050: }
051: }
052:
053: /**
054: * Add a new data value to the current estimator.
055: *
056: * @param data the new data value
057: * @param given the new value that data is conditional upon
058: * @param weight the weight assigned to the data value
059: */
060: public void addValue(double data, double given, double weight) {
061:
062: m_Estimators[(int) given].addValue(data, weight);
063: }
064:
065: /**
066: * Get a probability estimator for a value
067: *
068: * @param given the new value that data is conditional upon
069: * @return the estimator for the supplied value given the condition
070: */
071: public Estimator getEstimator(double given) {
072:
073: return m_Estimators[(int) given];
074: }
075:
076: /**
077: * Get a probability estimate for a value
078: *
079: * @param data the value to estimate the probability of
080: * @param given the new value that data is conditional upon
081: * @return the estimated probability of the supplied value
082: */
083: public double getProbability(double data, double given) {
084:
085: return getEstimator(given).getProbability(data);
086: }
087:
088: /** Display a representation of this estimator */
089: public String toString() {
090:
091: String result = "DD Conditional Estimator. "
092: + m_Estimators.length + " sub-estimators:\n";
093: for (int i = 0; i < m_Estimators.length; i++) {
094: result += "Sub-estimator " + i + ": " + m_Estimators[i];
095: }
096: return result;
097: }
098:
099: /**
100: * Main method for testing this class.
101: *
102: * @param argv should contain a sequence of pairs of integers which
103: * will be treated as symbolic.
104: */
105: public static void main(String[] argv) {
106:
107: try {
108: if (argv.length == 0) {
109: System.out
110: .println("Please specify a set of instances.");
111: return;
112: }
113: int currentA = Integer.parseInt(argv[0]);
114: int maxA = currentA;
115: int currentB = Integer.parseInt(argv[1]);
116: int maxB = currentB;
117: for (int i = 2; i < argv.length - 1; i += 2) {
118: currentA = Integer.parseInt(argv[i]);
119: currentB = Integer.parseInt(argv[i + 1]);
120: if (currentA > maxA) {
121: maxA = currentA;
122: }
123: if (currentB > maxB) {
124: maxB = currentB;
125: }
126: }
127: DDConditionalEstimator newEst = new DDConditionalEstimator(
128: maxA + 1, maxB + 1, true);
129: for (int i = 0; i < argv.length - 1; i += 2) {
130: currentA = Integer.parseInt(argv[i]);
131: currentB = Integer.parseInt(argv[i + 1]);
132: System.out.println(newEst);
133: System.out.println("Prediction for " + currentA + '|'
134: + currentB + " = "
135: + newEst.getProbability(currentA, currentB));
136: newEst.addValue(currentA, currentB, 1);
137: }
138: } catch (Exception e) {
139: System.out.println(e.getMessage());
140: }
141: }
142: }
|