KickJava   Java API By Example, From Geeks To Geeks.

Java > Open Source Codes > org > lsmp > djep > djep > DifferentiationVisitor


/* @author rich
 * Created on 18-Jun-2003
 *
 * This code is covered by a Creative Commons
 * Attribution, Non Commercial, Share Alike license
 * <a HREF="http://creativecommons.org/licenses/by-nc-sa/1.0">License</a>
 */

8    
9 package org.lsmp.djep.djep;
10 import org.lsmp.djep.djep.diffRules.*;
11 import org.lsmp.djep.xjep.*;
12 import org.nfunk.jep.*;
13 import org.nfunk.jep.function.*;
14 import java.util.Hashtable JavaDoc;
15 import java.util.Enumeration JavaDoc;
16 import java.io.PrintStream JavaDoc;
17
18 /**
19  * A class for perfoming differentation of an expression.
20  * To use do
21  * <pre>
22  * JEP j = ...; Node in = ...;
23  * DifferentiationVisitor dv = new DifferentiationVisitor(jep);
24  * dv.addStandardDiffRules();
25  * Node out = dv.differentiate(in,"x");
26  * </pre>
27  * The class follows the visitor pattern described in
28  * {@link org.nfunk.jep.ParserVisitor ParserVisitor}.
29  * The rules for differentiating specific functions are contained in
30  * object which implement
31  * {@link DiffRulesI DiffRulesI}
32  * A number of inner classes which use this interface are defined for specific
33  * function types.
34  * In particular
35  * {@link MacroDiffRules MacroDiffRules}
36  * allow the rule for differentiation to be specified by strings.
37  * New rules can be added using
38  * {@link DJep#addDiffRule} method.
39  * @author R Morris
40  * Created on 19-Jun-2003
41  */

42 public class DifferentiationVisitor extends DeepCopyVisitor
43 {
44     private static final boolean DEBUG = false;
45     private DJep localDJep;
46     private DJep globalDJep;
47     private NodeFactory nf;
48     private TreeUtils tu;
49 // private OperatorSet opSet;
50
/**
51    * Construction with a given set of tree utilities
52    * @param tu
53    */

54   public DifferentiationVisitor(DJep jep)
55   {
56     this.globalDJep = jep;
57     
58     addDiffRule(new AdditionDiffRule("+"));
59     addDiffRule(new SubtractDiffRule("-"));
60     addDiffRule(new MultiplyDiffRule("*"));
61     addDiffRule(new DivideDiffRule("/"));
62     addDiffRule(new PowerDiffRule("^"));
63     addDiffRule(new PassThroughDiffRule("UMinus",globalDJep.getOperatorSet().getUMinus().getPFMC()));
64
65   }
66   
67   
68    /**
69    * Adds the standard set of differentation rules.
70    * Corresponds to all standard functions in the JEP plus a few more.
71    * <pre>
72    * sin,cos,tan,asin,acos,atan,sinh,cosh,tanh,asinh,acosh,atanh
73    * sqrt,log,ln,abs,angle
74    * sum,im,re are handled seperatly.
75    * rand and mod currently unhandled
76    *
77    * Also adds rules for functions not in JEP function list:
78    * sec,cosec,cot,exp,pow,sgn
79    *
80    * TODO include if, min, max, sgn
81    * </pre>
82    * @return false on error
83    */

84   boolean addStandardDiffRules()
85   {
86     try
87     {
88         addDiffRule(new MacroDiffRules(globalDJep,"sin","cos(x)"));
89         addDiffRule(new MacroDiffRules(globalDJep,"cos","-sin(x)"));
90         addDiffRule(new MacroDiffRules(globalDJep,"tan","1/((cos(x))^2)"));
91
92         MacroFunction sec = new MacroFunction("sec",1,"1/cos(x)",globalDJep);
93         globalDJep.addFunction("sec",sec);
94         MacroFunction cosec = new MacroFunction("cosec",1,"1/sin(x)",globalDJep);
95         globalDJep.addFunction("cosec",cosec);
96         MacroFunction cot = new MacroFunction("cot",1,"1/tan(x)",globalDJep);
97         globalDJep.addFunction("cot",cot);
98         
99         addDiffRule(new MacroDiffRules(globalDJep,"sec","sec(x) * tan(x)"));
100         addDiffRule(new MacroDiffRules(globalDJep,"cosec","-cosec(x) * cot(x)"));
101         addDiffRule(new MacroDiffRules(globalDJep,"cot","-(cosec(x))^2"));
102             
103         addDiffRule(new MacroDiffRules(globalDJep,"asin","1/(sqrt(1-x^2))"));
104         addDiffRule(new MacroDiffRules(globalDJep,"acos","-1/(sqrt(1-x^2))"));
105         addDiffRule(new MacroDiffRules(globalDJep,"atan","1/(1+x^2)"));
106
107         addDiffRule(new MacroDiffRules(globalDJep,"sinh","cosh(x)"));
108         addDiffRule(new MacroDiffRules(globalDJep,"cosh","sinh(x)"));
109         addDiffRule(new MacroDiffRules(globalDJep,"tanh","1-(tanh(x))^2"));
110
111         addDiffRule(new MacroDiffRules(globalDJep,"asinh","1/(sqrt(1+x^2))"));
112         addDiffRule(new MacroDiffRules(globalDJep,"acosh","1/(sqrt(x^2-1))"));
113         addDiffRule(new MacroDiffRules(globalDJep,"atanh","1/(1-x^2)"));
114
115         addDiffRule(new MacroDiffRules(globalDJep,"sqrt","1/(2 (sqrt(x)))"));
116         
117         globalDJep.addFunction("exp",new Exp());
118         addDiffRule(new MacroDiffRules(globalDJep,"exp","exp(x)"));
119 // globalDJep.addFunction("pow",new Pow());
120
// addDiffRule(new MacroDiffRules(globalDJep,"pow","y*(pow(x,y-1))","(ln(x)) (pow(x,y))"));
121
addDiffRule(new MacroDiffRules(globalDJep,"ln","1/x"));
122         addDiffRule(new MacroDiffRules(globalDJep,"log", // -> (1/ln(10)) /x = log(e) / x but don't know if e exists
123
globalDJep.getNodeFactory().buildOperatorNode(globalDJep.getOperatorSet().getDivide(),
124                 globalDJep.getNodeFactory().buildConstantNode(
125                     globalDJep.getTreeUtils().getNumber(1/Math.log(10.0))),
126                 globalDJep.getNodeFactory().buildVariableNode(globalDJep.getSymbolTable().makeVarIfNeeded("x")))));
127         // TODO problems here with using a global variable (x) in an essentially local context
128
addDiffRule(new MacroDiffRules(globalDJep,"abs","abs(x)/x"));
129         addDiffRule(new MacroDiffRules(globalDJep,"angle","y/(x^2+y^2)","-x/(x^2+y^2)"));
130         addDiffRule(new MacroDiffRules(globalDJep,"mod","1","0"));
131         addDiffRule(new PassThroughDiffRule(globalDJep,"sum"));
132         addDiffRule(new PassThroughDiffRule(globalDJep,"re"));
133         addDiffRule(new PassThroughDiffRule(globalDJep,"im"));
134         addDiffRule(new PassThroughDiffRule(globalDJep,"rand"));
135         
136         MacroFunction complex = new MacroFunction("macrocomplex",2,"x+i*y",globalDJep);
137         globalDJep.addFunction("macrocomplex",complex);
138         addDiffRule(new MacroFunctionDiffRules(globalDJep,complex));
139         
140 /* addDiffRule(new PassThroughDiffRule("\"<\"",globalDJep.getOperatorSet().getLT().getPFMC()));
141         addDiffRule(new PassThroughDiffRule("\">\"",new Comparative(1)));
142         addDiffRule(new PassThroughDiffRule("\"<=\"",new Comparative(2)));
143         addDiffRule(new PassThroughDiffRule("\">=\"",new Comparative(3)));
144         addDiffRule(new PassThroughDiffRule("\"!=\"",new Comparative(4)));
145         addDiffRule(new PassThroughDiffRule("\"==\"",new Comparative(5)));
146 */

147 // addDiffRule(new DiffDiffRule(this,"diff"));
148
// TODO do we want to add eval here?
149
// addDiffRule(new EvalDiffRule(this,"eval",eval));
150

151         //addDiffRule(new PassThroughDiffRule("\"&&\""));
152
//addDiffRule(new PassThroughDiffRule("\"||\""));
153
//addDiffRule(new PassThroughDiffRule("\"!\""));
154

155         // also consider if, min, max, sgn, dot, cross,
156
//addDiffRule(new MacroDiffRules(this,"sgn","0"));
157
return true;
158     }
159     catch(ParseException e)
160     {
161         System.err.println(e.getMessage());
162         return false;
163     }
164   }
165   
166   /** The set of all differentation rules indexed by name of function. */
167   Hashtable JavaDoc diffRules = new Hashtable JavaDoc();
168   /** Adds the rules for a given function. */
169   void addDiffRule(DiffRulesI rule)
170   {
171     diffRules.put(rule.getName(),rule);
172     if(DEBUG) System.out.println("Adding rule for "+rule.getName());
173   }
174   /** finds the rule for function with given name. */
175   DiffRulesI getDiffRule(String JavaDoc name)
176   {
177     return (DiffRulesI) diffRules.get(name);
178   }
179   
180   /**
181    * Prints all the differentation rules for all functions on System.out.
182    */

183   public void printDiffRules() { printDiffRules(System.out); }
184   
185   /**
186    * Prints all the differentation rules for all functions on specifed stream.
187    */

188   public void printDiffRules(PrintStream JavaDoc out)
189   {
190     out.println("Standard Functions and their derivatives");
191     for(Enumeration JavaDoc enume = globalDJep.getFunctionTable().keys(); enume.hasMoreElements();)
192     {
193         String JavaDoc key = (String JavaDoc) enume.nextElement();
194         PostfixMathCommandI value = (PostfixMathCommandI) globalDJep.getFunctionTable().get(key);
195         DiffRulesI rule = (DiffRulesI) diffRules.get(key);
196         if(rule==null)
197             out.print(key+" No diff rules specified ("+value.getNumberOfParameters()+" arguments).");
198         else
199             out.print(rule.toString());
200         out.println();
201     }
202     for(Enumeration JavaDoc enume = diffRules.keys(); enume.hasMoreElements();)
203         {
204             String JavaDoc key = (String JavaDoc) enume.nextElement();
205             DiffRulesI rule = (DiffRulesI) diffRules.get(key);
206             if(!globalDJep.getFunctionTable().containsKey(key))
207             {
208                 out.print(rule.toString());
209                 out.println("\tnot in JEP function list");
210             }
211         }
212     }
213
214     /**
215      * Differentiates an expression tree wrt a variable var.
216      * @param node the top node of the expresion tree
217      * @param var the variable to differentiate wrt
218      * @return the top node of the differentiated expression
219      * @throws ParseException if some error occured while trying to differentiate, for instance of no rule supplied for given function.
220      * @throws IllegalArgumentException
221      */

222     public Node differentiate(Node node,String JavaDoc var,DJep djep) throws ParseException,IllegalArgumentException JavaDoc
223     {
224       this.localDJep = djep;
225       this.nf=djep.getNodeFactory();
226       this.tu=djep.getTreeUtils();
227       //this.opSet=djep.getOperatorSet();
228

229       if (node == null)
230           throw new IllegalArgumentException JavaDoc("node parameter is null");
231       if (var == null)
232           throw new IllegalArgumentException JavaDoc("var parameter is null");
233
234       Node res = (Node) node.jjtAccept(this,var);
235       return res;
236     }
237
238     /********** Now the recursive calls to differentiate the tree ************/
239
240     /**
241      * Applies differentiation to a function.
242      * Used the rules specified by objects of type {@link DiffRulesI}.
243      * @param node The node of the function.
244      * @param data The variable to differentiate wrt.
245      **/

246
247     public Object JavaDoc visit(ASTFunNode node, Object JavaDoc data) throws ParseException
248     {
249         String JavaDoc name = node.getName();
250
251        //System.out.println("FUN: "+ node + " nchild "+nchild);
252
Node children[] = TreeUtils.getChildrenAsArray(node);
253         Node dchildren[] = acceptChildrenAsArray(node,data);
254
255         if(node.getPFMC() instanceof DiffRulesI)
256         {
257              return ((DiffRulesI) node.getPFMC()).differentiate(node,(String JavaDoc) data,children,dchildren,localDJep);
258         }
259         DiffRulesI rules = (DiffRulesI) diffRules.get(name);
260         if(rules != null)
261         return rules.differentiate(node,(String JavaDoc) data,children,dchildren,localDJep);
262
263         throw new ParseException("Sorry I don't know how to differentiate "+node+"\n");
264     }
265
266      /**
267       * Differentiates a variable.
268       * May want to alter behaviour when using multi equation as diff(f,x)
269       * might not be zero.
270       * @return 1 if the variable has the same name as data
271       * @return 0 if the variable has a different name.
272       */

273      public Object JavaDoc visit(ASTVarNode node, Object JavaDoc data) throws ParseException {
274        String JavaDoc varName = (String JavaDoc) data;
275        Variable var = node.getVar();
276        if(var instanceof DVariable)
277        {
278         DVariable difvar = (DVariable) var;
279         if(varName.equals(var.getName()))
280             return nf.buildConstantNode(tu.getONE());
281         else if(difvar.hasEquation())
282         {
283             PartialDerivative deriv = difvar.findDerivative((String JavaDoc) data,localDJep);
284             return nf.buildVariableNode(deriv);
285         }
286         else
287             return nf.buildConstantNode(tu.getZERO());
288        }
289        if(var instanceof PartialDerivative)
290        {
291             PartialDerivative pvar = (PartialDerivative) var;
292             DVariable dvar = pvar.getRoot();
293             PartialDerivative deriv = dvar.findDerivative(pvar,varName,localDJep);
294             return nf.buildVariableNode(deriv);
295        }
296        throw new ParseException("Encountered non differentiable variable");
297      }
298
299      /**
300       * Differentiates a constant.
301       * @return 0 direvatives of constants are always zero.
302       */

303      public Object JavaDoc visit(ASTConstant node, Object JavaDoc data) throws ParseException {
304         return nf.buildConstantNode(tu.getZERO());
305      }
306 }
307
308 /*end*/
309
Popular Tags