Skip to content

Commit 52445ce

Browse files
committed
[WIP] add #[spirv(typed_buffer)] for explicit SpirvType::InterfaceBlocks.
1 parent 9cba610 commit 52445ce

File tree

10 files changed

+251
-87
lines changed

10 files changed

+251
-87
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -935,18 +935,48 @@ fn trans_intrinsic_type<'tcx>(
935935
.err("#[spirv(runtime_array)] type must have size 4"));
936936
}
937937

938-
// We use a generic to indicate the underlying element type.
939-
// The spirv type of it will be generated by querying the type of the first generic.
938+
// We use a generic param to indicate the underlying element type.
939+
// The SPIR-V element type will be generated from the first generic param.
940940
if let Some(elem_ty) = args.types().next() {
941-
let element = cx.layout_of(elem_ty).spirv_type(span, cx);
942-
Ok(SpirvType::RuntimeArray { element }.def(span, cx))
941+
Ok(SpirvType::RuntimeArray {
942+
element: cx.layout_of(elem_ty).spirv_type(span, cx),
943+
}
944+
.def(span, cx))
943945
} else {
944946
Err(cx
945947
.tcx
946948
.dcx()
947949
.err("#[spirv(runtime_array)] type must have a generic element type"))
948950
}
949951
}
952+
IntrinsicType::TypedBuffer => {
953+
if ty.size != Size::from_bytes(4) {
954+
return Err(cx
955+
.tcx
956+
.sess
957+
.dcx()
958+
.err("#[spirv(typed_buffer)] type must have size 4"));
959+
}
960+
961+
// We use a generic param to indicate the underlying data type.
962+
// The SPIR-V data type will be generated from the first generic param.
963+
if let Some(data_ty) = args.types().next() {
964+
// HACK(eddyb) this should be a *pointer* to an "interface block",
965+
// but SPIR-V screwed up and used no explicit indirection for the
966+
// descriptor indexing case, and instead made a `RuntimeArray` of
967+
// `InterfaceBlock`s be an "array of typed buffer resources".
968+
Ok(SpirvType::InterfaceBlock {
969+
inner_type: cx.layout_of(data_ty).spirv_type(span, cx),
970+
}
971+
.def(span, cx))
972+
} else {
973+
Err(cx
974+
.tcx
975+
.sess
976+
.dcx()
977+
.err("#[spirv(typed_buffer)] type must have a generic data type"))
978+
}
979+
}
950980
IntrinsicType::Matrix => {
951981
let span = def_id_for_spirv_type_adt(ty)
952982
.map(|did| cx.tcx.def_span(did))

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub enum IntrinsicType {
6565
SampledImage,
6666
RayQueryKhr,
6767
RuntimeArray,
68+
TypedBuffer,
6869
Matrix,
6970
}
7071

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,11 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
699699
};
700700
ty = match cx.lookup_type(ty) {
701701
SpirvType::Array { element, .. }
702-
| SpirvType::RuntimeArray { element } => element,
702+
| SpirvType::RuntimeArray { element }
703+
// HACK(eddyb) this is pretty bad because it's not
704+
// checking that the index is an `OpConstant 0`, but
705+
// there's no other valid choice anyway.
706+
| SpirvType::InterfaceBlock { inner_type: element } => element,
703707

704708
SpirvType::Adt { field_types, .. } => *index_to_usize()
705709
.and_then(|i| field_types.get(i))

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 90 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -494,89 +494,107 @@ impl<'tcx> CodegenCx<'tcx> {
494494
.dcx()
495495
.span_fatal(hir_param.ty_span, "pair type not supported yet")
496496
}
497+
// FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"?
498+
// FIXME(eddyb) should we talk about "descriptor indexing" or
499+
// actually use more reasonable terms like "resource arrays"?
500+
let needs_interface_block_and_supports_descriptor_indexing = matches!(
501+
storage_class,
502+
Ok(StorageClass::Uniform | StorageClass::StorageBuffer)
503+
);
504+
let needs_interface_block = needs_interface_block_and_supports_descriptor_indexing
505+
|| storage_class == Ok(StorageClass::PushConstant);
506+
// NOTE(eddyb) `#[spirv(typed_buffer)]` adds `SpirvType::InterfaceBlock`s
507+
// which must bypass the automated ones (i.e. the user is taking control).
508+
let has_explicit_interface_block = needs_interface_block_and_supports_descriptor_indexing
509+
&& {
510+
// Peel off arrays first (used for "descriptor indexing").
511+
let outermost_or_array_element = match self.lookup_type(value_spirv_type) {
512+
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
513+
element
514+
}
515+
_ => value_spirv_type,
516+
};
517+
matches!(
518+
self.lookup_type(outermost_or_array_element),
519+
SpirvType::InterfaceBlock { .. }
520+
)
521+
};
497522
let var_ptr_spirv_type;
498-
let (value_ptr, value_len) = match storage_class {
499-
Ok(
500-
StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer,
501-
) => {
502-
let var_spirv_type = SpirvType::InterfaceBlock {
503-
inner_type: value_spirv_type,
504-
}
505-
.def(hir_param.span, self);
506-
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);
523+
let (value_ptr, value_len) = if needs_interface_block && !has_explicit_interface_block {
524+
let var_spirv_type = SpirvType::InterfaceBlock {
525+
inner_type: value_spirv_type,
526+
}
527+
.def(hir_param.span, self);
528+
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);
507529

508-
let value_ptr = bx.struct_gep(
509-
var_spirv_type,
510-
var_id.unwrap().with_type(var_ptr_spirv_type),
511-
0,
512-
);
530+
let value_ptr = bx.struct_gep(
531+
var_spirv_type,
532+
var_id.unwrap().with_type(var_ptr_spirv_type),
533+
0,
534+
);
513535

514-
let value_len = if is_unsized_with_len {
515-
match self.lookup_type(value_spirv_type) {
516-
SpirvType::RuntimeArray { .. } => {}
517-
_ => {
518-
self.tcx.dcx().span_err(
519-
hir_param.ty_span,
520-
"only plain slices are supported as unsized types",
521-
);
522-
}
536+
let value_len = if is_unsized_with_len {
537+
match self.lookup_type(value_spirv_type) {
538+
SpirvType::RuntimeArray { .. } => {}
539+
_ => {
540+
self.tcx.dcx().span_err(
541+
hir_param.ty_span,
542+
"only plain slices are supported as unsized types",
543+
);
523544
}
545+
}
524546

525-
// FIXME(eddyb) shouldn't this be `usize`?
526-
let len_spirv_type = self.type_isize();
527-
let len = bx
528-
.emit()
529-
.array_length(len_spirv_type, None, var_id.unwrap(), 0)
530-
.unwrap();
531-
532-
Some(len.with_type(len_spirv_type))
533-
} else {
534-
if is_unsized {
535-
// It's OK to use a RuntimeArray<u32> and not have a length parameter, but
536-
// it's just nicer ergonomics to use a slice.
537-
self.tcx
538-
.dcx()
539-
.span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray<T>");
540-
}
541-
None
542-
};
547+
// FIXME(eddyb) shouldn't this be `usize`?
548+
let len_spirv_type = self.type_isize();
549+
let len = bx
550+
.emit()
551+
.array_length(len_spirv_type, None, var_id.unwrap(), 0)
552+
.unwrap();
543553

544-
(Ok(value_ptr), value_len)
545-
}
546-
Ok(StorageClass::UniformConstant) => {
547-
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
554+
Some(len.with_type(len_spirv_type))
555+
} else {
556+
if is_unsized {
557+
// It's OK to use a RuntimeArray<u32> and not have a length parameter, but
558+
// it's just nicer ergonomics to use a slice.
559+
self.tcx
560+
.dcx()
561+
.span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray<T>");
562+
}
563+
None
564+
};
548565

566+
(Ok(value_ptr), value_len)
567+
} else {
568+
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
569+
570+
// FIXME(eddyb) should we talk about "descriptor indexing" or
571+
// actually use more reasonable terms like "resource arrays"?
572+
let unsized_is_descriptor_indexing =
573+
needs_interface_block_and_supports_descriptor_indexing
574+
|| storage_class == Ok(StorageClass::UniformConstant);
575+
if unsized_is_descriptor_indexing {
549576
match self.lookup_type(value_spirv_type) {
550577
SpirvType::RuntimeArray { .. } => {
551578
if is_unsized_with_len {
552579
self.tcx.dcx().span_err(
553580
hir_param.ty_span,
554-
"uniform_constant must use &RuntimeArray<T>, not &[T]",
581+
"descriptor indexing must use &RuntimeArray<T>, not &[T]",
555582
);
556583
}
557584
}
558585
_ => {
559586
if is_unsized {
560587
self.tcx.dcx().span_err(
561588
hir_param.ty_span,
562-
"only plain slices are supported as unsized types",
589+
"only RuntimeArray is supported, not other unsized types",
563590
);
564591
}
565592
}
566593
}
567-
568-
let value_len = if is_pair {
569-
// We've already emitted an error, fill in a placeholder value
570-
Some(bx.undef(self.type_isize()))
571-
} else {
572-
None
573-
};
574-
575-
(Ok(var_id.unwrap().with_type(var_ptr_spirv_type)), value_len)
576-
}
577-
_ => {
578-
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
579-
594+
} else {
595+
// FIXME(eddyb) determine, based on the type, what kind of type
596+
// this is, to narrow it further to e.g. "buffer in a non-buffer
597+
// storage class" or "storage class expects fixed data sizes".
580598
if is_unsized {
581599
self.tcx.dcx().span_fatal(
582600
hir_param.ty_span,
@@ -589,12 +607,19 @@ impl<'tcx> CodegenCx<'tcx> {
589607
),
590608
);
591609
}
592-
593-
(
594-
var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)),
595-
None,
596-
)
597610
}
611+
612+
let value_len = if is_pair {
613+
// We've already emitted an error, fill in a placeholder value
614+
Some(bx.undef(self.type_isize()))
615+
} else {
616+
None
617+
};
618+
619+
(
620+
var_id.map(|var_id| var_id.with_type(var_ptr_spirv_type)),
621+
value_len,
622+
)
598623
};
599624

600625
// Compute call argument(s) to match what the Rust entry `fn` expects,

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,8 @@ impl SpirvType<'_> {
342342
| Self::AccelerationStructureKhr
343343
| Self::RayQueryKhr
344344
| Self::Sampler
345-
| Self::SampledImage { .. } => Size::from_bytes(4),
346-
347-
Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).sizeof(cx)?,
345+
| Self::SampledImage { .. }
346+
| Self::InterfaceBlock { .. } => Size::from_bytes(4),
348347
};
349348
Some(result)
350349
}
@@ -372,9 +371,8 @@ impl SpirvType<'_> {
372371
| Self::AccelerationStructureKhr
373372
| Self::RayQueryKhr
374373
| Self::Sampler
375-
| Self::SampledImage { .. } => Align::from_bytes(4).unwrap(),
376-
377-
Self::InterfaceBlock { inner_type } => cx.lookup_type(inner_type).alignof(cx),
374+
| Self::SampledImage { .. }
375+
| Self::InterfaceBlock { .. } => Align::from_bytes(4).unwrap(),
378376
}
379377
}
380378

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,10 @@ impl Symbols {
342342
"runtime_array",
343343
SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray),
344344
),
345+
(
346+
"typed_buffer",
347+
SpirvAttribute::IntrinsicType(IntrinsicType::TypedBuffer),
348+
),
345349
(
346350
"matrix",
347351
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),

crates/spirv-std/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,15 @@ mod runtime_array;
107107
mod sampler;
108108
pub mod scalar;
109109
pub(crate) mod sealed;
110+
mod typed_buffer;
110111
pub mod vector;
111112

112113
pub use self::sampler::Sampler;
113114
pub use crate::macros::Image;
114115
pub use byte_addressable_buffer::ByteAddressableBuffer;
115116
pub use num_traits;
116117
pub use runtime_array::*;
118+
pub use typed_buffer::*;
117119

118120
pub use glam;
119121

0 commit comments

Comments
 (0)