Over multiple iterations of improving FormaK (reference IMU model, rocket model, the original Python code generation), I've wanted to leverage the power of Sympy to provide efficient implementations of symbolic concepts before converting to Python or C++.
The tool for this job is simplify (API docs). With one call, it can apply polynomial simplification, trigonometric simplification, and other approaches. Combine this with Common Subexpression Elimination and we have a powerful pair of tools for writing efficient code regardless of the model.
There's just one problem: Sympy can be incredibly sluggish for some expressions. Each call can take tens of seconds, and those calls can stack up to minutes spent waiting and hoping for a result. For this experiment, I wanted to take some time to dive into what's going on and try to understand why simplify can be so darn slow sometimes.
Experiment Setup¶
The first part of the experiment was to make a slow expression. Taking inspiration from a particularly slow case in the reference IMU design, I opted to simplify expressions for converting from rotation matrices to quaternions then back to rotation matrices.
slow_expr = Quaternion.from_rotation_matrix(reference).to_rotation_matrix()
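For anyone wanting to reproduce the round trip, here is a minimal sketch. It assumes that reference is a fully symbolic 3x3 matrix with entries a through i (the symbol names match the expression shown further down, but the exact construction of the matrix is my assumption, not taken from the reference IMU model).

```python
from sympy import Matrix, symbols
from sympy.algebras.quaternion import Quaternion

# Assumption: a fully symbolic 3x3 matrix, filled row by row with a..i
reference = Matrix(3, 3, list(symbols("a:i")))

# Rotation matrix -> quaternion -> rotation matrix, with no simplification yet
slow_expr = Quaternion.from_rotation_matrix(reference).to_rotation_matrix()

# Each entry is a large expression in the nine symbols
print(slow_expr.shape)
```

Building the expression is quick; it's the later call to simplify that is slow.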
This produces long polynomials that (apparently) carry a large simplification burden. To find a slow-but-not-too-slow expression, I performed a bottom-up traversal of the slow expression until I reached a subexpression that took ~10 seconds to simplify.
(a/4 + e/4 + i/4 - (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h -
c*e*g)**(1/3))*sign(-b + d)**2/4 - (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g
+ c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i
+ b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i +
b*f*g + c*d*h - c*e*g)**(1/3)/4)/(a/4 + e/4 + i/4 + (-a - e + i + (a*e*i -
a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 + (-a + e - i
+ (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a
- e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f +
h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)
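The bottom-up search can be sketched roughly as follows. This is not the code used for the experiment: the helper name `find_slow_subexpression` is hypothetical, and using `sympy.postorder_traversal` (which visits children before parents) is my assumption about one reasonable way to do the traversal.

```python
import time

from sympy import cos, postorder_traversal, simplify, sin, symbols


def find_slow_subexpression(expr, threshold_s):
    """Visit subexpressions children-first and return (subexpression, seconds)
    for the first one whose simplify() call takes at least threshold_s seconds.
    Returns None if no subexpression is that slow."""
    seen = set()
    for sub in postorder_traversal(expr):
        # Skip symbols and numbers, and anything already timed
        if sub.is_Atom or sub in seen:
            continue
        seen.add(sub)
        start = time.perf_counter()
        simplify(sub)
        elapsed = time.perf_counter() - start
        if elapsed >= threshold_s:
            return sub, elapsed
    return None


# Toy stand-in for the real expression; a threshold of 0 simply returns the
# first non-atomic subexpression that gets visited.
x, y = symbols("x y")
toy = sin(x) ** 2 + cos(x) ** 2 + x * y
result = find_slow_subexpression(toy, threshold_s=0.0)
print(result)
```

For the real expression, a threshold of around 10.0 would match the ~10 second target. Note that Sympy caches some intermediate results, so timings of later subexpressions can be affected by earlier calls.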
Profiling¶
The profiling is built on Python's cProfile. It's not exactly performant, but it's easy enough to take a function invocation:
simplify(expr)
and convert the function call to a profiling exercise that can yield information about the function's inner timings:
cProfile.runctx(
    "simplify(expr, inverse=False)",
    globals=globals(),
    locals={"expr": expr},
    filename=filename,
)
This profiling provides lots of information, such as the cumulative time in a function or the time spent in the function itself (excluding sub-functions). Unfortunately, all that information is often too much to understand intuitively during an exploratory exercise like this one, where I'm not familiar with the code.
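The saved profile can be read back with the standard library's `pstats` module. Here is a minimal sketch that profiles a small stand-in expression (not the slow one) and prints the ten entries with the largest cumulative time. The temporary file location and the toy expression are assumptions for illustration, and the namespace is passed explicitly rather than via `globals()` so the snippet runs anywhere.

```python
import cProfile
import os
import pstats
import tempfile

from sympy import cos, simplify, sin, symbols

x = symbols("x")
expr = sin(x) ** 2 + cos(x) ** 2  # small stand-in, not the slow expression

# Write the raw profile data to a temporary file
filename = os.path.join(tempfile.mkdtemp(), "simplify_pstats")
cProfile.runctx(
    "simplify(expr, inverse=False)",
    globals={"simplify": simplify},
    locals={"expr": expr},
    filename=filename,
)

# Load the profile, sort by cumulative time and show the top 10 rows
stats = pstats.Stats(filename)
stats.sort_stats("cumulative").print_stats(10)
```

Sorting by "cumulative" puts the outermost calls first, which is handy for seeing the overall call chain; sorting by "tottime" instead highlights functions that are expensive on their own.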
Flamegraphs¶
Flamegraphs to the rescue! Flamegraphs show function timing via samples on the horizontal axis and call dependencies on the vertical axis. I love the way they look, and they're also flexible for quickly understanding function performance. If a function accounts for a material share of the time, it will stand out visually. If there is a mix of calls (or none of the calls you expect), there isn't a clear performance win to be had at that stage, and it's worth investing in further exploration to better understand the problem.
The flamegraph for this shows time going to cancel, which itself is split between factor_terms and sign_simp. factor_terms is built from a recursive series of calls to gcd_terms (which makes sense), but the recursion obscures the timing of related functions and where there may be a performance improvement.
Results?¶
In the time I had for this experiment, I wasn't able to get anything to a particularly satisfying conclusion. I have a few more hypotheses to check (perhaps common subexpression elimination could be used to speed up the simplify process?) but for now I remain stuck. Simplification may remain a final finishing step for a "release" build and won't be actively used quite yet in the FormaK design and development process.
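As a starting point for the common subexpression elimination hypothesis, here is a minimal sketch of what `sympy.cse` does with a toy expression that repeats a cube-root term, similar in shape to the determinant term in the slow expression. Whether this actually speeds up simplify is untested; the sketch only shows the repeated term being pulled out once and substituted back. The toy expression is an assumption for illustration.

```python
from sympy import Rational, cse, symbols

a, b, d, e = symbols("a b d e")

# Toy expression: the same cube-root term appears in numerator and denominator
det = (a * e - b * d) ** Rational(1, 3)
expr = (a + e + det) / (a + e - det)

# replacements: list of (new_symbol, subexpression) pairs
# reduced: list containing the expression rewritten with the new symbols
replacements, reduced = cse(expr)
print(replacements)
print(reduced[0])

# Substituting back in reverse order recovers the original expression
restored = reduced[0]
for sym, sub in reversed(replacements):
    restored = restored.subs(sym, sub)
```

One idea to test would be simplifying the reduced expression (where the repeated term is a single symbol) before substituting back, with the caveat that simplify can no longer see inside the replaced terms.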
Bonus Round¶
I got a suggestion from @hugovk to try out the latest Python release:
https://fosstodon.org/@hugovk@mastodon.social/111351314621822006
If you're not already, you could compare Python 3.12:
"Dictionary, list, and set comprehensions are now inlined, rather than creating a new single-use function object for each execution of the comprehension. This speeds up execution of a comprehension by up to two times."
https://docs.python.org/3/whatsnew/3.12.html#whatsnew312-pep709 #Python #Python312 #PEP709
The overall trend I found was that performance on 3.10 and 3.11 was close (roughly 49 to 54 seconds under the profiler), but 3.12 was actually slower than the other two (roughly 59 to 60 seconds). It seems like lots of libraries are making major updates for 3.12 (as well as comprehensions getting a speedup), so it may be that as libraries update, performance will return to parity with, or surpass, that of the libraries running on 3.10.
I've included the data from profiling below.
3.12 Version e991b8c¶
Python Version sys.version_info(major=3, minor=12, micro=0, releaselevel='final', serial=0)
Slow Expression
(a/4 + e/4 + i/4 - (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 - (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)/(a/4 + e/4 + i/4 + (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 + (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)
Begin Profiling
End Profiling
Profiling took about 58.970478 seconds
Tue Nov 7 00:50:37 2023 simplify_pstats
83403827 function calls (71409971 primitive calls) in 58.891 seconds
Ordered by: cumulative time
List reduced from 1377 to 10 due to restriction <10>
    ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
     512/1    0.066    0.000   58.902   58.902  {built-in method builtins.exec}
     147/1    0.007    0.000   58.902   58.902  simplify.py:420(simplify)
         1    0.000    0.000   55.085   55.085  piecewise.py:1333(piecewise_simplify)
         1    0.000    0.000   54.667   54.667  piecewise.py:1145(piecewise_simplify_arguments)
       203    0.006    0.000   44.180    0.218  polytools.py:6801(cancel)
        68    0.001    0.000   41.195    0.606  expr.py:3788(cancel)
       389    0.001    0.000   35.406    0.091  exprtools.py:1156(factor_terms)
148423/389    0.681    0.000   35.404    0.091  exprtools.py:1217(do)
      9880    0.109    0.000   32.502    0.003  exprtools.py:980(gcd_terms)
      9880    0.262    0.000   32.006    0.003  exprtools.py:915(_gcd_terms)
Reruns:
- 83462111 function calls (71463102 primitive calls) in 59.368 seconds
- 83402767 function calls (71409367 primitive calls) in 60.314 seconds
3.11 Version 486a2eb¶
Python Version sys.version_info(major=3, minor=11, micro=6, releaselevel='final', serial=0)
Slow Expression
(a/4 + e/4 + i/4 - (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 - (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)/(a/4 + e/4 + i/4 + (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 + (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)
Begin Profiling
End Profiling
Profiling took about 51.619768 seconds
Tue Nov 7 00:52:18 2023 simplify_pstats
84451839 function calls (72211354 primitive calls) in 51.589 seconds
Ordered by: cumulative time
List reduced from 1564 to 10 due to restriction <10>
    ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
     513/1    0.047    0.000   51.599   51.599  {built-in method builtins.exec}
     147/1    0.007    0.000   51.599   51.599  simplify.py:420(simplify)
         1    0.000    0.000   48.176   48.176  piecewise.py:1333(piecewise_simplify)
         1    0.000    0.000   47.823   47.823  piecewise.py:1145(piecewise_simplify_arguments)
       203    0.007    0.000   38.900    0.192  polytools.py:6801(cancel)
        68    0.001    0.000   36.493    0.537  expr.py:3788(cancel)
       389    0.001    0.000   31.691    0.081  exprtools.py:1156(factor_terms)
148423/389    0.839    0.000   31.689    0.081  exprtools.py:1217(do)
  9088/405    0.024    0.000   29.952    0.074  exprtools.py:1242(<listcomp>)
 21008/538    0.043    0.000   29.707    0.055  exprtools.py:1263(<listcomp>)
Reruns:
- 84435622 function calls (72194285 primitive calls) in 53.926 seconds
- 84430371 function calls (72188205 primitive calls) in 51.855 seconds
3.10 Version 21aaae8¶
Python Version sys.version_info(major=3, minor=10, micro=13, releaselevel='final', serial=0)
Slow Expression
(a/4 + e/4 + i/4 - (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 - (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)/(a/4 + e/4 + i/4 + (-a - e + i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-b + d)**2/4 + (-a + e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(c - g)**2/4 + (a - e - i + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3))*sign(-f + h)**2/4 + (a*e*i - a*f*h - b*d*i + b*f*g + c*d*h - c*e*g)**(1/3)/4)
Begin Profiling
End Profiling
Profiling took about 50.896008 seconds
Tue Nov 7 00:54:13 2023 simplify_pstats
84462355 function calls (72217560 primitive calls) in 50.865 seconds
Ordered by: cumulative time
List reduced from 1552 to 10 due to restriction <10>
    ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
     513/1    0.058    0.000   50.878   50.878  {built-in method builtins.exec}
         1    0.000    0.000   50.877   50.877  <string>:1(<module>)
     147/1    0.007    0.000   50.877   50.877  simplify.py:420(simplify)
         1    0.000    0.000   47.561   47.561  piecewise.py:1333(piecewise_simplify)
         1    0.000    0.000   47.213   47.213  piecewise.py:1145(piecewise_simplify_arguments)
       203    0.006    0.000   38.083    0.188  polytools.py:6801(cancel)
        68    0.001    0.000   35.561    0.523  expr.py:3788(cancel)
       389    0.001    0.000   30.927    0.080  exprtools.py:1156(factor_terms)
148423/389    0.838    0.000   30.925    0.079  exprtools.py:1217(do)
  9088/405    0.023    0.000   29.286    0.072  exprtools.py:1242(<listcomp>)
Reruns:
- 84422661 function calls (72181588 primitive calls) in 50.693 seconds
- 84433281 function calls (72191175 primitive calls) in 49.238 seconds
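To compare the versions at a glance, the totals reported above (the first run plus the two reruns for each Python version) can be averaged. This is just arithmetic on the numbers already listed; with only three runs each, treat it as a rough summary rather than a proper benchmark.

```python
from statistics import mean

# Total profiled seconds per run, copied from the output above
runs = {
    "3.12": [58.891, 59.368, 60.314],
    "3.11": [51.589, 53.926, 51.855],
    "3.10": [50.865, 50.693, 49.238],
}

means = {version: mean(times) for version, times in runs.items()}
baseline = means["3.10"]

for version, avg in means.items():
    # Ratio is relative to the 3.10 mean
    print(f"Python {version}: mean {avg:.2f} s ({avg / baseline:.2f}x of 3.10)")
```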