chiark / gitweb /
f12a96d18d9496645394195604ef286abb1d7f61
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
12 def dprint(*args):
13   if not dbg_enabled(): return
14   print(*args)
15
16 def dbg(*args):
17   if not dbg_enabled(): return
18   for vn in args:
19     print('\n    ' + vn + '\n')
20     pprint(eval(vn))
21     print('\n          =\n')
22     pprint(cse(eval(vn)))
23
24 N = CoordSysCartesian('N')
25
26 calculated = False
27
28 def vector_symbols(vnames):
29   out = []
30   for vname in vnames.split(' '):
31     v = Vector.zero
32     for cname in 'i j k'.split(' '):
33       v += getattr(N, cname) * symbols(vname + '_' + cname)
34     out.append(v)
35   return out
36
37 A, B, C, D = vector_symbols('A B C D')
38 p = vector_symbols('p')
39
40 E, H = vector_symbols('E H')
41 En, Hn = vector_symbols('En Hn')
42
43 EFl, GHl = symbols('EFl GHl')
44
45 def vector_component(v, ix):
46   return v.components[N.base_vectors()[ix]]
47
48 # x array in numerical algorithm has:
49 #    N x 3    coordinates of points 0..N-3
50 #    1        EFl = length parameter |EF| for point 1
51 #    1        GHl = length parameter |GH| for point N-2
52
53 # fixed array in numerical algorithm has:
54 #    4 x 3    E, En, H, Hn
55
56 #def subst_vect():
57
58 def calculate():
59   global calculated
60   if calculated: return
61
62   Q = B + 0.5 * (B - A)
63   R = C + 0.5 * (C - D)
64   QR = R - Q
65   BC = C - B
66   cost_ABCD = (QR & QR) / (BC & BC)
67   global cost_ABCD
68   dbg('cost_ABCD')
69   dprint(A)
70
71   F = E + En * EFl
72   G = H + Hn * GHl
73
74   #cost_EFCD = subst_vect(cost_ABCD, (A,E), (B,F))
75   #cost_FBCD = subst_vect(cost_ABCD,        (A,F))
76   #cost_ABCG = subst_vect(cost_ABCD,        (D,G))
77   #cost_ABGH = subst_vect(cost_ABCD, (C,G), (D,H))
78
79   # diff_A_i = cost_ABCD.diff(vector_component(A, 0))
80   # global diff_A_i
81   # dbg('diff_A_i')
82
83   calculated = True
84
85 def ourccode(*a, **kw):
86   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
87
88 def cprintraw(*s):
89   print(*s)
90
91 def cprint(s):
92   for l in s.split('\n'):
93     cprintraw(l, '\\')
94
95 def cse_prep_cprint(v, tmp_prefix):
96   # => v, but also having cprint'd the common subexpression assignments
97   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
98                  itertools.count())
99   (defs, vs) = cse(v, symbols=sym_iter)
100   for defname, defval in defs:
101     cprint('double '+ourccode(defval, assign_to=defname))
102   return vs[0]
103
104 def cassign(v, assign_to, tmp_prefix):
105   v = cse_prep_cprint(v, tmp_prefix)
106   cprint(ourccode(v, assign_to=assign_to))
107
108 def gen_diff(current, smalls):
109   global j
110   if not smalls:
111     j = zeros(len(params),0)
112     for param in params:
113       global d
114       paramv = eval(param)
115       d = diff(current, paramv)
116       dbg('d')
117       j = j.row_join(d)
118     dbg('j')
119     j = cse_prep_cprint(j, 'jtmp')
120     for ix in range(0, j.cols):
121       cprint(ourccode(j.col(ix), 'J_COL'))
122       cprint('J_END_COL(%d)' % ix)
123   else:
124     small = smalls[0]
125     smalls = smalls[1:]
126     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
127     gen_diff(current, smalls)
128     cprint('} else { /* %s small */' % small)
129     gen_diff(current.replace(
130       sinc(small),
131       1 - small*small/factorial(3) - small**4/factorial(5),
132       ),
133       smalls
134     )
135     cprint('} /* %s small */' % small)
136
137 def gen_misc():
138   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
139
140 def gen_point_index_macro(macro_basename, c_array_name, base_index):
141   cprintraw('#define %s (&%s[%s])'
142             % (macro_basename, c_array_name, base_index))
143   ijk = 'i j k'.split(' ')
144   for ii in range(0, len(ijk)):
145     cprintraw('#define %s_%s (%s[%d])'
146               % (macro_basename, ijk[ii], macro_basename, ii))
147
148 def gen_abcd_ijk_macros():
149   abcd = 'A B C D'.split(' ')
150   eh = 'E En H Hn'.split(' ')
151   for ehi in range(0, len(eh)):
152     gen_point_index_macro(eh[ehi], 'FIXED', ehi * 3)
153
154   for ai in range(0, len(abcd)):
155     gen_point_index_macro(abcd[ai], 'X', '(P%+d) * 3' % (ai - 2))
156
157   cprint('#define E_CALCULATE_MID')
158   cassign(cost_ABCD,'d','dtmp_mid')
159
160 def gen_e_calculate():
161   cprint('#define E_CALCULATE_MID')
162   cprint('if (P==0) '
163   cassign(cost_ABCD,'d','dtmp_mid')
164
165 def gen_f_populate():
166   cprint('#define F_POPULATE')
167   cassign(cost_ABCD,'F','ftmp')
168   cprintraw('')
169
170 def gen_j_populate():
171   cprint('#define J_POPULATE')
172   gen_diff(result_dirnscaled, (sh*sh*la, th*th*la))
173   cprintraw('')
174
175 def gen_C():
176   gen_misc()
177   gen_abcd_ijk_macros()
178   gen_e_calculate()
179
180 def get_python():
181   # https://github.com/sympy/sympy/issues/13642
182   # "lambdify sinc gives wrong answer!"
183   out = q_sqparm
184   sinc_fixed = Function('sinc_fixed')
185   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
186   out = out.subs(sinc,sinc_fixed)
187   p = list(map(eval,params))
188   return lambdify(p, out)