KickJava   Java API By Example, From Geeks To Geeks.

Java > Open Source Codes > mondrian > olap > fun > LinReg


1 /*
2 // $Id: //open/mondrian/src/main/mondrian/olap/fun/LinReg.java#11 $
3 // This software is subject to the terms of the Common Public License
4 // Agreement, available at the following URL:
5 // http://www.opensource.org/licenses/cpl.html.
6 // Copyright (C) 2005-2006 Julian Hyde
7 // All Rights Reserved.
8 // You must accept the terms of that agreement to use this software.
9 */

10
11
12 package mondrian.olap.fun;
13
14 import mondrian.olap.*;
15 import mondrian.olap.type.TupleType;
16 import mondrian.olap.type.SetType;
17 import mondrian.calc.*;
18 import mondrian.calc.impl.AbstractDoubleCalc;
19 import mondrian.calc.impl.ValueCalc;
20 import mondrian.mdx.ResolvedFunCall;
21
22 import java.util.ArrayList JavaDoc;
23 import java.util.Iterator JavaDoc;
24 import java.util.List JavaDoc;
25
26 /**
27  * Abstract base class for definitions of linear regression functions.
28  *
29  * @see InterceptFunDef
30  * @see PointFunDef
31  * @see R2FunDef
32  * @see SlopeFunDef
33  * @see VarianceFunDef
34  *
35  * <h2>Correlation coefficient</h2>
36  * <p><i>Correlation coefficient</i></p>
37  *
38  * <p>The correlation coefficient, r, ranges from -1 to +1. The
39  * nonparametric Spearman correlation coefficient, abbreviated rs, has
40  * the same range.</p>
41  *
42  * <table border="1" cellpadding="6" cellspacing="0">
43  * <tr>
44  * <td>Value of r (or rs)</td>
45  * <td>Interpretation</td>
46  * </tr>
47  * <tr>
48  * <td valign="top">r= 0</td>
49  *
50  * <td>The two variables do not vary together at all.</td>
51  * </tr>
52  * <tr>
53  * <td valign="top">0 &lt; r &lt; 1</td>
54  * <td>
55  * <p>The two variables tend to increase or decrease together.</p>
56  * </td>
57  * </tr>
58  * <tr>
59  * <td valign="top">r = 1.0</td>
60  * <td>
61  * <p>Perfect correlation.</p>
62  * </td>
63  * </tr>
64  *
65  * <tr>
66  * <td valign="top">-1 &lt; r &lt; 0</td>
67  * <td>
68  * <p>One variable increases as the other decreases.</p>
69  * </td>
70  * </tr>
71  *
72  * <tr>
73  * <td valign="top">r = -1.0</td>
74  * <td>
75  * <p></p>
76  * <p>Perfect negative or inverse correlation.</p>
77  * </td>
78  * </tr>
79  * </table>
80  *
81  * <p>If r or rs is far from zero, there are four possible explanations:</p>
82  * <ul>
83  * <li>The X variable helps determine the value of the Y variable.
84  * <li>The Y variable helps determine the value of the X variable.
85  * <li>Another variable influences both X and Y.
86  * <li>X and Y don't really correlate at all, and you just
87  * happened to observe such a strong correlation by chance. The P value
88  * determines how often this could occur.
89  * </ul>
90  * <p><i>r2 </i></p>
91  *
92  * <p>Perhaps the best way to interpret the value of r is to square it to
93  * calculate r2. Statisticians call this quantity the coefficient of
94  * determination, but scientists call it r squared. It has a value
95  * that ranges from zero to one, and is the fraction of the variance in
96  * the two variables that is shared. For example, if r2=0.59, then 59% of
97  * the variance in X can be explained by variation in Y. &nbsp;Likewise,
98  * 59% of the variance in Y can be explained by (or goes along with)
99  * variation in X. More simply, 59% of the variance is shared between X
100  * and Y.</p>
101  *
102  * <p>(<a HREF="http://www.graphpad.com/articles/interpret/corl_n_linear_reg/correlation.htm">Source</a>).
103  *
104  * <p>Also see: <a HREF="http://mathworld.wolfram.com/LeastSquaresFitting.html">least squares fitting</a>.
105  */

106
107
108 public abstract class LinReg extends FunDefBase {
109     /** Code for the specific function. */
110     final int regType;
111
112     public static final int Point = 0;
113     public static final int R2 = 1;
114     public static final int Intercept = 2;
115     public static final int Slope = 3;
116     public static final int Variance = 4;
117
118     static final Resolver InterceptResolver = new ReflectiveMultiResolver(
119             "LinRegIntercept",
120             "LinRegIntercept(<Set>, <Numeric Expression>[, <Numeric Expression>])",
121             "Calculates the linear regression of a set and returns the value of b in the regression line y = ax + b.",
122             new String JavaDoc[]{"fnxn","fnxnn"},
123             InterceptFunDef.class);
124
125     static final Resolver PointResolver = new ReflectiveMultiResolver(
126             "LinRegPoint",
127             "LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])",
128             "Calculates the linear regression of a set and returns the value of y in the regression line y = ax + b.",
129             new String JavaDoc[]{"fnnxn","fnnxnn"},
130             PointFunDef.class);
131
132     static final Resolver SlopeResolver = new ReflectiveMultiResolver(
133             "LinRegSlope",
134             "LinRegSlope(<Set>, <Numeric Expression>[, <Numeric Expression>])",
135             "Calculates the linear regression of a set and returns the value of a in the regression line y = ax + b.",
136             new String JavaDoc[]{"fnxn","fnxnn"},
137             SlopeFunDef.class);
138
139     static final Resolver R2Resolver = new ReflectiveMultiResolver(
140             "LinRegR2",
141             "LinRegR2(<Set>, <Numeric Expression>[, <Numeric Expression>])",
142             "Calculates the linear regression of a set and returns R2 (the coefficient of determination).",
143             new String JavaDoc[]{"fnxn","fnxnn"},
144             R2FunDef.class);
145
146     static final Resolver VarianceResolver = new ReflectiveMultiResolver(
147             "LinRegVariance",
148             "LinRegVariance(<Set>, <Numeric Expression>[, <Numeric Expression>])",
149             "Calculates the linear regression of a set and returns the variance associated with the regression line y = ax + b.",
150             new String JavaDoc[]{"fnxn","fnxnn"},
151             VarianceFunDef.class);
152
153
154     public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
155         final ListCalc listCalc = compiler.compileList(call.getArg(0));
156         final DoubleCalc yCalc = compiler.compileDouble(call.getArg(1));
157         final DoubleCalc xCalc = call.getArgCount() > 2 ?
158                 compiler.compileDouble(call.getArg(2)) :
159                 new ValueCalc(call);
160         final boolean isTuples =
161                 ((SetType) listCalc.getType()).getElementType() instanceof
162                 TupleType;
163         return new LinRegCalc(call, listCalc, yCalc, xCalc, isTuples, regType);
164     }
165
166     /////////////////////////////////////////////////////////////////////////
167
//
168
// Helper
169
//
170
/////////////////////////////////////////////////////////////////////////
171
static class Value {
172         private List JavaDoc xs;
173         private List JavaDoc ys;
174         /**
175          * The intercept for the linear regression model. Initialized
176          * following a call to accuracy.
177          */

178         double intercept;
179
180         /**
181          * The slope for the linear regression model. Initialized following a
182          * call to accuracy.
183          */

184         double slope;
185
186          /** the coefficient of determination */
187         double rSquared = Double.MAX_VALUE;
188
189         /** variance = sum square diff mean / n - 1 */
190         double variance = Double.MAX_VALUE;
191
192         Value(double intercept, double slope, List JavaDoc xs, List JavaDoc ys) {
193             this.intercept = intercept;
194             this.slope = slope;
195             this.xs = xs;
196             this.ys = ys;
197         }
198
199         public double getIntercept() {
200             return this.intercept;
201         }
202
203         public double getSlope() {
204             return this.slope;
205         }
206
207         public double getRSquared() {
208             return this.rSquared;
209         }
210
211         /**
212          * strength of the correlation
213          *
214          * @param rSquared
215          */

216         public void setRSquared(double rSquared) {
217             this.rSquared = rSquared;
218         }
219
220         public double getVariance() {
221             return this.variance;
222         }
223
224         public void setVariance(double variance) {
225             this.variance = variance;
226         }
227
228         public String JavaDoc toString() {
229             return "LinReg.Value: slope of "
230                 + slope
231                 + " and an intercept of " + intercept
232                 + ". That is, y="
233                 + intercept
234                 + (slope>0.0 ? " +" : " ")
235                 + slope
236                 + " * x.";
237         }
238     }
239
240     /**
241      * Definition of the <code>LinRegIntercept</code> MDX function.
242      *
243      * <p>Synopsis:
244      * <blockquote><code>LinRegIntercept(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric Expression&gt;])</code></blockquote>
245      */

246     public static class InterceptFunDef extends LinReg {
247         public InterceptFunDef(FunDef funDef) {
248             super(funDef, Intercept);
249         }
250     }
251
252     /**
253      * Definition of the <code>LinRegPoint</code> MDX function.
254      *
255      * <p>Synopsis:
256      * <blockquote><code>LinRegPoint(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric Expression&gt;])</code></blockquote>
257      */

258     public static class PointFunDef extends LinReg {
259         public PointFunDef(FunDef funDef) {
260             super(funDef, Point);
261         }
262
263         public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
264             final DoubleCalc xPointCalc = compiler.compileDouble(call.getArg(0));
265             final ListCalc listCalc = compiler.compileList(call.getArg(1));
266             final DoubleCalc yCalc = compiler.compileDouble(call.getArg(2));
267             final DoubleCalc xCalc = call.getArgCount() > 3 ?
268                     compiler.compileDouble(call.getArg(3)) :
269                     new ValueCalc(call);
270             final boolean isTuples =
271                     ((SetType) listCalc.getType()).getElementType() instanceof
272                     TupleType;
273             return new PointCalc(call, xPointCalc, listCalc, yCalc, xCalc, isTuples);
274         }
275
276     }
277
278     private static class PointCalc extends AbstractDoubleCalc {
279         private final DoubleCalc xPointCalc;
280         private final ListCalc listCalc;
281         private final DoubleCalc yCalc;
282         private final DoubleCalc xCalc;
283         private final boolean tuples;
284
285         public PointCalc(
286                 ResolvedFunCall call,
287                 DoubleCalc xPointCalc,
288                 ListCalc listCalc,
289                 DoubleCalc yCalc, DoubleCalc xCalc, boolean tuples) {
290             super(call, new Calc[]{xPointCalc, listCalc, yCalc, xCalc});
291             this.xPointCalc = xPointCalc;
292             this.listCalc = listCalc;
293             this.yCalc = yCalc;
294             this.xCalc = xCalc;
295             this.tuples = tuples;
296         }
297
298         public double evaluateDouble(Evaluator evaluator) {
299             double xPoint = xPointCalc.evaluateDouble(evaluator);
300             Value value =
301                     process(evaluator, listCalc, yCalc, xCalc, tuples);
302             if (value == null) {
303                 return FunUtil.DoubleNull;
304             }
305             // use first arg to generate y position
306
double yPoint = xPoint * value.getSlope() +
307                     value.getIntercept();
308             return yPoint;
309         }
310     }
311
312     /**
313      * Definition of the <code>LinRegSlope</code> MDX function.
314      *
315      * <p>Synopsis:
316      * <blockquote><code>LinRegSlope(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric Expression&gt;])</code></blockquote>
317      */

318     public static class SlopeFunDef extends LinReg {
319         public SlopeFunDef(FunDef funDef) {
320             super(funDef, Slope);
321         }
322     }
323
324     /**
325      * Definition of the <code>LinRegR2</code> MDX function.
326      *
327      * <p>Synopsis:
328      * <blockquote><code>LinRegR2(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric Expression&gt;])</code></blockquote>
329      */

330     public static class R2FunDef extends LinReg {
331         public R2FunDef(FunDef funDef) {
332             super(funDef, R2);
333         }
334     }
335
336     /**
337      * Definition of the <code>LinRegVariance</code> MDX function.
338      *
339      * <p>Synopsis:
340      * <blockquote><code>LinRegVariance(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric Expression&gt;])</code></blockquote>
341      */

342     public static class VarianceFunDef extends LinReg {
343         public VarianceFunDef(FunDef funDef) {
344             super(funDef, Variance);
345         }
346     }
347
348     protected static void debug(String JavaDoc type, String JavaDoc msg) {
349         // comment out for no output
350
// RME
351
//System.out.println(type + ": " +msg);
352
}
353
354
355     protected LinReg(FunDef funDef, int regType) {
356         super(funDef);
357         this.regType = regType;
358     }
359
360     protected static LinReg.Value process(
361             Evaluator evaluator,
362             ListCalc listCalc,
363             DoubleCalc yCalc,
364             DoubleCalc xCalc,
365             boolean isTuples) {
366         List JavaDoc members = listCalc.evaluateList(evaluator);
367
368         evaluator = evaluator.push();
369
370         SetWrapper[] sws = evaluateSet(
371                 evaluator, members, new DoubleCalc[] {yCalc, xCalc}, isTuples);
372         SetWrapper swY = sws[0];
373         SetWrapper swX = sws[1];
374
375         if (swY.errorCount > 0) {
376 debug("LinReg.process","ERROR error(s) count =" +swY.errorCount);
377             // TODO: throw exception
378
return null;
379         } else if (swY.v.size() == 0) {
380             return null;
381         }
382
383         return linearReg(swX.v, swY.v);
384     }
385
386     public static LinReg.Value accuracy(LinReg.Value value) {
387         // for variance
388
double sumErrSquared = 0.0;
389
390         double sumErr = 0.0;
391
392         // for r2
393
// data
394
double sumSquaredY = 0.0;
395         double sumY = 0.0;
396         // predicted
397
double sumSquaredYF = 0.0;
398         double sumYF = 0.0;
399
400         // Obtain the forecast values for this model
401
List JavaDoc yfs = forecast(value);
402
403         // Calculate the Sum of the Absolute Errors
404
Iterator JavaDoc ity = value.ys.iterator();
405         Iterator JavaDoc ityf = yfs.iterator();
406         while (ity.hasNext()) {
407             // Get next data point
408
Double JavaDoc dy = (Double JavaDoc) ity.next();
409             if (dy == null) {
410                 continue;
411             }
412             Double JavaDoc dyf = (Double JavaDoc) ityf.next();
413             if (dyf == null) {
414                 continue;
415             }
416
417             double y = dy.doubleValue();
418             double yf = dyf.doubleValue();
419
420             // Calculate error in forecast, and update sums appropriately
421

422             // the y residual or error
423
double error = yf - y;
424
425             sumErr += error;
426             sumErrSquared += error*error;
427
428             sumY += y;
429             sumSquaredY += (y*y);
430
431             sumYF =+ yf;
432             sumSquaredYF =+ (yf*yf);
433         }
434
435
436         // Initialize the accuracy indicators
437
int n = value.ys.size();
438
439         // Variance
440
// The estimate the value of the error variance is a measure of
441
// variability of the y values about the estimated line.
442
// http://home.ubalt.edu/ntsbarsh/Business-stat/opre504.htm
443
// s2 = SSE/(n-2) = sum (y - yf)2 /(n-2)
444
if (n > 2) {
445             double variance = sumErrSquared / (n-2);
446
447             value.setVariance(variance);
448         }
449
450         // R2
451
// R2 = 1 - (SSE/SST)
452
// SSE = sum square error = Sum( (error-MSE)*(error-MSE) )
453
// MSE = mean error = Sum( error )/n
454
// SST = sum square y diff = Sum( (y-MST)*(y-MST) )
455
// MST = mean y = Sum( y )/n
456
double MSE = sumErr/n;
457         double MST = sumY/n;
458         double SSE = 0.0;
459         double SST = 0.0;
460         ity = value.ys.iterator();
461         ityf = yfs.iterator();
462         while (ity.hasNext()) {
463             // Get next data point
464
Double JavaDoc dy = (Double JavaDoc) ity.next();
465             if (dy == null) {
466                 continue;
467             }
468             Double JavaDoc dyf = (Double JavaDoc) ityf.next();
469             if (dyf == null) {
470                 continue;
471             }
472
473             double y = dy.doubleValue();
474             double yf = dyf.doubleValue();
475
476             double error = yf - y;
477             SSE += (error - MSE)*(error - MSE);
478             SST += (y - MST)*(y - MST);
479         }
480         if (SST != 0.0) {
481             double rSquared = 1 - (SSE/SST);
482
483             value.setRSquared(rSquared);
484         }
485
486
487         return value;
488     }
489
490     public static LinReg.Value linearReg(List JavaDoc xlist, List JavaDoc ylist) {
491
492         // y and x have same number of points
493
int size = ylist.size();
494         double sumX = 0.0;
495         double sumY = 0.0;
496         double sumXX = 0.0;
497         double sumXY = 0.0;
498
499 debug("LinReg.linearReg","ylist.size()=" +ylist.size());
500 debug("LinReg.linearReg","xlist.size()=" +xlist.size());
501         int n = 0;
502         for (int i = 0; i < size; i++) {
503             Object JavaDoc yo = ylist.get(i);
504             Object JavaDoc xo = xlist.get(i);
505             if ((yo == null) || (xo == null)) {
506                 continue;
507             }
508             n++;
509             double y = ((Double JavaDoc) yo).doubleValue();
510             double x = ((Double JavaDoc) xo).doubleValue();
511
512 debug("LinReg.linearReg"," " +i+ " (" +x+ "," +y+ ")");
513             sumX += x;
514             sumY += y;
515             sumXX += x*x;
516             sumXY += x*y;
517         }
518
519         double xMean = sumX / n;
520         double yMean = sumY / n;
521
522 debug("LinReg.linearReg", "yMean=" +yMean);
523 debug("LinReg.linearReg", "(n*sumXX - sumX*sumX)=" +(n*sumXX - sumX*sumX));
524         // The regression line is the line that minimizes the variance of the
525
// errors. The mean error is zero; so, this means that it minimizes the
526
// sum of the squares errors.
527
double slope = (n*sumXY - sumX*sumY) / (n*sumXX - sumX*sumX);
528         double intercept = yMean - slope*xMean;
529
530         LinReg.Value value = new LinReg.Value(intercept, slope, xlist, ylist);
531 debug("LinReg.linearReg","value=" +value);
532
533         return value;
534     }
535
536
537     public static List JavaDoc forecast(LinReg.Value value) {
538         List JavaDoc yfs = new ArrayList JavaDoc(value.xs.size());
539
540         Iterator JavaDoc it = value.xs.iterator();
541         while (it.hasNext()) {
542             Double JavaDoc d = (Double JavaDoc) it.next();
543             // If the value is missing we still must put a place
544
// holder in the y axis, otherwise there is a discontinuity
545
// between the data and the fit.
546
if (d == null) {
547                 yfs.add(null);
548             } else {
549                 double x = d.doubleValue();
550                 double yf = value.intercept + value.slope * x;
551                 yfs.add(new Double JavaDoc(yf));
552             }
553         }
554
555         return yfs;
556     }
557
558     private static class LinRegCalc extends AbstractDoubleCalc {
559         private final ListCalc listCalc;
560         private final DoubleCalc yCalc;
561         private final DoubleCalc xCalc;
562         private final boolean tuples;
563         private final int regType;
564
565         public LinRegCalc(
566                 ResolvedFunCall call,
567                 ListCalc listCalc,
568                 DoubleCalc yCalc,
569                 DoubleCalc xCalc,
570                 boolean tuples,
571                 int regType) {
572             super(call, new Calc[]{listCalc, yCalc, xCalc});
573             this.listCalc = listCalc;
574             this.yCalc = yCalc;
575             this.xCalc = xCalc;
576             this.tuples = tuples;
577             this.regType = regType;
578         }
579
580         public double evaluateDouble(Evaluator evaluator) {
581             Value value =
582                     process(evaluator, listCalc, yCalc, xCalc, tuples);
583             if (value == null) {
584                 return FunUtil.DoubleNull;
585             }
586             switch (regType) {
587             case Intercept:
588                 return value.getIntercept();
589             case Slope:
590                 return value.getSlope();
591             case Variance:
592                 return value.getVariance();
593             case R2:
594                 return value.getRSquared();
595             default:
596             case Point:
597                 throw Util.newInternal("unexpected value " + regType);
598             }
599         }
600     }
601 }
602
603 // End LinReg.java
604
Popular Tags