Skip to content

Commit 64718ab

Browse files
committed
working dupvonly for fwd mode
1 parent 42cc672 commit 64718ab

File tree

4 files changed

+33
-19
lines changed

4 files changed

+33
-19
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ pub enum DiffActivity {
5656
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
5757
/// with it. Drop the code which updates the original input/output for maximum performance.
5858
DualOnly,
59+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60+
/// with it. Drop the code which updates the original input/output for maximum performance.
61+
/// It expects the shadow argument to be `width` times larger than the original input/output.
62+
DualvOnly,
5963
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
6064
Duplicated,
6165
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@@ -139,6 +143,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
139143
activity == DiffActivity::Dual
140144
|| activity == DiffActivity::Dualv
141145
|| activity == DiffActivity::DualOnly
146+
|| activity == DiffActivity::DualvOnly
142147
|| activity == DiffActivity::Const
143148
}
144149
DiffMode::Reverse => {
@@ -161,7 +166,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
161166
if matches!(activity, Const) {
162167
return true;
163168
}
164-
if matches!(activity, Dual | DualOnly | Dualv) {
169+
if matches!(activity, Dual | DualOnly | Dualv | DualvOnly) {
165170
return true;
166171
}
167172
// FIXME(ZuseZ4) We should make this more robust to also
@@ -178,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
178183
DiffMode::Error => false,
179184
DiffMode::Source => false,
180185
DiffMode::Forward => {
181-
matches!(activity, Dual | DualOnly | Dualv | Const)
186+
matches!(activity, Dual | DualOnly | Dualv | DualvOnly | Const)
182187
}
183188
DiffMode::Reverse => {
184189
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
@@ -196,6 +201,7 @@ impl Display for DiffActivity {
196201
DiffActivity::Dual => write!(f, "Dual"),
197202
DiffActivity::Dualv => write!(f, "Dualv"),
198203
DiffActivity::DualOnly => write!(f, "DualOnly"),
204+
DiffActivity::DualvOnly => write!(f, "DualvOnly"),
199205
DiffActivity::Duplicated => write!(f, "Duplicated"),
200206
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
201207
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
@@ -228,6 +234,7 @@ impl FromStr for DiffActivity {
228234
"Dual" => Ok(DiffActivity::Dual),
229235
"Dualv" => Ok(DiffActivity::Dualv),
230236
"DualOnly" => Ok(DiffActivity::DualOnly),
237+
"DualvOnly" => Ok(DiffActivity::DualvOnly),
231238
"Duplicated" => Ok(DiffActivity::Duplicated),
232239
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
233240
_ => Err(()),

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -799,12 +799,18 @@ mod llvm_enzyme {
799799
d_inputs.push(shadow_arg.clone());
800800
}
801801
}
802-
DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Dualv => {
803-
let iterations = if matches!(activity, DiffActivity::Dualv) {
804-
1
805-
} else {
806-
x.width
807-
};
802+
DiffActivity::Dual
803+
| DiffActivity::DualOnly
804+
| DiffActivity::Dualv
805+
| DiffActivity::DualvOnly => {
806+
// the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
807+
// Enzyme to not expect N arguments, but one argument (which is instead larger).
808+
let iterations =
809+
if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
810+
1
811+
} else {
812+
x.width
813+
};
808814
for i in 0..iterations {
809815
let mut shadow_arg = arg.clone();
810816
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
@@ -908,7 +914,7 @@ mod llvm_enzyme {
908914
let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
909915
d_decl.output = FnRetTy::Ty(ty);
910916
}
911-
if let DiffActivity::DualOnly = x.ret_activity {
917+
if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
912918
// No need to change the return type,
913919
// we will just return the shadow in place of the primal return.
914920
// However, if we have a width > 1, then we don't return -> T, but -> [T; width]

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ fn has_sret(fnc: &Value) -> bool {
5151
// using iterators and peek()?
5252
fn match_args_from_caller_to_enzyme<'ll>(
5353
cx: &SimpleCx<'ll>,
54-
builder: &SBuilder<'ll,'ll>,
54+
builder: &SBuilder<'ll, 'll>,
5555
width: u32,
5656
args: &mut Vec<&'ll llvm::Value>,
5757
inputs: &[DiffActivity],
@@ -81,6 +81,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
8181
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
8282
let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap();
8383
let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap();
84+
let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap();
8485

8586
while activity_pos < inputs.len() {
8687
let diff_activity = inputs[activity_pos as usize];
@@ -94,6 +95,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
9495
DiffActivity::Dual => (enzyme_dup, true),
9596
DiffActivity::Dualv => (enzyme_dupv, true),
9697
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
98+
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
9799
DiffActivity::Duplicated => (enzyme_dup, true),
98100
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
99101
DiffActivity::FakeActivitySize => (enzyme_const, false),
@@ -106,10 +108,9 @@ fn match_args_from_caller_to_enzyme<'ll>(
106108
// T=f32 => 4 bytes
107109
// n_elems is the next integer.
108110
// Now we multiply `4 * next_outer_arg` to get the stride.
109-
//let mul = builder
110-
// .build_mul(cx.get_const_i64(4), next_outer_arg)
111-
// .unwrap();
112-
let mul = unsafe {llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)};
111+
let mul = unsafe {
112+
llvm::LLVMBuildMul(builder.llbuilder, cx.get_const_i64(4), next_outer_arg, UNNAMED)
113+
};
113114
args.push(mul);
114115
}
115116
args.push(outer_arg);
@@ -140,11 +141,8 @@ fn match_args_from_caller_to_enzyme<'ll>(
140141
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
141142
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
142143

143-
let iterations = if matches!(diff_activity, DiffActivity::Dualv) {
144-
1
145-
} else {
146-
width as usize
147-
};
144+
let iterations =
145+
if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
148146

149147
for i in 0..iterations {
150148
let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
4040
new_activities.push(activity);
4141
new_positions.push(i + 1);
4242
}
43+
// Now we need to figure out the size of each slice element in memory.
44+
// Can we actually do that here?
45+
4346
continue;
4447
}
4548
}

0 commit comments

Comments
 (0)