So I have this (maybe rather strange) use case of JITing only "parts" of some RISC-V (RV32) bytecode.
For a bit of context, I normally run these instructions in an interpreter, because I am generating proofs of correct execution, which requires me to collect auxiliary information per cycle. However, I also have an "unconstrained" mode, which runs the instructions without collecting this data; this is the part I want to JIT.
For my use case, I'm only given this bytecode, and I only want to JIT the parts of it that are "sandwiched", say, between two placeholder opcodes.
I think most of what I'm trying to do is clear to me, but the one thing I'm having trouble fleshing out is the JALR instruction.
My naive approach is to actually JIT the whole program, treating each instruction as its own basic block, and basically implementing a dynamic jump table at runtime.
I'm curious to gather feedback and see if anyone can help me think of a better approach to handling these dynamic jumps.
Thank you in advance!
The classical way of handling indirect jumps in a dynamic binary translation (DBT) system like Valgrind, Pin or Dynamo (to give some keywords to search on!) is to have a big dispatch hashtable from entry-point PC to the "trace" or "block" of jit'd code starting at that address. You then compile indirect jumps (or calls) into a sequence that looks up the target in that hashtable, jumps to the equivalent native code if present, and otherwise invokes the JIT compiler to generate new code (or escapes to an interpreter, or whatever).
It sounds like what you're wanting to do is a slightly more precomputed version of that, where you handle every individual instruction as a potential entry point ahead of time. You'll probably get much better results if you do it on demand (if your environment permits true JIT), since relatively few instructions will be targets of indirect jumps/calls.
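A minimal sketch of the on-demand dispatch table described above, in plain Rust with no JIT involved. Everything here is hypothetical and only for illustration: `Block` is an ordinary function pointer standing in for a pointer to generated code, `Dispatcher::compile` stands in for the real translator, and `fallthrough_block` is a dummy block that just falls through to `pc + 4`.

```rust
use std::collections::HashMap;

/// Stand-in for a pointer to JIT-compiled code. It runs one block against the
/// guest registers and returns the next guest PC.
type Block = fn(&mut [u32; 32], u32) -> u32;

/// Dummy "translated block": bumps x1 so executions can be counted,
/// then falls through to the next instruction.
fn fallthrough_block(regs: &mut [u32; 32], pc: u32) -> u32 {
    regs[1] = regs[1].wrapping_add(1);
    pc.wrapping_add(4)
}

/// Cache from guest entry PC to translated block.
struct Dispatcher {
    cache: HashMap<u32, Block>,
    /// How many times the (hypothetical) compiler was invoked.
    compiles: usize,
}

impl Dispatcher {
    fn new() -> Self {
        Dispatcher { cache: HashMap::new(), compiles: 0 }
    }

    /// Hypothetical compile step. A real system would translate guest code
    /// starting at `_pc` here and return a pointer to the generated code.
    fn compile(&mut self, _pc: u32) -> Block {
        self.compiles += 1;
        fallthrough_block
    }

    /// Look up the block for `pc`, compiling it only on a cache miss.
    fn lookup_or_compile(&mut self, pc: u32) -> Block {
        if let Some(&block) = self.cache.get(&pc) {
            return block;
        }
        let block = self.compile(pc);
        self.cache.insert(pc, block);
        block
    }

    /// Dispatch loop: every block returns the next PC, which is looked up again.
    /// An indirect jump such as JALR is just a block whose returned PC is
    /// computed at runtime.
    fn run(&mut self, regs: &mut [u32; 32], mut pc: u32, steps: usize) -> u32 {
        for _ in 0..steps {
            let block = self.lookup_or_compile(pc);
            pc = block(regs, pc);
        }
        pc
    }
}
```

Running the same PC range twice through `run` only compiles each entry point once; the second pass is served from the cache.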
Thanks for the response, lots of interesting info here.
I think that last part is exactly what I'm looking to do. My next question is: once I do know my entry point, say at some pc=X,
how much of the program do I start JITing from pc=X? Until I hit another indirect jump?
Maybe I'm misunderstanding, but my understanding of what you're saying is that I should basically maintain a mapping, just in time, of pc -> jitted function. But how do I share state between these?
And for more clarity: I will always be starting off in the interpreter, and then enter the JIT at some arbitrary PC.
Is it that I would pass in some pointer to my "virtual registers", i.e. some *mut [Value; 32], and hand this to each jitted function?
Yes, exactly, typically you have a "struct CPUState" and within a jit'd block you can load and store to this, perhaps keeping values in registers (SSA values) until the end; and your interpreter operates on the same state.
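As a rough illustration of the shared-state idea, here is a sketch in plain Rust. The struct layout, `reg_offset`, `interp_addi` and `fake_jit_block` are all assumptions made up for this example; in particular `fake_jit_block` is hand-written Rust standing in for generated code, and it only mimics what emitted loads and stores at fixed byte offsets from the state pointer would do.

```rust
/// Guest state shared by the interpreter and every translated block.
/// `#[repr(C)]` keeps the field layout predictable, so generated code can
/// address fields by byte offset from a single pointer.
#[repr(C)]
pub struct CpuState {
    pub regs: [u32; 32],
    pub pc: u32,
}

/// Byte offset of register `index` from the start of `CpuState`
/// (`regs` is the first field, so it starts at offset 0).
pub const fn reg_offset(index: usize) -> usize {
    index * std::mem::size_of::<u32>()
}

/// Interpreter side: ADDI through ordinary Rust field access.
pub fn interp_addi(state: &mut CpuState, rd: usize, rs1: usize, imm: i32) {
    let value = state.regs[rs1].wrapping_add(imm as u32);
    if rd != 0 {
        // x0 is hard-wired to zero in RISC-V, so writes to it are dropped.
        state.regs[rd] = value;
    }
    state.pc = state.pc.wrapping_add(4);
}

/// Stand-in for a JIT-compiled block computing `x2 = x1 + 7`.
/// It takes only the state pointer, like generated code would, and reads and
/// writes registers at fixed byte offsets.
///
/// # Safety
/// `state` must point to a valid, exclusively borrowed `CpuState`.
pub unsafe extern "C" fn fake_jit_block(state: *mut CpuState) {
    unsafe {
        let base = state as *mut u8;
        let x1 = (base.add(reg_offset(1)) as *const u32).read();
        (base.add(reg_offset(2)) as *mut u32).write(x1.wrapping_add(7));
        (*state).pc = (*state).pc.wrapping_add(4);
    }
}
```

Because both sides touch the same memory, the interpreter can hand control to a block at any PC and pick up again afterwards without copying registers around.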
How long to JIT, and whether you follow one path or both sides of a conditional, etc., is an interesting heuristic question that different systems answer differently
Basically your questions boil down to "how do I design a dynamic binary translation system" and I'd recommend reading at least the Dynamo paper (https://dl.acm.org/doi/pdf/10.1145/349299.349303) -- "fragment cache", "trace selection" and all the rest
Valgrind is a good open-source contemporary example of a system like this, and you may find it instructive to read their internal docs and source too
Awesome thank you very much!
@Chris Fallin Last question for now: which do you think is better for my registers, the stack operations or the *mut [Value; 32] approach?
Couldn't find much info on how Cranelift treats/reasons about stack assignments. Are they persistent between functions from the same context?
Cranelift has stackslots but they are local to the current function; so if you need to store state that persists across different compiled instruction traces and your interpreter (I think you do, for CPU registers?) then you'll want to emit normal loads and stores to an array of values in memory
Hey @Chris Fallin, I've made a lot of progress already and it's been such a pleasure to use Cranelift so far!
No pressure at all, but I'm curious if you would have some insight on how to implement this function correctly, as it fails during (Cranelift) compilation with:
failed to compile: Compilation(Verifier(VerifierErrors([VerifierError { location: inst3, context: Some("jump block1(v8)"), message: "uses value arg from non-dominating block1" }])))
#[tracing::instrument(skip_all, fields(opcode = instruction.opcode.mnemonic(), pc = pc))]
fn translate_branch(&mut self, instruction: &Instruction, pc: u32) -> BuilderResult {
    tracing::trace!("Translating branch");
    let (rs1, rs2, imm) = instruction.b_type();
    let rs1_val = self.builder.ins().load(
        self.int_32_type,
        MemFlags::trusted(),
        self.registers_ptr,
        rs1.register_offset(),
    );
    let rs2_val = self.builder.ins().load(
        self.int_32_type,
        MemFlags::trusted(),
        self.registers_ptr,
        rs2.register_offset(),
    );
    let cond = match instruction.opcode {
        Opcode::BEQ => self.builder.ins().icmp(IntCC::Equal, rs1_val, rs2_val),
        Opcode::BNE => self.builder.ins().icmp(IntCC::NotEqual, rs1_val, rs2_val),
        Opcode::BLT => self.builder.ins().icmp(IntCC::SignedLessThan, rs1_val, rs2_val),
        Opcode::BGE => {
            self.builder.ins().icmp(IntCC::SignedGreaterThanOrEqual, rs1_val, rs2_val)
        }
        Opcode::BLTU => self.builder.ins().icmp(IntCC::UnsignedLessThan, rs1_val, rs2_val),
        Opcode::BGEU => {
            self.builder.ins().icmp(IntCC::UnsignedGreaterThanOrEqual, rs1_val, rs2_val)
        }
        _ => unreachable!(),
    };
    // If weve already visited this branch point, jump to it.
    if let Some(branch_block) = self.branch_points.get(&pc) {
        tracing::trace!("Found a branch we already know about: {}", pc);
        self.builder.ins().jump(*branch_block, &[cond]);
        return BuilderResult::Branch;
    }
    let branch_block = self.builder.create_block();
    self.branch_points.insert(pc, branch_block);
    let cond_param = self.builder.append_block_param(branch_block, self.int_32_type);
    self.builder.ins().jump(branch_block, &[cond_param]);
    self.builder.switch_to_block(branch_block);
    let branched = self.builder.create_block();
    let not_branched = self.builder.create_block();
    self.builder.ins().brif(cond, branched, &[], not_branched, &[]);
    self.builder.switch_to_block(branched);
    self.builder.seal_block(branched);
    BuilderResult::NewBranch { target: pc.wrapping_add(imm), not_branched }
}
I believe the problem is here:
    // If weve already visited this branch point, jump to it.
    if let Some(branch_block) = self.branch_points.get(&pc) {
        tracing::trace!("Found a branch we already know about: {}", pc);
        self.builder.ins().jump(*branch_block, &[cond]);
        return BuilderResult::Branch;
    }
Basically I think we can encounter a situation when translating like a for loop that says something like
So my strategy is to store these branch points and immediately evaluate the case where we branched (since it's an immediate offset). Sure, this means we might have some overlapping instructions in different blocks, but that's OK for now.
If I hit this branch at this pc again, then I can signal to the "translator" to follow the non-branch case, but this is where the error comes from.
Thanks for all the help so far!
The error message makes sense, I think: cond comes from a block that is like a successor of the branch block. But I'm curious if there's an obvious way to make this work that I'm not seeing.
the error message is indicating that an SSA invariant is not valid -- basically you're using a value that isn't always set on all paths into a use-point (that's the "non-dominating block" bit)
the high-level idea is that you need to add block params whenever you have a "merge point" -- control flow coming in and a use that comes from one of several defs in the joining paths
if you're tracking SSA values for registers in the translated riscv32 code, I suspect the simplest way about this will be to have a block param for each machine register on every basic block
(there are more optimal ways, but you'll need to get more into SSA construction algorithms to get there -- the above should be sufficient to get a correct compilation)
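To make the "block param for each machine register on every basic block" suggestion concrete, here is a toy model in plain Rust. It is not the Cranelift API: `Value`, `Block` and `Translator` are hypothetical types invented for this sketch. In Cranelift the corresponding pieces would be `append_block_param` when creating a block and the argument list passed to `jump` or `brif`.

```rust
/// Toy IR: a value is just an id.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Value(u32);

const NUM_REGS: usize = 32;

/// Every block owns one parameter per guest register.
struct Block {
    params: [Value; NUM_REGS],
}

struct Translator {
    next_value: u32,
    blocks: Vec<Block>,
    /// Current SSA value of each guest register in the block being built.
    regs: [Value; NUM_REGS],
    /// Recorded jumps: (target block index, arguments passed).
    jumps: Vec<(usize, [Value; NUM_REGS])>,
}

impl Translator {
    fn new() -> Self {
        let mut t = Translator {
            next_value: 0,
            blocks: Vec::new(),
            regs: [Value(0); NUM_REGS],
            jumps: Vec::new(),
        };
        let entry = t.create_block();
        t.switch_to_block(entry);
        t
    }

    fn fresh(&mut self) -> Value {
        let v = Value(self.next_value);
        self.next_value += 1;
        v
    }

    /// Create a block with a fresh parameter for every guest register.
    fn create_block(&mut self) -> usize {
        let mut params = [Value(0); NUM_REGS];
        for p in params.iter_mut() {
            *p = self.fresh();
        }
        self.blocks.push(Block { params });
        self.blocks.len() - 1
    }

    /// On entry to a block, each register is that block's own parameter,
    /// never a value defined in some other block.
    fn switch_to_block(&mut self, block: usize) {
        self.regs = self.blocks[block].params;
    }

    /// Model an instruction that writes guest register `reg`.
    fn def_reg(&mut self, reg: usize) -> Value {
        let v = self.fresh();
        self.regs[reg] = v;
        v
    }

    /// Every jump passes the current value of every register as an argument.
    fn jump(&mut self, target: usize) {
        self.jumps.push((target, self.regs));
    }
}
```

With this discipline a back edge into an already-created block never refers to a value from a non-dominating block: the jump only passes values visible at the jump site, and the target only uses its own parameters. The cost is many redundant parameters, which is the inefficiency the smarter SSA construction algorithms avoid.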
This topic was moved here from #general > (open discussion) JIT only parts of an unstructured programs by Till Schneidereit.
@Nate You may be interested in https://github.com/Amanieu/a-tale-of-binary-translation which seems to be exactly what you are trying to do. It's a simple binary translator that I developed as a teaching example which targets RV32I.
This is the cranelift backend: https://github.com/Amanieu/a-tale-of-binary-translation/blob/master/src/backend/jit/cranelift.rs
Last updated: Jan 24 2025 at 00:11 UTC