From 04368bf71442446e659163f39d0f0284c2d82be5 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 24 Jun 2026 01:21:05 +0200 Subject: [PATCH 1/2] Fix typetree generation for arguments, e.g. slices --- compiler/rustc_ast/src/expand/typetree.rs | 1 + compiler/rustc_codegen_llvm/src/builder.rs | 2 +- .../src/builder/autodiff.rs | 3 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 1 + .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 5 + compiler/rustc_codegen_llvm/src/llvm/mod.rs | 15 +++ compiler/rustc_codegen_llvm/src/typetree.rs | 122 ++++++++++++------ compiler/rustc_middle/src/ty/typetree.rs | 72 ++++++----- 8 files changed, 149 insertions(+), 72 deletions(-) diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 9619c80904426..1d099f475b5f6 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -28,6 +28,7 @@ pub enum Kind { Anything, Integer, Pointer, + RustSlice, Half, Float, Double, diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index afb6985d21a95..6f5443a7c3c7a 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -1175,7 +1175,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { // vs. copying a struct with mixed types requires different derivative handling. // The TypeTree tells Enzyme exactly what memory layout to expect. if let Some(tt) = tt { - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); + crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, self.cx().tcx, memcpy, tt); } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index ee17468ec0c03..8bcf54a28da41 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -293,6 +293,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_name: &str, @@ -379,7 +380,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( ); if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() { - crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree); + crate::typetree::add_tt(cx.llmod, cx.llcx, tcx, fn_to_diff, fnc_tree); } let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 1caa95f369360..8ffa6dcd02c7d 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1793,6 +1793,7 @@ fn codegen_autodiff<'ll, 'tcx>( // Build body generate_enzyme_call( bx, + tcx, bx.cx, fn_to_diff, &diff_symbol, diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 195e050a9b651..c95e2e97a5c12 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -292,6 +292,11 @@ pub(crate) mod Enzyme_AD { unsafe { (self.EnzymeTypeTreeToString)(tree) } } + pub(crate) fn tree_to_cstr(&self, tree: *mut EnzymeTypeTree) -> &std::ffi::CStr { + let c_str = self.tree_to_string(tree); + unsafe { std::ffi::CStr::from_ptr(c_str) } + } + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 2ec19b1795b5a..5c5127c691990 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -76,6 +76,21 @@ pub(crate) fn CreateAttrStringValue<'ll>( ) } } +pub(crate) fn CreateAttrStringValueFromCStr<'ll>( + llcx: &'ll Context, + attr: &std::ffi::CStr, + value: &std::ffi::CStr, +) -> &'ll Attribute { + unsafe { + LLVMCreateStringAttribute( + llcx, + (*attr).as_ptr(), + (*attr).to_bytes().len() as c_uint, + (*value).as_ptr(), + (*value).to_bytes().len() as c_uint, + ) + } +} pub(crate) fn CreateAttrString<'ll>(llcx: &'ll Context, attr: &str) -> &'ll Attribute { unsafe { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 4f433f273c8cc..1e3c38d3bf79a 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -1,6 +1,8 @@ -use std::ffi::{CString, c_char, c_uint}; +use std::ffi::{CString, c_char}; -use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; +use rustc_ast::expand::typetree::{FncTree, Kind, TypeTree as RustTypeTree}; +use rustc_middle::bug; +use rustc_middle::ty::TyCtxt; use crate::attributes; use crate::llvm::{self, EnzymeWrapper, Value}; @@ -9,27 +11,38 @@ fn to_enzyme_typetree( rust_typetree: RustTypeTree, _data_layout: &str, llcx: &llvm::Context, -) -> llvm::TypeTree { +) -> (llvm::TypeTree, Vec) { let mut enzyme_tt = llvm::TypeTree::new(); - process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); - enzyme_tt + let extra_ints = process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx); + + let mut int_vec = vec![]; + for _ in 0..extra_ints { + let mut int_tt = llvm::TypeTree::new(); + int_tt.insert(&[0], llvm::CConcreteType::DT_Integer, llcx); + int_vec.push(int_tt); + } + + (enzyme_tt, int_vec) } + fn process_typetree_recursive( enzyme_tt: &mut llvm::TypeTree, rust_typetree: &RustTypeTree, parent_indices: &[i64], llcx: &llvm::Context, -) { +) -> u32 { + let mut extra_ints = 0; for rust_type in &rust_typetree.0 { let concrete_type = match rust_type.kind { - rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, - rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, - rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, - rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, - rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, - rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, - rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_FP128, - rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, + Kind::Anything => llvm::CConcreteType::DT_Anything, + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + Kind::RustSlice => llvm::CConcreteType::DT_Pointer, + Kind::Half => llvm::CConcreteType::DT_Half, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::F128 => llvm::CConcreteType::DT_FP128, + Kind::Unknown => llvm::CConcreteType::DT_Unknown, }; let mut indices = parent_indices.to_vec(); @@ -43,18 +56,25 @@ fn process_typetree_recursive( enzyme_tt.insert(&indices, concrete_type, llcx); - if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer + if matches!(rust_type.kind, Kind::RustSlice) { + // We lower slices to `ptr,int`, so add the int here. + extra_ints += 1; + } + + if matches!(rust_type.kind, Kind::Pointer | Kind::RustSlice) && !rust_type.child.0.is_empty() { process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx); } } + extra_ints } #[cfg_attr(not(feature = "llvm_enzyme"), allow(unused))] -pub(crate) fn add_tt<'ll>( +pub(crate) fn add_tt<'tcx, 'll>( llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, + tcx: TyCtxt<'tcx>, fn_def: &'ll Value, tt: FncTree, ) { @@ -77,39 +97,59 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); + let mut offset = 0; for (i, input) in inputs.iter().enumerate() { - unsafe { - let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let (enzyme_tt, extra_ints) = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + + // This scope is just a visual reminder that we *must* drop the enzyme_wrapper before + // we drop any typetrees (mainly enzyme_tt and extra_ints). Drop calls can not accept + // arguments like an enzyme_wrapper, so the typetree drop impl has to call get_instance + // on the static enzyme instance, which is behind a Mutex. Therefore we'd deadlock if we + // hold the enzyme_wrapper while dropping the typetrees. + { let enzyme_wrapper = EnzymeWrapper::get_instance(); - let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); - let c_str = std::ffi::CStr::from_ptr(c_str); - - let attr = llvm::LLVMCreateStringAttribute( - llcx, - c_attr_name.as_ptr(), - c_attr_name.as_bytes().len() as c_uint, - c_str.as_ptr(), - c_str.to_bytes().len() as c_uint, - ); + let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); + let attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[attr], + ); enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); + for v in &extra_ints { + offset += 1; + let c_str = enzyme_wrapper.tree_to_cstr(v.inner); + let int_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[int_attr], + ); + enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); + } } } + // We will only fail this if Rust types got lowered to LLVM in a way that we didn't predict. + // Error, so we can learn from our mistakes. + let expected = offset as usize + inputs.len(); + let actual = llvm::count_params(fn_def) as usize; + if expected != actual { + tcx.dcx().warn(format!( + "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ + but the generated LLVM function has {actual} parameter(s)" + )); + } - unsafe { - let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let (enzyme_tt, extra_ints) = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + { let enzyme_wrapper = EnzymeWrapper::get_instance(); - let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); - let c_str = std::ffi::CStr::from_ptr(c_str); - - let ret_attr = llvm::LLVMCreateStringAttribute( - llcx, - c_attr_name.as_ptr(), - c_attr_name.as_bytes().len() as c_uint, - c_str.as_ptr(), - c_str.to_bytes().len() as c_uint, - ); + if !extra_ints.is_empty() { + bug!("A return type should not have extra integers. Implementation bug!"); + } + let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); + + let ret_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); diff --git a/compiler/rustc_middle/src/ty/typetree.rs b/compiler/rustc_middle/src/ty/typetree.rs index 9e941bdb849ec..f541301afccca 100644 --- a/compiler/rustc_middle/src/ty/typetree.rs +++ b/compiler/rustc_middle/src/ty/typetree.rs @@ -32,7 +32,8 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { // Create TypeTree for return type let ret = typetree_from_ty(tcx, sig.output()); - FncTree { args, ret } + let f = FncTree { args, ret }; + f } /// Generate a TypeTree for a specific type. @@ -64,31 +65,29 @@ fn typetree_from_ty_impl_inner<'tcx>( } visited.push(ty); - if ty.is_scalar() { - let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { - (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) - } else if ty.is_floating_point() { - match ty { - x if x == tcx.types.f16 => (Kind::Half, 2), - x if x == tcx.types.f32 => (Kind::Float, 4), - x if x == tcx.types.f64 => (Kind::Double, 8), - x if x == tcx.types.f128 => (Kind::F128, 16), - _ => (Kind::Integer, 0), - } - } else { - (Kind::Integer, 0) - }; - - // Use offset 0 for scalars that are direct targets of references (like &f64) - // Use offset -1 for scalars used directly (like function return types) - let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; - return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); + if ty.is_slice() { + bug!("incorrect autodiff typetree handling for slice: {}", ty); } if ty.is_ref() || ty.is_raw_ptr() || ty.is_box() { let Some(inner_ty) = ty.builtin_deref(true) else { - return TypeTree::new(); + bug!("incorrect autodiff typetree handling for type: {}", ty); }; + // slices are represented as `&'{erased} mut [f32]` + // This reads as a reference to a slice of f32. + // So we'd end up with ptr->RustSlice->f32 without this extra handling + if inner_ty.is_slice() { + if let ty::Slice(element_ty) = inner_ty.kind() { + let element_tree = + typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); + return TypeTree(vec![Type { + offset: -1, + size: tcx.data_layout.pointer_size().bytes_usize(), + kind: Kind::RustSlice, + child: element_tree, + }]); + } + } let child = typetree_from_ty_impl_inner(tcx, inner_ty, depth + 1, visited, true); return TypeTree(vec![Type { @@ -121,14 +120,6 @@ fn typetree_from_ty_impl_inner<'tcx>( } } - if ty.is_slice() { - if let ty::Slice(element_ty) = ty.kind() { - let element_tree = - typetree_from_ty_impl_inner(tcx, *element_ty, depth + 1, visited, false); - return element_tree; - } - } - if let ty::Tuple(tuple_types) = ty.kind() { if tuple_types.is_empty() { return TypeTree::new(); @@ -204,5 +195,28 @@ fn typetree_from_ty_impl_inner<'tcx>( } } + if ty.is_scalar() { + let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { + (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) + } else if ty.is_floating_point() { + match ty { + x if x == tcx.types.f16 => (Kind::Half, 2), + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + x if x == tcx.types.f128 => (Kind::F128, 16), + _ => bug!("Unexpected floating point type: {:?}", ty), + } + } else { + // is_scalar also accepts things like FnDef or FnPtr, for which we don't know how to + // generate a TypeTree, so return nothing. + return TypeTree::new(); + }; + + // Use offset 0 for scalars that are direct targets of references (like &f64) + // Use offset -1 for scalars used directly (like function return types) or slices. + let offset = if is_reference_target && !ty.is_array() { 0 } else { -1 }; + return TypeTree(vec![Type { offset, size, kind, child: TypeTree::new() }]); + } + TypeTree::new() } From 64d9d50ed8ffdf94915b97bcc06111b6fe15ee60 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Thu, 25 Jun 2026 02:07:00 +0200 Subject: [PATCH 2/2] Fix typetree generation for memcpy to move TA failure in testcase from memcpy to a later extractvalue --- compiler/rustc_ast/src/expand/typetree.rs | 3 + .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 1 + compiler/rustc_codegen_llvm/src/typetree.rs | 83 ++++++++++++++----- .../rustc_codegen_ssa/src/traits/builder.rs | 21 ++++- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 17 ++++ compiler/rustc_middle/src/ty/typetree.rs | 6 ++ 6 files changed, 107 insertions(+), 24 deletions(-) diff --git a/compiler/rustc_ast/src/expand/typetree.rs b/compiler/rustc_ast/src/expand/typetree.rs index 1d099f475b5f6..84576c67729a3 100644 --- a/compiler/rustc_ast/src/expand/typetree.rs +++ b/compiler/rustc_ast/src/expand/typetree.rs @@ -58,6 +58,9 @@ impl TypeTree { } Self(ints) } + pub fn add_indirection(self) -> Self { + Self(vec![Type { offset: 0, size: 1, kind: Kind::Pointer, child: self }]) + } } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)] diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index c95e2e97a5c12..d125760a5b9aa 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -66,6 +66,7 @@ unsafe extern "C" { NameLen: libc::size_t, ) -> Option<&Value>; + pub(crate) fn LLVMRustIsIntrinsicCall(V: &Value) -> bool; } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 1e3c38d3bf79a..ca35abe556b3b 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -8,7 +8,7 @@ use crate::attributes; use crate::llvm::{self, EnzymeWrapper, Value}; fn to_enzyme_typetree( - rust_typetree: RustTypeTree, + rust_typetree: &RustTypeTree, _data_layout: &str, llcx: &llvm::Context, ) -> (llvm::TypeTree, Vec) { @@ -86,6 +86,13 @@ pub(crate) fn add_tt<'tcx, 'll>( #[cfg(not(feature = "llvm_enzyme"))] return; + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return; + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return; + } + let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; @@ -99,7 +106,7 @@ pub(crate) fn add_tt<'tcx, 'll>( let mut offset = 0; for (i, input) in inputs.iter().enumerate() { - let (enzyme_tt, extra_ints) = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let (enzyme_tt, extra_ints) = to_enzyme_typetree(&input, llvm_data_layout, llcx); // This scope is just a visual reminder that we *must* drop the enzyme_wrapper before // we drop any typetrees (mainly enzyme_tt and extra_ints). Drop calls can not accept @@ -111,38 +118,62 @@ pub(crate) fn add_tt<'tcx, 'll>( let c_str = enzyme_wrapper.tree_to_cstr(enzyme_tt.inner); let attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - attributes::apply_to_llfn( - fn_def, - llvm::AttributePlace::Argument(i as u32 + offset), - &[attr], - ); + dbg!(&fn_def); + if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + dbg!("callsite"); + attributes::apply_to_callsite( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[attr], + ); + } else { + dbg!("llfn"); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[attr], + ); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); for v in &extra_ints { offset += 1; let c_str = enzyme_wrapper.tree_to_cstr(v.inner); let int_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - attributes::apply_to_llfn( - fn_def, - llvm::AttributePlace::Argument(i as u32 + offset), - &[int_attr], - ); + dbg!(&fn_def); + if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + dbg!("callsite"); + attributes::apply_to_callsite( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[int_attr], + ); + } else { + dbg!("llfn"); + attributes::apply_to_llfn( + fn_def, + llvm::AttributePlace::Argument(i as u32 + offset), + &[int_attr], + ); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } } // We will only fail this if Rust types got lowered to LLVM in a way that we didn't predict. // Error, so we can learn from our mistakes. - let expected = offset as usize + inputs.len(); - let actual = llvm::count_params(fn_def) as usize; - if expected != actual { - tcx.dcx().warn(format!( - "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ - but the generated LLVM function has {actual} parameter(s)" - )); + if unsafe { !llvm::LLVMRustIsIntrinsicCall(fn_def) } { + let expected = offset as usize + inputs.len(); + let actual = llvm::count_params(fn_def) as usize; + if expected != actual { + tcx.dcx().warn(format!( + "autodiff type-tree failure. We expected {expected} LLVM argument(s), \ + but the generated LLVM function has {actual} parameter(s)" + )); + } } - let (enzyme_tt, extra_ints) = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); - { + let (enzyme_tt, extra_ints) = to_enzyme_typetree(&ret_tt, llvm_data_layout, llcx); + if ret_tt != RustTypeTree::new() { let enzyme_wrapper = EnzymeWrapper::get_instance(); if !extra_ints.is_empty() { bug!("A return type should not have extra integers. Implementation bug!"); @@ -151,7 +182,15 @@ pub(crate) fn add_tt<'tcx, 'll>( let ret_attr = llvm::CreateAttrStringValueFromCStr(llcx, &c_attr_name, &c_str); - attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + if unsafe { llvm::LLVMRustIsIntrinsicCall(fn_def) } { + attributes::apply_to_callsite( + fn_def, + llvm::AttributePlace::ReturnValue, + &[ret_attr], + ); + } else { + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); + } enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index d68549c6871f4..eb83d39ec9ed8 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -2,10 +2,12 @@ use std::assert_matches; use std::ops::Deref; use rustc_abi::{Align, Scalar, Size, WrappingRange}; +use rustc_ast::expand::typetree::{TypeTree, FncTree}; use rustc_hir::attrs::AttributeKind; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; use rustc_middle::mir; use rustc_middle::ty::layout::{FnAbiOf, LayoutOf, TyAndLayout}; +use rustc_middle::ty::typetree::typetree_from_ty; use rustc_middle::ty::{AtomicOrdering, Instance, Ty}; use rustc_session::config::OptLevel; use rustc_span::Span; @@ -456,7 +458,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, - tt: Option, + tt: Option, ); fn memmove( &mut self, @@ -466,6 +468,7 @@ pub trait BuilderMethods<'a, 'tcx>: src_align: Align, size: Self::Value, flags: MemFlags, + //tt: Option, ); fn memset( &mut self, @@ -474,6 +477,7 @@ pub trait BuilderMethods<'a, 'tcx>: size: Self::Value, align: Align, flags: MemFlags, + //tt: Option, ); // Produce a value from calling the `vscale` intrinsic (containing the `vscale` multiplier that @@ -517,6 +521,19 @@ pub trait BuilderMethods<'a, 'tcx>: let temp = self.load_operand(src.with_type(layout)); temp.val.store_with_flags(self, dst.with_type(layout), flags); } else if !layout.is_zst() { + let tt = typetree_from_ty(self.tcx(), layout.ty); + // We seem to pass all values to memcpy with one more indirection. + let tt = tt.add_indirection(); + dbg!(&tt); + use rustc_middle::ty::print::with_no_trimmed_paths; + + with_no_trimmed_paths!({ + eprintln!("memcpy ty = {:?}", layout.ty); + }); + let fnc_tree = FncTree { + args: vec![tt.clone(), tt], + ret: TypeTree::new(), + }; let bytes = self.const_usize(layout.size.bytes()); let bytes = if layout.peel_transparent_wrappers(self).ty.is_scalable_vector() { let vscale = self.vscale(self.type_i64()); @@ -524,7 +541,7 @@ pub trait BuilderMethods<'a, 'tcx>: } else { bytes }; - self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None); + self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, Some(fnc_tree)); } } diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8b063af187a58..df50a0863e384 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -160,6 +160,23 @@ extern "C" void LLVMRustPrintStatisticsJSON(RustStringRef OutBuf) { llvm::PrintStatisticsJSON(OS); } +extern "C" bool LLVMRustIsIntrinsicCall(LLVMValueRef V) { + llvm::Value *Val = llvm::unwrap(V); +llvm:errs() << "LLVMRustIsIntrinsicCall: " << *Val << "\n"; + + if (auto *CB = llvm::dyn_cast(Val)) { + if (auto *Callee = CB->getCalledFunction()) + return Callee->isIntrinsic(); + + return false; + } + + if (auto *F = llvm::dyn_cast(Val)) + return F->isIntrinsic(); + + return false; +} + // Some of the functions here rely on LLVM modules that may not always be // available. As such, we only try to build it in the first place, if // llvm.offload is enabled. diff --git a/compiler/rustc_middle/src/ty/typetree.rs b/compiler/rustc_middle/src/ty/typetree.rs index f541301afccca..7e3f5b4ab3389 100644 --- a/compiler/rustc_middle/src/ty/typetree.rs +++ b/compiler/rustc_middle/src/ty/typetree.rs @@ -39,6 +39,12 @@ pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>) -> FncTree { /// Generate a TypeTree for a specific type. /// Mainly a convenience wrapper around the actual implementation. pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { + if !tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) { + return TypeTree::new(); + } + if tcx.sess.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::NoTT) { + return TypeTree::new(); + } let mut visited = Vec::new(); typetree_from_ty_impl_inner(tcx, ty, 0, &mut visited, false) }