Skip to content

Implement OpTypeMatrix #738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 30, 2021
40 changes: 40 additions & 0 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -843,5 +843,45 @@ fn trans_intrinsic_type<'tcx>(
Err(ErrorReported)
}
}
IntrinsicType::Matrix => {
let span = def_id_for_spirv_type_adt(ty)
.map(|did| cx.tcx.def_span(did))
.expect("#[spirv(matrix)] must be added to a type which has DefId");

let field_types = (0..ty.fields.count())
.map(|i| trans_type_impl(cx, span, ty.field(cx, i), false))
.collect::<Vec<_>>();
if field_types.len() < 2 {
cx.tcx
.sess
.span_err(span, "#[spirv(matrix)] type must have at least two fields");
return Err(ErrorReported);
}
let elem_type = field_types[0];
if !field_types.iter().all(|&ty| ty == elem_type) {
cx.tcx.sess.span_err(
span,
"#[spirv(matrix)] type fields must all be the same type",
);
return Err(ErrorReported);
}
match cx.lookup_type(elem_type) {
SpirvType::Vector { .. } => (),
ty => {
cx.tcx
.sess
.struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
.note(&format!("field type is {}", ty.debug(elem_type, cx)))
.emit();
return Err(ErrorReported);
}
}

Ok(SpirvType::Matrix {
element: elem_type,
count: field_types.len() as u32,
}
.def(span, cx))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't have to (and shouldn't) do this digging into field types for this, and should use the standard type translation tools on the fields instead. Also, this is missing quite a bit of validation (e.g. if fields are different types, or there are no fields - IIRC your code ICEs if there's no fields) - there should be tests for each of the three kinds of error. Something like this should work instead:

IntrinsicType::Matrix => {
    let field_types = (0..ty.fields.count())
        .map(|i| trans_type_impl(cx, span, ty.field(cx, i), false))
        .collect::<Vec<_>>();
    if field_types.is_empty() {
        cx.tcx
            .sess
            .err("#[spirv(matrix)] type must have at least one field");
        return Err(ErrorReported);
    }
    let elem_type = field_types[0];
    if !field_types.iter().all(|&ty| ty == elem_type) {
        cx.tcx
            .sess
            .err("#[spirv(matrix)] type fields must all be the same type");
        return Err(ErrorReported);
    }
    match cx.lookup_type(elem_type) {
        SpirvType::Vector { .. } => (),
        ty => {
            cx.tcx
                .sess
                .struct_err("#[spirv(matrix)] type fields must all be vectors")
                .note(&format!("field type is {}", ty.debug(elem_type, cx)))
                .emit();
            return Err(ErrorReported);
        }
    }

    Ok(SpirvType::Matrix {
        element: elem_type,
        count: field_types.len() as u32,
    }
    .def(span, cx))
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
I updated the code and change Matrix length validation since OpTypeMatrix requires a length of at least 2.
https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#OpTypeMatrix

}
}
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub enum IntrinsicType {
SampledImage,
RayQueryKhr,
RuntimeArray,
Matrix,
}

// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
Expand Down
17 changes: 11 additions & 6 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
)),
},
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
SpirvType::Vector { element, count } => {
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
self.constant_composite(
ty.clone().def(self.span(), self),
Expand Down Expand Up @@ -277,7 +277,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
)
.unwrap()
}
SpirvType::Vector { element, count } => {
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
self.emit()
.composite_construct(
Expand Down Expand Up @@ -426,7 +426,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}
SpirvType::Vector { element, .. }
| SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element } => {
| SpirvType::RuntimeArray { element }
| SpirvType::Matrix { element, .. } => {
ty = element;
ty_kind = self.lookup_type(ty);

Expand Down Expand Up @@ -1080,7 +1081,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} => field_types[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Vector { element, .. } => element,
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => element,
SpirvType::InterfaceBlock { inner_type } => {
assert_eq!(idx, 0);
inner_type
Expand All @@ -1107,7 +1109,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
SpirvType::Adt { field_offsets, .. } => field_offsets[idx as usize],
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element, .. }
| SpirvType::Vector { element, .. } => {
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => {
self.lookup_type(element).sizeof(self).unwrap() * idx
}
_ => unreachable!(),
Expand Down Expand Up @@ -1843,7 +1846,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value {
let result_type = match self.lookup_type(agg_val.ty) {
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
SpirvType::Array { element, .. } | SpirvType::Vector { element, .. } => element,
SpirvType::Array { element, .. }
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => element,
other => self.fatal(&format!(
"extract_value not implemented on type {:?}",
other
Expand Down
5 changes: 5 additions & 0 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
count: inst.operands[1].unwrap_literal_int32(),
}
.def(self.span(), self),
Op::TypeMatrix => SpirvType::Matrix {
element: inst.operands[0].unwrap_id_ref(),
count: inst.operands[1].unwrap_literal_int32(),
}
.def(self.span(), self),
Op::TypeArray => {
self.err("OpTypeArray in asm! is not supported yet");
return;
Expand Down
15 changes: 15 additions & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,21 @@ impl<'tcx> CodegenCx<'tcx> {
*offset = final_offset;
result
}
SpirvType::Matrix { element, count } => {
let total_size = ty_concrete
.sizeof(self)
.expect("create_const_alloc: Matrices must be sized");
let final_offset = *offset + total_size;
let values = (0..count).map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
});
let result = self.constant_composite(ty, values);
assert!(*offset <= final_offset);
// Matrices sometimes have padding at the end (e.g. Mat4x3), skip over it.
*offset = final_offset;
result
}
SpirvType::RuntimeArray { element } => {
let mut values = Vec::new();
while offset.bytes_usize() != alloc.len() {
Expand Down
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
TypeKind::Struct
}
SpirvType::Vector { .. } => TypeKind::Vector,
SpirvType::Array { .. } | SpirvType::RuntimeArray { .. } => TypeKind::Array,
SpirvType::Array { .. } | SpirvType::RuntimeArray { .. } | SpirvType::Matrix { .. } => TypeKind::Array,
SpirvType::Pointer { .. } => TypeKind::Pointer,
SpirvType::Function { .. } => TypeKind::Function,
// HACK(eddyb) this is probably the closest `TypeKind` (which is still
Expand Down
21 changes: 17 additions & 4 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ pub enum SpirvType {
/// Note: vector count is literal.
count: u32,
},
Matrix {
element: Word,
/// Note: matrix count is literal.
count: u32,
},
Array {
element: Word,
/// Note: array count is ref to constant.
Expand Down Expand Up @@ -174,6 +179,7 @@ impl SpirvType {
result
}
Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count),
Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count),
Self::Array { element, count } => {
// ArrayStride decoration wants in *bytes*
let element_size = cx
Expand Down Expand Up @@ -347,6 +353,7 @@ impl SpirvType {
Self::Vector { element, count } => {
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
}
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
Self::Array { element, count } => {
cx.lookup_type(element).sizeof(cx)? * cx.builder.lookup_const_u64(count).unwrap()
}
Expand Down Expand Up @@ -377,9 +384,9 @@ impl SpirvType {
.bytes(),
)
.expect("alignof: Vectors must have power-of-2 size"),
Self::Array { element, .. } | Self::RuntimeArray { element } => {
cx.lookup_type(element).alignof(cx)
}
Self::Array { element, .. }
| Self::RuntimeArray { element }
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
Self::Pointer { .. } => cx.tcx.data_layout.pointer_align.abi,
Self::Image { .. }
| Self::AccelerationStructureKhr
Expand Down Expand Up @@ -455,6 +462,12 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("element", &self.cx.debug_type(element))
.field("count", &count)
.finish(),
SpirvType::Matrix { element, count } => f
.debug_struct("Matrix")
.field("id", &self.id)
.field("element", &self.cx.debug_type(element))
.field("count", &count)
.finish(),
SpirvType::Array { element, count } => f
.debug_struct("Array")
.field("id", &self.id)
Expand Down Expand Up @@ -612,7 +625,7 @@ impl SpirvTypePrinter<'_, '_> {
}
f.write_str(" }")
}
SpirvType::Vector { element, count } => {
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
ty(self.cx, stack, f, element)?;
write!(f, "x{}", count)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ impl Symbols {
"runtime_array",
SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
),
(
"matrix",
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
),
("unroll_loops", SpirvAttribute::UnrollLoops),
]
.iter()
Expand Down
12 changes: 12 additions & 0 deletions tests/ui/spirv-attr/invalid-matrix-type-empty.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Tests that matrix type inference fails correctly, for empty struct
// build-fail

use spirv_std as _;

#[spirv(matrix)]
pub struct _EmptyStruct {}

#[spirv(fragment)]
pub fn _entry() {
let _empty_struct = _EmptyStruct {};
}
8 changes: 8 additions & 0 deletions tests/ui/spirv-attr/invalid-matrix-type-empty.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
error: #[spirv(matrix)] type must have at least two fields
--> $DIR/invalid-matrix-type-empty.rs:7:1
|
7 | pub struct _EmptyStruct {}
| ^^^^^^^^^^^^^^^^^^^^^^^^^^

error: aborting due to previous error

25 changes: 25 additions & 0 deletions tests/ui/spirv-attr/invalid-matrix-type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Tests that matrix type inference fails correctly
// build-fail

use spirv_std as _;

#[spirv(matrix)]
pub struct _FewerFields {
_v: glam::Vec3,
}

#[spirv(matrix)]
pub struct _NotVectorField {
_x: f32,
_y: f32,
_z: f32,
}

#[spirv(matrix)]
pub struct _DifferentType {
_x: glam::Vec3,
_y: glam::Vec2,
}

#[spirv(fragment)]
pub fn _entry(_arg1: _FewerFields, _arg2: _NotVectorField, _arg3: _DifferentType) {}
31 changes: 31 additions & 0 deletions tests/ui/spirv-attr/invalid-matrix-type.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
error: #[spirv(matrix)] type must have at least two fields
--> $DIR/invalid-matrix-type.rs:7:1
|
7 | / pub struct _FewerFields {
8 | | _v: glam::Vec3,
9 | | }
| |_^

error: #[spirv(matrix)] type fields must all be vectors
--> $DIR/invalid-matrix-type.rs:12:1
|
12 | / pub struct _NotVectorField {
13 | | _x: f32,
14 | | _y: f32,
15 | | _z: f32,
16 | | }
| |_^
|
= note: field type is f32

error: #[spirv(matrix)] type fields must all be the same type
--> $DIR/invalid-matrix-type.rs:19:1
|
19 | / pub struct _DifferentType {
20 | | _x: glam::Vec3,
21 | | _y: glam::Vec2,
22 | | }
| |_^

error: aborting due to 3 previous errors

53 changes: 53 additions & 0 deletions tests/ui/spirv-attr/matrix-type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// build-pass
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing

use spirv_std as _;

#[derive(Clone, Copy)]
#[spirv(matrix)]
pub struct Affine3 {
pub x: glam::Vec3,
pub y: glam::Vec3,
pub z: glam::Vec3,
pub w: glam::Vec3,
}

impl Affine3 {
pub const ZERO: Self = Self {
x: glam::Vec3::ZERO,
y: glam::Vec3::ZERO,
z: glam::Vec3::ZERO,
w: glam::Vec3::ZERO,
};

pub const IDENTITY: Self = Self {
x: glam::Vec3::X,
y: glam::Vec3::Y,
z: glam::Vec3::Z,
w: glam::Vec3::ZERO,
};
}

impl Default for Affine3 {
#[inline]
fn default() -> Self {
Self::IDENTITY
}
}

#[spirv(closest_hit)]
pub fn main_attrs(
#[spirv(object_to_world)] _object_to_world: Affine3,
#[spirv(world_to_object)] _world_to_object: Affine3,
) {
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this isn't actually testing much due to the majority of it being dead code. (Also, nit, the name Affine3 is a little misleading - it's not an affine transformation - but isn't super important for a test, haha)

I would like to see field accesses/etc. tested as well, I'm nervous about just crossing our fingers and hoping that matricies behave exactly like structs in all ways and no instructions need to be modified to handle matricies specially.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I chose Affine3 because glam::f32::Affine3A has the same layout https://docs.rs/glam/0.18.0/glam/f32/struct.Affine3A.html (Sry, I don't know about math 😨).

I feel adding more tests for field operations in tests/ui/spirv-attr/* is not appropriate because it may should contain only about attr tests.
Is it OK? or is there any better place?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh! whoops, misread it as 4 Vec4s, not 4 Vec3s, 4 Vec3s definitely is an affine transform, haha, sorry

Yeah, not totally sure about where to put tests, anywhere is probably fine

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added tests to matrix-type.rs I think this is enough but are there suggestions for more cases?


#[spirv(fragment)]
pub fn main_default(out: &mut Affine3) {
*out = Affine3::default();
}

#[spirv(fragment)]
pub fn main_add(affine3: Affine3, out: &mut glam::Vec3) {
*out = affine3.x + affine3.y + affine3.z + affine3.w;
}
12 changes: 12 additions & 0 deletions tests/ui/spirv-attr/multiple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,21 @@ use spirv_std as _;
#[spirv(sampler, sampler)]
struct _SameIntrinsicType {}

#[spirv(matrix, matrix)]
struct _SameIntrinsicMatrixType {
x: glam::Vec3,
y: glam::Vec3,
}

#[spirv(sampler, generic_image_type)]
struct _DiffIntrinsicType {}

#[spirv(sampler, matrix)]
struct _SamplerAndMatrix {
x: glam::Vec3,
y: glam::Vec3,
}

#[spirv(block, block)]
struct _Block {}

Expand Down
Loading