def cprint(s):
print(s)
-def cassign(v, assign_to, tmp_prefix):
+def cse_prep_cprint(v, tmp_prefix):
+ # => v, but also having cprint'd the common subexpression assignments
sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
itertools.count())
(defs, vs) = cse(v, symbols=sym_iter)
for defname, defval in defs:
cprint(ccode(defval, assign_to=defname))
- v = vs[0]
- if isinstance(v,Matrix) and v.cols > 1:
- for ix in range(0, v.cols):
- cprint(ccode(v.col(ix), '%s[%d]' % (assign_to, ix)))
- else:
- cprint(ccode(v, assign_to=assign_to))
-
+ return vs[0]
+
+def cassign(v, assign_to, tmp_prefix):
+ v = cse_prep_cprint(v, tmp_prefix)
+ cprint(ccode(v, assign_to=assign_to))
def gen_diff(current, smalls):
global j
dbg('d')
j = j.row_join(d)
dbg('j')
- cassign((j,), 'j', 'jtmp')
+ j = cse_prep_cprint(j, 'jtmp')
+ for ix in range(0, j.cols):
+ cprint(ccode(j.col(ix), 'J(%d)' % ix))
+ cprint('J_END_COL(%d)' % ix)
else:
small = smalls[0]
smalls = smalls[1:]
- cprint('if (!is_small(' + ccode(small) + ')) {')
+ cprint('if (!IS_SMALL(' + ccode(small) + ')) {')
gen_diff(current, smalls)
cprint('} else { /* %s small */' % small)
gen_diff(current.replace(