package mondrian.olap.fun;

import mondrian.olap.*;
import mondrian.olap.type.TupleType;
import mondrian.olap.type.SetType;
import mondrian.calc.*;
import mondrian.calc.impl.AbstractDoubleCalc;
import mondrian.calc.impl.ValueCalc;
import mondrian.mdx.ResolvedFunCall;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Abstract base class for the MDX linear-regression functions
 * <code>LinRegIntercept</code>, <code>LinRegPoint</code>,
 * <code>LinRegSlope</code>, <code>LinRegR2</code> and
 * <code>LinRegVariance</code>.
 *
 * <p>Each function evaluates a y expression and an x expression for every
 * element of a set, fits the least-squares line <code>y = ax + b</code> to the
 * resulting (x, y) pairs, and returns one property of that line. When the
 * optional x expression is omitted, a {@link ValueCalc} over the call is used
 * for x.
 *
 * <p>The concrete nested subclasses differ only in which property they return;
 * {@link #regType} records which one.
 */
public abstract class LinReg extends FunDefBase {

    /**
     * Which statistic this function returns: one of {@link #Point},
     * {@link #R2}, {@link #Intercept}, {@link #Slope} or {@link #Variance}.
     */
    final int regType;

    public static final int Point = 0;
    public static final int R2 = 1;
    public static final int Intercept = 2;
    public static final int Slope = 3;
    public static final int Variance = 4;

    static final Resolver InterceptResolver = new ReflectiveMultiResolver(
        "LinRegIntercept",
        "LinRegIntercept(<Set>, <Numeric Expression>[, <Numeric Expression>])",
        "Calculates the linear regression of a set and returns the value of b in the regression line y = ax + b.",
        new String[]{"fnxn", "fnxnn"},
        InterceptFunDef.class);

    static final Resolver PointResolver = new ReflectiveMultiResolver(
        "LinRegPoint",
        "LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])",
        "Calculates the linear regression of a set and returns the value of y in the regression line y = ax + b.",
        new String[]{"fnnxn", "fnnxnn"},
        PointFunDef.class);

    static final Resolver SlopeResolver = new ReflectiveMultiResolver(
        "LinRegSlope",
        "LinRegSlope(<Set>, <Numeric Expression>[, <Numeric Expression>])",
        "Calculates the linear regression of a set and returns the value of a in the regression line y = ax + b.",
        new String[]{"fnxn", "fnxnn"},
        SlopeFunDef.class);

    static final Resolver R2Resolver = new ReflectiveMultiResolver(
        "LinRegR2",
        "LinRegR2(<Set>, <Numeric Expression>[, <Numeric Expression>])",
        "Calculates the linear regression of a set and returns R2 (the coefficient of determination).",
        new String[]{"fnxn", "fnxnn"},
        R2FunDef.class);

    static final Resolver VarianceResolver = new ReflectiveMultiResolver(
        "LinRegVariance",
        "LinRegVariance(<Set>, <Numeric Expression>[, <Numeric Expression>])",
        "Calculates the linear regression of a set and returns the variance associated with the regression line y = ax + b.",
        new String[]{"fnxn", "fnxnn"},
        VarianceFunDef.class);

    /**
     * Compiles a call of the form
     * <code>Fn(&lt;Set&gt;, &lt;y expr&gt;[, &lt;x expr&gt;])</code>.
     * {@link PointFunDef} overrides this for its different argument order.
     */
    public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
        final ListCalc listCalc = compiler.compileList(call.getArg(0));
        final DoubleCalc yCalc = compiler.compileDouble(call.getArg(1));
        final DoubleCalc xCalc = call.getArgCount() > 2 ?
            compiler.compileDouble(call.getArg(2)) :
            new ValueCalc(call);
        final boolean isTuples =
            ((SetType) listCalc.getType()).getElementType() instanceof
            TupleType;
        return new LinRegCalc(call, listCalc, yCalc, xCalc, isTuples, regType);
    }

    /**
     * Result of a linear regression: the fitted line
     * <code>y = slope * x + intercept</code>, the data it was fitted to, and
     * the accuracy statistics filled in by {@link LinReg#accuracy}.
     */
    static class Value {
        /** x values the line was fitted to; may contain nulls (skipped). */
        private List xs;
        /** y values the line was fitted to; may contain nulls (skipped). */
        private List ys;

        /** Value of b in the regression line y = ax + b. */
        double intercept;

        /** Value of a in the regression line y = ax + b. */
        double slope;

        /** Coefficient of determination; MAX_VALUE until computed. */
        double rSquared = Double.MAX_VALUE;

        /** Variance of the residuals; MAX_VALUE until computed. */
        double variance = Double.MAX_VALUE;

        Value(double intercept, double slope, List xs, List ys) {
            this.intercept = intercept;
            this.slope = slope;
            this.xs = xs;
            this.ys = ys;
        }

        public double getIntercept() {
            return this.intercept;
        }

        public double getSlope() {
            return this.slope;
        }

        public double getRSquared() {
            return this.rSquared;
        }

        /**
         * Sets the coefficient of determination (R<sup>2</sup>) of the line.
         *
         * @param rSquared Coefficient of determination
         */
        public void setRSquared(double rSquared) {
            this.rSquared = rSquared;
        }

        public double getVariance() {
            return this.variance;
        }

        public void setVariance(double variance) {
            this.variance = variance;
        }

        public String toString() {
            // Negative slopes already carry their own '-' sign; every other
            // slope (including zero) needs an explicit '+'.
            return "LinReg.Value: slope of "
                + slope
                + " and an intercept of " + intercept
                + ". That is, y="
                + intercept
                + (slope >= 0.0 ? " +" : " ")
                + slope
                + " * x.";
        }
    }

    /**
     * Definition of the <code>LinRegIntercept</code> MDX function: returns
     * b in the regression line y = ax + b.
     */
    public static class InterceptFunDef extends LinReg {
        public InterceptFunDef(FunDef funDef) {
            super(funDef, Intercept);
        }
    }

    /**
     * Definition of the <code>LinRegPoint</code> MDX function: returns the
     * y value of the regression line at a given x.
     */
    public static class PointFunDef extends LinReg {
        public PointFunDef(FunDef funDef) {
            super(funDef, Point);
        }

        /**
         * Compiles
         * <code>LinRegPoint(&lt;x point&gt;, &lt;Set&gt;, &lt;y expr&gt;[, &lt;x expr&gt;])</code>;
         * every argument is shifted one place right relative to the other
         * LinReg functions.
         */
        public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
            final DoubleCalc xPointCalc = compiler.compileDouble(call.getArg(0));
            final ListCalc listCalc = compiler.compileList(call.getArg(1));
            final DoubleCalc yCalc = compiler.compileDouble(call.getArg(2));
            final DoubleCalc xCalc = call.getArgCount() > 3 ?
                compiler.compileDouble(call.getArg(3)) :
                new ValueCalc(call);
            final boolean isTuples =
                ((SetType) listCalc.getType()).getElementType() instanceof
                TupleType;
            return new PointCalc(call, xPointCalc, listCalc, yCalc, xCalc, isTuples);
        }
    }

    /**
     * Compiled form of <code>LinRegPoint</code>: fits the line over the set
     * and evaluates it at the x point.
     */
    private static class PointCalc extends AbstractDoubleCalc {
        private final DoubleCalc xPointCalc;
        private final ListCalc listCalc;
        private final DoubleCalc yCalc;
        private final DoubleCalc xCalc;
        private final boolean tuples;

        public PointCalc(
            ResolvedFunCall call,
            DoubleCalc xPointCalc,
            ListCalc listCalc,
            DoubleCalc yCalc, DoubleCalc xCalc, boolean tuples) {
            super(call, new Calc[]{xPointCalc, listCalc, yCalc, xCalc});
            this.xPointCalc = xPointCalc;
            this.listCalc = listCalc;
            this.yCalc = yCalc;
            this.xCalc = xCalc;
            this.tuples = tuples;
        }

        /**
         * Returns <code>slope * xPoint + intercept</code>, or
         * {@link FunUtil#DoubleNull} when no regression could be computed.
         */
        public double evaluateDouble(Evaluator evaluator) {
            double xPoint = xPointCalc.evaluateDouble(evaluator);
            Value value =
                process(evaluator, listCalc, yCalc, xCalc, tuples);
            if (value == null) {
                return FunUtil.DoubleNull;
            }
            double yPoint = xPoint * value.getSlope() +
                value.getIntercept();
            return yPoint;
        }
    }

    /**
     * Definition of the <code>LinRegSlope</code> MDX function: returns a in
     * the regression line y = ax + b.
     */
    public static class SlopeFunDef extends LinReg {
        public SlopeFunDef(FunDef funDef) {
            super(funDef, Slope);
        }
    }

    /**
     * Definition of the <code>LinRegR2</code> MDX function: returns the
     * coefficient of determination of the regression line.
     */
    public static class R2FunDef extends LinReg {
        public R2FunDef(FunDef funDef) {
            super(funDef, R2);
        }
    }

    /**
     * Definition of the <code>LinRegVariance</code> MDX function: returns the
     * variance associated with the regression line.
     */
    public static class VarianceFunDef extends LinReg {
        public VarianceFunDef(FunDef funDef) {
            super(funDef, Variance);
        }
    }

    /**
     * Diagnostic hook; currently a no-op.
     *
     * @param type Name of the calling method
     * @param msg Message text
     */
    protected static void debug(String type, String msg) {
    }

    protected LinReg(FunDef funDef, int regType) {
        super(funDef);
        this.regType = regType;
    }

    /**
     * Evaluates the set, computes the y and x values for each element, and
     * fits a regression line to them.
     *
     * @param evaluator Evaluation context
     * @param listCalc Compiled set expression
     * @param yCalc Compiled y expression
     * @param xCalc Compiled x expression
     * @param isTuples Whether the set's elements are tuples (not members)
     * @return The fitted line, or null if evaluating y produced errors or
     *   yielded no values
     */
    protected static LinReg.Value process(
        Evaluator evaluator,
        ListCalc listCalc,
        DoubleCalc yCalc,
        DoubleCalc xCalc,
        boolean isTuples) {
        List members = listCalc.evaluateList(evaluator);

        evaluator = evaluator.push();

        SetWrapper[] sws = evaluateSet(
            evaluator, members, new DoubleCalc[] {yCalc, xCalc}, isTuples);
        SetWrapper swY = sws[0];
        SetWrapper swX = sws[1];

        if (swY.errorCount > 0) {
            debug("LinReg.process", "ERROR error(s) count =" + swY.errorCount);
            return null;
        } else if (swY.v.size() == 0) {
            return null;
        }

        return linearReg(swX.v, swY.v);
    }

    /**
     * Computes the goodness-of-fit statistics of a fitted line and stores
     * them in the given value.
     *
     * <p>Only points where both the observed y and the forecast y (and hence
     * x) are non-null take part. With <code>n</code> such points:
     * <ul>
     * <li>the variance is set to the sum of squared residuals divided by
     *     <code>n - 2</code>, but only when <code>n &gt; 2</code>;</li>
     * <li>R<sup>2</sup> is set to <code>1 - SSE/SST</code>, where SSE is the
     *     sum of squared deviations of the residuals from their mean and SST
     *     is the sum of squared deviations of y from its mean, but only when
     *     SST is non-zero.</li>
     * </ul>
     * Statistics that are not set keep their previous value
     * ({@link Double#MAX_VALUE} for a freshly built value).
     *
     * @param value Fitted regression line, including its x and y lists
     * @return <code>value</code>, updated in place
     */
    public static LinReg.Value accuracy(LinReg.Value value) {
        // Forecast y for each x; same order as value.xs, which is paired
        // index-for-index with value.ys.
        List yfs = forecast(value);

        // First pass: sums over the usable points. Both iterators are
        // advanced on every step so that a null in one list can never shift
        // the pairing of the other.
        int n = 0;
        double sumErr = 0.0;
        double sumErrSquared = 0.0;
        double sumY = 0.0;
        Iterator ity = value.ys.iterator();
        Iterator ityf = yfs.iterator();
        while (ity.hasNext() && ityf.hasNext()) {
            Double dy = (Double) ity.next();
            Double dyf = (Double) ityf.next();
            if (dy == null || dyf == null) {
                continue;
            }
            double y = dy.doubleValue();
            double error = dyf.doubleValue() - y;

            n++;
            sumErr += error;
            sumErrSquared += error * error;
            sumY += y;
        }

        if (n == 0) {
            // Nothing to measure; leave the statistics untouched.
            return value;
        }

        if (n > 2) {
            // Two degrees of freedom are consumed by slope and intercept.
            double variance = sumErrSquared / (n - 2);
            value.setVariance(variance);
        }

        // Second pass: squared deviations about the means, over the same
        // usable points as the first pass.
        double meanErr = sumErr / n;
        double meanY = sumY / n;
        double sse = 0.0;
        double sst = 0.0;
        ity = value.ys.iterator();
        ityf = yfs.iterator();
        while (ity.hasNext() && ityf.hasNext()) {
            Double dy = (Double) ity.next();
            Double dyf = (Double) ityf.next();
            if (dy == null || dyf == null) {
                continue;
            }
            double y = dy.doubleValue();
            double error = dyf.doubleValue() - y;

            sse += (error - meanErr) * (error - meanErr);
            sst += (y - meanY) * (y - meanY);
        }
        if (sst != 0.0) {
            double rSquared = 1 - (sse / sst);
            value.setRSquared(rSquared);
        }

        return value;
    }

    /**
     * Fits the least-squares line <code>y = slope * x + intercept</code> to
     * paired lists of {@link Double} values.
     *
     * <p>Index <code>i</code> of <code>xlist</code> is paired with index
     * <code>i</code> of <code>ylist</code>; <code>xlist</code> must be at least
     * as long as <code>ylist</code>. Pairs where either element is null are
     * skipped. If no pair is usable, or all usable x values are equal, the
     * denominator of the slope is zero and the slope and intercept are not
     * finite.
     *
     * @param xlist x values (Double or null)
     * @param ylist y values (Double or null)
     * @return The fitted line, holding references to both lists
     */
    public static LinReg.Value linearReg(List xlist, List ylist) {

        int size = ylist.size();
        double sumX = 0.0;
        double sumY = 0.0;
        double sumXX = 0.0;
        double sumXY = 0.0;

        debug("LinReg.linearReg", "ylist.size()=" + ylist.size());
        debug("LinReg.linearReg", "xlist.size()=" + xlist.size());
        int n = 0;
        for (int i = 0; i < size; i++) {
            Object yo = ylist.get(i);
            Object xo = xlist.get(i);
            if ((yo == null) || (xo == null)) {
                continue;
            }
            n++;
            double y = ((Double) yo).doubleValue();
            double x = ((Double) xo).doubleValue();

            debug("LinReg.linearReg", " " + i + " (" + x + "," + y + ")");
            sumX += x;
            sumY += y;
            sumXX += x * x;
            sumXY += x * y;
        }

        double xMean = sumX / n;
        double yMean = sumY / n;

        debug("LinReg.linearReg", "yMean=" + yMean);
        debug("LinReg.linearReg", "(n*sumXX - sumX*sumX)=" + (n * sumXX - sumX * sumX));
        double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX);
        double intercept = yMean - slope * xMean;

        LinReg.Value value = new LinReg.Value(intercept, slope, xlist, ylist);
        debug("LinReg.linearReg", "value=" + value);

        return value;
    }

    /**
     * Computes the y value predicted by the fitted line for each x the line
     * was fitted to.
     *
     * @param value Fitted regression line
     * @return List of the same size and order as the x values, holding
     *   <code>intercept + slope * x</code>, or null where x is null
     */
    public static List forecast(LinReg.Value value) {
        List yfs = new ArrayList(value.xs.size());

        Iterator it = value.xs.iterator();
        while (it.hasNext()) {
            Double d = (Double) it.next();
            if (d == null) {
                // Keep the output aligned index-for-index with xs.
                yfs.add(null);
            } else {
                double x = d.doubleValue();
                double yf = value.intercept + value.slope * x;
                yfs.add(new Double(yf));
            }
        }

        return yfs;
    }

    /**
     * Compiled form of <code>LinRegIntercept</code>, <code>LinRegSlope</code>,
     * <code>LinRegR2</code> and <code>LinRegVariance</code>.
     */
    private static class LinRegCalc extends AbstractDoubleCalc {
        private final ListCalc listCalc;
        private final DoubleCalc yCalc;
        private final DoubleCalc xCalc;
        private final boolean tuples;
        private final int regType;

        public LinRegCalc(
            ResolvedFunCall call,
            ListCalc listCalc,
            DoubleCalc yCalc,
            DoubleCalc xCalc,
            boolean tuples,
            int regType) {
            super(call, new Calc[]{listCalc, yCalc, xCalc});
            this.listCalc = listCalc;
            this.yCalc = yCalc;
            this.xCalc = xCalc;
            this.tuples = tuples;
            this.regType = regType;
        }

        /**
         * Fits the line and returns the statistic selected by
         * {@link #regType}, or {@link FunUtil#DoubleNull} when no regression
         * could be computed.
         */
        public double evaluateDouble(Evaluator evaluator) {
            Value value =
                process(evaluator, listCalc, yCalc, xCalc, tuples);
            if (value == null) {
                return FunUtil.DoubleNull;
            }
            switch (regType) {
            case Intercept:
                return value.getIntercept();
            case Slope:
                return value.getSlope();
            case Variance:
                // linearReg does not fill in the accuracy statistics; they
                // must be computed before being read, otherwise the initial
                // Double.MAX_VALUE placeholder would be returned.
                return accuracy(value).getVariance();
            case R2:
                return accuracy(value).getRSquared();
            default:
            case Point:
                // LinRegPoint is compiled to PointCalc, never to this class.
                throw Util.newInternal("unexpected value " + regType);
            }
        }
    }
}