Skip to content

Commit b0676cb

Browse files
authored
Implement OpTypeMatrix (#738)
* Implement OpTypeMatrix * clippy * Use cached Symbol * Implement #[spirv(matrix(ty, m, n))] instead of Matrix trait * Update #[spirv(matrix(..))] - #[spirv(matrix(ty, m, n))] Specify all of type, rows, columns. - #[spirv(matrix(ty, m))] Specify all of type, rows. Infer columns. - #[spirv(matrix(ty))] Specify all of type. Infer others. - #[spirv(matrix)] Infer all. * Drop #[spirv(matrix(..))] (with arguments) * Fix IntrinsicType::Matrix type construction * Update matrix-type.rs * Update tests/ui/spirv-attr/multiple.rs to test Matrix * Fix tests/ui/spirv-attr/matrix-type.rs * Add failing tests for #[spirv(matrix) * Update error messages for #[spirv(matrix)]
1 parent 1e3881b commit b0676cb

File tree

15 files changed

+322
-74
lines changed

15 files changed

+322
-74
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,5 +843,45 @@ fn trans_intrinsic_type<'tcx>(
843843
Err(ErrorReported)
844844
}
845845
}
846+
IntrinsicType::Matrix => {
847+
let span = def_id_for_spirv_type_adt(ty)
848+
.map(|did| cx.tcx.def_span(did))
849+
.expect("#[spirv(matrix)] must be added to a type which has DefId");
850+
851+
let field_types = (0..ty.fields.count())
852+
.map(|i| trans_type_impl(cx, span, ty.field(cx, i), false))
853+
.collect::<Vec<_>>();
854+
if field_types.len() < 2 {
855+
cx.tcx
856+
.sess
857+
.span_err(span, "#[spirv(matrix)] type must have at least two fields");
858+
return Err(ErrorReported);
859+
}
860+
let elem_type = field_types[0];
861+
if !field_types.iter().all(|&ty| ty == elem_type) {
862+
cx.tcx.sess.span_err(
863+
span,
864+
"#[spirv(matrix)] type fields must all be the same type",
865+
);
866+
return Err(ErrorReported);
867+
}
868+
match cx.lookup_type(elem_type) {
869+
SpirvType::Vector { .. } => (),
870+
ty => {
871+
cx.tcx
872+
.sess
873+
.struct_span_err(span, "#[spirv(matrix)] type fields must all be vectors")
874+
.note(&format!("field type is {}", ty.debug(elem_type, cx)))
875+
.emit();
876+
return Err(ErrorReported);
877+
}
878+
}
879+
880+
Ok(SpirvType::Matrix {
881+
element: elem_type,
882+
count: field_types.len() as u32,
883+
}
884+
.def(span, cx))
885+
}
846886
}
847887
}

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub enum IntrinsicType {
6565
SampledImage,
6666
RayQueryKhr,
6767
RuntimeArray,
68+
Matrix,
6869
}
6970

7071
// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
210210
)),
211211
},
212212
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
213-
SpirvType::Vector { element, count } => {
213+
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
214214
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
215215
self.constant_composite(
216216
ty.clone().def(self.span(), self),
@@ -277,7 +277,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
277277
)
278278
.unwrap()
279279
}
280-
SpirvType::Vector { element, count } => {
280+
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
281281
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
282282
self.emit()
283283
.composite_construct(
@@ -426,7 +426,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
426426
}
427427
SpirvType::Vector { element, .. }
428428
| SpirvType::Array { element, .. }
429-
| SpirvType::RuntimeArray { element } => {
429+
| SpirvType::RuntimeArray { element }
430+
| SpirvType::Matrix { element, .. } => {
430431
ty = element;
431432
ty_kind = self.lookup_type(ty);
432433

@@ -1080,7 +1081,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
10801081
} => field_types[idx as usize],
10811082
SpirvType::Array { element, .. }
10821083
| SpirvType::RuntimeArray { element, .. }
1083-
| SpirvType::Vector { element, .. } => element,
1084+
| SpirvType::Vector { element, .. }
1085+
| SpirvType::Matrix { element, .. } => element,
10841086
SpirvType::InterfaceBlock { inner_type } => {
10851087
assert_eq!(idx, 0);
10861088
inner_type
@@ -1107,7 +1109,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
11071109
SpirvType::Adt { field_offsets, .. } => field_offsets[idx as usize],
11081110
SpirvType::Array { element, .. }
11091111
| SpirvType::RuntimeArray { element, .. }
1110-
| SpirvType::Vector { element, .. } => {
1112+
| SpirvType::Vector { element, .. }
1113+
| SpirvType::Matrix { element, .. } => {
11111114
self.lookup_type(element).sizeof(self).unwrap() * idx
11121115
}
11131116
_ => unreachable!(),
@@ -1843,7 +1846,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
18431846
fn extract_value(&mut self, agg_val: Self::Value, idx: u64) -> Self::Value {
18441847
let result_type = match self.lookup_type(agg_val.ty) {
18451848
SpirvType::Adt { field_types, .. } => field_types[idx as usize],
1846-
SpirvType::Array { element, .. } | SpirvType::Vector { element, .. } => element,
1849+
SpirvType::Array { element, .. }
1850+
| SpirvType::Vector { element, .. }
1851+
| SpirvType::Matrix { element, .. } => element,
18471852
other => self.fatal(&format!(
18481853
"extract_value not implemented on type {:?}",
18491854
other

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
303303
count: inst.operands[1].unwrap_literal_int32(),
304304
}
305305
.def(self.span(), self),
306+
Op::TypeMatrix => SpirvType::Matrix {
307+
element: inst.operands[0].unwrap_id_ref(),
308+
count: inst.operands[1].unwrap_literal_int32(),
309+
}
310+
.def(self.span(), self),
306311
Op::TypeArray => {
307312
self.err("OpTypeArray in asm! is not supported yet");
308313
return;

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,21 @@ impl<'tcx> CodegenCx<'tcx> {
482482
*offset = final_offset;
483483
result
484484
}
485+
SpirvType::Matrix { element, count } => {
486+
let total_size = ty_concrete
487+
.sizeof(self)
488+
.expect("create_const_alloc: Matrices must be sized");
489+
let final_offset = *offset + total_size;
490+
let values = (0..count).map(|_| {
491+
self.create_const_alloc2(alloc, offset, element)
492+
.def_cx(self)
493+
});
494+
let result = self.constant_composite(ty, values);
495+
assert!(*offset <= final_offset);
496+
// Matrices sometimes have padding at the end (e.g. Mat4x3), skip over it.
497+
*offset = final_offset;
498+
result
499+
}
485500
SpirvType::RuntimeArray { element } => {
486501
let mut values = Vec::new();
487502
while offset.bytes_usize() != alloc.len() {

crates/rustc_codegen_spirv/src/codegen_cx/type_.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
174174
TypeKind::Struct
175175
}
176176
SpirvType::Vector { .. } => TypeKind::Vector,
177-
SpirvType::Array { .. } | SpirvType::RuntimeArray { .. } => TypeKind::Array,
177+
SpirvType::Array { .. } | SpirvType::RuntimeArray { .. } | SpirvType::Matrix { .. } => TypeKind::Array,
178178
SpirvType::Pointer { .. } => TypeKind::Pointer,
179179
SpirvType::Function { .. } => TypeKind::Function,
180180
// HACK(eddyb) this is probably the closest `TypeKind` (which is still

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ pub enum SpirvType {
4747
/// Note: vector count is literal.
4848
count: u32,
4949
},
50+
Matrix {
51+
element: Word,
52+
/// Note: matrix count is literal.
53+
count: u32,
54+
},
5055
Array {
5156
element: Word,
5257
/// Note: array count is ref to constant.
@@ -174,6 +179,7 @@ impl SpirvType {
174179
result
175180
}
176181
Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count),
182+
Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count),
177183
Self::Array { element, count } => {
178184
// ArrayStride decoration wants in *bytes*
179185
let element_size = cx
@@ -347,6 +353,7 @@ impl SpirvType {
347353
Self::Vector { element, count } => {
348354
cx.lookup_type(element).sizeof(cx)? * count.next_power_of_two() as u64
349355
}
356+
Self::Matrix { element, count } => cx.lookup_type(element).sizeof(cx)? * count as u64,
350357
Self::Array { element, count } => {
351358
cx.lookup_type(element).sizeof(cx)? * cx.builder.lookup_const_u64(count).unwrap()
352359
}
@@ -377,9 +384,9 @@ impl SpirvType {
377384
.bytes(),
378385
)
379386
.expect("alignof: Vectors must have power-of-2 size"),
380-
Self::Array { element, .. } | Self::RuntimeArray { element } => {
381-
cx.lookup_type(element).alignof(cx)
382-
}
387+
Self::Array { element, .. }
388+
| Self::RuntimeArray { element }
389+
| Self::Matrix { element, .. } => cx.lookup_type(element).alignof(cx),
383390
Self::Pointer { .. } => cx.tcx.data_layout.pointer_align.abi,
384391
Self::Image { .. }
385392
| Self::AccelerationStructureKhr
@@ -455,6 +462,12 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
455462
.field("element", &self.cx.debug_type(element))
456463
.field("count", &count)
457464
.finish(),
465+
SpirvType::Matrix { element, count } => f
466+
.debug_struct("Matrix")
467+
.field("id", &self.id)
468+
.field("element", &self.cx.debug_type(element))
469+
.field("count", &count)
470+
.finish(),
458471
SpirvType::Array { element, count } => f
459472
.debug_struct("Array")
460473
.field("id", &self.id)
@@ -612,7 +625,7 @@ impl SpirvTypePrinter<'_, '_> {
612625
}
613626
f.write_str(" }")
614627
}
615-
SpirvType::Vector { element, count } => {
628+
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => {
616629
ty(self.cx, stack, f, element)?;
617630
write!(f, "x{}", count)
618631
}

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,10 @@ impl Symbols {
334334
"runtime_array",
335335
SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
336336
),
337+
(
338+
"matrix",
339+
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
340+
),
337341
("unroll_loops", SpirvAttribute::UnrollLoops),
338342
]
339343
.iter()
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Tests that matrix type inference fails correctly, for empty struct
2+
// build-fail
3+
4+
use spirv_std as _;
5+
6+
#[spirv(matrix)]
7+
pub struct _EmptyStruct {}
8+
9+
#[spirv(fragment)]
10+
pub fn _entry() {
11+
let _empty_struct = _EmptyStruct {};
12+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
error: #[spirv(matrix)] type must have at least two fields
2+
--> $DIR/invalid-matrix-type-empty.rs:7:1
3+
|
4+
7 | pub struct _EmptyStruct {}
5+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
6+
7+
error: aborting due to previous error
8+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Tests that matrix type inference fails correctly
2+
// build-fail
3+
4+
use spirv_std as _;
5+
6+
#[spirv(matrix)]
7+
pub struct _FewerFields {
8+
_v: glam::Vec3,
9+
}
10+
11+
#[spirv(matrix)]
12+
pub struct _NotVectorField {
13+
_x: f32,
14+
_y: f32,
15+
_z: f32,
16+
}
17+
18+
#[spirv(matrix)]
19+
pub struct _DifferentType {
20+
_x: glam::Vec3,
21+
_y: glam::Vec2,
22+
}
23+
24+
#[spirv(fragment)]
25+
pub fn _entry(_arg1: _FewerFields, _arg2: _NotVectorField, _arg3: _DifferentType) {}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
error: #[spirv(matrix)] type must have at least two fields
2+
--> $DIR/invalid-matrix-type.rs:7:1
3+
|
4+
7 | / pub struct _FewerFields {
5+
8 | | _v: glam::Vec3,
6+
9 | | }
7+
| |_^
8+
9+
error: #[spirv(matrix)] type fields must all be vectors
10+
--> $DIR/invalid-matrix-type.rs:12:1
11+
|
12+
12 | / pub struct _NotVectorField {
13+
13 | | _x: f32,
14+
14 | | _y: f32,
15+
15 | | _z: f32,
16+
16 | | }
17+
| |_^
18+
|
19+
= note: field type is f32
20+
21+
error: #[spirv(matrix)] type fields must all be the same type
22+
--> $DIR/invalid-matrix-type.rs:19:1
23+
|
24+
19 | / pub struct _DifferentType {
25+
20 | | _x: glam::Vec3,
26+
21 | | _y: glam::Vec2,
27+
22 | | }
28+
| |_^
29+
30+
error: aborting due to 3 previous errors
31+

tests/ui/spirv-attr/matrix-type.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// build-pass
2+
// compile-flags: -Ctarget-feature=+RayTracingKHR,+ext:SPV_KHR_ray_tracing
3+
4+
use spirv_std as _;
5+
6+
#[derive(Clone, Copy)]
7+
#[spirv(matrix)]
8+
pub struct Affine3 {
9+
pub x: glam::Vec3,
10+
pub y: glam::Vec3,
11+
pub z: glam::Vec3,
12+
pub w: glam::Vec3,
13+
}
14+
15+
impl Affine3 {
16+
pub const ZERO: Self = Self {
17+
x: glam::Vec3::ZERO,
18+
y: glam::Vec3::ZERO,
19+
z: glam::Vec3::ZERO,
20+
w: glam::Vec3::ZERO,
21+
};
22+
23+
pub const IDENTITY: Self = Self {
24+
x: glam::Vec3::X,
25+
y: glam::Vec3::Y,
26+
z: glam::Vec3::Z,
27+
w: glam::Vec3::ZERO,
28+
};
29+
}
30+
31+
impl Default for Affine3 {
32+
#[inline]
33+
fn default() -> Self {
34+
Self::IDENTITY
35+
}
36+
}
37+
38+
#[spirv(closest_hit)]
39+
pub fn main_attrs(
40+
#[spirv(object_to_world)] _object_to_world: Affine3,
41+
#[spirv(world_to_object)] _world_to_object: Affine3,
42+
) {
43+
}
44+
45+
#[spirv(fragment)]
46+
pub fn main_default(out: &mut Affine3) {
47+
*out = Affine3::default();
48+
}
49+
50+
#[spirv(fragment)]
51+
pub fn main_add(affine3: Affine3, out: &mut glam::Vec3) {
52+
*out = affine3.x + affine3.y + affine3.z + affine3.w;
53+
}

tests/ui/spirv-attr/multiple.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,21 @@ use spirv_std as _;
88
#[spirv(sampler, sampler)]
99
struct _SameIntrinsicType {}
1010

11+
#[spirv(matrix, matrix)]
12+
struct _SameIntrinsicMatrixType {
13+
x: glam::Vec3,
14+
y: glam::Vec3,
15+
}
16+
1117
#[spirv(sampler, generic_image_type)]
1218
struct _DiffIntrinsicType {}
1319

20+
#[spirv(sampler, matrix)]
21+
struct _SamplerAndMatrix {
22+
x: glam::Vec3,
23+
y: glam::Vec3,
24+
}
25+
1426
#[spirv(block, block)]
1527
struct _Block {}
1628

0 commit comments

Comments
 (0)