chiark / gitweb /
curveopt: symbolic: wip, before go back to conditional in C
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
12 def dprint(*args):
13   if not dbg_enabled(): return
14   print(*args)
15
16 def dbg(*args):
17   if not dbg_enabled(): return
18   for vn in args:
19     print('\n    ' + vn + '\n')
20     pprint(eval(vn))
21     print('\n          =\n')
22     pprint(cse(eval(vn)))
23
24 N = CoordSysCartesian('N')
25
26 calculated = False
27
28 def vector_symbols(vnames):
29   out = []
30   for vname in vnames.split(' '):
31     v = Vector.zero
32     for cname in 'i j k'.split(' '):
33       v += getattr(N, cname) * symbols(vname + '_' + cname)
34     out.append(v)
35   return out
36
37 abcd = vector_symbols('A B C D')
38 p = vector_symbols('p')
39
40 E, F = vector_symbols('E F')
41 En, Fn = vector_symbols('En Fn')
42
43 def vector_component(v, ix):
44   return v.components[N.base_vectors()[ix]]
45
46 # x array in numerical algorithm has:
47 #    N x 3    coordinates of points 1..N-2
48 #    1        length parameter |EA| for point 0
49 #    1        length parameter |DE| for point N-1
50
51 def point(p):
52   return Piecewise((p == 0, E + En * 
53
54 def calculate():
55   global calculated
56   if calculated: return
57
58   for i in range(0..len(abcd)):
59     abcd[i] = 
60
61 Piecewise(( P 
62
63   Q = B + 0.5 * (B - A)
64   R = C + 0.5 * (C - D)
65   QR = R - Q
66   BC = C - B
67   cost_ABCD = (QR & QR) / (BC & BC)
68   global cost_ABCD
69   dbg('cost_ABCD')
70   dprint(A)
71
72   cost_ABCD = subst_vect(
73   
74
75   A_end = E + 
76
77   cost_EBCD = subst_vect(cost_ABCD, A, A_end)
78
79   # diff_A_i = cost_ABCD.diff(vector_component(A, 0))
80   # global diff_A_i
81   # dbg('diff_A_i')
82
83   calculated = True
84
85 def ourccode(*a, **kw):
86   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
87
88 def cprintraw(*s):
89   print(*s)
90
91 def cprint(s):
92   for l in s.split('\n'):
93     cprintraw(l, '\\')
94
95 def cse_prep_cprint(v, tmp_prefix):
96   # => v, but also having cprint'd the common subexpression assignments
97   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
98                  itertools.count())
99   (defs, vs) = cse(v, symbols=sym_iter)
100   for defname, defval in defs:
101     cprint('double '+ourccode(defval, assign_to=defname))
102   return vs[0]
103
104 def cassign(v, assign_to, tmp_prefix):
105   v = cse_prep_cprint(v, tmp_prefix)
106   cprint(ourccode(v, assign_to=assign_to))
107
108 def gen_diff(current, smalls):
109   global j
110   if not smalls:
111     j = zeros(len(params),0)
112     for param in params:
113       global d
114       paramv = eval(param)
115       d = diff(current, paramv)
116       dbg('d')
117       j = j.row_join(d)
118     dbg('j')
119     j = cse_prep_cprint(j, 'jtmp')
120     for ix in range(0, j.cols):
121       cprint(ourccode(j.col(ix), 'J_COL'))
122       cprint('J_END_COL(%d)' % ix)
123   else:
124     small = smalls[0]
125     smalls = smalls[1:]
126     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
127     gen_diff(current, smalls)
128     cprint('} else { /* %s small */' % small)
129     gen_diff(current.replace(
130       sinc(small),
131       1 - small*small/factorial(3) - small**4/factorial(5),
132       ),
133       smalls
134     )
135     cprint('} /* %s small */' % small)
136
137 def gen_misc():
138   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
139
140 def gen_abcd_ijk_macros():
141   abcd = 'A B C D'.split(' ')
142   ijk = 'i j k'.split(' ')
143   for ai in range(0, len(abcd)):
144     for ii in range(0, len(ijk)):
145       cprintraw(('#define %s_%s' % (abcd[ai], ijk[ii]) +
146                  ' X[ ((P) - 1 + %d) * 3 + %d ]' % (ai, ii)))
147
148   cprint('#define E_CALCULATE_MID')
149   cassign(cost_ABCD,'d','dtmp_mid')
150
151 def gen_e_calculate():
152   cprint('#define E_CALCULATE_MID')
153   cassign(cost_ABCD,'d','dtmp_mid')
154
155 def gen_f_populate():
156   cprint('#define F_POPULATE')
157   cassign(cost_ABCD,'F','ftmp')
158   cprintraw('')
159
160 def gen_j_populate():
161   cprint('#define J_POPULATE')
162   gen_diff(result_dirnscaled, (sh*sh*la, th*th*la))
163   cprintraw('')
164
165 def gen_C():
166   gen_misc()
167   gen_abcd_ijk_macros()
168   gen_e_calculate()
169
170 def get_python():
171   # https://github.com/sympy/sympy/issues/13642
172   # "lambdify sinc gives wrong answer!"
173   out = q_sqparm
174   sinc_fixed = Function('sinc_fixed')
175   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
176   out = out.subs(sinc,sinc_fixed)
177   p = list(map(eval,params))
178   return lambdify(p, out)