Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 54 additions & 30 deletions crates/intrinsic-test/src/arm/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,72 @@ impl TypeDefinition for ArmType {
fn c_type(&self) -> String {
let prefix = self.kind.c_prefix();

if let Some(bit_len) = self.bit_len {
match (self.simd_len, self.vec_len) {
(None, None) => format!("{prefix}{bit_len}_t"),
(Some(SimdLen::Fixed(simd)), None) => format!("{prefix}{bit_len}x{simd}_t"),
(Some(SimdLen::Fixed(simd)), Some(vec)) => {
format!("{prefix}{bit_len}x{simd}x{vec}_t")
}
(Some(SimdLen::Scalable), None) => format!("sv{prefix}{bit_len}_t"),
(Some(SimdLen::Scalable), Some(vec)) => {
format!("sv{prefix}{bit_len}x{vec}_t")
}
(None, Some(_)) => todo!("{self:#?}"), // Likely an invalid case
match (self.bit_len, self.simd_len, self.vec_len) {
// e.g. `bool`
(Some(_), None, None) if matches!(self.kind, TypeKind::Bool) => {
format!("{prefix}")
}
} else {
todo!("{self:#?}")
// e.g. `float32_t`, `int64_t`
(Some(bit_len), None, None) => format!("{prefix}{bit_len}_t"),
// e.g. `float32x2_t`, `int64x2_t`
(Some(bit_len), Some(SimdLen::Fixed(simd)), None) => {
format!("{prefix}{bit_len}x{simd}_t")
}
// e.g. `float32x2x3_t`, `int64x2x3_t`
(Some(bit_len), Some(SimdLen::Fixed(simd)), Some(vec)) => {
format!("{prefix}{bit_len}x{simd}x{vec}_t")
}
// e.g. `svbool_t`
(Some(_), Some(SimdLen::Scalable), None) if matches!(self.kind, TypeKind::Bool) => {
format!("sv{prefix}_t")
}
// e.g. `svfloat32_t`, `svint64_t`
(Some(bit_len), Some(SimdLen::Scalable), None) => format!("sv{prefix}{bit_len}_t"),
// e.g. `svfloat32x3_t`, `svint64x3_t`
(Some(bit_len), Some(SimdLen::Scalable), Some(vec)) => {
format!("sv{prefix}{bit_len}x{vec}_t")
}
_ => todo!("{self:#?}"),
}
}

fn rust_type(&self) -> String {
let rust_prefix = self.kind.rust_prefix();
let c_prefix = self.kind.c_prefix();

if let Some(bit_len) = self.bit_len {
match (self.simd_len, self.vec_len) {
(None, None) => format!("{rust_prefix}{bit_len}"),
(Some(SimdLen::Fixed(simd)), None) => format!("{c_prefix}{bit_len}x{simd}_t"),
(Some(SimdLen::Fixed(simd)), Some(vec)) => {
format!("{c_prefix}{bit_len}x{simd}x{vec}_t")
}
(Some(SimdLen::Scalable), None) => format!("sv{c_prefix}{bit_len}_t"),
(Some(SimdLen::Scalable), Some(vec)) => {
format!("sv{c_prefix}{bit_len}x{vec}_t")
}
(None, Some(_)) => todo!("{self:#?}"), // Likely an invalid case
match (self.bit_len, self.simd_len, self.vec_len) {
// e.g. `svpattern`
(None, _, _) => format!("{rust_prefix}"),
// e.g. `bool`
(Some(_), None, None) if matches!(self.kind, TypeKind::Bool) => {
format!("{rust_prefix}")
}
} else {
todo!("{self:#?}")
// e.g. `i32`
(Some(bit_len), None, None) => format!("{rust_prefix}{bit_len}"),
// e.g. `int32x2_t`
(Some(bit_len), Some(SimdLen::Fixed(simd)), None) => {
format!("{c_prefix}{bit_len}x{simd}_t")
}
// e.g. `int32x2x3_t`
(Some(bit_len), Some(SimdLen::Fixed(simd)), Some(vec)) => {
format!("{c_prefix}{bit_len}x{simd}x{vec}_t")
}
// e.g. `svbool_t`
(Some(_), Some(SimdLen::Scalable), None) if matches!(self.kind, TypeKind::Bool) => {
format!("sv{c_prefix}_t")
}
// e.g. `svint32_t`
(Some(bit_len), Some(SimdLen::Scalable), None) => format!("sv{c_prefix}{bit_len}_t"),
// e.g. `svint32x3_t`
(Some(bit_len), Some(SimdLen::Scalable), Some(vec)) => {
format!("sv{c_prefix}{bit_len}x{vec}_t")
}
(Some(_), None, Some(_)) => todo!("{self:#?}"),
}
}

/// Determines the load function for this type.
fn get_load_function(&self) -> String {
fn load_function(&self) -> String {
if let IntrinsicType {
kind: k,
bit_len: Some(bl),
Expand All @@ -71,7 +95,7 @@ impl TypeDefinition for ArmType {
len = vec_len.unwrap_or(1),
)
} else {
todo!("get_load_function IntrinsicType: {self:#?}")
todo!("load_function IntrinsicType: {self:#?}")
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/intrinsic-test/src/common/argument.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ where
"let {name} = {load}({vals_name}.as_ptr().add((i+{idx}) % {PASSES}) as _);",
name = arg.generate_name(),
vals_name = test_values_array_name(&arg.ty),
load = arg.ty.get_load_function(),
load = arg.ty.load_function(),
)
} else {
format!(
Expand Down
27 changes: 2 additions & 25 deletions crates/intrinsic-test/src/common/gen_rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use itertools::Itertools;
use super::intrinsic_helpers::TypeDefinition;
use crate::common::cli::{CcArgStyle, ProcessedCli};
use crate::common::intrinsic::Intrinsic;
use crate::common::intrinsic_helpers::TypeKind;
use crate::common::values::{test_values_array_name, test_values_array_static};
use crate::common::{PASSES, SupportedArchitecture};

Expand Down Expand Up @@ -195,29 +194,6 @@ fn generate_rust_test_loop<A: SupportedArchitecture>(
writeln!(w, " ];")?;
}

let (cast_prefix, cast_suffix) = if intrinsic.results.is_simd() {
(
format!(
"std::mem::transmute::<_, [{}; {}]>(",
intrinsic.results.rust_scalar_type().replace("f", "NanEqF"),
intrinsic.results.num_lanes() * intrinsic.results.num_vectors()
),
")",
)
} else if intrinsic.results.kind == TypeKind::Float {
(
match intrinsic.results.inner_size() {
16 => format!("NanEqF16("),
32 => format!("NanEqF32("),
64 => format!("NanEqF64("),
_ => unimplemented!(),
},
")",
)
} else {
("".to_string(), "")
};

write!(
w,
r#"
Expand All @@ -231,14 +207,15 @@ for (id, rust, c) in specializations {{
c(__c_return_value.as_mut_ptr(){c_args});
let __c_return_value = __c_return_value.assume_init();

assert_eq!({cast_prefix}__rust_return_value{cast_suffix}, {cast_prefix}__c_return_value{cast_suffix}, "{{id}}");
{comparison}
}}
}}
}}
"#,
loaded_args = intrinsic.arguments.load_values_rust(),
rust_args = intrinsic.arguments.as_call_param_rust(),
c_args = intrinsic.arguments.as_c_call_param_rust(),
comparison = intrinsic.results.comparison_function(),
)
}

Expand Down
77 changes: 63 additions & 14 deletions crates/intrinsic-test/src/common/intrinsic_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::cmp;
use std::fmt;
use std::ops::Deref;
use std::ops::DerefMut;
use std::str::FromStr;

#[derive(Debug, PartialEq, Copy, Clone)]
Expand Down Expand Up @@ -90,9 +90,13 @@ impl TypeKind {
}
}

/// Returns the Rust prefix for this type kind i.e. `i`, `u`, or `f`.
/// Returns the Rust prefix for this type kind (i.e. `i` for `i16`, or `u` for `u16`). For type
/// kinds without any bit length at the end (e.g. `bool`), returns the whole type name.
pub fn rust_prefix(&self) -> &str {
match self {
Self::Bool => "bool",
Self::SvPattern => "svpattern",
Self::SvPrefetchOp => "svprfop",
Self::BFloat => "bf",
Self::Float => "f",
Self::Int(Sign::Signed) => "i",
Expand All @@ -101,7 +105,7 @@ impl TypeKind {
Self::Char(Sign::Unsigned) => "u",
Self::Char(Sign::Signed) => "i",
Self::Mask => "u",
_ => unreachable!("Unused type kind: {self:#?}"),
_ => unreachable!("type kind without Rust prefix: {self:#?}"),
}
}
}
Expand Down Expand Up @@ -195,9 +199,19 @@ impl IntrinsicType {
}
}

pub trait TypeDefinition: Clone + Deref<Target = IntrinsicType> {
pub trait TypeDefinition: Clone + DerefMut<Target = IntrinsicType> {
/// Determines the load function for this type.
fn get_load_function(&self) -> String;
fn load_function(&self) -> String;

/// Determines the comparison function for this type.
fn comparison_function(&self) -> String {
match self.simd_len {
Some(SimdLen::Scalable) => unimplemented!("architecture-specific"),
Some(SimdLen::Fixed(_)) | None => {
default_fixed_vector_comparison(self, self.num_lanes())
}
}
}

/// Gets a string containing the typename for this type in C.
fn c_type(&self) -> String;
Expand All @@ -208,14 +222,49 @@ pub trait TypeDefinition: Clone + Deref<Target = IntrinsicType> {
/// Gets a string containing the name of the scalar type corresponding to this type if it is a
/// vector.
fn rust_scalar_type(&self) -> String {
if self.is_simd() {
format!(
"{prefix}{bits}",
prefix = self.kind().rust_prefix(),
bits = self.inner_size()
)
} else {
self.rust_type()
}
let mut ty = self.clone();
ty.simd_len = None;
ty.vec_len = None;
ty.rust_type()
}
}

/// Returns the default comparison between results of an intrinsic - casting the vectors to arrays
/// and using `assert_eq` - using `NanEqF*` where required for floats.
pub(crate) fn default_fixed_vector_comparison<Ty: TypeDefinition>(
ty: &Ty,
num_lanes: u32,
) -> String {
let (cast_prefix, cast_suffix) = if ty.is_simd() {
(
format!(
"std::mem::transmute::<_, [{}; {}]>(",
ty.rust_scalar_type().replace("f", "NanEqF"),
num_lanes * ty.num_vectors()
),
")",
)
} else if ty.kind == TypeKind::Float {
(
match ty.inner_size() {
16 => format!("NanEqF16("),
32 => format!("NanEqF32("),
64 => format!("NanEqF64("),
_ => unimplemented!(),
},
")",
)
} else {
("".to_string(), "")
};

format!(
r#"
assert_eq!(
{cast_prefix}__rust_return_value{cast_suffix},
{cast_prefix}__c_return_value{cast_suffix},
"{{id}}"
);
"#,
)
}
2 changes: 1 addition & 1 deletion crates/intrinsic-test/src/x86/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl TypeDefinition for X86IntrinsicType {
}

/// Determines the load function for this type.
fn get_load_function(&self) -> String {
fn load_function(&self) -> String {
let type_value = self.param.type_data.clone();
if type_value.len() == 0 {
unimplemented!("the value for key 'type' is not present!");
Expand Down