Skip to content

Commit 4d34588

Browse files
anakryikoAlexei Starovoitov
authored andcommitted
bpf: unify 32-bit and 64-bit is_branch_taken logic
Combine 32-bit and 64-bit is_branch_taken logic for SCALAR_VALUE registers. It makes it easier to see parallels between two domains (32-bit and 64-bit), and makes subsequent refactoring more straightforward. No functional changes. Acked-by: Eduard Zingerman <[email protected]> Signed-off-by: Andrii Nakryiko <[email protected]> Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Alexei Starovoitov <[email protected]>
1 parent b74c2a8 commit 4d34588

File tree

1 file changed

+59
-141
lines changed

1 file changed

+59
-141
lines changed

kernel/bpf/verifier.c

Lines changed: 59 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -14204,166 +14204,86 @@ static u64 reg_const_value(struct bpf_reg_state *reg, bool subreg32)
1420414204
/*
1420514205
* <reg1> <op> <reg2>, currently assuming reg2 is a constant
1420614206
*/
14207-
static int is_branch32_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode)
14207+
static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
14208+
u8 opcode, bool is_jmp32)
1420814209
{
14209-
struct tnum subreg = tnum_subreg(reg1->var_off);
14210-
u32 val = (u32)tnum_subreg(reg2->var_off).value;
14211-
s32 sval = (s32)val;
14210+
struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
14211+
u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
14212+
u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
14213+
s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
14214+
s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
14215+
u64 uval = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
14216+
s64 sval = is_jmp32 ? (s32)uval : (s64)uval;
1421214217

1421314218
switch (opcode) {
1421414219
case BPF_JEQ:
14215-
if (tnum_is_const(subreg))
14216-
return !!tnum_equals_const(subreg, val);
14217-
else if (val < reg1->u32_min_value || val > reg1->u32_max_value)
14220+
if (tnum_is_const(t1))
14221+
return !!tnum_equals_const(t1, uval);
14222+
else if (uval < umin1 || uval > umax1)
1421814223
return 0;
14219-
else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value)
14224+
else if (sval < smin1 || sval > smax1)
1422014225
return 0;
1422114226
break;
1422214227
case BPF_JNE:
14223-
if (tnum_is_const(subreg))
14224-
return !tnum_equals_const(subreg, val);
14225-
else if (val < reg1->u32_min_value || val > reg1->u32_max_value)
14228+
if (tnum_is_const(t1))
14229+
return !tnum_equals_const(t1, uval);
14230+
else if (uval < umin1 || uval > umax1)
1422614231
return 1;
14227-
else if (sval < reg1->s32_min_value || sval > reg1->s32_max_value)
14232+
else if (sval < smin1 || sval > smax1)
1422814233
return 1;
1422914234
break;
1423014235
case BPF_JSET:
14231-
if ((~subreg.mask & subreg.value) & val)
14236+
if ((~t1.mask & t1.value) & uval)
1423214237
return 1;
14233-
if (!((subreg.mask | subreg.value) & val))
14238+
if (!((t1.mask | t1.value) & uval))
1423414239
return 0;
1423514240
break;
1423614241
case BPF_JGT:
14237-
if (reg1->u32_min_value > val)
14242+
if (umin1 > uval )
1423814243
return 1;
14239-
else if (reg1->u32_max_value <= val)
14244+
else if (umax1 <= uval)
1424014245
return 0;
1424114246
break;
1424214247
case BPF_JSGT:
14243-
if (reg1->s32_min_value > sval)
14248+
if (smin1 > sval)
1424414249
return 1;
14245-
else if (reg1->s32_max_value <= sval)
14250+
else if (smax1 <= sval)
1424614251
return 0;
1424714252
break;
1424814253
case BPF_JLT:
14249-
if (reg1->u32_max_value < val)
14254+
if (umax1 < uval)
1425014255
return 1;
14251-
else if (reg1->u32_min_value >= val)
14256+
else if (umin1 >= uval)
1425214257
return 0;
1425314258
break;
1425414259
case BPF_JSLT:
14255-
if (reg1->s32_max_value < sval)
14260+
if (smax1 < sval)
1425614261
return 1;
14257-
else if (reg1->s32_min_value >= sval)
14262+
else if (smin1 >= sval)
1425814263
return 0;
1425914264
break;
1426014265
case BPF_JGE:
14261-
if (reg1->u32_min_value >= val)
14266+
if (umin1 >= uval)
1426214267
return 1;
14263-
else if (reg1->u32_max_value < val)
14268+
else if (umax1 < uval)
1426414269
return 0;
1426514270
break;
1426614271
case BPF_JSGE:
14267-
if (reg1->s32_min_value >= sval)
14272+
if (smin1 >= sval)
1426814273
return 1;
14269-
else if (reg1->s32_max_value < sval)
14274+
else if (smax1 < sval)
1427014275
return 0;
1427114276
break;
1427214277
case BPF_JLE:
14273-
if (reg1->u32_max_value <= val)
14278+
if (umax1 <= uval)
1427414279
return 1;
14275-
else if (reg1->u32_min_value > val)
14280+
else if (umin1 > uval)
1427614281
return 0;
1427714282
break;
1427814283
case BPF_JSLE:
14279-
if (reg1->s32_max_value <= sval)
14284+
if (smax1 <= sval)
1428014285
return 1;
14281-
else if (reg1->s32_min_value > sval)
14282-
return 0;
14283-
break;
14284-
}
14285-
14286-
return -1;
14287-
}
14288-
14289-
14290-
/*
14291-
* <reg1> <op> <reg2>, currently assuming reg2 is a constant
14292-
*/
14293-
static int is_branch64_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, u8 opcode)
14294-
{
14295-
u64 val = reg2->var_off.value;
14296-
s64 sval = (s64)val;
14297-
14298-
switch (opcode) {
14299-
case BPF_JEQ:
14300-
if (tnum_is_const(reg1->var_off))
14301-
return !!tnum_equals_const(reg1->var_off, val);
14302-
else if (val < reg1->umin_value || val > reg1->umax_value)
14303-
return 0;
14304-
else if (sval < reg1->smin_value || sval > reg1->smax_value)
14305-
return 0;
14306-
break;
14307-
case BPF_JNE:
14308-
if (tnum_is_const(reg1->var_off))
14309-
return !tnum_equals_const(reg1->var_off, val);
14310-
else if (val < reg1->umin_value || val > reg1->umax_value)
14311-
return 1;
14312-
else if (sval < reg1->smin_value || sval > reg1->smax_value)
14313-
return 1;
14314-
break;
14315-
case BPF_JSET:
14316-
if ((~reg1->var_off.mask & reg1->var_off.value) & val)
14317-
return 1;
14318-
if (!((reg1->var_off.mask | reg1->var_off.value) & val))
14319-
return 0;
14320-
break;
14321-
case BPF_JGT:
14322-
if (reg1->umin_value > val)
14323-
return 1;
14324-
else if (reg1->umax_value <= val)
14325-
return 0;
14326-
break;
14327-
case BPF_JSGT:
14328-
if (reg1->smin_value > sval)
14329-
return 1;
14330-
else if (reg1->smax_value <= sval)
14331-
return 0;
14332-
break;
14333-
case BPF_JLT:
14334-
if (reg1->umax_value < val)
14335-
return 1;
14336-
else if (reg1->umin_value >= val)
14337-
return 0;
14338-
break;
14339-
case BPF_JSLT:
14340-
if (reg1->smax_value < sval)
14341-
return 1;
14342-
else if (reg1->smin_value >= sval)
14343-
return 0;
14344-
break;
14345-
case BPF_JGE:
14346-
if (reg1->umin_value >= val)
14347-
return 1;
14348-
else if (reg1->umax_value < val)
14349-
return 0;
14350-
break;
14351-
case BPF_JSGE:
14352-
if (reg1->smin_value >= sval)
14353-
return 1;
14354-
else if (reg1->smax_value < sval)
14355-
return 0;
14356-
break;
14357-
case BPF_JLE:
14358-
if (reg1->umax_value <= val)
14359-
return 1;
14360-
else if (reg1->umin_value > val)
14361-
return 0;
14362-
break;
14363-
case BPF_JSLE:
14364-
if (reg1->smax_value <= sval)
14365-
return 1;
14366-
else if (reg1->smin_value > sval)
14286+
else if (smin1 > sval)
1436714287
return 0;
1436814288
break;
1436914289
}
@@ -14477,9 +14397,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
1447714397
}
1447814398
}
1447914399

14480-
if (is_jmp32)
14481-
return is_branch32_taken(reg1, reg2, opcode);
14482-
return is_branch64_taken(reg1, reg2, opcode);
14400+
return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
1448314401
}
1448414402

1448514403
/* Adjusts the register min/max values in the case that the dst_reg is the
@@ -14489,15 +14407,15 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
1448914407
*/
1449014408
static void reg_set_min_max(struct bpf_reg_state *true_reg,
1449114409
struct bpf_reg_state *false_reg,
14492-
u64 val, u32 val32,
14410+
u64 uval, u32 uval32,
1449314411
u8 opcode, bool is_jmp32)
1449414412
{
1449514413
struct tnum false_32off = tnum_subreg(false_reg->var_off);
1449614414
struct tnum false_64off = false_reg->var_off;
1449714415
struct tnum true_32off = tnum_subreg(true_reg->var_off);
1449814416
struct tnum true_64off = true_reg->var_off;
14499-
s64 sval = (s64)val;
14500-
s32 sval32 = (s32)val32;
14417+
s64 sval = (s64)uval;
14418+
s32 sval32 = (s32)uval32;
1450114419

1450214420
/* If the dst_reg is a pointer, we can't learn anything about its
1450314421
* variable offset from the compare (unless src_reg were a pointer into
@@ -14520,49 +14438,49 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
1452014438
*/
1452114439
case BPF_JEQ:
1452214440
if (is_jmp32) {
14523-
__mark_reg32_known(true_reg, val32);
14441+
__mark_reg32_known(true_reg, uval32);
1452414442
true_32off = tnum_subreg(true_reg->var_off);
1452514443
} else {
14526-
___mark_reg_known(true_reg, val);
14444+
___mark_reg_known(true_reg, uval);
1452714445
true_64off = true_reg->var_off;
1452814446
}
1452914447
break;
1453014448
case BPF_JNE:
1453114449
if (is_jmp32) {
14532-
__mark_reg32_known(false_reg, val32);
14450+
__mark_reg32_known(false_reg, uval32);
1453314451
false_32off = tnum_subreg(false_reg->var_off);
1453414452
} else {
14535-
___mark_reg_known(false_reg, val);
14453+
___mark_reg_known(false_reg, uval);
1453614454
false_64off = false_reg->var_off;
1453714455
}
1453814456
break;
1453914457
case BPF_JSET:
1454014458
if (is_jmp32) {
14541-
false_32off = tnum_and(false_32off, tnum_const(~val32));
14542-
if (is_power_of_2(val32))
14459+
false_32off = tnum_and(false_32off, tnum_const(~uval32));
14460+
if (is_power_of_2(uval32))
1454314461
true_32off = tnum_or(true_32off,
14544-
tnum_const(val32));
14462+
tnum_const(uval32));
1454514463
} else {
14546-
false_64off = tnum_and(false_64off, tnum_const(~val));
14547-
if (is_power_of_2(val))
14464+
false_64off = tnum_and(false_64off, tnum_const(~uval));
14465+
if (is_power_of_2(uval))
1454814466
true_64off = tnum_or(true_64off,
14549-
tnum_const(val));
14467+
tnum_const(uval));
1455014468
}
1455114469
break;
1455214470
case BPF_JGE:
1455314471
case BPF_JGT:
1455414472
{
1455514473
if (is_jmp32) {
14556-
u32 false_umax = opcode == BPF_JGT ? val32 : val32 - 1;
14557-
u32 true_umin = opcode == BPF_JGT ? val32 + 1 : val32;
14474+
u32 false_umax = opcode == BPF_JGT ? uval32 : uval32 - 1;
14475+
u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32;
1455814476

1455914477
false_reg->u32_max_value = min(false_reg->u32_max_value,
1456014478
false_umax);
1456114479
true_reg->u32_min_value = max(true_reg->u32_min_value,
1456214480
true_umin);
1456314481
} else {
14564-
u64 false_umax = opcode == BPF_JGT ? val : val - 1;
14565-
u64 true_umin = opcode == BPF_JGT ? val + 1 : val;
14482+
u64 false_umax = opcode == BPF_JGT ? uval : uval - 1;
14483+
u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval;
1456614484

1456714485
false_reg->umax_value = min(false_reg->umax_value, false_umax);
1456814486
true_reg->umin_value = max(true_reg->umin_value, true_umin);
@@ -14591,16 +14509,16 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
1459114509
case BPF_JLT:
1459214510
{
1459314511
if (is_jmp32) {
14594-
u32 false_umin = opcode == BPF_JLT ? val32 : val32 + 1;
14595-
u32 true_umax = opcode == BPF_JLT ? val32 - 1 : val32;
14512+
u32 false_umin = opcode == BPF_JLT ? uval32 : uval32 + 1;
14513+
u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32;
1459614514

1459714515
false_reg->u32_min_value = max(false_reg->u32_min_value,
1459814516
false_umin);
1459914517
true_reg->u32_max_value = min(true_reg->u32_max_value,
1460014518
true_umax);
1460114519
} else {
14602-
u64 false_umin = opcode == BPF_JLT ? val : val + 1;
14603-
u64 true_umax = opcode == BPF_JLT ? val - 1 : val;
14520+
u64 false_umin = opcode == BPF_JLT ? uval : uval + 1;
14521+
u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval;
1460414522

1460514523
false_reg->umin_value = max(false_reg->umin_value, false_umin);
1460614524
true_reg->umax_value = min(true_reg->umax_value, true_umax);
@@ -14649,15 +14567,15 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
1464914567
*/
1465014568
static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
1465114569
struct bpf_reg_state *false_reg,
14652-
u64 val, u32 val32,
14570+
u64 uval, u32 uval32,
1465314571
u8 opcode, bool is_jmp32)
1465414572
{
1465514573
opcode = flip_opcode(opcode);
1465614574
/* This uses zero as "not present in table"; luckily the zero opcode,
1465714575
* BPF_JA, can't get here.
1465814576
*/
1465914577
if (opcode)
14660-
reg_set_min_max(true_reg, false_reg, val, val32, opcode, is_jmp32);
14578+
reg_set_min_max(true_reg, false_reg, uval, uval32, opcode, is_jmp32);
1466114579
}
1466214580

1466314581
/* Regs are known to be equal, so intersect their min/max/var_off */

0 commit comments

Comments
 (0)