chiark / gitweb /
helixish: Introduce matmultiply and augmatmultiply
[moebius3.git] / symbolic.py
1
2 from sympy import *
3 import itertools
4
5 from moedebug import *
6
7 from sympy.utilities.lambdify import lambdify, implemented_function
8
9 r, theta, s, la, mu, kappa = symbols('r theta s lambda mu kappa')
10
11 # start      original formulation
12 # rightvars  replaces 
13
14 def dprint(*args):
15   if not dbg_enabled(): return
16   print(*args)
17
18 def dbg(*args):
19   if not dbg_enabled(): return
20   for vn in args:
21     print('\n    ' + vn + '\n')
22     pprint(eval(vn))
23     print('\n          =\n')
24     pprint(cse(eval(vn)))
25
26 calculated = False
27
28 def calculate():
29   global calculated
30   if calculated: return
31
32   p_start = Matrix([
33     r * (1 - cos(theta)),
34     r * sin(theta),
35     mu * s,
36   ])
37
38   global p_rightvars
39   p_rightvars = p_start.subs( theta, s/r ).subs( r, 1/la )
40   dbg('p_rightvars')
41
42   global p_dirn_rightvars
43   p_dirn_rightvars = diff(p_rightvars, s)
44   dbg('p_dirn_rightvars')
45
46   zeta = Wild('zeta')
47
48   global p_nosing
49   p_nosing = (p_rightvars
50               .replace( 1-cos(zeta)  ,   2*sin(zeta/2)**2          )
51               .replace( sin(zeta)**2 ,   zeta*sinc(zeta)*sin(zeta) )
52               )
53   p_nosing[1] = (p_nosing[1]
54               .replace( sin(zeta) , zeta * sinc(zeta)                )
55                  )
56
57   dbg('p_nosing')
58
59   global t
60   t = symbols('t')
61
62   global q_owncoords, q_dirn_owncoords
63   q_owncoords = p_nosing.replace(s,t).replace(la,-la)
64   q_dirn_owncoords = p_dirn_rightvars.replace(s,t).replace(la,-la)
65
66   dbg('q_owncoords','q_dirn_owncoords')
67   dbg('q_owncoords.replace(t,0)','q_dirn_owncoords.replace(t,0)')
68
69   global p2q_translate, p2q_rotate
70   p2q_translate = p_nosing
71   #p2q_rotate_2d = Matrix([ p_dirn_rightvars[0:2],
72
73   #p2q_rotate = eye(3)
74   #p2q_rotate[0:2, 0] = Matrix([ p_dirn_rightvars[1], -p_dirn_rightvars[0] ])
75   #p2q_rotate[0:2, 1] = p_dirn_rightvars[0:2]
76
77   p2q_rotate = Matrix([[  cos(theta), sin(theta), 0 ],
78                        [ -sin(theta), cos(theta), 0 ],
79                        [  0         , 0,          1 ]]).subs(theta,la*s)
80   #p2q_rotate.add_col([0,0])
81   #p2q_rotate.add_row([0,0,1])
82
83   dbg('p2q_rotate')
84
85   global q_dirn_maincoords, q_maincoords
86   q_dirn_maincoords = p2q_rotate * q_dirn_owncoords;
87   q_maincoords = p2q_rotate * q_owncoords + p2q_translate
88
89   dbg('diff(p_dirn_rightvars,s)')
90   dbg('diff(q_dirn_maincoords,t)')
91   dbg('diff(q_dirn_maincoords,t).replace(t,0)')
92
93   assert(Eq(p2q_rotate * Matrix([0,1,mu]), p_dirn_rightvars))
94
95   #for v in 's','t','la','mu':
96   #  dbg('diff(q_maincoords,%s)' % v)
97
98   #print('\n eye3 subs etc.\n')
99   #dbg('''Eq(eye(3) * Matrix([1,0,mu]),
100   #     p_dirn_rightvars .cross(Matrix([0,0,1]) .subs(s,0)))''')
101
102   #dbg('''Eq(p2q_rotate * Matrix([1,0,mu]),
103   #          p_dirn_rightvars .cross(Matrix([0,0,1])))''')
104
105   #eq = Eq(qmat * q_dirn_owncoords_0, p_dirn_rightvars)
106   #print
107   #pprint(eq)
108   #solve(eq, Q)
109
110   dbg('q_maincoords.replace(t,0)','q_dirn_maincoords.replace(t,0)')
111
112   dbg('q_maincoords','q_dirn_maincoords')
113
114   global sinof_mu, cosof_mu
115   sinof_mu = sin(atan(mu))
116   cosof_mu = cos(atan(mu))
117
118   dbg('cosof_mu','sinof_mu')
119
120   o2p_rotate1 = Matrix([[ 1,  0,         0        ],
121                         [ 0,  cosof_mu, +sinof_mu ],
122                         [ 0, -sinof_mu,  cosof_mu ]])
123
124   global check_dirn_p_s0
125   check_dirn_p_s0 = o2p_rotate1 * p_dirn_rightvars.replace(s,0)
126   check_dirn_p_s0.simplify()
127   dbg('check_dirn_p_s0')
128
129   o2p_rotate2 = Matrix([[  cos(kappa), 0, -sin(kappa) ],
130                         [  0,          1,  0          ],
131                         [ +sin(kappa), 0,  cos(kappa) ]])
132
133   p_dirn_orgcoords = o2p_rotate2 * o2p_rotate1 * p_dirn_rightvars
134
135   check_dirn_p_s0 = p_dirn_orgcoords.replace(s,0)
136   check_dirn_p_s0.simplify()
137   dbg('check_dirn_p_s0')
138
139   global check_accel_p_s0
140   check_accel_p_s0 = diff(p_dirn_orgcoords,s).replace(s,0)
141   check_accel_p_s0.simplify()
142   dbg('check_accel_p_s0')
143
144   global q_dirn_orgcoords, q_orgcoords
145   q_dirn_orgcoords = o2p_rotate2 * o2p_rotate1 * q_dirn_maincoords;
146   q_orgcoords = o2p_rotate2 * o2p_rotate1 * q_maincoords;
147   dbg('q_orgcoords','q_dirn_orgcoords')
148
149   global sh, th
150   sh, th = symbols('alpha beta')
151
152   global q_dirn_sqparm, q_sqparm
153   q_dirn_sqparm = q_dirn_orgcoords.replace(s, sh**2).replace(t, th**2)
154   q_sqparm      = q_orgcoords     .replace(s, sh**2).replace(t, th**2)
155
156   dprint('----------------------------------------')
157   dbg('q_sqparm', 'q_dirn_sqparm')
158   dprint('----------------------------------------')
159   for v in 'sh','th','la','mu':
160     dbg('diff(q_sqparm,%s)' % v)
161     dbg('diff(q_dirn_sqparm,%s)' % v)
162   dprint('----------------------------------------')
163
164   gamma = symbols('gamma')
165
166   q_dirn_dirnscaled = q_dirn_sqparm * gamma
167
168   global result_dirnscaled
169   result_dirnscaled = q_sqparm.col_join(q_dirn_dirnscaled)
170   dbg('result_dirnscaled')
171
172   calculated = True
173
174 params = ('sh','th','la','mu','gamma','kappa')
175
176 def ourccode(*a, **kw):
177   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
178
179 def cprintraw(*s):
180   print(*s)
181
182 def cprint(s):
183   for l in s.split('\n'):
184     cprintraw(l, '\\')
185
186 def cse_prep_cprint(v, tmp_prefix):
187   # => v, but also having cprint'd the common subexpression assignments
188   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
189                  itertools.count())
190   (defs, vs) = cse(v, symbols=sym_iter)
191   for defname, defval in defs:
192     cprint('double '+ourccode(defval, assign_to=defname))
193   return vs[0]
194
195 def cassign(v, assign_to, tmp_prefix):
196   v = cse_prep_cprint(v, tmp_prefix)
197   cprint(ourccode(v, assign_to=assign_to))
198
199 def gen_diff(current, smalls):
200   global j
201   if not smalls:
202     j = zeros(len(params),0)
203     for param in params:
204       global d
205       paramv = eval(param)
206       d = diff(current, paramv)
207       dbg('d')
208       j = j.row_join(d)
209     dbg('j')
210     j = cse_prep_cprint(j, 'jtmp')
211     for ix in range(0, j.cols):
212       cprint(ourccode(j.col(ix), 'J_COL'))
213       cprint('J_END_COL(%d)' % ix)
214   else:
215     small = smalls[0]
216     smalls = smalls[1:]
217     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
218     gen_diff(current, smalls)
219     cprint('} else { /* %s small */' % small)
220     gen_diff(current.replace(
221       sinc(small),
222       1 - small*small/factorial(3) - small**4/factorial(5),
223       ),
224       smalls
225     )
226     cprint('} /* %s small */' % small)
227
228 def gen_misc():
229   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
230   cprintraw('#define N %d\n' % len(params))
231
232 def gen_x_extract():
233   cprint('#define X_EXTRACT')
234   for ix in range(0, len(params)):
235     cprint('double %s = X(%d);' % (eval(params[ix]), ix))
236   cprintraw()
237
238 def gen_f_populate():
239   cprint('#define F_POPULATE')
240   cassign(result_dirnscaled,'F','ftmp')
241   cprintraw('')
242
243 def gen_j_populate():
244   cprint('#define J_POPULATE')
245   gen_diff(result_dirnscaled, (sh*sh*la, th*th*la))
246   cprintraw('')
247
248 def gen_C():
249   gen_misc()
250   gen_x_extract()
251   gen_f_populate()
252   gen_j_populate()
253
254 def get_python():
255   # https://github.com/sympy/sympy/issues/13642
256   # "lambdify sinc gives wrong answer!"
257   out = q_sqparm
258   sinc_fixed = Function('sinc_fixed')
259   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
260   out = out.subs(sinc,sinc_fixed)
261   p = list(map(eval,params))
262   return lambdify(p, out)