Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/linux/tnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,6 @@ static inline bool tnum_subreg_is_const(struct tnum a)
{
return !(tnum_subreg(a)).mask;
}

/* Returns smallest member of t > z */
u64 tnum_step(struct tnum t, u64 z);
#endif /* _LINUX_TNUM_H */
39 changes: 39 additions & 0 deletions kernel/bpf/tnum.c
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,42 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value)
{
return tnum_with_subreg(a, tnum_const(value));
}

u64 tnum_step(struct tnum t, u64 z)
{
u64 tmax, j, p, q, r, s, v, u, w, res;
u8 k;

tmax = t.value | t.mask;

/* if z >= largest member of t, return largest member of t */
if (z >= tmax)
return tmax;

/* keep t's known bits, and match all unknown bits to z */
j = t.value | (z & t.mask);

if (j > z) {
p = ~z & t.value & ~t.mask;
k = fls64(p); /* k is the most-significant 0-to-1 flip */
q = U64_MAX << k;
r = q & z; /* positions > k matched to z */
s = ~q & t.value; /* positions <= k matched to t.value */
v = r | s;
res = v;
} else {
p = z & ~t.value & ~t.mask;
k = fls64(p); /* k is the most-significant 1-to-0 flip */
q = U64_MAX << k;
r = q & t.mask & z; /* unknown positions > k, matched to z */
s = q & ~t.mask; /* known positions > k, set to 1 */
v = r | s;
/* add 1 to unknown positions > k to make value greater than z */
u = v + (1ULL << k);
/* extract bits in unknown positions > k from u, rest from t.value */
w = u & (t.mask | t.value);
res = w;
}
return res;
}

169 changes: 162 additions & 7 deletions kernel/bpf/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -16059,11 +16059,165 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
}));
}

static bool intersection_u64_s64(u64 umin, u64 umax, s64 smin, s64 smax)
{
if ((u64)smin <= (u64)smax)
return !(((u64)smax < umin) || (umax < (u64)smin));
else
return !(((u64)smin > umax) && ((u64)smax < umin));
}

static bool intersection_u64_tnum(u64 umin, u64 umax, struct tnum t)
{
u64 tmin = t.value;
u64 tmax = t.value | t.mask;

return !((tmin > umax) || (tmax < umin) ||
((t.value != (umin & ~t.mask)) && (tnum_step(t, umin) > umax)));
}

static bool intersection_s64_tnum(s64 smin, s64 smax, struct tnum t)
{
if ((u64)smin <= (u64)smax)
return intersection_u64_tnum((u64)smin, (u64)smax, t);

return (intersection_u64_tnum((u64)smin, U64_MAX, t) ||
intersection_u64_tnum(0, (u64)smax, t));
}

static bool intersection_u32_s32(u32 u32_min, u32 u32_max, s32 s32_min, s32 s32_max)
{
if ((u32)s32_min <= (u32)s32_max)
return !(((u32)s32_max < u32_min) || (u32_max < (u32)s32_min));
else
return !(((u32)s32_min > u32_max) && ((u32)s32_max < u32_min));
}

static bool intersection_u32_tnum(u32 u32_min, u32 u32_max, struct tnum t)
{
struct tnum t32 = tnum_subreg(t);
u32 t32_min = t32.value;
u32 t32_max = t32.value | t32.mask;

return !((t32_min > u32_max) ||
(t32_max < u32_min) ||
((t32.value != (u32_min & ~t32.mask)) &&
(tnum_step(t32, u32_min) > u32_max)));
}

static bool intersection_s32_tnum(s32 s32_min, s32 s32_max, struct tnum t)
{
if ((u32)s32_min <= (u32)s32_max)
return intersection_u32_tnum((u32)s32_min, (u32)s32_max, t);

return (intersection_u32_tnum((u32)s32_min, U32_MAX, t) ||
intersection_u32_tnum(0, (u32)s32_max, t));
}

static bool check_intersection_all(struct bpf_verifier_env *env,
struct bpf_reg_state *reg,
const char *ctx)
{

bool intersection_exists;

if (reg->umin_value > reg->umax_value ||
reg->smin_value > reg->smax_value ||
reg->u32_min_value > reg->u32_max_value ||
reg->s32_min_value > reg->s32_max_value) {
intersection_exists = false;
} else if ((reg->var_off.value & reg->var_off.mask) != 0) {
intersection_exists = false;
} else if (!intersection_u64_s64(reg->umin_value, reg->umax_value,
reg->smin_value, reg->smax_value)) {
intersection_exists = false;
} else if (!intersection_u64_tnum(reg->umin_value, reg->umax_value, reg->var_off)) {
intersection_exists = false;
} else if (!intersection_s64_tnum(reg->smin_value, reg->smax_value, reg->var_off)) {
intersection_exists = false;
} else if (!intersection_u32_s32(reg->u32_min_value, reg->u32_max_value,
reg->s32_min_value, reg->s32_max_value)) {
intersection_exists = false;
} else if (!intersection_u32_tnum(reg->u32_min_value, reg->u32_max_value, reg->var_off)) {
intersection_exists = false;
} else if (!intersection_s32_tnum(reg->s32_min_value, reg->s32_max_value, reg->var_off)) {
intersection_exists = false;
} else {
intersection_exists = true;
}

if (!intersection_exists) {
return false;
} else
return true;

}

static void regs_refine_cond_op(struct bpf_reg_state *reg1,
struct bpf_reg_state *reg2,
u8 opcode, bool is_jmp32);
static u8 rev_opcode(u8 opcode);

static int simulate_both_branches_taken(struct bpf_verifier_env *env,
struct bpf_reg_state *false_reg1,
struct bpf_reg_state *false_reg2,
u8 opcode, bool is_jmp32)
{

struct bpf_reg_state false_reg1_c, false_reg2_c, true_reg1, true_reg2;
bool t1, t2, f1, f2;

memcpy(&false_reg1_c, false_reg1, sizeof(struct bpf_reg_state));
memcpy(&false_reg2_c, false_reg2, sizeof(struct bpf_reg_state));
memcpy(&true_reg1, false_reg1, sizeof(struct bpf_reg_state));
memcpy(&true_reg2, false_reg2, sizeof(struct bpf_reg_state));

/* fallthrough (FALSE) branch */
check_intersection_all(env, &false_reg1_c, "BR_false_reg1_c");
check_intersection_all(env, &false_reg2_c, "BR_false_reg2_c");
regs_refine_cond_op(&false_reg1_c, &false_reg2_c, rev_opcode(opcode), is_jmp32);
check_intersection_all(env, &false_reg1_c, "BS_false_reg1_c");
check_intersection_all(env, &false_reg2_c, "BS_false_reg2_c");
reg_bounds_sync(&false_reg1_c);
reg_bounds_sync(&false_reg2_c);
f1 = check_intersection_all(env, &false_reg1_c, "AS_false_reg1_c");
f2 = check_intersection_all(env, &false_reg2_c, "AS_false_reg2_c");

/* jump (TRUE) branch */
check_intersection_all(env, &true_reg1, "BR_true_reg1");
check_intersection_all(env, &true_reg2, "BR_true_reg2");
regs_refine_cond_op(&true_reg1, &true_reg2, opcode, is_jmp32);
check_intersection_all(env, &true_reg1, "BS_true_reg1");
check_intersection_all(env, &true_reg2, "BS_true_reg2");
reg_bounds_sync(&true_reg1);
reg_bounds_sync(&true_reg2);
t1 = check_intersection_all(env, &true_reg1, "AS_true_reg1");
t2 = check_intersection_all(env, &true_reg2, "AS_true_reg2");

if (!f1 || !f2) {
/* If there is no intersection among *any pair* of abstract values in
* either reg_states in the FALSE branch (i.e. false_reg1, false_reg2),
* the FALSE branch must be dead. Only TRUE branch will be taken.
*/
return 1;
} else if (!t1 || !t2) {
/* If there is no intersection among *any pair* of abstract values in
* either reg_states in the TRUE branch (i.e. true_reg1, true_reg2),
* the TRUE branch must be dead. Only FALSE branch will be taken.
*/
return 0;
}

return -1;
}

/*
* <reg1> <op> <reg2>, currently assuming reg2 is a constant
*/
static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
u8 opcode, bool is_jmp32)
static int is_scalar_branch_taken(struct bpf_verifier_env *env,
struct bpf_reg_state *reg1,
struct bpf_reg_state *reg2,
u8 opcode, bool is_jmp32)
{
struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
Expand Down Expand Up @@ -16215,7 +16369,7 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
break;
}

return -1;
return simulate_both_branches_taken(env, reg1, reg2, opcode, is_jmp32);
}

static int flip_opcode(u32 opcode)
Expand Down Expand Up @@ -16286,8 +16440,9 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
* -1 - unknown. Example: "if (reg1 < 5)" is unknown when register value
* range [0,10]
*/
static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
u8 opcode, bool is_jmp32)
static int is_branch_taken(struct bpf_verifier_env *env,
struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
u8 opcode, bool is_jmp32)
{
if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32)
return is_pkt_ptr_branch_taken(reg1, reg2, opcode);
Expand Down Expand Up @@ -16325,7 +16480,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
}

/* now deal with two scalars, but not necessarily constants */
return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
return is_scalar_branch_taken(env, reg1, reg2, opcode, is_jmp32);
}

/* Opcode that corresponds to a *false* branch condition.
Expand Down Expand Up @@ -16933,7 +17088,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
}

is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32;
pred = is_branch_taken(dst_reg, src_reg, opcode, is_jmp32);
pred = is_branch_taken(env, dst_reg, src_reg, opcode, is_jmp32);
if (pred >= 0) {
/* If we get here with a dst_reg pointer type it is because
* above is_branch_taken() special cased the 0 comparison.
Expand Down
Loading