diff --git a/Makefile b/Makefile index 96f7927..daedcb2 100644 --- a/Makefile +++ b/Makefile @@ -8,11 +8,12 @@ CFLAGS = -std=c99 -Wall -Wextra -pedantic -Werror -O2 \ -D_FORTIFY_SOURCE=2 -fstack-protector-strong -fPIE -fcf-protection \ -Isrc -Isrc/fe -Isrc/ir -Isrc/amdgpu LDFLAGS = -pie +LIBS = -lm # Linux/ELF only: -Wl,-z,relro,-z,now -Wl,-z,noexecstack SOURCES = src/main.c \ src/fe/preproc.c src/fe/lexer.c src/fe/parser.c src/fe/sema.c \ - src/ir/bir.c src/ir/bir_print.c src/ir/bir_lower.c src/ir/bir_mem2reg.c src/ir/bir_dce.c \ + src/ir/bir.c src/ir/bir_print.c src/ir/bir_lower.c src/ir/bir_mem2reg.c src/ir/bir_cfold.c src/ir/bir_dce.c \ src/amdgpu/isel.c src/amdgpu/emit.c src/amdgpu/encode.c src/amdgpu/enc_tab.c \ src/tensix_isel.c src/tensix_emit.c src/tensix_coarsen.c src/tensix_datamov.c OBJECTS = $(SOURCES:.c=.o) @@ -21,7 +22,7 @@ TARGET = barracuda all: $(TARGET) $(TARGET): $(OBJECTS) - $(CC) $(CFLAGS) $(LDFLAGS) -o $@ $^ + $(CC) $(CFLAGS) $(LDFLAGS) -o $@ $^ $(LIBS) %.o: %.c $(CC) $(CFLAGS) -c $< -o $@ @@ -31,9 +32,10 @@ TCFLAGS = -std=c99 -D_POSIX_C_SOURCE=200809L -Wall -Wextra -O0 -g \ -Isrc -Isrc/fe -Isrc/ir -Isrc/amdgpu TSRC = tests/tmain.c tests/tsmoke.c tests/tcomp.c tests/tenc.c \ tests/ttabs.c tests/ttypes.c tests/terrs.c tests/tphase.c \ - tests/tdce.c + tests/tdce.c \ + tests/tcfold.c TOBJS = $(TSRC:.c=.o) -COBJS = src/ir/bir.o src/ir/bir_print.o src/ir/bir_lower.o src/ir/bir_mem2reg.o src/ir/bir_dce.o \ +COBJS = src/ir/bir.o src/ir/bir_print.o src/ir/bir_lower.o src/ir/bir_mem2reg.o src/ir/bir_cfold.o src/ir/bir_dce.o \ src/amdgpu/encode.o src/amdgpu/enc_tab.o src/amdgpu/isel.o src/amdgpu/emit.o \ src/fe/lexer.o src/fe/parser.o src/fe/preproc.o src/fe/sema.o @@ -41,7 +43,7 @@ test: $(TARGET) trunner ./trunner --all trunner: $(TOBJS) $(COBJS) - $(CC) $(TCFLAGS) -o $@ $^ + $(CC) $(TCFLAGS) -o $@ $^ $(LIBS) tests/%.o: tests/%.c $(CC) $(TCFLAGS) -c $< -o $@ diff --git a/src/ir/bir_cfold.c b/src/ir/bir_cfold.c new file mode 100644 index 0000000..bd6d39d --- /dev/null +++ b/src/ir/bir_cfold.c @@ -0,0 +1,405 @@ +#include "bir_cfold.h" +#include + +/* + * bir_cfold: constant folding. + * + * Single forward pass per function. SSA guarantees defs dominate + * uses, so when we reach an instruction all its operands have + * already been resolved through val_rewrite[]. After folding, + * DCE cleans up the dead instructions. + */ + +/* ---- Helpers ---- */ + +/* Is inline operand j a block reference (not a value reference)? */ +static int is_inline_block_ref(uint16_t op, uint8_t j) +{ + switch (op) { + case BIR_BR: return j == 0; + case BIR_BR_COND: return j == 1 || j == 2; + case BIR_SWITCH: return j == 1; + case BIR_PHI: return j % 2 == 0; + default: return 0; + } +} + +/* Is extra operand j a block reference? */ +static int is_extra_block_ref(uint16_t op, uint32_t j) +{ + if (op == BIR_PHI) return j % 2 == 0; + if (op == BIR_SWITCH) return j == 1 || (j >= 3 && j % 2 == 1); + return 0; +} + +/* Rewrite a single operand through val_rewrite[]. */ +static uint32_t rewrite_val(uint32_t ref, const uint32_t *val_rewrite) +{ + if (ref == BIR_VAL_NONE || BIR_VAL_IS_CONST(ref)) + return ref; + uint32_t rw = val_rewrite[BIR_VAL_INDEX(ref)]; + return rw != BIR_VAL_NONE ? rw : ref; +} + +/* Get the integer bit width from a type index. Returns 0 if not integer. */ +static int int_width(const bir_module_t *M, uint32_t tidx) +{ + if (tidx >= M->num_types) return 0; + if (M->types[tidx].kind != BIR_TYPE_INT) return 0; + return (int)M->types[tidx].width; +} + +/* Get the float bit width from a type index. Returns 0 if not float. */ +static int float_width(const bir_module_t *M, uint32_t tidx) +{ + if (tidx >= M->num_types) return 0; + if (M->types[tidx].kind != BIR_TYPE_FLOAT) return 0; + return (int)M->types[tidx].width; +} + +/* Mask an integer to w bits (unsigned). */ +static int64_t mask_to_width(int64_t val, int w) +{ + if (w >= 64) return val; + return val & (int64_t)((1ULL << w) - 1); +} + +/* Sign-extend an integer from w bits to int64_t. */ +static int64_t sign_extend(int64_t val, int w) +{ + if (w >= 64) return val; + val = mask_to_width(val, w); + int64_t sign_bit = (int64_t)(1ULL << (w - 1)); + return (val ^ sign_bit) - sign_bit; +} + +/* ---- Working State ---- */ + +typedef struct { + bir_module_t *M; + uint32_t val_rewrite[BIR_MAX_INSTS]; +} cf_t; + +static cf_t G; + +/* ---- Fold One Instruction ---- */ + +/* + * Try to fold instruction at absolute index ii. + * If foldable, sets S->val_rewrite[ii] and returns 1. + * Otherwise returns 0. + */ +static int try_fold(cf_t *S, uint32_t ii) +{ + bir_module_t *M = S->M; + bir_inst_t *I = &M->insts[ii]; + uint16_t op = I->op; + + /* SELECT with constant condition — propagate chosen operand */ + if (op == BIR_SELECT && I->num_operands >= 3) { + uint32_t cond = I->operands[0]; + if (!BIR_VAL_IS_CONST(cond)) return 0; + uint32_t ci = BIR_VAL_INDEX(cond); + if (ci >= M->num_consts) return 0; + const bir_const_t *c = &M->consts[ci]; + if (c->kind != BIR_CONST_INT) return 0; + S->val_rewrite[ii] = c->d.ival ? I->operands[1] : I->operands[2]; + return 1; + } + + /* From here, require all value operands to be constants */ + if (I->num_operands == 0 || I->num_operands == BIR_OPERANDS_OVERFLOW) + return 0; + + for (uint8_t j = 0; j < I->num_operands && j < BIR_OPERANDS_INLINE; j++) { + if (is_inline_block_ref(op, j)) continue; + uint32_t v = I->operands[j]; + if (v == BIR_VAL_NONE) continue; + if (!BIR_VAL_IS_CONST(v)) return 0; + } + + /* Collect constant operands */ + const bir_const_t *c0 = NULL, *c1 = NULL; + + if (I->num_operands >= 1 && !is_inline_block_ref(op, 0)) { + uint32_t idx = BIR_VAL_INDEX(I->operands[0]); + if (idx < M->num_consts) c0 = &M->consts[idx]; + } + if (I->num_operands >= 2 && !is_inline_block_ref(op, 1)) { + uint32_t idx = BIR_VAL_INDEX(I->operands[1]); + if (idx < M->num_consts) c1 = &M->consts[idx]; + } + + int rw = int_width(M, I->type); + int fw = float_width(M, I->type); + + /* Integer comparison (before binary ops — icmp result is i1, + * which would match the rw > 0 gate and hit default: return 0) */ + if (op == BIR_ICMP && c0 && c1 + && c0->kind == BIR_CONST_INT && c1->kind == BIR_CONST_INT) { + int64_t a = c0->d.ival, b = c1->d.ival; + int sw = int_width(M, c0->type); + if (sw == 0) sw = 32; + int64_t sa = sign_extend(a, sw), sb = sign_extend(b, sw); + uint64_t ua = (uint64_t)mask_to_width(a, sw); + uint64_t ub = (uint64_t)mask_to_width(b, sw); + int result; + + switch (I->subop) { + case BIR_ICMP_EQ: result = (ua == ub); break; + case BIR_ICMP_NE: result = (ua != ub); break; + case BIR_ICMP_SLT: result = (sa < sb); break; + case BIR_ICMP_SLE: result = (sa <= sb); break; + case BIR_ICMP_SGT: result = (sa > sb); break; + case BIR_ICMP_SGE: result = (sa >= sb); break; + case BIR_ICMP_ULT: result = (ua < ub); break; + case BIR_ICMP_ULE: result = (ua <= ub); break; + case BIR_ICMP_UGT: result = (ua > ub); break; + case BIR_ICMP_UGE: result = (ua >= ub); break; + default: return 0; + } + + uint32_t ci = bir_const_int(M, I->type, result); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Integer binary ops */ + if (c0 && c1 && c0->kind == BIR_CONST_INT && c1->kind == BIR_CONST_INT + && rw > 0) { + int64_t a = c0->d.ival, b = c1->d.ival; + int64_t r; + int sw = int_width(M, c0->type); + if (sw == 0) sw = rw; + + switch (op) { + case BIR_ADD: r = a + b; break; + case BIR_SUB: r = a - b; break; + case BIR_MUL: r = a * b; break; + case BIR_SDIV: + if (b == 0) return 0; + if (sign_extend(b, sw) == -1 && sign_extend(a, sw) == INT64_MIN) + return 0; + r = sign_extend(a, sw) / sign_extend(b, sw); + break; + case BIR_UDIV: + if (b == 0) return 0; + r = (int64_t)((uint64_t)mask_to_width(a, sw) + / (uint64_t)mask_to_width(b, sw)); + break; + case BIR_SREM: + if (b == 0) return 0; + if (sign_extend(b, sw) == -1 && sign_extend(a, sw) == INT64_MIN) + return 0; + r = sign_extend(a, sw) % sign_extend(b, sw); + break; + case BIR_UREM: + if (b == 0) return 0; + r = (int64_t)((uint64_t)mask_to_width(a, sw) + % (uint64_t)mask_to_width(b, sw)); + break; + case BIR_AND: r = a & b; break; + case BIR_OR: r = a | b; break; + case BIR_XOR: r = a ^ b; break; + case BIR_SHL: r = (int64_t)((uint64_t)a << (b & (sw - 1))); break; + case BIR_LSHR: + r = (int64_t)((uint64_t)mask_to_width(a, sw) + >> (b & (sw - 1))); + break; + case BIR_ASHR: + r = sign_extend(a, sw) >> (b & (sw - 1)); + break; + default: return 0; + } + + r = mask_to_width(r, rw); + uint32_t ci = bir_const_int(M, I->type, r); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Float comparison */ + if (op == BIR_FCMP && c0 && c1 + && c0->kind == BIR_CONST_FLOAT && c1->kind == BIR_CONST_FLOAT) { + double a = c0->d.fval, b = c1->d.fval; + int ord = (a == a) && (b == b); /* neither is NaN */ + int result; + + switch (I->subop) { + case BIR_FCMP_OEQ: result = ord && (a == b); break; + case BIR_FCMP_ONE: result = ord && (a != b); break; + case BIR_FCMP_OLT: result = ord && (a < b); break; + case BIR_FCMP_OLE: result = ord && (a <= b); break; + case BIR_FCMP_OGT: result = ord && (a > b); break; + case BIR_FCMP_OGE: result = ord && (a >= b); break; + case BIR_FCMP_UEQ: result = !ord || (a == b); break; + case BIR_FCMP_UNE: result = !ord || (a != b); break; + case BIR_FCMP_ULT: result = !ord || (a < b); break; + case BIR_FCMP_ULE: result = !ord || (a <= b); break; + case BIR_FCMP_UGT: result = !ord || (a > b); break; + case BIR_FCMP_UGE: result = !ord || (a >= b); break; + case BIR_FCMP_ORD: result = ord; break; + case BIR_FCMP_UNO: result = !ord; break; + default: return 0; + } + + uint32_t ci = bir_const_int(M, I->type, result); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Float binary ops */ + if (c0 && c1 && c0->kind == BIR_CONST_FLOAT + && c1->kind == BIR_CONST_FLOAT && fw > 0) { + double a = c0->d.fval, b = c1->d.fval; + double r; + + switch (op) { + case BIR_FADD: r = a + b; break; + case BIR_FSUB: r = a - b; break; + case BIR_FMUL: r = a * b; break; + case BIR_FDIV: r = a / b; break; + case BIR_FREM: r = fmod(a, b); break; + case BIR_FMAX: r = a > b ? a : b; break; + case BIR_FMIN: r = a < b ? a : b; break; + default: return 0; + } + + if (fw == 32) r = (double)(float)r; + uint32_t ci = bir_const_float(M, I->type, r); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Integer conversions (unary, one constant operand) */ + if (c0 && c0->kind == BIR_CONST_INT && rw > 0) { + int64_t a = c0->d.ival; + int64_t r; + int sw = int_width(M, c0->type); + if (sw == 0) sw = 32; + + switch (op) { + case BIR_TRUNC: r = mask_to_width(a, rw); break; + case BIR_ZEXT: r = mask_to_width(a, sw); break; + case BIR_SEXT: r = sign_extend(a, sw); break; + default: return 0; + } + + r = mask_to_width(r, rw); + uint32_t ci = bir_const_int(M, I->type, r); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Int-to-float conversions */ + if (c0 && c0->kind == BIR_CONST_INT && fw > 0) { + int64_t a = c0->d.ival; + int sw = int_width(M, c0->type); + if (sw == 0) sw = 32; + double r; + + switch (op) { + case BIR_SITOFP: r = (double)sign_extend(a, sw); break; + case BIR_UITOFP: r = (double)(uint64_t)mask_to_width(a, sw); break; + default: return 0; + } + + if (fw == 32) r = (double)(float)r; + uint32_t ci = bir_const_float(M, I->type, r); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Float-to-int conversions */ + if (c0 && c0->kind == BIR_CONST_FLOAT && rw > 0) { + double a = c0->d.fval; + int64_t r; + + switch (op) { + case BIR_FPTOSI: r = (int64_t)a; break; + case BIR_FPTOUI: r = (int64_t)(uint64_t)a; break; + default: return 0; + } + + r = mask_to_width(r, rw); + uint32_t ci = bir_const_int(M, I->type, r); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + /* Float width conversions */ + if (c0 && c0->kind == BIR_CONST_FLOAT && fw > 0) { + double r = c0->d.fval; + + switch (op) { + case BIR_FPTRUNC: + if (fw == 32) r = (double)(float)r; + else return 0; /* skip folding for f16 */ + break; + case BIR_FPEXT: break; /* already double internally */ + default: return 0; + } + + uint32_t ci = bir_const_float(M, I->type, r); + S->val_rewrite[ii] = BIR_MAKE_CONST(ci); + return 1; + } + + return 0; +} + +/* ---- Per-Function Pass ---- */ + +static int cf_run_func(cf_t *S, uint32_t fi) +{ + bir_module_t *M = S->M; + const bir_func_t *F = &M->funcs[fi]; + if (F->num_blocks == 0 || F->total_insts == 0) return 0; + + uint32_t base = M->blocks[F->first_block].first_inst; + uint32_t end = base + F->total_insts; + int changes = 0; + + for (uint32_t i = base; i < end; i++) + S->val_rewrite[i] = BIR_VAL_NONE; + + for (uint32_t i = base; i < end; i++) { + bir_inst_t *I = &M->insts[i]; + + /* Rewrite operands through val_rewrite[] */ + if (I->num_operands == BIR_OPERANDS_OVERFLOW) { + uint32_t start = I->operands[0]; + uint32_t count = I->operands[1]; + for (uint32_t j = 0; j < count + && (start + j) < M->num_extra_ops; j++) { + if (is_extra_block_ref(I->op, j)) continue; + M->extra_operands[start + j] = + rewrite_val(M->extra_operands[start + j], + S->val_rewrite); + } + } else { + for (uint8_t j = 0; j < I->num_operands + && j < BIR_OPERANDS_INLINE; j++) { + if (is_inline_block_ref(I->op, j)) continue; + I->operands[j] = rewrite_val(I->operands[j], + S->val_rewrite); + } + } + + /* Try to fold */ + changes += try_fold(S, i); + } + + return changes; +} + +/* ---- Public API ---- */ + +int bir_cfold(bir_module_t *M) +{ + G.M = M; + int total = 0; + for (uint32_t fi = 0; fi < M->num_funcs; fi++) + total += cf_run_func(&G, fi); + return total; +} diff --git a/src/ir/bir_cfold.h b/src/ir/bir_cfold.h new file mode 100644 index 0000000..5826cc8 --- /dev/null +++ b/src/ir/bir_cfold.h @@ -0,0 +1,18 @@ +#ifndef BARRACUDA_BIR_CFOLD_H +#define BARRACUDA_BIR_CFOLD_H + +#include "bir.h" + +/* + * Constant folding. + * + * Runs after mem2reg, before DCE. Evaluates constant expressions + * at compile time (arithmetic, comparisons, conversions, select) + * and replaces them with constant values. DCE cleans up the + * now-dead instructions afterward. + * + * Returns the total number of instructions folded (>= 0). + */ +int bir_cfold(bir_module_t *M); + +#endif /* BARRACUDA_BIR_CFOLD_H */ diff --git a/src/main.c b/src/main.c index d9d2d0a..93e6442 100644 --- a/src/main.c +++ b/src/main.c @@ -4,6 +4,7 @@ #include "sema.h" #include "bir_lower.h" #include "bir_mem2reg.h" +#include "bir_cfold.h" #include "bir_dce.h" #include "amdgpu.h" #include "tensix.h" @@ -62,6 +63,7 @@ static void usage(const char *prog) " --parse Parse and dump AST\n" " --ir Lower to BIR and print IR\n" " --no-mem2reg Skip mem2reg optimization pass\n" + " --no-cfold Skip constant folding\n" " --no-dce Skip dead code elimination\n" " --sema Run semantic analysis and dump types\n" " --pp Preprocess only and print result\n" @@ -91,6 +93,7 @@ int main(int argc, char *argv[]) int mode_amdgpu_bin = 0; int mode_tensix = 0; int no_mem2reg = 0; + int no_cfold = 0; int no_dce = 0; int no_pp = 0; amd_target_t amd_target = AMD_TARGET_GFX1100; @@ -176,6 +179,8 @@ int main(int argc, char *argv[]) defines[num_defines++] = argv[i] + 2; } else if (strcmp(argv[i], "--no-mem2reg") == 0) no_mem2reg = 1; + else if (strcmp(argv[i], "--no-cfold") == 0) + no_cfold = 1; else if (strcmp(argv[i], "--no-dce") == 0) no_dce = 1; else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) { @@ -330,6 +335,8 @@ int main(int argc, char *argv[]) if (lrc == BC_OK) { if (!no_mem2reg) bir_mem2reg(bir_module); + if (!no_cfold) + bir_cfold(bir_module); if (!no_dce) bir_dce(bir_module); diff --git a/tests/tcfold.c b/tests/tcfold.c new file mode 100644 index 0000000..2904070 --- /dev/null +++ b/tests/tcfold.c @@ -0,0 +1,134 @@ +/* tcfold.c -- Constant folding tests. + * Verify that constant folding evaluates constant expressions. */ + +#include "tharns.h" + +static char obuf[TH_BUFSZ]; + +/* ---- Helpers ---- */ + +static const char *strnstr_range(const char *start, const char *end, + const char *needle) +{ + size_t nlen = strlen(needle); + for (const char *p = start; p + nlen <= end; p++) { + if (memcmp(p, needle, nlen) == 0) return p; + } + return NULL; +} + +/* ---- cf: integer arithmetic folded ---- */ + +static void cf_int_arith(void) +{ + int rc = th_run(BC_BIN " --ir tests/test_cf.cu", obuf, TH_BUFSZ); + CHEQ(rc, 0); + const char *fn = strstr(obuf, "cf_int_arith"); + CHECK(fn != NULL); + const char *body = strchr(fn, '\n'); + CHECK(body != NULL); + const char *fn_end = strstr(body, "\n}"); + CHECK(fn_end != NULL); + /* The constant add (3+4) must be folded — add with param survives */ + CHECK(strnstr_range(body, fn_end, "= add") != NULL); + /* The constant 7 (= 3+4) should appear as an operand */ + CHECK(strnstr_range(body, fn_end, " 7") != NULL); + PASS(); +} +TH_REG("cf", cf_int_arith) + +/* ---- cf: chained constants fold ---- */ + +static void cf_chain(void) +{ + int rc = th_run(BC_BIN " --ir tests/test_cf.cu", obuf, TH_BUFSZ); + CHEQ(rc, 0); + const char *fn = strstr(obuf, "cf_chain"); + CHECK(fn != NULL); + const char *body = strchr(fn, '\n'); + CHECK(body != NULL); + const char *fn_end = strstr(body, "\n}"); + CHECK(fn_end != NULL); + /* mul (5*4=20) must be folded away */ + CHECK(strnstr_range(body, fn_end, "= mul") == NULL); + /* The constant 20 should appear as an operand */ + CHECK(strnstr_range(body, fn_end, " 20") != NULL); + PASS(); +} +TH_REG("cf", cf_chain) + +/* ---- cf: icmp + select with constant condition ---- */ + +static void cf_icmp_select(void) +{ + int rc = th_run(BC_BIN " --ir tests/test_cf.cu", obuf, TH_BUFSZ); + CHEQ(rc, 0); + const char *fn = strstr(obuf, "cf_icmp_select"); + CHECK(fn != NULL); + const char *body = strchr(fn, '\n'); + CHECK(body != NULL); + const char *fn_end = strstr(body, "\n}"); + CHECK(fn_end != NULL); + /* icmp and select must be folded away */ + CHECK(strnstr_range(body, fn_end, "= icmp") == NULL); + CHECK(strnstr_range(body, fn_end, "= select") == NULL); + PASS(); +} +TH_REG("cf", cf_icmp_select) + +/* ---- cf: integer division by zero not folded ---- */ + +static void cf_divzero(void) +{ + int rc = th_run(BC_BIN " --ir tests/test_cf.cu", obuf, TH_BUFSZ); + CHEQ(rc, 0); + const char *fn = strstr(obuf, "cf_divzero"); + CHECK(fn != NULL); + const char *body = strchr(fn, '\n'); + CHECK(body != NULL); + const char *fn_end = strstr(body, "\n}"); + CHECK(fn_end != NULL); + /* sdiv by zero must survive — folding it would be UB */ + CHECK(strnstr_range(body, fn_end, "= sdiv") != NULL); + PASS(); +} +TH_REG("cf", cf_divzero) + +/* ---- cf: integer/float conversions folded ---- */ + +static void cf_conv(void) +{ + int rc = th_run(BC_BIN " --ir tests/test_cf.cu", obuf, TH_BUFSZ); + CHEQ(rc, 0); + const char *fn = strstr(obuf, "cf_conv"); + CHECK(fn != NULL); + const char *body = strchr(fn, '\n'); + CHECK(body != NULL); + const char *fn_end = strstr(body, "\n}"); + CHECK(fn_end != NULL); + /* sitofp and fptosi must be folded away */ + CHECK(strnstr_range(body, fn_end, "= sitofp") == NULL); + CHECK(strnstr_range(body, fn_end, "= fptosi") == NULL); + PASS(); +} +TH_REG("cf", cf_conv) + +/* ---- cf: float arithmetic folded ---- */ + +static void cf_float(void) +{ + int rc = th_run(BC_BIN " --ir tests/test_cf.cu", obuf, TH_BUFSZ); + CHEQ(rc, 0); + const char *fn = strstr(obuf, "cf_float"); + CHECK(fn != NULL); + const char *body = strchr(fn, '\n'); + CHECK(body != NULL); + const char *fn_end = strstr(body, "\n}"); + CHECK(fn_end != NULL); + /* fadd must be folded away — constant 4 (=1.5+2.5) stored directly */ + CHECK(strnstr_range(body, fn_end, "= fadd") == NULL); + /* The folded value 4 should appear as an f32 operand */ + CHECK(strnstr_range(body, fn_end, "f32 4") != NULL); + PASS(); +} +TH_REG("cf", cf_float) diff --git a/tests/test_cf.cu b/tests/test_cf.cu new file mode 100644 index 0000000..d9d1b11 --- /dev/null +++ b/tests/test_cf.cu @@ -0,0 +1,40 @@ +/* test_cf.cu — Constant folding test cases. + * + * Each kernel targets a specific constant folding edge case. + * The test harness compiles with --ir and checks which + * instructions survived. */ + +/* Integer arithmetic with constant operands */ +__global__ void cf_int_arith(int *out, int x) { + out[0] = x + (3 + 4); + out[1] = x + (10 - 3); +} + +/* Chained folding: a=2+3, b=a*4, c=b+x */ +__global__ void cf_chain(int *out, int x) { + int a = 2 + 3; + int b = a * 4; + out[0] = b + x; +} + +/* Integer comparison + select folds away */ +__global__ void cf_icmp_select(int *out, int x) { + out[0] = (2 < 5) ? x : 0; +} + +/* Integer division by zero must not fold (undefined behavior) */ +__global__ void cf_divzero(int *out) { + out[0] = 1 / 0; +} + +/* Integer/float conversions with constant operands */ +__global__ void cf_conv(int *iout, float *fout) { + fout[0] = (float)42; /* sitofp */ + iout[0] = (int)3.14f; /* fptosi */ +} + +/* Float arithmetic with constant operands */ +__global__ void cf_float(float *out) { + out[0] = 1.5f + 2.5f; +} +