shamatar opened issue #4077:
Feature
Implement an optimization pass that eliminates the `__multi3` function from a WASM binary during JIT compilation by replacing it with ISA-specific sequences (mainly for `x86_64` and `arm64`), and then inlines those sequences into call sites, which would allow further optimizations.
Benefit
A lot of code dealing with cryptography would benefit from faster full-width `u64` multiplications wherever such a `__multi3` arises.
Implementation
If someone could give a few hints about where to start, I'd try to implement it myself.
Alternatives
Not that I'm aware of. Patching in a call to some native library function is a huge overhead for modern CPUs (compared with about 4 cycles on `x86_64` for e.g. `mulx` or `mul`), and while it would most likely be faster, it's still far from the optimal case on a hot path.
As an example, a simple multiply-add-carry function like `a*b + c + carry -> (high, low)` that accumulates into `u128` without overflow compiles down to the listing below, and it can be a good test subject (transformed from `wasm` into `wat`, so it may not be the most readable):

    (module
      (type (;0;) (func (param i32 i64 i64 i64 i64)))
      (func $mac (type 0) (param i32 i64 i64 i64 i64)
        (local i32)
        global.get $__stack_pointer
        i32.const 16
        i32.sub
        local.tee 5
        global.set $__stack_pointer
        local.get 5
        local.get 2
        i64.const 0
        local.get 1
        i64.const 0
        call $__multi3
        local.get 0
        local.get 5
        i64.load
        local.tee 2
        local.get 3
        i64.add
        local.tee 3
        local.get 4
        i64.add
        local.tee 4
        i64.store
        local.get 0
        local.get 5
        i32.const 8
        i32.add
        i64.load
        local.get 3
        local.get 2
        i64.lt_u
        i64.extend_i32_u
        i64.add
        local.get 4
        local.get 3
        i64.lt_u
        i64.extend_i32_u
        i64.add
        i64.store offset=8
        local.get 5
        i32.const 16
        i32.add
        global.set $__stack_pointer
      )
      (func $__multi3 (type 0) (param i32 i64 i64 i64 i64)
        (local i64 i64 i64 i64 i64 i64)
        local.get 0
        local.get 3
        i64.const 4294967295
        i64.and
        local.tee 5
        local.get 1
        i64.const 4294967295
        i64.and
        local.tee 6
        i64.mul
        local.tee 7
        local.get 5
        local.get 1
        i64.const 32
        i64.shr_u
        local.tee 8
        i64.mul
        local.tee 9
        local.get 3
        i64.const 32
        i64.shr_u
        local.tee 10
        local.get 6
        i64.mul
        i64.add
        local.tee 5
        i64.const 32
        i64.shl
        i64.add
        local.tee 6
        i64.store
        local.get 0
        local.get 10
        local.get 8
        i64.mul
        local.get 5
        local.get 9
        i64.lt_u
        i64.extend_i32_u
        i64.const 32
        i64.shl
        local.get 5
        i64.const 32
        i64.shr_u
        i64.or
        i64.add
        local.get 6
        local.get 7
        i64.lt_u
        i64.extend_i32_u
        i64.add
        local.get 4
        local.get 1
        i64.mul
        local.get 3
        local.get 2
        i64.mul
        i64.add
        i64.add
        i64.store offset=8
      )
      (table (;0;) 1 1 funcref)
      (memory (;0;) 16)
      (global $__stack_pointer (mut i32) i32.const 1048576)
      (global (;1;) i32 i32.const 1048576)
      (global (;2;) i32 i32.const 1048576)
      (export "memory" (memory 0))
      (export "mac" (func $mac))
      (export "__data_end" (global 1))
      (export "__heap_base" (global 2))
    )

cfallin commented on issue #4077:
Hi @shamatar -- thanks for raising this issue! I agree that the lack of "architecture-specific acceleration" for wide arithmetic lowered into Wasm bytecode is suboptimal and annoying.
On the other hand, somehow recognizing that a Wasm module's internal function happens to be the `__multi3` function from the Wasm toolchain's runtime, and then actually replacing that function with another function, seems very likely to me to break the abstraction boundary (and hence security and portability properties) that the Wasm runtime provides. Either we check that it is exactly the `__multi3` that we expect (same Wasm bytecode body), in which case this is a very brittle optimization (breaks with any little Wasm-toolchain implementation detail change), or we don't, in which case we might replace the code with an implementation that does something different (incorrectly). So, for certain, we would not implement logic that somehow recognizes `__multi3` by name as special.
I think that there is a way around this though: we could pattern-match the actual operations and use the faster architecture-specific lowering when possible. In this specific case, a pattern could match whatever subgraphs correspond to the fast x64 instructions (`mulx` and friends?). I don't know enough about the state of the art of computer arithmetic implementations on x64 or aarch64 to suggest specific mappings.
The place to put such lowerings would be in the ISLE DSL (see e.g. cranelift/codegen/src/isa/x64/lower.isle); we'd be happy to help if you want to try to implement something here.
shamatar commented on issue #4077:
Hey @cfallin
Thank you for a detailed response. I only mentioned replacing `__multi3` by name because I have seen, in one of the old threads, an idea to call an external library for it as a speedup. But if name matching is potentially too fragile (it may be safely-enough allowed under some feature flag if such a flag existed :) ), I'd try to implement matching by logic. The `__multi3` body looks like the standard approach of splitting `u64 -> u32` and simulating with `u32 * u32 -> u64` products, and hopefully it should be catchable.
As for "state of the art": the best one can get is the `adx + mulx` combo for the generic case, I believe, but e.g. LLVM never emits `mulx`, so I'd start with trying to optimize into `add + mul` for `x86_64`.
I'm not sure that the file you mentioned is the right place, as https://github.com/bytecodealliance/wasmtime/blob/b69fede72f6df81edbb0b3286b01e0a101a1758f/cranelift/codegen/src/isa/x64/lower.isle#L916 already has a rule for full-width multiplication lowering, but the optimization first needs to catch `__multi3`, and emit just `imul x y` for it + inline into callsite. Maybe it also requires more than one pass, but the first step is to get to just `imul x y`.
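To make the "`u64 -> u32` and `u32 * u32 -> u64`" idea concrete, here is a minimal Rust sketch of schoolbook widening multiplication built only from 64-bit operations. It illustrates the technique and is not the actual `__multi3` source; the name `mul_wide_schoolbook` and the variable names are made up for this example.

```rust
/// Widening u64 * u64 -> (low, high), using only 64-bit operations on 32-bit halves.
/// Hypothetical helper for illustration; not code from any toolchain.
fn mul_wide_schoolbook(a: u64, b: u64) -> (u64, u64) {
    const MASK: u64 = 0xffff_ffff;
    let (a_lo, a_hi) = (a & MASK, a >> 32);
    let (b_lo, b_hi) = (b & MASK, b >> 32);

    // Each partial product has 32-bit inputs, so it cannot overflow 64 bits.
    let ll = a_lo * b_lo;
    let lh = a_lo * b_hi;
    let hl = a_hi * b_lo;
    let hh = a_hi * b_hi;

    // The sum of the two middle terms may carry out of 64 bits.
    let (mid, mid_carry) = lh.overflowing_add(hl);
    // Low word: ll plus the low half of the middle sum, shifted into place.
    let (low, low_carry) = ll.overflowing_add(mid << 32);
    // High word: hh plus the high half of the middle sum plus both carries.
    let high = hh
        .wrapping_add(((mid_carry as u64) << 32) | (mid >> 32))
        .wrapping_add(low_carry as u64);
    (low, high)
}

fn main() {
    // Cross-check against Rust's native 128-bit multiplication.
    for &(a, b) in &[(0u64, 0u64), (3, 4), (u64::MAX, 2), (u64::MAX, u64::MAX)] {
        let wide = (a as u128) * (b as u128);
        assert_eq!(mul_wide_schoolbook(a, b), (wide as u64, (wide >> 64) as u64));
    }
    println!("schoolbook product matches u128 product");
}
```

The four partial products and the two carry checks are the shape a pattern matcher would have to recognize before it could substitute a single widening multiply.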
cfallin commented on issue #4077:
> catch `__multi3`, and emit just `imul x y` for it + inline into callsite.

That's true, one way this could work in the future is to recognize the whole function and replace it with an `imul.i128` (if that's what you mean by "catch `__multi3`, and emit just `imul`"?), but that would, I think, hit the brittleness issues I mentioned.
Maybe the right way to start is to draw the dataflow graph of the bytecode you listed above, and find subgraphs that correspond to what `adx` and `mulx` do? If we can write a pattern that matches each of those, maybe we can get something at least slightly better than what we have today.
Eventually we will also have an inliner (it's really an orthogonal improvement) and when we do, it could combine with good-enough pattern matching on the function body here to produce an inlined sequence of instructions. In other words, getting to a tight sequence of instructions, and inlining that at the callsite, are two separate problems so let's solve each separately.
cc @abrown and @jlb6740 for thoughts on the x86-64 lowering here...
shamatar commented on issue #4077:
> to recognize the whole function and replace it

Yes, it's the final goal, and ideally the function body will be pattern-matched somehow (and not by name). My impression on lower.isle file is that it's a definition of "simple" rules to lower some short sequences of CLIF instructions to machine instructions, and not the rules for "large" pattern matching.
I'll make a dataflow graph, I just need some time.
shamatar commented on issue #4077:
More precisely, a rule
```
;; `i64` and smaller.
;; Multiply two registers.
(rule (lower (has_type (fits_in_64 ty) (imul x y)))
      (x64_mul ty x y))
```
would match `u64 * u64` multiplication, but I'm also not sure whether this rule is about getting the low bits only (half-width mul) or all of them (full-width), as it doesn't reflect anything about the return type. `__multi3` indeed has `i64.mul` in it, but the inputs to this `i64.mul` are 32-bit values, so it's naturally a full-width mul. And I'd like to "guess" that the larger piece of code is indeed a full-width u64 multiplication that just doesn't have a WASM instruction :)
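The half-width versus full-width distinction can be checked with a small Rust sketch. It assumes that a same-width 64-bit multiply (such as Wasm's `i64.mul`) keeps only the low 64 bits of the product, which Rust's `wrapping_mul` models; `imul64` is a hypothetical helper name, not a Cranelift API.

```rust
/// Same-width multiply: keeps only the low 64 bits of the product.
/// Hypothetical stand-in for a 64-bit `imul` / `i64.mul`.
fn imul64(x: u64, y: u64) -> u64 {
    x.wrapping_mul(y)
}

fn main() {
    // Inputs that fit in 32 bits: the low 64 bits ARE the full product.
    let (a, b) = (0xdead_beef_u64, 0xfeed_face_u64);
    assert_eq!(imul64(a, b) as u128, (a as u128) * (b as u128));

    // Arbitrary 64-bit inputs: the high half of the product is lost.
    let (c, d) = (u64::MAX, 3u64);
    let full = (c as u128) * (d as u128);
    assert_eq!(imul64(c, d), full as u64);
    assert_ne!(imul64(c, d) as u128, full);
    println!("low-half multiply is full-width only for 32-bit inputs");
}
```

This is why each `i64.mul` inside `__multi3` is effectively a full-width 32x32 multiply, while the same instruction on unmasked 64-bit operands is only a half-width one.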
abrown edited a comment on issue #4077:
(Catching up on this thread...) For context, the code behind the WAT that you posted is probably coming from somewhere like this in LLVM: https://github.com/llvm-mirror/compiler-rt/blob/master/lib/builtins/multi3.c. There is quite a bit of shifting and masking there which, as we can see in the WAT, get translated directly to WebAssembly.
I'd be interested to understand what library is being used and how you're using it (e.g., how come `__multi3` is on the hot path?). Here's why:
a. it could be that the library could be tweaked for better compilation to Wasm (e.g., should the algorithm use SIMD instead?)
b. if it is a widely-used library, then it motivates improving this sequence in Wasmtime (or even in Wasm itself)

To optimize this in Wasmtime, I suspect the bottom-up approach (i.e., attempting to get rid of the extra shifting and masking) is going to be "good enough" relative to the code sequence currently emitted by Clang in native code (@cfallin has that) and will be less brittle than trying to match the entire function or the called name. But another option might be to attempt to optimize this at the LLVM layer: perhaps it does not need to be as literal about how it translates the C code I linked to to the WAT @shamatar posted.
shamatar edited a comment on issue #4077:
I've made a minimal example of the typical `u64 * u64 + u64 + u64` operation, which cannot overflow because the output range is `(2^64 - 1) * (2^64 - 1) + 2 * (2^64 - 1) = 2^128 - 1`:

    #![feature(bigint_helper_methods)]
    #[repr(C)]
    pub struct U128Pair {
        pub low: u64,
        pub high: u64,
    }
    // #[no_mangle]
    // pub extern "C" fn mac(a: u64, b: u64, c: u64, carry: u64) -> U128Pair {
    //     let result = (a as u128) * (b as u128) + (c as u128) + (carry as u128);
    //     U128Pair {
    //         low: result as u64,
    //         high: (result >> 64) as u64
    //     }
    // }
    #[no_mangle]
    pub extern "C" fn mac(a: u64, b: u64, c: u64, carry: u64) -> U128Pair {
        let (low, high) = a.carrying_mul(b, c);
        let (low, of) = low.overflowing_add(carry);
        let high = high.wrapping_add(of as u64);
        U128Pair { low, high }
    }

Both the commented code and the one using more explicit operations produce the same assembly. This is a typical way to do big integer math on a CPU. SIMD optimizations are possible, but they are a separate field of the art due to non-carry SIMD, non-widening SIMD, and CPU frequency quirks if SIMD is used, e.g. on Intel. So my example is an "average case" and solving exactly this part would be the most beneficial. (I actually don't need a speedup here for any production code, but it's a nice free-time task.)
So the compiler fits `__multi3` in there as a half-width u128 multiplication. Even though the name is the same, it's actually not used at "full power", since the high parts of the arguments are constant zeroes. I think it's possible to match the internals of `__multi3`, but it's still more than just elimination of shifts: it must be replaced by just 1 machine instruction + stack manipulation.

    export function mac(a:{ a:long, b:long }, b:long, c:long, d:long, e:long) {
      var f:long_ptr = stack_pointer - 16;
      stack_pointer = f;
      multi3(f, c, 0L, b, 0L);
      a.a = (e = (d = (c = f[0]) + d) + e);
      a.b = i64_extend_i32_u(e < d) + (f + 8)[0]:long + i64_extend_i32_u(d < c);
      stack_pointer = f + 16;
    }
    function multi3(a:{ a:long, b:long }, b:long, c:long, d:long, e:long) {
      var f:long;
      var g:long;
      var h:long;
      var i:long;
      var j:long;
      var k:long;
      a.a = (g = (h = (f = d & 4294967295L) * (g = b & 4294967295L)) + ((f = (j = f * (i = b >> 32L)) + (k = d >> 32L) * g) << 32L));
      a.b = k * i + (i64_extend_i32_u(f < j) << 32L | f >> 32L) + i64_extend_i32_u(g < h) + e * b + d * c;
    }

Note: it's actually only possible to replace the way `__multi3` is used here, and not in general, since here it's really a full-width multiplication for u64, expressed as a half-width mul for u128 with constant-zero high parts: `multi3(f, c, 0L, b, 0L);`
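Both claims in this comment (the result never overflows 128 bits, and the half-width u128 multiply with zero high parts equals a full-width u64 multiply) can be checked on stable Rust. This is a sketch under the assumption that `__multi3` returns the low 128 bits of the product; the `multi3` and `mac` functions below are stand-ins written for this example, not the toolchain's code.

```rust
/// Half-width 128-bit multiply: low 128 bits of the product.
/// Stand-in for `__multi3`, written for this example.
fn multi3(x: u128, y: u128) -> u128 {
    x.wrapping_mul(y)
}

/// `mac` as the compiler sees it: zero-extend, half-width multiply, then add.
/// Returns `None` if any addition overflowed 128 bits.
fn mac(a: u64, b: u64, c: u64, carry: u64) -> Option<(u64, u64)> {
    let product = multi3(a as u128, b as u128);
    let sum = product
        .checked_add(c as u128)?
        .checked_add(carry as u128)?;
    Some((sum as u64, (sum >> 64) as u64))
}

fn main() {
    let m = u64::MAX;
    // With zero high parts, nothing is lost by the wrapping multiply.
    assert_eq!(Some(multi3(m as u128, m as u128)), (m as u128).checked_mul(m as u128));
    // The largest inputs give exactly 2^128 - 1, so no addition overflows.
    assert_eq!(mac(m, m, m, m), Some((u64::MAX, u64::MAX)));
    println!("worst case fits in 128 bits");
}
```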
shamatar commented on issue #4077:
I should add that even though it may be possible to fine-tune LLVM too (even further from my world), it's still not possible to generate WASM code that would be better than the internals of `__multi3` (we can only remove the `e * b + d * c` subexpression, as trivial multiplications by 0), so the problem would just move from matching the current `__multi3` internals to matching some other similar function body.
cfallin commented on issue #4077:
> My impression on lower.isle file is that it's a definition of "simple" rules to lower some short sequences of CLIF instructions to machine instructions, and not the rules for "large" pattern matching.

It's potentially both! We're still in the midst of translating all of the basic lowering into ISLE, so that is most of what we have. But the design goal is absolutely to allow arbitrarily complex patterns, and we hope to grow the library of patterns we match and optimize over time.
I took a look at the gcc output on x86-64 for an `__int128 * __int128` case (in C) and I saw:

    movq %rdx, %r8
    movq %rdx, %rax
    mulq %rdi
    imulq %r8, %rsi
    addq %rsi, %rdx
    imulq %rdi, %rcx
    addq %rcx, %rdx

so it should be possible to do a lot better here, and without recent extensions like mulx/adx.
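For readers less used to AT&T syntax, the sequence above can be modelled register by register in Rust. This sketch assumes the System V x86-64 convention (first `__int128` in `rsi:rdi`, second in `rcx:rdx`, result in `rdx:rax`), which the thread does not state explicitly; `mul_i128_gcc_model` is a name invented for this example.

```rust
/// Register-level model of the gcc sequence (one widening `mulq`, two `imulq`).
/// Assumed calling convention: a = rsi:rdi, b = rcx:rdx, result = rdx:rax.
fn mul_i128_gcc_model(a: u128, b: u128) -> u128 {
    let (rdi, mut rsi) = (a as u64, (a >> 64) as u64);
    let (mut rdx, mut rcx) = (b as u64, (b >> 64) as u64);

    let r8 = rdx; // movq %rdx, %r8
    let mut rax = rdx; // movq %rdx, %rax
    let wide = (rax as u128) * (rdi as u128); // mulq %rdi  (rdx:rax = rax * rdi)
    rax = wide as u64;
    rdx = (wide >> 64) as u64;
    rsi = rsi.wrapping_mul(r8); // imulq %r8, %rsi   (a_hi * b_lo, low half)
    rdx = rdx.wrapping_add(rsi); // addq %rsi, %rdx
    rcx = rcx.wrapping_mul(rdi); // imulq %rdi, %rcx  (b_hi * a_lo, low half)
    rdx = rdx.wrapping_add(rcx); // addq %rcx, %rdx

    ((rdx as u128) << 64) | rax as u128
}

fn main() {
    let samples = [0u128, 1, 0xffff_ffff_ffff_ffff, 1 << 64, 0x0123_4567_89ab_cdef_fedc_ba98_7654_3210, u128::MAX];
    for &a in &samples {
        for &b in &samples {
            assert_eq!(mul_i128_gcc_model(a, b), a.wrapping_mul(b));
        }
    }
    println!("model matches wrapping 128-bit multiplication");
}
```

Only the low-by-low product needs a widening multiply; the two cross terms contribute to the high word alone, so their low halves suffice.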
shamatar commented on issue #4077:
> so it should be possible to do a lot better here, and without recent extensions like mulx/adx.

Yes, it's kind of possible due to two carry chains (in my example of `u64 * u64 + u64 + u64`), but it would be the next step (and `adx` is much more useful for even "larger" math, like u256 full-width multiplication).
So my example's "optimal" code is like

    mac:
        mov r8, rdx
        mov rax, rsi
        mul rdi
        add rax, r8
        adc rdx, 0
        add rax, rcx
        adc rdx, 0
        ret

where `adx` would allow speeding up the two carry propagation chains. `mulx` allows more flexible register allocation, but none of the compilers use it in practice as far as I know.
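The `mul` / `add` / `adc` listing can be mirrored in Rust with explicit carry flags. This is a sketch assuming the arguments arrive as `a = rdi`, `b = rsi`, `c = rdx`, `carry = rcx` and the result is returned in `rdx:rax`; the function name `mac_carry_chain` is invented for this example.

```rust
/// Model of the hand-written sequence: one widening `mul`, then two add/adc pairs.
/// Returns (low, high), i.e. (rax, rdx) under the assumed register assignment.
fn mac_carry_chain(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) {
    // mul rdi: rdx:rax = rax * rdi
    let wide = (a as u128) * (b as u128);
    let (mut lo, mut hi) = (wide as u64, (wide >> 64) as u64);

    // add rax, r8 ; adc rdx, 0
    let (sum, cf) = lo.overflowing_add(c);
    lo = sum;
    hi = hi.wrapping_add(cf as u64);

    // add rax, rcx ; adc rdx, 0
    let (sum, cf) = lo.overflowing_add(carry);
    lo = sum;
    hi = hi.wrapping_add(cf as u64);

    (lo, hi)
}

fn main() {
    let m = u64::MAX;
    for &(a, b, c, k) in &[(2u64, 3u64, 4u64, 5u64), (m, m, 0, 0), (m, m, m, m), (m, 1, m, 1)] {
        let expect = (a as u128) * (b as u128) + (c as u128) + (k as u128);
        assert_eq!(mac_carry_chain(a, b, c, k), (expect as u64, (expect >> 64) as u64));
    }
    println!("carry chain matches the u128 reference");
}
```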
sparker-arm commented on issue #4077:
Since Cranelift supports i128, maybe we could perform clif-to-clif transforms to generate an i128 mul (if a backend wants that). That way we can avoid each backend having to add complicated matching patterns. Plus, at least for aarch64, there's already a lot of existing support for i128.
shamatar commented on issue #4077:
I'm still trying to understand pattern-matching and to get it (even if fragile) to capture the full body of `__multi3`. Then it can be replaced by an `i128` half-width multiplication or something else. Then I'd have to make another large step to inline it and also constant-fold it (since in practice it's not an `i128` half-mul, but an `i64` full-mul), which is most likely even more complex.
shamatar commented on issue #4077:
I've tried to insert an optimization into Cranelift directly. I should be able to match the full body of the generated `__multi3` and transform it into an `i128` multiplication. Please check #4106.
shamatar edited a comment on issue #4077:
Made an initial functional PR, folds from

    block0(v0: i64, v1: i64, v2: i32, v3: i64, v4: i64, v5: i64, v6: i64):
    @5130   v7 = iconst.i64 0
    @5136   v8 = iconst.i64 0xffff_ffff
    @513c   v9 = band_imm v5, 0xffff_ffff
    @5141   v10 = iconst.i64 0xffff_ffff
    @5147   v11 = band_imm v3, 0xffff_ffff
    @514a   v12 = imul v9, v11
    @5151   v13 = iconst.i64 32
    @5153   v14 = ushr_imm v3, 32
    @5156   v15 = imul v9, v14
    @515b   v16 = iconst.i64 32
    @515d   v17 = ushr_imm v5, 32
    @5162   v18 = imul v17, v11
    @5163   v19 = iadd v15, v18
    @5166   v20 = iconst.i64 32
    @5168   v21 = ishl_imm v19, 32
    @5169   v22 = iadd v12, v21
    @516c   v23 = heap_addr.i64 heap0, v2, 1
    @516c   store little v22, v23
    @5175   v24 = imul v17, v14
    @517a   v25 = icmp ult v19, v15
    @517a   v26 = bint.i32 v25
    @517b   v27 = uextend.i64 v26
    @517c   v28 = iconst.i64 32
    @517e   v29 = ishl_imm v27, 32
    @5181   v30 = iconst.i64 32
    @5183   v31 = ushr_imm v19, 32
    @5184   v32 = bor v29, v31
    @5185   v33 = iadd v24, v32
    @518a   v34 = icmp ult v22, v12
    @518a   v35 = bint.i32 v34
    @518b   v36 = uextend.i64 v35
    @518c   v37 = iadd v33, v36
    @5191   v38 = imul v6, v3
    @5196   v39 = imul v5, v4
    @5197   v40 = iadd v38, v39
    @5198   v41 = iadd v37, v40
    @5199   v42 = heap_addr.i64 heap0, v2, 1
    @5199   store little v41, v42+8
    @519c   jump block1
    block1:
    @519c   return

into

    block0(v0: i64, v1: i64, v2: i32, v3: i64, v4: i64, v5: i64, v6: i64):
            v43 = iconcat v3, v4
            v44 = iconcat v5, v6
            v45 = imul v43, v44
            v46, v47 = isplit v45
    @5130   v7 = iconst.i64 0
    @5136   v8 = iconst.i64 0xffff_ffff
    @513c   v9 = nop
    @5141   v10 = iconst.i64 0xffff_ffff
    @5147   v11 = nop
    @514a   v12 = nop
    @5151   v13 = iconst.i64 32
    @5153   v14 = nop
    @5156   v15 = nop
    @515b   v16 = iconst.i64 32
    @515d   v17 = nop
    @5162   v18 = nop
    @5163   v19 = nop
    @5166   v20 = iconst.i64 32
    @5168   v21 = nop
    @5169   v22 = nop
    @516c   v23 = heap_addr.i64 heap0, v2, 1
    @516c   store little v46, v23
    @5175   v24 = nop
    @517a   v25 = nop
    @517a   v26 = nop
    @517b   v27 = nop
    @517c   v28 = iconst.i64 32
    @517e   v29 = nop
    @5181   v30 = iconst.i64 32
    @5183   v31 = nop
    @5184   v32 = nop
    @5185   v33 = nop
    @518a   v34 = nop
    @518a   v35 = nop
    @518b   v36 = nop
    @518c   v37 = nop
    @5191   v38 = nop
    @5196   v39 = nop
    @5197   v40 = nop
    @5198   v41 = nop
    @5199   v42 = heap_addr.i64 heap0, v2, 1
    @5199   store little v47, v42+8
    @519c   jump block1
    block1:
    @519c   return

Very naive benchmark gives around 30% speedup
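A fold like this can be sanity-checked outside the compiler by transcribing both versions into Rust and comparing them on edge-case inputs. The sketch below keeps the CLIF value names, models every CLIF operation as wrapping, and assumes `iconcat lo, hi` / `isplit` use low-then-high order; the function names `before` and `after` are made up for this example and are not part of the PR.

```rust
const MASK: u64 = 0xffff_ffff;

/// Transcription of the original block; returns the two stored values (v22, v41).
fn before(v3: u64, v4: u64, v5: u64, v6: u64) -> (u64, u64) {
    let v9 = v5 & MASK;
    let v11 = v3 & MASK;
    let v12 = v9.wrapping_mul(v11);
    let v14 = v3 >> 32;
    let v15 = v9.wrapping_mul(v14);
    let v17 = v5 >> 32;
    let v18 = v17.wrapping_mul(v11);
    let v19 = v15.wrapping_add(v18);
    let v22 = v12.wrapping_add(v19 << 32);
    let v24 = v17.wrapping_mul(v14);
    // icmp ult + extend + shift: carry out of the middle sum, placed at bit 32.
    let v32 = (((v19 < v15) as u64) << 32) | (v19 >> 32);
    let v33 = v24.wrapping_add(v32);
    // Carry out of the low word.
    let v37 = v33.wrapping_add((v22 < v12) as u64);
    let v40 = v6.wrapping_mul(v3).wrapping_add(v5.wrapping_mul(v4));
    (v22, v37.wrapping_add(v40))
}

/// Transcription of the folded block; returns the two stored values (v46, v47).
fn after(v3: u64, v4: u64, v5: u64, v6: u64) -> (u64, u64) {
    let v43 = (v3 as u128) | ((v4 as u128) << 64); // iconcat v3, v4
    let v44 = (v5 as u128) | ((v6 as u128) << 64); // iconcat v5, v6
    let v45 = v43.wrapping_mul(v44); // imul.i128
    (v45 as u64, (v45 >> 64) as u64) // isplit
}

fn main() {
    let samples = [0, 1, MASK, MASK + 1, 0x1234_5678_9abc_def0, u64::MAX];
    for &v3 in &samples {
        for &v4 in &samples {
            for &v5 in &samples {
                for &v6 in &samples {
                    assert_eq!(before(v3, v4, v5, v6), after(v3, v4, v5, v6));
                }
            }
        }
    }
    println!("fold agrees on {} input combinations", samples.len().pow(4));
}
```

A check like this covers only the arithmetic; it says nothing about whether the matcher rejects functions that merely resemble this body.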
shamatar commented on issue #4077:
Few comments:
- is there any pass that removes unused values? I believe it should be, but I'd like to trigger it manually
- it may be a better option if, instead of concat - imul - split, I tried a manual implementation using imul/umulhi
- is there any inlining step? In my benchmarks I tried both an "extern C" version of the main MAC arithmetic function and a normal Rust version with the "inline" attribute, and there is no difference, which most likely indicates that "__multi3" is inserted by the compiler backend at a later stage. And after I substantially reduce the body of "__multi3" it should be a good candidate for inlining, which would also give me more information to work with, like the fact that parts of the input are always 0
bjorn3 commented on issue #4077:
> is there any pass that removes unused values? I believe it should be, but I'd like to trigger it manually

Yes, the dce pass. Note that lowering to the backend specific ir already does this implicitly. Also I think you shouldn't replace the instructions with nops. It results in invalid clif ir (as nop doesn't return any values) and it is incorrect if any of the instruction return values are used elsewhere. For example because you matched a function that looks somewhat like __multi3 but not exactly.
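The point about uses can be illustrated with a toy dead-code elimination over a straight-line instruction list. This is a sketch of the general idea only; the `Inst` type and `dce` function are hypothetical and unrelated to Cranelift's actual data structures or its DCE pass.

```rust
use std::collections::HashSet;

/// Toy instruction: an optional result value, argument values, and a side-effect flag.
#[derive(Debug, Clone)]
struct Inst {
    result: Option<u32>,
    args: Vec<u32>,
    side_effect: bool,
}

/// Backward liveness over a single block: keep an instruction only if it has a
/// side effect or its result is used by an instruction that was itself kept.
fn dce(insts: &[Inst]) -> Vec<Inst> {
    let mut live: HashSet<u32> = HashSet::new();
    let mut kept = Vec::new();
    for inst in insts.iter().rev() {
        let needed = inst.side_effect || inst.result.map_or(false, |r| live.contains(&r));
        if needed {
            live.extend(inst.args.iter().copied());
            kept.push(inst.clone());
        }
    }
    kept.reverse();
    kept
}

fn main() {
    let prog = vec![
        Inst { result: Some(0), args: vec![], side_effect: false },  // v0 = const
        Inst { result: Some(1), args: vec![0], side_effect: false }, // v1 = f(v0), never used
        Inst { result: Some(2), args: vec![0], side_effect: false }, // v2 = g(v0)
        Inst { result: None, args: vec![2], side_effect: true },     // store v2
    ];
    let kept = dce(&prog);
    assert_eq!(kept.len(), 3);
    println!("{} of {} instructions kept", kept.len(), prog.len());
}
```

Removing an instruction only when nothing live depends on it is what makes this safe, in contrast to unconditionally overwriting matched instructions whose results might still be used elsewhere.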
> is there any inlining step?

No, there isn't. It is also non-trivial to implement as currently every function is independently lowered to clif ir, optimized and codegened to native code. Wasmtime even compiles multiple functions in parallel.
shamatar commented on issue #4077:
Ty @bjorn3
I've analyzed my NOPing approach and it's overzealous indeed.
As for inlining: are there any plans to add it? It would require some form of "synchronization point" after CLIF generation; the inlining itself can then also be parallelized, followed again by parallel native codegen.
bjorn3 commented on issue #4077:
Yes, https://github.com/bytecodealliance/rfcs/blob/main/accepted/cranelift-roadmap-2022.md#inlining
cfallin labeled issue #4077:
Feature
Implement an optimization pass that would eliminate the
__multi3
function from WASM binary during JIT by replacing it with ISA specific (mainly forx86_64
andarm64
) sequences, and then inline such sequences into callsites that would allow further optimizationsBenefit
A lot of code dealing with cryptography would benefit form faster full width
u64
multiplications where such__multi3
arisesImplementation
If someone would give a few hints about where to start I'd try to implement it by myself
Alternatives
Not that I'm aware of. Patching into calling come native library function is a huge overhead for modern CPUs (4 cycles for
x86_64
for e.g.mulx
ormul
), and while it would be faster most likely, it's still far from optimal case on a hot pathAs an example a simple multiply-add-carry function like
a*b + c + carry -> (high, low)
that accumulates intou128
without overflows compiles down to the listing below, and it can be a good test subject (transformed fromwasm
intowat
, may be not the best readable)(module (type (;0;) (func (param i32 i64 i64 i64 i64))) (func $mac (type 0) (param i32 i64 i64 i64 i64) (local i32) global.get $__stack_pointer i32.const 16 i32.sub local.tee 5 global.set $__stack_pointer local.get 5 local.get 2 i64.const 0 local.get 1 i64.const 0 call $__multi3 local.get 0 local.get 5 i64.load local.tee 2 local.get 3 i64.add local.tee 3 local.get 4 i64.add local.tee 4 i64.store local.get 0 local.get 5 i32.const 8 i32.add i64.load local.get 3 local.get 2 i64.lt_u i64.extend_i32_u i64.add local.get 4 local.get 3 i64.lt_u i64.extend_i32_u i64.add i64.store offset=8 local.get 5 i32.const 16 i32.add global.set $__stack_pointer ) (func $__multi3 (type 0) (param i32 i64 i64 i64 i64) (local i64 i64 i64 i64 i64 i64) local.get 0 local.get 3 i64.const 4294967295 i64.and local.tee 5 local.get 1 i64.const 4294967295 i64.and local.tee 6 i64.mul local.tee 7 local.get 5 local.get 1 i64.const 32 i64.shr_u local.tee 8 i64.mul local.tee 9 local.get 3 i64.const 32 i64.shr_u local.tee 10 local.get 6 i64.mul i64.add local.tee 5 i64.const 32 i64.shl i64.add local.tee 6 i64.store local.get 0 local.get 10 local.get 8 i64.mul local.get 5 local.get 9 i64.lt_u i64.extend_i32_u i64.const 32 i64.shl local.get 5 i64.const 32 i64.shr_u i64.or i64.add local.get 6 local.get 7 i64.lt_u i64.extend_i32_u i64.add local.get 4 local.get 1 i64.mul local.get 3 local.get 2 i64.mul i64.add i64.add i64.store offset=8 ) (table (;0;) 1 1 funcref) (memory (;0;) 16) (global $__stack_pointer (mut i32) i32.const 1048576) (global (;1;) i32 i32.const 1048576) (global (;2;) i32 i32.const 1048576) (export "memory" (memory 0)) (export "mac" (func $mac)) (export "__data_end" (global 1)) (export "__heap_base" (global 2)) )
cfallin labeled issue #4077:
jameysharp commented on issue #4077:
I've been investigating what we can do to improve the performance of `__multi3`. For reference, the C source for this function is in LLVM at `compiler-rt/lib/builtins/multi3.c`, which helps a little with understanding what's going on.

Part of the solution proposed in earlier discussion here is function inlining. @elliottt and I discussed this some yesterday and what I'd like to see is a separate tool that supports an inlining transformation on core WebAssembly. I'd then recommend that people use that to eliminate the optimization barriers around `__multi3`. I don't see inlining actually happening soon in Wasmtime itself and it'd be nice to have something to recommend for that.

One difficult problem is that this function returns a two-member `struct`, which according to the WebAssembly basic C ABI means it must be returned through linear memory. If it were returned on the wasm stack we'd be able to keep both parts in registers but instead we have to write to RAM. There is no optimization we can legally do to avoid the writes, even with inlining, although with inlining we likely can avoid reading the values back from RAM. Ideally the C ABI would be updated to take advantage of multi-value returns, which apparently didn't exist when it was specified.

Now, what is this function actually computing, and what patterns can we recognize and transform in the mid-end?
<details>
<summary>Raw CLIF for __multi3, slightly edited</summary>

    test optimize precise-output
    set opt_level=speed_and_size
    target x86_64

    ; v3 = a_lo
    ; v4 = a_hi
    ; v5 = b_lo
    ; v6 = b_hi
    ; v11 = a0
    ; v14 = a1
    ; v9 = b0
    ; v17 = b1
    function %multi3(i64 vmctx, i64, i32, i64, i64, i64, i64) fast {
        gv0 = vmctx
        gv1 = load.i64 notrap aligned readonly gv0+80

    block0(v0: i64, v1: i64, v2: i32, v3: i64, v4: i64, v5: i64, v6: i64):
        v7 = iconst.i64 0
        v8 = iconst.i64 0xffff_ffff
        v9 = band v5, v8
        v10 = iconst.i64 0xffff_ffff
        v11 = band v3, v10
        v12 = imul v9, v11
        v13 = iconst.i64 32
        v14 = ushr v3, v13
        v15 = imul v9, v14
        v16 = iconst.i64 32
        v17 = ushr v5, v16
        v18 = imul v17, v11
        v19 = iadd v15, v18
        v20 = iconst.i64 32
        v21 = ishl v19, v20
        v22 = iadd v12, v21
        v23 = uextend.i64 v2
        v24 = global_value.i64 gv1
        v25 = iadd v24, v23
        store little heap v22, v25
        v26 = imul v17, v14
        v27 = icmp ult v19, v15
        v28 = uextend.i32 v27
        v29 = uextend.i64 v28
        v30 = iconst.i64 32
        v31 = ishl v29, v30
        v32 = iconst.i64 32
        v33 = ushr v19, v32
        v34 = bor v31, v33
        v35 = iadd v26, v34
        v36 = icmp ult v22, v12
        v37 = uextend.i32 v36
        v38 = uextend.i64 v37
        v39 = iadd v35, v38
        v40 = imul v6, v3
        v41 = imul v5, v4
        v42 = iadd v40, v41
        v43 = iadd v39, v42
        v44 = uextend.i64 v2
        v45 = global_value.i64 gv1
        v46 = iadd v45, v44
        v47 = iadd_imm v46, 8
        store little heap v43, v47
        return
    }

</details>
This function implements a 128-bit multiply where the operands and result are stored as pairs of 64-bit integers. It builds up the result, in part, using a series of 32x32->64 multiplies. Viewed as a sequence of 32-bit "digits", the result of long multiplication should look like this, although each single-digit product may have a carry into the column to its left:

           a3    a2    a1    a0
    *      b3    b2    b1    b0
    ===========================
        a0*b3 a0*b2 a0*b1 a0*b0
    +   a1*b2 a1*b1 a1*b0
    +   a2*b1 a2*b0
    +   a3*b0

However only a1/a0 and b1/b0 are actually treated this way, to compute the lower half of the result along with part of the upper half.

                 a1    a0
    *            b1    b0
    =====================
              a0*b1 a0*b0
    +   a1*b1 a1*b0

In Cranelift, we'd want to implement this with a pair of instructions: `imul` and `umulhi`, to produce the lower and upper 64 bits, respectively.

The remaining part of the upper half is just `a_lo*b_hi+a_hi*b_lo`, performed as regular 64-bit multiplies which are equivalent to these parts of the 32-bit-at-a-time long multiplication:

        a0*b3 a0*b2
    +   a1*b2
    +   a2*b1 a2*b0
    +   a3*b0

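The decomposition described so far can be checked with a small Rust sketch. The names `umulhi`, `multi3` and `multi3_reference` are hypothetical helpers for illustration, not Cranelift or compiler-rt APIs; `umulhi` only models the high-half semantics of the CLIF instruction using `u128` arithmetic.

```rust
/// Models the upper 64 bits of an unsigned 64x64 multiply
/// (hypothetical stand-in for the CLIF `umulhi` instruction).
fn umulhi(a: u64, b: u64) -> u64 {
    (((a as u128) * (b as u128)) >> 64) as u64
}

/// 128-bit wrapping multiply on (lo, hi) pairs, returning (lo, hi).
pub fn multi3(a_lo: u64, a_hi: u64, b_lo: u64, b_hi: u64) -> (u64, u64) {
    // Lower 64 bits: a plain 64-bit multiply.
    let lo = a_lo.wrapping_mul(b_lo);
    // Cross terms only contribute to the upper half; anything that would
    // land above bit 127 is discarded by the wrapping arithmetic.
    let cross = a_lo
        .wrapping_mul(b_hi)
        .wrapping_add(a_hi.wrapping_mul(b_lo));
    let hi = umulhi(a_lo, b_lo).wrapping_add(cross);
    (lo, hi)
}

/// Reference result computed with Rust's native u128 multiply.
pub fn multi3_reference(a: u128, b: u128) -> (u64, u64) {
    let p = a.wrapping_mul(b);
    (p as u64, (p >> 64) as u64)
}
```

Comparing `multi3` against `multi3_reference` on a handful of inputs is a cheap sanity check that one `imul`, one `umulhi`, and the two cross-term multiplies are enough to reproduce the full result.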
There's actually nothing we can improve in this part, as shown in @cfallin's gcc-generated assembly listing, which has two `imulq` and two `addq`.

So back to the lower-half 64x64->128 multiply, performed 32 bits at a time. Our current optimizer produces this sequence which is equivalent to `v22 = imul.i64 v3, v5`:

    v8 = iconst.i64 0xffff_ffff
    v9 = band v5, v8  ; v8 = 0xffff_ffff
    v11 = band v3, v8  ; v8 = 0xffff_ffff
    v12 = imul v9, v11
    v13 = iconst.i64 32
    v14 = ushr v3, v13  ; v13 = 32
    v15 = imul v9, v14
    v17 = ushr v5, v13  ; v13 = 32
    v18 = imul v17, v11
    v19 = iadd v15, v18
    v21 = ishl v19, v13  ; v13 = 32
    v22 = iadd v12, v21

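To make the claimed equivalence concrete, here is a Rust transcription of that sequence, with variable names following the CLIF value numbers. The function name `low_half_by_digits` is made up for this sketch. The high bits lost when `v19` wraps or is shifted left by 32 would all land above bit 63, so they never affect the low 64 bits of the product.

```rust
/// Low 64 bits of a*b, built from 32-bit "digits" the same way the
/// CLIF sequence does (a = v3, b = v5).
pub fn low_half_by_digits(a: u64, b: u64) -> u64 {
    let mask = 0xffff_ffffu64;          // v8
    let b0 = b & mask;                  // v9
    let a0 = a & mask;                  // v11
    let v12 = b0.wrapping_mul(a0);      // a0*b0
    let a1 = a >> 32;                   // v14
    let v15 = b0.wrapping_mul(a1);      // a1*b0
    let b1 = b >> 32;                   // v17
    let v18 = b1.wrapping_mul(a0);      // a0*b1
    let v19 = v15.wrapping_add(v18);    // middle column, may wrap
    let v21 = v19 << 32;                // move middle column into place
    v12.wrapping_add(v21)               // v22
}
```

For any inputs this should agree with `a.wrapping_mul(b)`, which is the behaviour of a single 64-bit `imul`.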
Then I think this part (which reuses some of the intermediate values above) is equivalent to either `umulhi` or `smulhi`. I haven't yet figured out what purpose the `icmp` instructions serve, although I can see they're effectively checking if the adds overflowed for `v22` and `v19` above.

    v26 = imul v17, v14
    v27 = icmp ult v19, v15
    v49 = uextend.i64 v27
    v31 = ishl v49, v13  ; v13 = 32
    v33 = ushr v19, v13  ; v13 = 32
    v34 = bor v31, v33
    v35 = iadd v26, v34
    v36 = icmp ult v22, v12
    v51 = uextend.i64 v36
    v39 = iadd v35, v51

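One way to probe this is to transcribe the sequence into Rust and compare it against a reference high half. In the sketch below the function names are hypothetical and the variables follow the CLIF value numbers. The two `icmp ult` results are treated as carry bits: an unsigned add wrapped exactly when the sum is smaller than one of its addends. Under that reading, the carry out of the middle column has weight 2^96 (bit 32 of the high half) and the carry out of the low half has weight 2^64 (bit 0 of the high half).

```rust
/// High 64 bits of the unsigned product a*b, built from 32-bit digits
/// as in the CLIF above (a = v3, b = v5).
pub fn high_half_by_digits(a: u64, b: u64) -> u64 {
    let mask = 0xffff_ffffu64;
    let (a0, a1) = (a & mask, a >> 32);         // v11, v14
    let (b0, b1) = (b & mask, b >> 32);         // v9, v17
    let v12 = b0.wrapping_mul(a0);              // a0*b0
    let v15 = b0.wrapping_mul(a1);              // a1*b0
    let v18 = b1.wrapping_mul(a0);              // a0*b1
    let v19 = v15.wrapping_add(v18);            // middle column, may wrap
    let v22 = v12.wrapping_add(v19 << 32);      // low half of the product
    let v26 = b1.wrapping_mul(a1);              // a1*b1
    let carry_mid = (v19 < v15) as u64;         // v27: did v15 + v18 wrap?
    let v34 = (carry_mid << 32) | (v19 >> 32);  // upper part of middle column
    let v35 = v26.wrapping_add(v34);
    let carry_low = (v22 < v12) as u64;         // v36: did v12 + v21 wrap?
    v35.wrapping_add(carry_low)                 // v39
}

/// Reference: upper 64 bits of the unsigned 128-bit product.
pub fn umulhi_reference(a: u64, b: u64) -> u64 {
    (((a as u128) * (b as u128)) >> 64) as u64
}
```

On the inputs tried here, including ones with the top bit set, the transcription matches the unsigned reference, which suggests `umulhi` rather than `smulhi` is the instruction to target.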
If we can get egraph rules to simplify these two sequences to the upper and lower halves of a 64-bit multiply, then once we also solve #5623 we'll get the same sequence of instructions out that gcc produces for 128-bit multiplies.
Alternatively, if we could match on multiple results at once then we could write a rule that matches this combination of three `imul`, one `umulhi`, and three `iadd` and turns them into a single 128-bit `imul` surrounded by `iconcat` and `isplit`. I believe on x64 we already lower that to something equivalent to the gcc-generated sequence.
Last updated: Nov 22 2024 at 16:03 UTC