From 6615253348a2b2a16957cf761dee5c4c9c3e942f Mon Sep 17 00:00:00 2001 From: Harishankar Vishwanathan Date: Wed, 17 Dec 2025 14:32:41 -0500 Subject: [PATCH] bpf/verifier: pruning branches based on no intersection between domains When verifying branches that compare two registers, if we can determine that the possible value ranges of the two registers do not intersect based on their unsigned and signed min/max values and tnums, we can conclude that certain branch outcomes are impossible. This patch adds checks for intersection between the unsigned ranges/signed ranges and tnums of registers. If no intersection exists, we can prune the impossible branch. Signed-off-by: Harishankar Vishwanathan --- include/linux/tnum.h | 3 +- kernel/bpf/tnum.c | 39 ++++++++++ kernel/bpf/verifier.c | 169 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 203 insertions(+), 8 deletions(-) diff --git a/include/linux/tnum.h b/include/linux/tnum.h index c52b862dad45b..e5d87853892fa 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -125,5 +125,6 @@ static inline bool tnum_subreg_is_const(struct tnum a) { return !(tnum_subreg(a)).mask; } - +/* Returns smallest member of t > z */ +u64 tnum_step(struct tnum t, u64 z); #endif /* _LINUX_TNUM_H */ diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index f8e70e9c3998d..cca2b5675a1d4 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -253,3 +253,42 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value) { return tnum_with_subreg(a, tnum_const(value)); } + +u64 tnum_step(struct tnum t, u64 z) +{ + u64 tmax, j, p, q, r, s, v, u, w, res; + u8 k; + + tmax = t.value | t.mask; + + /* if z >= largest member of t, return largest member of t */ + if (z >= tmax) + return tmax; + + /* keep t's known bits, and match all unknown bits to z */ + j = t.value | (z & t.mask); + + if (j > z) { + p = ~z & t.value & ~t.mask; + k = fls64(p); /* k is the most-significant 0-to-1 flip */ + q = U64_MAX << k; + r = q & z; /* positions > k matched to z */ + s = ~q & t.value; /* positions <= k matched to t.value */ + v = r | s; + res = v; + } else { + p = z & ~t.value & ~t.mask; + k = fls64(p); /* k is the most-significant 1-to-0 flip */ + q = U64_MAX << k; + r = q & t.mask & z; /* unknown positions > k, matched to z */ + s = q & ~t.mask; /* known positions > k, set to 1 */ + v = r | s; + /* add 1 to unknown positions > k to make value greater than z */ + u = v + (1ULL << k); + /* extract bits in unknown positions > k from u, rest from t.value */ + w = u & (t.mask | t.value); + res = w; + } + return res; +} + diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index d6b8a77fbe3bf..3d983fa49836c 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -16059,11 +16059,165 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate, })); } +static bool intersection_u64_s64(u64 umin, u64 umax, s64 smin, s64 smax) +{ + if ((u64)smin <= (u64)smax) + return !(((u64)smax < umin) || (umax < (u64)smin)); + else + return !(((u64)smin > umax) && ((u64)smax < umin)); +} + +static bool intersection_u64_tnum(u64 umin, u64 umax, struct tnum t) +{ + u64 tmin = t.value; + u64 tmax = t.value | t.mask; + + return !((tmin > umax) || (tmax < umin) || + ((t.value != (umin & ~t.mask)) && (tnum_step(t, umin) > umax))); +} + +static bool intersection_s64_tnum(s64 smin, s64 smax, struct tnum t) +{ + if ((u64)smin <= (u64)smax) + return intersection_u64_tnum((u64)smin, (u64)smax, t); + + return (intersection_u64_tnum((u64)smin, U64_MAX, t) || + intersection_u64_tnum(0, (u64)smax, t)); +} + +static bool intersection_u32_s32(u32 u32_min, u32 u32_max, s32 s32_min, s32 s32_max) +{ + if ((u32)s32_min <= (u32)s32_max) + return !(((u32)s32_max < u32_min) || (u32_max < (u32)s32_min)); + else + return !(((u32)s32_min > u32_max) && ((u32)s32_max < u32_min)); +} + +static bool intersection_u32_tnum(u32 u32_min, u32 u32_max, struct tnum t) +{ + struct tnum t32 = tnum_subreg(t); + u32 t32_min = t32.value; + u32 t32_max = t32.value | t32.mask; + + return !((t32_min > u32_max) || + (t32_max < u32_min) || + ((t32.value != (u32_min & ~t32.mask)) && + (tnum_step(t32, u32_min) > u32_max))); +} + +static bool intersection_s32_tnum(s32 s32_min, s32 s32_max, struct tnum t) +{ + if ((u32)s32_min <= (u32)s32_max) + return intersection_u32_tnum((u32)s32_min, (u32)s32_max, t); + + return (intersection_u32_tnum((u32)s32_min, U32_MAX, t) || + intersection_u32_tnum(0, (u32)s32_max, t)); +} + +static bool check_intersection_all(struct bpf_verifier_env *env, + struct bpf_reg_state *reg, + const char *ctx) +{ + + bool intersection_exists; + + if (reg->umin_value > reg->umax_value || + reg->smin_value > reg->smax_value || + reg->u32_min_value > reg->u32_max_value || + reg->s32_min_value > reg->s32_max_value) { + intersection_exists = false; + } else if ((reg->var_off.value & reg->var_off.mask) != 0) { + intersection_exists = false; + } else if (!intersection_u64_s64(reg->umin_value, reg->umax_value, + reg->smin_value, reg->smax_value)) { + intersection_exists = false; + } else if (!intersection_u64_tnum(reg->umin_value, reg->umax_value, reg->var_off)) { + intersection_exists = false; + } else if (!intersection_s64_tnum(reg->smin_value, reg->smax_value, reg->var_off)) { + intersection_exists = false; + } else if (!intersection_u32_s32(reg->u32_min_value, reg->u32_max_value, + reg->s32_min_value, reg->s32_max_value)) { + intersection_exists = false; + } else if (!intersection_u32_tnum(reg->u32_min_value, reg->u32_max_value, reg->var_off)) { + intersection_exists = false; + } else if (!intersection_s32_tnum(reg->s32_min_value, reg->s32_max_value, reg->var_off)) { + intersection_exists = false; + } else { + intersection_exists = true; + } + + if (!intersection_exists) { + return false; + } else + return true; + +} + +static void regs_refine_cond_op(struct bpf_reg_state *reg1, + struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32); +static u8 rev_opcode(u8 opcode); + +static int simulate_both_branches_taken(struct bpf_verifier_env *env, + struct bpf_reg_state *false_reg1, + struct bpf_reg_state *false_reg2, + u8 opcode, bool is_jmp32) +{ + + struct bpf_reg_state false_reg1_c, false_reg2_c, true_reg1, true_reg2; + bool t1, t2, f1, f2; + + memcpy(&false_reg1_c, false_reg1, sizeof(struct bpf_reg_state)); + memcpy(&false_reg2_c, false_reg2, sizeof(struct bpf_reg_state)); + memcpy(&true_reg1, false_reg1, sizeof(struct bpf_reg_state)); + memcpy(&true_reg2, false_reg2, sizeof(struct bpf_reg_state)); + + /* fallthrough (FALSE) branch */ + check_intersection_all(env, &false_reg1_c, "BR_false_reg1_c"); + check_intersection_all(env, &false_reg2_c, "BR_false_reg2_c"); + regs_refine_cond_op(&false_reg1_c, &false_reg2_c, rev_opcode(opcode), is_jmp32); + check_intersection_all(env, &false_reg1_c, "BS_false_reg1_c"); + check_intersection_all(env, &false_reg2_c, "BS_false_reg2_c"); + reg_bounds_sync(&false_reg1_c); + reg_bounds_sync(&false_reg2_c); + f1 = check_intersection_all(env, &false_reg1_c, "AS_false_reg1_c"); + f2 = check_intersection_all(env, &false_reg2_c, "AS_false_reg2_c"); + + /* jump (TRUE) branch */ + check_intersection_all(env, &true_reg1, "BR_true_reg1"); + check_intersection_all(env, &true_reg2, "BR_true_reg2"); + regs_refine_cond_op(&true_reg1, &true_reg2, opcode, is_jmp32); + check_intersection_all(env, &true_reg1, "BS_true_reg1"); + check_intersection_all(env, &true_reg2, "BS_true_reg2"); + reg_bounds_sync(&true_reg1); + reg_bounds_sync(&true_reg2); + t1 = check_intersection_all(env, &true_reg1, "AS_true_reg1"); + t2 = check_intersection_all(env, &true_reg2, "AS_true_reg2"); + + if (!f1 || !f2) { + /* If there is no intersection among *any pair* of abstract values in + * either reg_states in the FALSE branch (i.e. false_reg1, false_reg2), + * the FALSE branch must be dead. Only TRUE branch will be taken. + */ + return 1; + } else if (!t1 || !t2) { + /* If there is no intersection among *any pair* of abstract values in + * either reg_states in the TRUE branch (i.e. true_reg1, true_reg2), + * the TRUE branch must be dead. Only FALSE branch will be taken. + */ + return 0; + } + + return -1; +} + /* * , currently assuming reg2 is a constant */ -static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, - u8 opcode, bool is_jmp32) +static int is_scalar_branch_taken(struct bpf_verifier_env *env, + struct bpf_reg_state *reg1, + struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) { struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off; struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off; @@ -16215,7 +16369,7 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta break; } - return -1; + return simulate_both_branches_taken(env, reg1, reg2, opcode, is_jmp32); } static int flip_opcode(u32 opcode) @@ -16286,8 +16440,9 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg, * -1 - unknown. Example: "if (reg1 < 5)" is unknown when register value * range [0,10] */ -static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, - u8 opcode, bool is_jmp32) +static int is_branch_taken(struct bpf_verifier_env *env, + struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) { if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32) return is_pkt_ptr_branch_taken(reg1, reg2, opcode); @@ -16325,7 +16480,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg } /* now deal with two scalars, but not necessarily constants */ - return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); + return is_scalar_branch_taken(env, reg1, reg2, opcode, is_jmp32); } /* Opcode that corresponds to a *false* branch condition. @@ -16933,7 +17088,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, } is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32; - pred = is_branch_taken(dst_reg, src_reg, opcode, is_jmp32); + pred = is_branch_taken(env, dst_reg, src_reg, opcode, is_jmp32); if (pred >= 0) { /* If we get here with a dst_reg pointer type it is because * above is_branch_taken() special cased the 0 comparison.