aboutsummaryrefslogtreecommitdiff
path: root/arch/x86/net
diff options
context:
space:
mode:
Diffstat (limited to 'arch/x86/net')
-rw-r--r--arch/x86/net/bpf_jit_comp.c219
1 files changed, 132 insertions, 87 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 45e4eb5bcbb2..cbf94d44f3d5 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -61,7 +61,12 @@ static bool is_imm8(int value)
static bool is_simm32(s64 value)
{
- return value == (s64) (s32) value;
+ return value == (s64)(s32)value;
+}
+
+static bool is_uimm32(u64 value)
+{
+ return value == (u64)(u32)value;
}
/* mov dst, src */
@@ -212,7 +217,7 @@ struct jit_context {
/* emit x64 prologue code for BPF program and check it's size.
* bpf_tail_call helper will skip it while jumping into another program
*/
-static void emit_prologue(u8 **pprog, u32 stack_depth)
+static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
u8 *prog = *pprog;
int cnt = 0;
@@ -247,18 +252,21 @@ static void emit_prologue(u8 **pprog, u32 stack_depth)
/* mov qword ptr [rbp+24],r15 */
EMIT4(0x4C, 0x89, 0x7D, 24);
- /* Clear the tail call counter (tail_call_cnt): for eBPF tail calls
- * we need to reset the counter to 0. It's done in two instructions,
- * resetting rax register to 0 (xor on eax gets 0 extended), and
- * moving it to the counter location.
- */
+ if (!ebpf_from_cbpf) {
+ /* Clear the tail call counter (tail_call_cnt): for eBPF tail
+ * calls we need to reset the counter to 0. It's done in two
+ * instructions, resetting rax register to 0, and moving it
+ * to the counter location.
+ */
+
+ /* xor eax, eax */
+ EMIT2(0x31, 0xc0);
+ /* mov qword ptr [rbp+32], rax */
+ EMIT4(0x48, 0x89, 0x45, 32);
- /* xor eax, eax */
- EMIT2(0x31, 0xc0);
- /* mov qword ptr [rbp+32], rax */
- EMIT4(0x48, 0x89, 0x45, 32);
+ BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
+ }
- BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
*pprog = prog;
}
@@ -356,6 +364,86 @@ static void emit_load_skb_data_hlen(u8 **pprog)
*pprog = prog;
}
+static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
+ u32 dst_reg, const u32 imm32)
+{
+ u8 *prog = *pprog;
+ u8 b1, b2, b3;
+ int cnt = 0;
+
+ /* optimization: if imm32 is positive, use 'mov %eax, imm32'
+ * (which zero-extends imm32) to save 2 bytes.
+ */
+ if (sign_propagate && (s32)imm32 < 0) {
+ /* 'mov %rax, imm32' sign extends imm32 */
+ b1 = add_1mod(0x48, dst_reg);
+ b2 = 0xC7;
+ b3 = 0xC0;
+ EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
+ goto done;
+ }
+
+ /* optimization: if imm32 is zero, use 'xor %eax, %eax'
+ * to save 3 bytes.
+ */
+ if (imm32 == 0) {
+ if (is_ereg(dst_reg))
+ EMIT1(add_2mod(0x40, dst_reg, dst_reg));
+ b2 = 0x31; /* xor */
+ b3 = 0xC0;
+ EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
+ goto done;
+ }
+
+ /* mov %eax, imm32 */
+ if (is_ereg(dst_reg))
+ EMIT1(add_1mod(0x40, dst_reg));
+ EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+done:
+ *pprog = prog;
+}
+
+static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
+ const u32 imm32_hi, const u32 imm32_lo)
+{
+ u8 *prog = *pprog;
+ int cnt = 0;
+
+ if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
+ /* For emitting plain u32, where sign bit must not be
+ * propagated LLVM tends to load imm64 over mov32
+ * directly, so save couple of bytes by just doing
+ * 'mov %eax, imm32' instead.
+ */
+ emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
+ } else {
+ /* movabsq %rax, imm64 */
+ EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
+ EMIT(imm32_lo, 4);
+ EMIT(imm32_hi, 4);
+ }
+
+ *pprog = prog;
+}
+
+static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
+{
+ u8 *prog = *pprog;
+ int cnt = 0;
+
+ if (is64) {
+ /* mov dst, src */
+ EMIT_mov(dst_reg, src_reg);
+ } else {
+ /* mov32 dst, src */
+ if (is_ereg(dst_reg) || is_ereg(src_reg))
+ EMIT1(add_2mod(0x40, dst_reg, src_reg));
+ EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+ }
+
+ *pprog = prog;
+}
+
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
int oldproglen, struct jit_context *ctx)
{
@@ -369,7 +457,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
int proglen = 0;
u8 *prog = temp;
- emit_prologue(&prog, bpf_prog->aux->stack_depth);
+ emit_prologue(&prog, bpf_prog->aux->stack_depth,
+ bpf_prog_was_classic(bpf_prog));
if (seen_ld_abs)
emit_load_skb_data_hlen(&prog);
@@ -378,7 +467,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
const s32 imm32 = insn->imm;
u32 dst_reg = insn->dst_reg;
u32 src_reg = insn->src_reg;
- u8 b1 = 0, b2 = 0, b3 = 0;
+ u8 b2 = 0, b3 = 0;
s64 jmp_offset;
u8 jmp_cond;
bool reload_skb_data;
@@ -414,16 +503,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
break;
- /* mov dst, src */
case BPF_ALU64 | BPF_MOV | BPF_X:
- EMIT_mov(dst_reg, src_reg);
- break;
-
- /* mov32 dst, src */
case BPF_ALU | BPF_MOV | BPF_X:
- if (is_ereg(dst_reg) || is_ereg(src_reg))
- EMIT1(add_2mod(0x40, dst_reg, src_reg));
- EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
+ emit_mov_reg(&prog,
+ BPF_CLASS(insn->code) == BPF_ALU64,
+ dst_reg, src_reg);
break;
/* neg dst */
@@ -486,58 +570,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
break;
case BPF_ALU64 | BPF_MOV | BPF_K:
- /* optimization: if imm32 is positive,
- * use 'mov eax, imm32' (which zero-extends imm32)
- * to save 2 bytes
- */
- if (imm32 < 0) {
- /* 'mov rax, imm32' sign extends imm32 */
- b1 = add_1mod(0x48, dst_reg);
- b2 = 0xC7;
- b3 = 0xC0;
- EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
- break;
- }
-
case BPF_ALU | BPF_MOV | BPF_K:
- /* optimization: if imm32 is zero, use 'xor <dst>,<dst>'
- * to save 3 bytes.
- */
- if (imm32 == 0) {
- if (is_ereg(dst_reg))
- EMIT1(add_2mod(0x40, dst_reg, dst_reg));
- b2 = 0x31; /* xor */
- b3 = 0xC0;
- EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
- break;
- }
-
- /* mov %eax, imm32 */
- if (is_ereg(dst_reg))
- EMIT1(add_1mod(0x40, dst_reg));
- EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
+ emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
+ dst_reg, imm32);
break;
case BPF_LD | BPF_IMM | BPF_DW:
- /* optimization: if imm64 is zero, use 'xor <dst>,<dst>'
- * to save 7 bytes.
- */
- if (insn[0].imm == 0 && insn[1].imm == 0) {
- b1 = add_2mod(0x48, dst_reg, dst_reg);
- b2 = 0x31; /* xor */
- b3 = 0xC0;
- EMIT3(b1, b2, add_2reg(b3, dst_reg, dst_reg));
-
- insn++;
- i++;
- break;
- }
-
- /* movabsq %rax, imm64 */
- EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
- EMIT(insn[0].imm, 4);
- EMIT(insn[1].imm, 4);
-
+ emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
insn++;
i++;
break;
@@ -594,36 +633,38 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
case BPF_ALU | BPF_MUL | BPF_X:
case BPF_ALU64 | BPF_MUL | BPF_K:
case BPF_ALU64 | BPF_MUL | BPF_X:
- EMIT1(0x50); /* push rax */
- EMIT1(0x52); /* push rdx */
+ {
+ bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
+
+ if (dst_reg != BPF_REG_0)
+ EMIT1(0x50); /* push rax */
+ if (dst_reg != BPF_REG_3)
+ EMIT1(0x52); /* push rdx */
/* mov r11, dst_reg */
EMIT_mov(AUX_REG, dst_reg);
if (BPF_SRC(insn->code) == BPF_X)
- /* mov rax, src_reg */
- EMIT_mov(BPF_REG_0, src_reg);
+ emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
else
- /* mov rax, imm32 */
- EMIT3_off32(0x48, 0xC7, 0xC0, imm32);
+ emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
- if (BPF_CLASS(insn->code) == BPF_ALU64)
+ if (is64)
EMIT1(add_1mod(0x48, AUX_REG));
else if (is_ereg(AUX_REG))
EMIT1(add_1mod(0x40, AUX_REG));
/* mul(q) r11 */
EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
- /* mov r11, rax */
- EMIT_mov(AUX_REG, BPF_REG_0);
-
- EMIT1(0x5A); /* pop rdx */
- EMIT1(0x58); /* pop rax */
-
- /* mov dst_reg, r11 */
- EMIT_mov(dst_reg, AUX_REG);
+ if (dst_reg != BPF_REG_3)
+ EMIT1(0x5A); /* pop rdx */
+ if (dst_reg != BPF_REG_0) {
+ /* mov dst_reg, rax */
+ EMIT_mov(dst_reg, BPF_REG_0);
+ EMIT1(0x58); /* pop rax */
+ }
break;
-
+ }
/* shifts */
case BPF_ALU | BPF_LSH | BPF_K:
case BPF_ALU | BPF_RSH | BPF_K:
@@ -641,7 +682,11 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
case BPF_RSH: b3 = 0xE8; break;
case BPF_ARSH: b3 = 0xF8; break;
}
- EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
+
+ if (imm32 == 1)
+ EMIT2(0xD1, add_1reg(b3, dst_reg));
+ else
+ EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
break;
case BPF_ALU | BPF_LSH | BPF_X: