Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum Kind {
Anything,
Integer,
Pointer,
RustSlice,
Half,
Float,
Double,
Expand Down Expand Up @@ -57,6 +58,9 @@ impl TypeTree {
}
Self(ints)
}
pub fn add_indirection(self) -> Self {
Self(vec![Type { offset: 0, size: 1, kind: Kind::Pointer, child: self }])
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, self.cx().tcx, memcpy, tt);
}
}

Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
tcx: TyCtxt<'tcx>,
cx: &SimpleCx<'ll>,
fn_to_diff: &'ll Value,
outer_name: &str,
Expand Down Expand Up @@ -379,7 +380,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
);

if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
crate::typetree::add_tt(cx.llmod, cx.llcx, tcx, fn_to_diff, fnc_tree);
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1793,6 +1793,7 @@ fn codegen_autodiff<'ll, 'tcx>(
// Build body
generate_enzyme_call(
bx,
tcx,
bx.cx,
fn_to_diff,
&diff_symbol,
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ unsafe extern "C" {
NameLen: libc::size_t,
) -> Option<&Value>;

pub(crate) fn LLVMRustIsIntrinsicCall(V: &Value) -> bool;
}

unsafe extern "C" {
Expand Down Expand Up @@ -292,6 +293,11 @@ pub(crate) mod Enzyme_AD {
unsafe { (self.EnzymeTypeTreeToString)(tree) }
}

pub(crate) fn tree_to_cstr(&self, tree: *mut EnzymeTypeTree) -> &std::ffi::CStr {
let c_str = self.tree_to_string(tree);
unsafe { std::ffi::CStr::from_ptr(c_str) }
}

pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
}
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ pub(crate) fn CreateAttrStringValue<'ll>(
)
}
}
pub(crate) fn CreateAttrStringValueFromCStr<'ll>(
llcx: &'ll Context,
attr: &std::ffi::CStr,
value: &std::ffi::CStr,
) -> &'ll Attribute {
unsafe {
LLVMCreateStringAttribute(
llcx,
(*attr).as_ptr(),
(*attr).to_bytes().len() as c_uint,
(*value).as_ptr(),
(*value).to_bytes().len() as c_uint,
)
}
}

pub(crate) fn CreateAttrString<'ll>(llcx: &'ll Context, attr: &str) -> &'ll Attribute {
unsafe {
Expand Down
169 changes: 124 additions & 45 deletions compiler/rustc_codegen_llvm/src/typetree.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,48 @@
use std::ffi::{CString, c_char, c_uint};
use std::ffi::{CString, c_char};

use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
use rustc_ast::expand::typetree::{FncTree, Kind, TypeTree as RustTypeTree};
use rustc_middle::bug;
use rustc_middle::ty::TyCtxt;

use crate::attributes;
use crate::llvm::{self, EnzymeWrapper, Value};

fn to_enzyme_typetree(
rust_typetree: RustTypeTree,
rust_typetree: &RustTypeTree,
_data_layout: &str,
llcx: &llvm::Context,
) -> llvm::TypeTree {
) -> (llvm::TypeTree, Vec<llvm::TypeTree>) {
let mut enzyme_tt = llvm::TypeTree::new();
process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
enzyme_tt
let extra_ints = process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);

let mut int_vec = vec![];
for _ in 0..extra_ints {
let mut int_tt = llvm::TypeTree::new();
int_tt.insert(&[0], llvm::CConcreteType::DT_Integer, llcx);
int_vec.push(int_tt);
}

(enzyme_tt, int_vec)
}

fn process_typetree_recursive(
enzyme_tt: &mut llvm::TypeTree,
rust_typetree: &RustTypeTree,
parent_indices: &[i64],
llcx: &llvm::Context,
) {
) -> u32 {
let mut extra_ints = 0;
for rust_type in &rust_typetree.0 {
let concrete_type = match rust_type.kind {
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer,
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half,
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float,
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double,
rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_FP128,
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
Kind::Anything => llvm::CConcreteType::DT_Anything,
Kind::Integer => llvm::CConcreteType::DT_Integer,
Kind::Pointer => llvm::CConcreteType::DT_Pointer,
Kind::RustSlice => llvm::CConcreteType::DT_Pointer,
Kind::Half => llvm::CConcreteType::DT_Half,
Kind::Float => llvm::CConcreteType::DT_Float,
Kind::Double => llvm::CConcreteType::DT_Double,
Kind::F128 => llvm::CConcreteType::DT_FP128,
Kind::Unknown => llvm::CConcreteType::DT_Unknown,
};

let mut indices = parent_indices.to_vec();
Expand All @@ -43,18 +56,25 @@ fn process_typetree_recursive(

enzyme_tt.insert(&indices, concrete_type, llcx);

if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
if matches!(rust_type.kind, Kind::RustSlice) {
// We lower slices to `ptr,int`, so add the int here.
extra_ints += 1;
}

if matches!(rust_type.kind, Kind::Pointer | Kind::RustSlice)
&& !rust_type.child.0.is_empty()
{
process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
}
}
extra_ints
}

#[cfg_attr(not(feature = "llvm_enzyme"), allow(unused))]
pub(crate) fn add_tt<'ll>(
pub(crate) fn add_tt<'tcx, 'll>(
llmod: &'ll llvm::Module,
llcx: &'ll llvm::Context,
tcx: TyCtxt<'tcx>,
fn_def: &'ll Value,
tt: FncTree,
) {
Expand All @@ -66,6 +86,13 @@ pub(crate) fn add_tt<'ll>(
#[cfg(not(feature = "llvm_enzyme"))]
return;

if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) {
return;
}
if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) {
return;
}

let inputs = tt.args;
let ret_tt: RustTypeTree = tt.ret;

Expand All @@ -77,41 +104,93 @@ pub(crate) fn add_tt<'ll>(
let attr_name = "enzyme_type";
let c_attr_name = CString::new(attr_name).unwrap();

let mut offset = 0;
for (i, input) in inputs.iter().enumerate() {
unsafe {
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
let enzyme_wrapper = EnzymeWrapper::get_instance();
let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);

let attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);
let (enzyme_tt, extra_ints) = to_enzyme_typetree(&input, llvm_data_layout, llcx);

attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
// This scope is just a visual reminder that we *must* drop the enzyme_wrapper before
// we drop any typetrees (mainly enzyme_tt and extra_ints). Drop calls can not accept
// arguments like an enzyme_wrapper, so the typetree drop impl has to call get_instance
// on the static enzyme instance, which is behind a Mutex. Therefore we'd deadlock if we
// hold the enzyme_wrapper while dropping the typetrees.
{
let enzyme_wrapper = EnzymeWrapper::get_instance();
let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner);

let attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str);
dbg!(&fn_def);
if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } {
dbg!("callsite");
attributes::apply_to_callsite(
fn_def,
llvm::AttributePlace::Argument(i as u32 + offset),
&[attr],
);
} else {
dbg!("llfn");
attributes::apply_to_llfn(
fn_def,
llvm::AttributePlace::Argument(i as u32 + offset),
&[attr],
);
}
enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
for v in &extra_ints {
offset += 1;
let c_str = enzyme_wrapper.tree_to_cstr(v.inner);
let int_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str);
dbg!(&fn_def);
if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } {
dbg!("callsite");
attributes::apply_to_callsite(
fn_def,
llvm::AttributePlace::Argument(i as u32 + offset),
&[int_attr],
);
} else {
dbg!("llfn");
attributes::apply_to_llfn(
fn_def,
llvm::AttributePlace::Argument(i as u32 + offset),
&[int_attr],
);
}
enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
}
}
}
// We will only fail this if Rust types got lowered to LLVM in a way that we didn't predict.
// Error, so we can learn from our mistakes.
if unsafe { !llvm::LLVMRustIsIntrinsicCall(fn_def) } {
let expected = offset as usize + inputs.len();
let actual = llvm::count_params(fn_def) as usize;
if expected != actual {
tcx.dcx().warn(format!(
"autodiff type-tree failure. We expected {expected} LLVM argument(s), \
but the generated LLVM function has {actual} parameter(s)"
));
}
}

unsafe {
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
let (enzyme_tt, extra_ints) = to_enzyme_typetree(&ret_tt, llvm_data_layout, llcx);
if ret_tt != RustTypeTree::new() {
let enzyme_wrapper = EnzymeWrapper::get_instance();
let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner);
let c_str = std::ffi::CStr::from_ptr(c_str);

let ret_attr = llvm::LLVMCreateStringAttribute(
llcx,
c_attr_name.as_ptr(),
c_attr_name.as_bytes().len() as c_uint,
c_str.as_ptr(),
c_str.to_bytes().len() as c_uint,
);

attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
if !extra_ints.is_empty() {
bug!("A return type should not have extra integers. Implementation bug!");
}
let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner);

let ret_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str);

if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } {
attributes::apply_to_callsite(
fn_def,
llvm::AttributePlace::ReturnValue,
&[ret_attr],
);
} else {
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
}
enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
}
}
21 changes: 19 additions & 2 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::assert_matches;
use std::ops::Deref;

use rustc_abi::{Align, Scalar, Size, WrappingRange};
use rustc_ast::expand::typetree::{TypeTree, FncTree};
use rustc_hir::attrs::AttributeKind;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs;
use rustc_middle::mir;
use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout};
use rustc_middle::ty::typetree::typetree_from_ty;
use rustc_middle::ty::{AtomicOrdering, Instance, Ty};
use rustc_session::config::OptLevel;
use rustc_span::Span;
Expand Down Expand Up @@ -456,7 +458,7 @@ pub trait BuilderMethods<'a, 'tcx>:
src_align: Align,
size: Self::Value,
flags: MemFlags,
tt: Option<rustc_ast::expand::typetree::FncTree>,
tt: Option<FncTree>,
);
fn memmove(
&mut self,
Expand All @@ -466,6 +468,7 @@ pub trait BuilderMethods<'a, 'tcx>:
src_align: Align,
size: Self::Value,
flags: MemFlags,
//tt: Option<FncTree>,
);
fn memset(
&mut self,
Expand All @@ -474,6 +477,7 @@ pub trait BuilderMethods<'a, 'tcx>:
size: Self::Value,
align: Align,
flags: MemFlags,
//tt: Option<FncTree>,
);

// Produce a value from calling the `vscale` intrinsic (containing the `vscale` multiplier that
Expand Down Expand Up @@ -517,14 +521,27 @@ pub trait BuilderMethods<'a, 'tcx>:
let temp = self.load_operand(src.with_type(layout));
temp.val.store_with_flags(self, dst.with_type(layout), flags);
} else if !layout.is_zst() {
let tt = typetree_from_ty(self.tcx(), layout.ty);
// We seem to pass all values to memcpy with one more indirection.
let tt = tt.add_indirection();
dbg!(&tt);
use rustc_middle::ty::print::with_no_trimmed_paths;

with_no_trimmed_paths!({
eprintln!("memcpy ty = {:?}", layout.ty);
});
let fnc_tree = FncTree {
args: vec![tt.clone(), tt],
ret: TypeTree::new(),
};
let bytes = self.const_usize(layout.size.bytes());
let bytes = if layout.peel_transparent_wrappers(self).ty.is_scalable_vector() {
let vscale = self.vscale(self.type_i64());
self.mul(vscale, bytes)
} else {
bytes
};
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree));
}
}

Expand Down
Loading
Loading