chiark / gitweb /
curveopt: symbolic: wip
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
12 def dprint(*args):
13   if not dbg_enabled(): return
14   print(*args)
15
16 def dbg(*args):
17   if not dbg_enabled(): return
18   for vn in args:
19     print('\n    ' + vn + '\n')
20     pprint(eval(vn))
21     print('\n          =\n')
22     pprint(cse(eval(vn)))
23
24 N = CoordSysCartesian('N')
25
26 calculated = False
27
28 def vector_symbols(vnames):
29   out = []
30   for vname in vnames.split(' '):
31     v = Vector.zero
32     for cname in 'i j k'.split(' '):
33       v += getattr(N, cname) * symbols(vname + '_' + cname)
34     out.append(v)
35   return out
36
37 A, B, C, D = vector_symbols('A B C D')
38 p = vector_symbols('p')
39
40 E, H = vector_symbols('E H')
41 En, Hn = vector_symbols('En Hn')
42
43 EFl, GHl = symbols('EFl GHl')
44
45 def vector_component(v, ix):
46   return v.components[N.base_vectors()[ix]]
47
48 # x array in numerical algorithm has:
49 #    N x 3    coordinates of points 0..N-3
50 #    1        EFl = length parameter |EF| for point 1
51 #    1        GHl = length parameter |GH| for point N-2
52
53 # fixed array in numerical algorithm has:
54 #    4 x 3    E, En, H, Hn
55
56 #def subst_vect():
57
58 def calculate():
59   global calculated
60   if calculated: return
61
62   Q = B + 0.5 * (B - A)
63   R = C + 0.5 * (C - D)
64   QR = R - Q
65   BC = C - B
66   cost_ABCD = (QR & QR) / (BC & BC)
67   global cost_ABCD
68   dbg('cost_ABCD')
69   dprint(A)
70
71   F = E + En * EFl
72   G = H + Hn * GHl
73
74   # diff_A_i = cost_ABCD.diff(vector_component(A, 0))
75   # global diff_A_i
76   # dbg('diff_A_i')
77
78   calculated = True
79
80 def ourccode(*a, **kw):
81   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
82
83 def cprintraw(*s):
84   print(*s)
85
86 def cprint(s):
87   for l in s.split('\n'):
88     cprintraw(l, '\\')
89
90 def cse_prep_cprint(v, tmp_prefix):
91   # => v, but also having cprint'd the common subexpression assignments
92   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
93                  itertools.count())
94   (defs, vs) = cse(v, symbols=sym_iter)
95   for defname, defval in defs:
96     cprint('double '+ourccode(defval, assign_to=defname))
97   return vs[0]
98
99 def cassign(v, assign_to, tmp_prefix):
100   v = cse_prep_cprint(v, tmp_prefix)
101   cprint(ourccode(v, assign_to=assign_to))
102
103 def gen_diff(current, smalls):
104   global j
105   if not smalls:
106     j = zeros(len(params),0)
107     for param in params:
108       global d
109       paramv = eval(param)
110       d = diff(current, paramv)
111       dbg('d')
112       j = j.row_join(d)
113     dbg('j')
114     j = cse_prep_cprint(j, 'jtmp')
115     for ix in range(0, j.cols):
116       cprint(ourccode(j.col(ix), 'J_COL'))
117       cprint('J_END_COL(%d)' % ix)
118   else:
119     small = smalls[0]
120     smalls = smalls[1:]
121     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
122     gen_diff(current, smalls)
123     cprint('} else { /* %s small */' % small)
124     gen_diff(current.replace(
125       sinc(small),
126       1 - small*small/factorial(3) - small**4/factorial(5),
127       ),
128       smalls
129     )
130     cprint('} /* %s small */' % small)
131
132 def gen_misc():
133   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
134
135 def gen_point_coords_macro(macro_basename):
136   ijk = 'i j k'.split(' ')
137   for ii in range(0, len(ijk)):
138     cprintraw('#define %s_%s (%s[%d])'
139               % (macro_basename, ijk[ii], macro_basename, ii))
140
141 def gen_point_index_macro(macro_basename, c_array_name, base_index):
142   cprintraw('#define %s (&%s[%s])'
143             % (macro_basename, c_array_name, base_index))
144   gen_point_coords_macro(macro_basename)
145
146 def gen_point_references():
147   abcd = 'A B C D'.split(' ')
148   eh = 'E En H Hn'.split(' ')
149   for ehi in range(0, len(eh)):
150     gen_point_index_macro(eh[ehi], 'FIXED', ehi * 3)
151
152   cprint('#define POINT(PP) (')
153   cprint(' (PP) == 0   ? E :')
154   cprint(' (PP) == 1   ? F :')
155   cprint(' (PP) == N-2 ? G :')
156   cprint(' (PP) == N-1 ? H :')
157   cprint(' &X[((PP)-2)*3]')
158   cprintraw(')')
159
160   for ai in range(0, len(abcd)):
161     cprintraw('#define %s POINT(P%+d)' % (abcd[ai], ai))
162     gen_point_coords_macro(abcd[ai])
163
164 def gen_calculate_FG():
165   cprint('#define CALCULAGE_F_G')
166   cassign(E,'E','tmp_E')
167   cassign(E,'F','tmp_F')
168
169 def gen_calculate_cost():
170   cprint('#define CALCULATE_COST')
171   cassign(cost_ABCD,'P_cost','tmp_P_cost')
172   cprintraw('')
173
174 def gen_f_populate():
175   cprint('#define F_POPULATE')
176   cassign(cost_ABCD,'F','ftmp')
177   cprintraw('')
178
179 def gen_j_populate():
180   cprint('#define J_POPULATE')
181   gen_diff(result_dirnscaled, (sh*sh*la, th*th*la))
182   cprintraw('')
183
184 def gen_C():
185   gen_misc()
186   gen_point_references()
187   #gen_calculate_FG()
188   gen_calculate_cost()
189
190 def get_python():
191   # https://github.com/sympy/sympy/issues/13642
192   # "lambdify sinc gives wrong answer!"
193   out = q_sqparm
194   sinc_fixed = Function('sinc_fixed')
195   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
196   out = out.subs(sinc,sinc_fixed)
197   p = list(map(eval,params))
198   return lambdify(p, out)