Skip to content

safety check for slices! #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ pub enum DiffActivity {
DualOnly,
Duplicated,
DuplicatedOnly,
FakeActivitySize
}

impl Display for DiffActivity {
Expand All @@ -122,6 +123,7 @@ impl Display for DiffActivity {
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
}
}
}
Expand Down
132 changes: 125 additions & 7 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![allow(unused_imports)]
#![allow(unused_variables)]
use crate::llvm::LLVMGetFirstBasicBlock;
use crate::llvm::LLVMBuildCondBr;
use crate::llvm::LLVMBuildICmp;
use crate::llvm::LLVMBuildRetVoid;
use crate::llvm::LLVMRustEraseInstBefore;
use crate::llvm::LLVMRustHasDbgMetadata;
Expand Down Expand Up @@ -47,6 +49,7 @@ use crate::typetree::to_enzyme_typetree;
use crate::DiffTypeTree;
use crate::LlvmCodegenBackend;
use crate::ModuleLlvm;
use llvm::IntPredicate;
use llvm::LLVMRustDISetInstMetadata;
use llvm::{
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock,
Expand Down Expand Up @@ -691,7 +694,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {


unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
llmod: &'a llvm::Module, llcx: &llvm::Context) {
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) {
dbg!("size_positions: {:?}", size_positions);
// first, remove all calls from fnc
let bb = LLVMGetFirstBasicBlock(tgt);
let br = LLVMRustGetTerminator(bb);
Expand All @@ -707,6 +711,11 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
let inner_args: Vec<&Value> = get_params(src);
let mut call_args: Vec<&Value> = vec![];

let mut safety_vals = vec![];
let builder = LLVMCreateBuilderInContext(llcx);
let last_inst = LLVMRustGetLastInstruction(bb).unwrap();
LLVMPositionBuilderAtEnd(builder, bb);

if inner_param_num == outer_param_num {
call_args = outer_args;
} else {
Expand Down Expand Up @@ -745,21 +754,71 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,

outer_pos += 3;
inner_pos += 2;

// Now we assert if int1 <= int2
let res = LLVMBuildICmp(
builder,
IntPredicate::IntULE as u32,
outer_arg,
next2_outer_arg,
"safety_check".as_ptr() as *const c_char);
safety_vals.push(res);
}
}
}


if inner_param_num as usize != call_args.len() {
panic!("Args len shouldn't differ. Please report this. {} : {}", inner_param_num, call_args.len());
}

// Now add the safety checks.
if !safety_vals.is_empty() {
dbg!("Adding safety checks");
// first we create one bb per check and two more for the fail and success case.
let fail_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_fail".as_ptr() as *const c_char);
let success_bb = LLVMAppendBasicBlockInContext(llcx, tgt, "ad_safety_success".as_ptr() as *const c_char);
let mut err_bb = vec![];
for i in 0..safety_vals.len() {
let name: String = format!("ad_safety_err_{}", i);
err_bb.push(LLVMAppendBasicBlockInContext(llcx, tgt, name.as_ptr() as *const c_char));
}
for (i, &val) in safety_vals.iter().enumerate() {
LLVMBuildCondBr(builder, val, err_bb[i], fail_bb);
LLVMPositionBuilderAtEnd(builder, err_bb[i]);
}
LLVMBuildCondBr(builder, safety_vals.last().unwrap(), success_bb, fail_bb);
LLVMPositionBuilderAtEnd(builder, fail_bb);



let mut arg_vec = vec![add_panic_msg_to_global(llmod, llcx)];
let name1 = "_ZN4core9panicking14panic_explicit17h8607a79b2acfb83bE";
let name2 = "_RN4core9panicking14panic_explicit17h8607a79b2acfb83bE";
let cname1 = CString::new(name1).unwrap();
let cname2 = CString::new(name2).unwrap();

let fnc1 = llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char);
let call;
if fnc1.is_none() {
let fnc2 = llvm::LLVMGetNamedFunction(llmod, cname1.as_ptr() as *const c_char);
assert!(fnc2.is_some());
let fnc2 = fnc2.unwrap();
let ty = LLVMRustGetFunctionType(fnc2);
// now call with msg
call = LLVMBuildCall2(builder, ty, fnc2, arg_vec.as_mut_ptr(), arg_vec.len(), name2.as_ptr() as *const c_char);
} else {
let fnc1 = fnc1.unwrap();
let ty = LLVMRustGetFunctionType(fnc1);
call = LLVMBuildCall2(builder, ty, fnc1, arg_vec.as_mut_ptr(), arg_vec.len(), name1.as_ptr() as *const c_char);
}
llvm::LLVMSetTailCall(call, 1);
llvm::LLVMBuildUnreachable(builder);
LLVMPositionBuilderAtEnd(builder, success_bb);
}

let inner_fnc_name = llvm::get_value_name(src);
let c_inner_fnc_name = CString::new(inner_fnc_name).unwrap();

let builder = LLVMCreateBuilderInContext(llcx);
let last_inst = LLVMRustGetLastInstruction(bb).unwrap();
LLVMPositionBuilderAtEnd(builder, bb);
let mut struct_ret = LLVMBuildCall2(
builder,
f_ty,
Expand Down Expand Up @@ -788,6 +847,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,

// Now clean up placeholder code.
LLVMRustEraseInstBefore(bb, last_inst);
//dbg!(&tgt);

let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(src));
let void_type = LLVMVoidTypeInContext(llcx);
Expand All @@ -810,6 +870,61 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
LLVMVerifyFunction(tgt, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction);
}

unsafe fn add_panic_msg_to_global<'a>(llmod: &'a llvm::Module, llcx: &'a llvm::Context) -> &'a llvm::Value {
use llvm::*;

// Convert the message to a CString
let msg = "autodiff safety check failed!";
let cmsg = CString::new(msg).unwrap();

let msg_global_name = "ad_safety_msg".to_string();
let cmsg_global_name = CString::new(msg_global_name).unwrap();

// Get the length of the message
let msg_len = msg.len();

// Create the array type
let i8_array_type = LLVMRustArrayType(LLVMInt8TypeInContext(llcx), msg_len as u64);

// Create the string constant
let string_const_val = LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const i8, msg_len as u32, 0);

// Create the array initializer
let mut array_elems: Vec<_> = Vec::with_capacity(msg_len);
for i in 0..msg_len {
let char_value = LLVMConstInt(LLVMInt8TypeInContext(llcx), cmsg.as_bytes()[i] as u64, 0);
array_elems.push(char_value);
}
let array_initializer = LLVMConstArray(LLVMInt8TypeInContext(llcx), array_elems.as_mut_ptr(), msg_len as u32);

// Create the struct type
let global_type = LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0);

// Create the struct initializer
let struct_initializer = LLVMConstStructInContext(llcx, [array_initializer].as_mut_ptr(), 1, 0);

// Add the global variable to the module
let global_var = LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const i8);
LLVMRustSetLinkage(global_var, Linkage::PrivateLinkage);
LLVMSetInitializer(global_var, struct_initializer);

//let msg_global_name = "ad_safety_msg".to_string();
//let cmsg_global_name = CString::new(msg_global_name).unwrap();
//let msg = "autodiff safety check failed!";
//let cmsg = CString::new(msg).unwrap();
//let msg_len = msg.len();
//let i8_array_type = llvm::LLVMRustArrayType(llvm::LLVMInt8TypeInContext(llcx), msg_len as u64);
//let global_type = llvm::LLVMStructTypeInContext(llcx, [i8_array_type].as_mut_ptr(), 1, 0);
//let string_const_val = llvm::LLVMConstStringInContext(llcx, cmsg.as_ptr() as *const c_char, msg_len as u32, 0);
//let initializer = llvm::LLVMConstStructInContext(llcx, [string_const_val].as_mut_ptr(), 1, 0);
//let global = llvm::LLVMAddGlobal(llmod, global_type, cmsg_global_name.as_ptr() as *const c_char);
//llvm::LLVMRustSetLinkage(global, llvm::Linkage::PrivateLinkage);
//llvm::LLVMSetInitializer(global, initializer);
//llvm::LLVMSetUnnamedAddress(global, llvm::UnnamedAddr::Global);

global_var
}

// As unsafe as it can be.
#[allow(unused_variables)]
#[allow(unused)]
Expand Down Expand Up @@ -895,7 +1010,7 @@ pub(crate) unsafe fn enzyme_ad(
llvm::set_print(true);
}

let mut res: &Value = match item.attrs.mode {
let mut tmp = match item.attrs.mode {
DiffMode::Forward => enzyme_rust_forward_diff(
logic_ref,
type_analysis,
Expand All @@ -916,11 +1031,14 @@ pub(crate) unsafe fn enzyme_ad(
),
_ => unreachable!(),
};
let mut res: &Value = tmp.0;
let size_positions: Vec<usize> = tmp.1;

let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));

let void_type = LLVMVoidTypeInContext(llcx);
let rev_mode = item.attrs.mode == DiffMode::Reverse;
create_call(target_fnc, res, rev_mode, llmod, llcx);
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions);
// TODO: implement drop for wrapper type?
FreeTypeAnalysis(type_analysis);

Expand Down
29 changes: 22 additions & 7 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
ret_diffactivity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
) -> &Value {
) -> (&Value, Vec<usize>) {
let ret_activity = cdiffe_from(ret_diffactivity);
assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF);
let mut input_activity: Vec<CDIFFE_TYPE> = vec![];
Expand Down Expand Up @@ -893,7 +893,7 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
KnownValues: known_values.as_mut_ptr(),
};

EnzymeCreateForwardDiff(
let res = EnzymeCreateForwardDiff(
logic_ref, // Logic
std::ptr::null(),
std::ptr::null(),
Expand All @@ -911,18 +911,19 @@ pub(crate) unsafe fn enzyme_rust_forward_diff(
args_uncacheable.as_ptr(),
args_uncacheable.len(), // uncacheable arguments
std::ptr::null_mut(), // write augmented function to this
)
);
(res, vec![])
}

pub(crate) unsafe fn enzyme_rust_reverse_diff(
logic_ref: EnzymeLogicRef,
type_analysis: EnzymeTypeAnalysisRef,
fnc: &Value,
input_activity: Vec<DiffActivity>,
rust_input_activity: Vec<DiffActivity>,
ret_activity: DiffActivity,
input_tts: Vec<TypeTree>,
output_tt: TypeTree,
) -> &Value {
) -> (&Value, Vec<usize>) {
let (primary_ret, ret_activity) = match ret_activity {
DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT),
DiffActivity::Active => (true, CDIFFE_TYPE::DFT_OUT_DIFF),
Expand All @@ -935,7 +936,16 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
// https://github.com/EnzymeAD/Enzyme.jl/blob/a511e4e6979d6161699f5c9919d49801c0764a09/src/compiler.jl#L3092
let diff_ret = false;

let input_activity: Vec<CDIFFE_TYPE> = input_activity.iter().map(|&x| cdiffe_from(x)).collect();
let mut primal_sizes = vec![];
let mut input_activity: Vec<CDIFFE_TYPE> = vec![];
for (i, &x) in rust_input_activity.iter().enumerate() {
if is_size(x) {
primal_sizes.push(i);
input_activity.push(CDIFFE_TYPE::DFT_CONSTANT);
continue;
}
input_activity.push(cdiffe_from(x));
}

let mut args_tree = input_tts.iter().map(|x| x.inner).collect::<Vec<_>>();

Expand Down Expand Up @@ -988,7 +998,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff(
std::ptr::null_mut(), // write augmented function to this
0,
);
res
(res, primal_sizes)
}

extern "C" {
Expand Down Expand Up @@ -2810,9 +2820,14 @@ pub mod Shared_AD {
DiffActivity::DualOnly => CDIFFE_TYPE::DFT_DUP_NONEED,
DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG,
DiffActivity::DuplicatedOnly => CDIFFE_TYPE::DFT_DUP_NONEED,
DiffActivity::FakeActivitySize => panic!("Implementation error"),
};
}

pub fn is_size(act: DiffActivity) -> bool {
return act == DiffActivity::FakeActivitySize;
}

#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub enum CDerivativeMode {
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2762,7 +2762,16 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<Diff
if !da.is_empty() {
// We are looking at a slice. The length of that slice will become an
// extra integer on llvm level. Integers are always const.
da.insert(i + 1 + offset, DiffActivity::Const);
// However, if the slice get's duplicated, we want to know to later check the
// size. So we mark the new size argument as FakeActivitySize.
let activity = match da[i] {
DiffActivity::DualOnly | DiffActivity::Dual |
DiffActivity::DuplicatedOnly | DiffActivity::Duplicated
=> DiffActivity::FakeActivitySize,
DiffActivity::Const => DiffActivity::Const,
_ => panic!("unexpected activity for ptr/ref"),
};
da.insert(i + 1 + offset, activity);
offset += 1;
}
trace!("ABI MATCHING!");
Expand Down
Loading