chiark / gitweb /
curveopt: symbolic: wip
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
def dprint(*args):
  """print(*args), but only while debugging is enabled (per dbg_enabled())."""
  if dbg_enabled():
    print(*args)
15
def dbg(*args):
  """Pretty-print named expressions and their CSE form, when debugging.

  Each argument is a variable *name* (a string), not a value: it is
  eval()'d from inside this function, so in practice it must name a
  module-level global. For each one, the expression is pprint'ed, then the
  result of cse() on it.
  """
  if dbg_enabled():
    # `vn` is deliberately the only extra local here: eval() below sees
    # this frame's locals, so introducing more could shadow caller names.
    for vn in args:
      print('\n    ' + vn + '\n')
      pprint(eval(vn))
      print('\n          =\n')
      pprint(cse(eval(vn)))
23
# Set to True by calculate() once the symbolic expressions exist, so that
# repeated calls do not rebuild them.
calculated = False

# Cartesian reference frame in which every symbolic vector is expressed.
# NOTE(review): CoordSysCartesian is deprecated in newer SymPy releases in
# favour of CoordSys3D; confirm against the SymPy version in use.
N = CoordSysCartesian('N')
27
def vector_symbols(vnames):
  """Return a list of symbolic vectors, one per space-separated name.

  For a name X the vector is X_i*N.i + X_j*N.j + X_k*N.k, where X_i, X_j
  and X_k are fresh sympy symbols and N is the module's coordinate frame.
  """
  axes = ('i', 'j', 'k')
  vectors = []
  for vname in vnames.split(' '):
    terms = [getattr(N, axis) * symbols(vname + '_' + axis) for axis in axes]
    vectors.append(sum(terms, Vector.zero))
  return vectors


# The four input points, each a fully symbolic vector in frame N.
A, B, C, D = vector_symbols('A B C D')
38
def calculate():
  """Build the symbolic cost expression `cost_ABCD`, at most once.

  Q is B pushed half of |AB| further along A->B, and R is C pushed half of
  |DC| further along D->C. The cost is |QR|^2 / |BC|^2, stored in the
  module-level `cost_ABCD`. Afterwards the module-level `calculated` flag is
  set, so later calls return immediately.
  """
  # Both names are rebound below, so both must be declared global before
  # any assignment. Declaring cost_ABCD only after assigning it (as this
  # function used to) is a SyntaxError on Python 3.6+, which broke import
  # of the whole module.
  global calculated, cost_ABCD
  if calculated:
    return

  Q = B + 0.5 * (B - A)
  R = C + 0.5 * (C - D)
  QR = R - Q
  BC = C - B
  # `&` is the sympy.vector dot product, so these are squared lengths.
  cost_ABCD = (QR & QR) / (BC & BC)
  dbg('cost_ABCD')
  dprint(A)

  calculated = True
53
def ourccode(*a, **kw):
  """sympy's ccode(), with sympy `sinc` mapped to a C function named `sinc`."""
  sinc_mapping = {'sinc': 'sinc'}
  return ccode(*a, user_functions=sinc_mapping, **kw)
56
def cprintraw(*s):
  """Write the pieces, space-separated, as one line of generated output."""
  print(' '.join(str(piece) for piece in s))
59
def cprint(s):
  """Emit text `s` line by line, each line ending in ' \\'.

  The trailing backslash continues a C preprocessor #define, so callers
  can build multi-line macro bodies.
  """
  lines = s.split('\n')
  for line in lines:
    cprintraw(line, '\\')
63
def cse_prep_cprint(v, tmp_prefix):
  """Run common-subexpression elimination on `v`, printing the temporaries.

  Temporaries are named <tmp_prefix>0, <tmp_prefix>1, ...; for each one,
  'double ' followed by its C assignment (from ourccode) is emitted via
  cprint(). Returns the reduced form of `v`, expressed in those
  temporaries.
  """
  # numbered_symbols() is a lazy generator on both Python 2 and 3. The
  # previous map(..., itertools.count()) is eager on Python 2 (which this
  # file supports via __future__) and so never returned there.
  tmp_symbols = numbered_symbols(tmp_prefix)
  (defs, reduced) = cse(v, symbols=tmp_symbols)
  for tmp, expr in defs:
    cprint('double ' + ourccode(expr, assign_to=tmp))
  return reduced[0]
72
def cassign(v, assign_to, tmp_prefix):
  """Emit C that computes expression `v` into `assign_to`.

  CSE temporaries (named with `tmp_prefix`) are emitted first, then the
  final assignment written in terms of them.
  """
  reduced = cse_prep_cprint(v, tmp_prefix)
  cprint(ourccode(reduced, assign_to=assign_to))
76
def gen_diff(current, smalls):
  """Emit C code for the derivatives of `current`, branching on small terms.

  With no `smalls` left: differentiate `current` with respect to each
  parameter named in the module-level `params`. The results are joined as
  the columns of `j`. The CSE temporaries are emitted, then one
  `J_COL` assignment and a `J_END_COL(ix)` marker per column.

  Otherwise: take the first expression `small` from `smalls` and emit
  `if (!IS_SMALL(small)) { ... } else { ... }`. The first branch recurses
  on `current` unchanged. The second branch recurses with every
  `sinc(small)` replaced by its Taylor polynomial, presumably so the
  generated code avoids evaluating sin(x)/x near x = 0.

  current: sympy expression to differentiate.
  smalls: tuple/list of sympy expressions that may be small.
  """
  # `j` and `d` are module globals only so that dbg(), which eval()s the
  # variable *name*, can find them.
  global j, d
  if not smalls:
    j = zeros(len(params), 0)
    # NOTE(review): `params` is not defined in the visible source; each
    # entry is eval()'d, so it is presumably a list of symbol names.
    for param in params:
      paramv = eval(param)
      # NOTE(review): row_join() needs a Matrix, so `current` is presumably
      # a column Matrix -- confirm.
      d = diff(current, paramv)
      dbg('d')
      j = j.row_join(d)
    dbg('j')
    j = cse_prep_cprint(j, 'jtmp')
    for ix in range(j.cols):
      cprint(ourccode(j.col(ix), 'J_COL'))
      cprint('J_END_COL(%d)' % ix)
    return

  small, rest = smalls[0], smalls[1:]
  # sinc(x) = sin(x)/x = 1 - x^2/3! + x^4/5! - ...  (the x^4 term is
  # positive; it was previously subtracted).
  sinc_series = 1 - small*small/factorial(3) + small**4/factorial(5)
  cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
  gen_diff(current, rest)
  cprint('} else { /* %s small */' % small)
  gen_diff(current.replace(sinc(small), sinc_series), rest)
  cprint('} /* %s small */' % small)
105
def gen_misc():
  """Emit the banner line that opens the generated C output."""
  banner = '// AUTOGENERATED - DO NOT EDIT\n'
  cprintraw(banner)
108
def gen_x_extract():
  """Placeholder section for gen_C(); currently emits nothing."""
111
def gen_f_populate():
  """Emit the F_POPULATE macro, which assigns cost_ABCD to F.

  cost_ABCD only exists as a module global after calculate() has run.
  """
  macro_name = 'F_POPULATE'
  cprint('#define ' + macro_name)
  cassign(cost_ABCD, assign_to='F', tmp_prefix='ftmp')
  # An empty line ends the backslash-continued macro body.
  cprintraw('')
116
def gen_j_populate():
  """Emit the J_POPULATE macro, whose body is produced by gen_diff()."""
  cprint('#define J_POPULATE')
  # NOTE(review): result_dirnscaled, sh, th and la are not defined in the
  # visible source, so as written this raises NameError after emitting the
  # #define line; looks carried over from another generator (WIP) -- confirm.
  target = result_dirnscaled
  small_terms = (sh*sh*la, th*th*la)
  gen_diff(target, small_terms)
  # An empty line ends the backslash-continued macro body.
  cprintraw('')
121
def gen_C():
  """Emit the whole generated C text, one section after another.

  NOTE(review): this does not call calculate(), so gen_f_populate() needs
  cost_ABCD to have been built beforehand.
  """
  for emit_section in (gen_misc, gen_x_extract, gen_f_populate,
                       gen_j_populate):
    emit_section()
127
def get_python():
  """Return a numpy-evaluable function for the expression `q_sqparm`.

  The returned callable takes, positionally, the symbols named in `params`.
  sympy's `sinc` is replaced by `sinc_fixed`, which is implemented as
  numpy.sinc(x/pi), i.e. sin(x)/x. lambdify's own handling of sinc is wrong
  (see the issue below).
  """
  # https://github.com/sympy/sympy/issues/13642
  # "lambdify sinc gives wrong answer!"
  # implemented_function() *returns* the function that carries the numeric
  # implementation; in recent SymPy it builds a new function object rather
  # than annotating its argument. Its result was previously discarded, so
  # the expression got an unimplemented sinc_fixed.
  sinc_fixed = implemented_function('sinc_fixed',
                                    lambda x: np.sinc(x/np.pi))
  # NOTE(review): q_sqparm and params are not defined in the visible
  # source; params is presumably a list of symbol names (each is eval()'d).
  out = q_sqparm.subs(sinc, sinc_fixed)
  p = list(map(eval, params))
  return lambdify(p, out)
136   return lambdify(p, out)