diff --git a/Cargo.toml b/Cargo.toml index 19c3f851a..d8345a38a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ edition = "2024" version = "0.49.0" license = "Apache-2.0 WITH LLVM-exception OR Apache-2.0 OR MIT" repository = "https://github.com/bytecodealliance/wit-bindgen" -rust-version = "1.85.0" +rust-version = "1.87.0" [workspace.dependencies] anyhow = "1.0.72" diff --git a/crates/go/src/lib.rs b/crates/go/src/lib.rs index 469632187..97266ad05 100644 --- a/crates/go/src/lib.rs +++ b/crates/go/src/lib.rs @@ -420,6 +420,7 @@ impl Go { false, imported_type, ); + generator.collect_lifters = true; let lift_result = abi::lift_from_memory(resolve, &mut generator, "src".to_string(), &ty); @@ -432,6 +433,21 @@ impl Go { "value".to_string(), &ty, ); + + let lifter_count = generator.lifter_count; + let (prefix, suffix) = if lifter_count > 0 { + ( + format!("lifters := make([]func(), 0, {lifter_count})\n"), + "\nreturn func() { + for _, lifter := range lifters { + lifter() + } +}", + ) + } else { + (String::new(), "\nreturn func() {}") + }; + let lower = mem::take(&mut generator.src); data.extend(InterfaceData::from_generator_and_code( generator, @@ -448,8 +464,12 @@ impl Go { ), format!("wasm_{kind}_lift_{snake}"), format!( - "func wasm_{kind}_lower_{snake}(pinner *runtime.Pinner, value {payload}, dst unsafe.Pointer) {{ - {lower} + "func wasm_{kind}_lower_{snake}( + pinner *runtime.Pinner, + value {payload}, + dst unsafe.Pointer, +) func() {{ + {prefix}{lower}{suffix} }} " ), @@ -1014,15 +1034,33 @@ impl Go { .collect::>() .join(", "); - let lift = if let Some(result) = func.result { + let lift = if let Some(ty) = func.result { let result = abi::lift_from_memory( resolve, &mut generator, IMPORT_RETURN_AREA.to_string(), - &result, + &ty, ); let code = mem::take(&mut generator.src); - format!("{code}\nreturn {result}") + if let Type::Id(ty) = ty + && let TypeDefKind::Tuple(tuple) = &resolve.types[ty].kind + { + let count = tuple.types.len(); + let tuple = generator.locals.tmp("tuple"); + + let results = (0..count) + .map(|index| format!("{tuple}.F{index}")) + .collect::>() + .join(", "); + + format!( + "{code} +{tuple} := {result} +return {results}" + ) + } else { + format!("{code}\nreturn {result}") + } } else { String::new() }; @@ -1403,6 +1441,8 @@ struct FunctionGenerator<'a> { need_unsafe: bool, need_pinner: bool, need_math: bool, + collect_lifters: bool, + lifter_count: u32, return_area_size: ArchitectureSize, return_area_align: Alignment, imports: BTreeSet, @@ -1441,6 +1481,8 @@ impl<'a> FunctionGenerator<'a> { need_unsafe: false, need_pinner: false, need_math: false, + collect_lifters: false, + lifter_count: 0, return_area_size: ArchitectureSize::default(), return_area_align: Alignment::default(), imports: BTreeSet::new(), @@ -1744,13 +1786,18 @@ for index := 0; index < int({length}); index++ {{ && let TypeDefKind::Tuple(tuple) = &resolve.types[ty].kind { let count = tuple.types.len(); + let tuple = self.locals.tmp("tuple"); let results = (0..count) - .map(|index| format!("({result}).F{index}")) + .map(|index| format!("{tuple}.F{index}")) .collect::>() .join(", "); - uwriteln!(self.src, "return {results}"); + uwriteln!( + self.src, + "{tuple} := {result} +return {results}" + ); } else { uwriteln!(self.src, "return {result}"); } @@ -2275,7 +2322,25 @@ default: | Instruction::HandleLower { handle: Handle::Own(_), .. - } => results.push(format!("({}).TakeHandle()", operands[0])), + } => { + let op = &operands[0]; + if self.collect_lifters { + self.lifter_count += 1; + let resource = self.locals.tmp("resource"); + let handle = self.locals.tmp("handle"); + uwriteln!( + self.src, + "{resource} := {op} +{handle} := {resource}.TakeHandle() +lifters = append(lifters, func() {{ + {resource}.SetHandle({handle}) +}})" + ); + results.push(handle) + } else { + results.push(format!("({op}).TakeHandle()")) + } + } Instruction::HandleLower { handle: Handle::Borrow(_), .. @@ -2469,6 +2534,10 @@ func (self *{camel}) TakeHandle() int32 {{ return self.handle.Take() }} +func (self *{camel}) SetHandle(handle int32) {{ + self.handle.Set(handle) +}} + func (self *{camel}) Handle() int32 {{ return self.handle.Use() }} @@ -2525,6 +2594,12 @@ func (self *{camel}) TakeHandle() int32 {{ return self.handle }} +func (self *{camel}) SetHandle(handle int32) {{ + if self.handle != handle {{ + panic("invalid handle") + }} +}} + func (self *{camel}) Drop() {{ handle := self.handle if self.handle != 0 {{ @@ -3011,12 +3086,12 @@ fn func_declaration(resolve: &Resolve, func: &Function) -> (String, bool) { } fn maybe_gofmt<'a>(format: Format, code: &'a [u8]) -> Cow<'a, [u8]> { - return thread::scope(|s| { + thread::scope(|s| { if let Format::True = format && let Ok((reader, mut writer)) = io::pipe() { s.spawn(move || { - _ = writer.write_all(&code); + _ = writer.write_all(code); }); if let Ok(output) = Command::new("gofmt").stdin(reader).output() @@ -3027,5 +3102,5 @@ fn maybe_gofmt<'a>(format: Format, code: &'a [u8]) -> Cow<'a, [u8]> { } Cow::Borrowed(code) - }); + }) } diff --git a/crates/go/src/wit_future.go b/crates/go/src/wit_future.go index d5b3ed9a0..720548c6b 100644 --- a/crates/go/src/wit_future.go +++ b/crates/go/src/wit_future.go @@ -17,7 +17,7 @@ type FutureVtable[T any] struct { DropReadable func(handle int32) DropWritable func(handle int32) Lift func(src unsafe.Pointer) T - Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer) + Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer) func() } type FutureReader[T any] struct { @@ -63,6 +63,10 @@ func (self *FutureReader[T]) TakeHandle() int32 { return self.handle.Take() } +func (self *FutureReader[T]) SetHandle(handle int32) { + self.handle.Set(handle) +} + func MakeFutureReader[T any](vtable *FutureVtable[T], handleValue int32) *FutureReader[T] { handle := wit_runtime.MakeHandle(handleValue) value := &FutureReader[T]{vtable, handle} @@ -87,24 +91,26 @@ func (self *FutureWriter[T]) Write(item T) bool { pinner := runtime.Pinner{} defer pinner.Unpin() + var lifter func() var buffer unsafe.Pointer if self.vtable.Lower == nil { buffer = unsafe.Pointer(unsafe.SliceData([]T{item})) pinner.Pin(buffer) } else { buffer = wit_runtime.Allocate(&pinner, uintptr(self.vtable.Size), uintptr(self.vtable.Align)) - self.vtable.Lower(&pinner, item, buffer) + lifter = self.vtable.Lower(&pinner, item, buffer) } code, _ := wit_async.FutureOrStreamWait(self.vtable.Write(handle, buffer), handle) - // TODO: restore handles to any unwritten resources, streams, or futures - switch code { case wit_async.RETURN_CODE_COMPLETED: return true case wit_async.RETURN_CODE_DROPPED: + if lifter != nil { + lifter() + } return false default: diff --git a/crates/go/src/wit_runtime.go b/crates/go/src/wit_runtime.go index ebfc6cc11..5313ff1f8 100644 --- a/crates/go/src/wit_runtime.go +++ b/crates/go/src/wit_runtime.go @@ -26,6 +26,16 @@ func (self *Handle) Take() int32 { return value } +func (self *Handle) Set(value int32) { + if value == 0 { + panic("nil handle") + } + if self.value != 0 { + panic("handle already set") + } + self.value = value +} + func (self *Handle) TakeOrNil() int32 { value := self.value self.value = 0 diff --git a/crates/go/src/wit_stream.go b/crates/go/src/wit_stream.go index 168343ee3..e2f58d225 100644 --- a/crates/go/src/wit_stream.go +++ b/crates/go/src/wit_stream.go @@ -17,7 +17,7 @@ type StreamVtable[T any] struct { DropReadable func(handle int32) DropWritable func(handle int32) Lift func(src unsafe.Pointer) T - Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer) + Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer) func() } type StreamReader[T any] struct { @@ -78,6 +78,10 @@ func (self *StreamReader[T]) TakeHandle() int32 { return self.handle.Take() } +func (self *StreamReader[T]) SetHandle(handle int32) { + self.handle.Set(handle) +} + func MakeStreamReader[T any](vtable *StreamVtable[T], handleValue int32) *StreamReader[T] { handle := wit_runtime.MakeHandle(handleValue) value := &StreamReader[T]{vtable, handle, false} @@ -112,24 +116,33 @@ func (self *StreamWriter[T]) Write(items []T) uint32 { writeCount := uint32(len(items)) + var lifters []func() var buffer unsafe.Pointer if self.vtable.Lower == nil { buffer = unsafe.Pointer(unsafe.SliceData(items)) pinner.Pin(buffer) } else { + lifters = make([]func(), 0, writeCount) buffer = wit_runtime.Allocate( &pinner, uintptr(self.vtable.Size*writeCount), uintptr(self.vtable.Align), ) for index, item := range items { - self.vtable.Lower(&pinner, item, unsafe.Add(buffer, index*int(self.vtable.Size))) + lifters = append( + lifters, + self.vtable.Lower(&pinner, item, unsafe.Add(buffer, index*int(self.vtable.Size))), + ) } } code, count := wit_async.FutureOrStreamWait(self.vtable.Write(handle, buffer, writeCount), handle) - // TODO: restore handles to any unwritten resources, streams, or futures + if lifters != nil && count < writeCount { + for _, lifter := range lifters[count:] { + lifter() + } + } if code == wit_async.RETURN_CODE_DROPPED { self.readerDropped = true diff --git a/tests/runtime-async/async/incomplete-writes/leaf.go b/tests/runtime-async/async/incomplete-writes/leaf.go new file mode 100644 index 000000000..995d165b7 --- /dev/null +++ b/tests/runtime-async/async/incomplete-writes/leaf.go @@ -0,0 +1,19 @@ +package export_my_test_leaf_interface + +import "runtime" + +type LeafThing struct { + pinner runtime.Pinner + handle int32 + value string +} + +func (self *LeafThing) Get() string { + return self.value +} + +func (self *LeafThing) OnDrop() {} + +func MakeLeafThing(value string) *LeafThing { + return &LeafThing{runtime.Pinner{}, 0, value} +} diff --git a/tests/runtime-async/async/incomplete-writes/runner.go b/tests/runtime-async/async/incomplete-writes/runner.go new file mode 100644 index 000000000..fbacd41ea --- /dev/null +++ b/tests/runtime-async/async/incomplete-writes/runner.go @@ -0,0 +1,151 @@ +package export_wit_world + +import ( + "fmt" + leaf "wit_component/my_test_leaf_interface" + test "wit_component/my_test_test_interface" +) + +func Run() { + { + tx, rx := test.MakeStreamTestThing() + defer tx.Drop() + defer rx.Drop() + + stream := test.ShortReadsTest(rx) + defer stream.Drop() + + // Write the things all at once. The callee will read them only + // one at a time, forcing us to re-take ownership of any + // unwritten items between writes. + tx.WriteAll([]*test.TestThing{ + test.MakeTestThing("a"), + test.MakeTestThing("b"), + test.MakeTestThing("c"), + }) + tx.Drop() + + things := []*test.TestThing{} + for !stream.WriterDropped() { + // Read just one item at a time, forcing the writer to + // re-take ownership of any unwritten items between + // writes. + buffer := make([]*test.TestThing, 1) + count := stream.Read(buffer) + if count == 1 { + things = append(things, buffer[0]) + } + } + + assertEqual(things[0].Get(), "a") + assertEqual(things[1].Get(), "b") + assertEqual(things[2].Get(), "c") + } + + { + tx, rx := test.MakeStreamMyTestLeafInterfaceLeafThing() + defer tx.Drop() + defer rx.Drop() + + stream := test.ShortReadsLeaf(rx) + defer stream.Drop() + + // Write the things all at once. The callee will read them only + // one at a time, forcing us to re-take ownership of any + // unwritten items between writes. + tx.WriteAll([]*leaf.LeafThing{ + leaf.MakeLeafThing("a"), + leaf.MakeLeafThing("b"), + leaf.MakeLeafThing("c"), + }) + tx.Drop() + + things := []*leaf.LeafThing{} + for !stream.WriterDropped() { + // Read just one item at a time, forcing the writer to + // re-take ownership of any unwritten items between + // writes. + buffer := make([]*leaf.LeafThing, 1) + count := stream.Read(buffer) + if count == 1 { + things = append(things, buffer[0]) + } + } + + assertEqual(things[0].Get(), "a") + assertEqual(things[1].Get(), "b") + assertEqual(things[2].Get(), "c") + } + + { + tx1, rx1 := test.MakeFutureTestThing() + tx2, rx2 := test.MakeFutureTestThing() + f1, f2 := test.DroppedReaderTest(rx1, rx2) + + { + // Write a thing to the first future, the read end of + // which the callee will drop without reading from, + // forcing us to re-take ownership. + thing := test.MakeTestThing("a") + assert(!tx1.Write(thing)) + + // Write it again to the second future. This time, the + // callee will read it. + assert(tx2.Write(thing)) + } + + { + // Drop the first future without reading from it. This + // will force the callee to re-take ownership of the + // thing it tried to write. + f1.Drop() + + // Read from the second future and assert it matches + // what we wrote above. + thing := f2.Read() + assertEqual(thing.Get(), "a") + } + } + + { + tx1, rx1 := test.MakeFutureMyTestLeafInterfaceLeafThing() + tx2, rx2 := test.MakeFutureMyTestLeafInterfaceLeafThing() + f1, f2 := test.DroppedReaderLeaf(rx1, rx2) + + { + // Write a thing to the first future, the read end of + // which the callee will drop without reading from, + // forcing us to re-take ownership. + thing := leaf.MakeLeafThing("a") + assert(!tx1.Write(thing)) + + // Write it again to the second future. This time, the + // callee will read it. + assert(tx2.Write(thing)) + } + + { + // Drop the first future without reading from it. This + // will force the callee to re-take ownership of the + // thing it tried to write. + f1.Drop() + + // Read from the second future and assert it matches + // what we wrote above. + thing := f2.Read() + assertEqual(thing.Get(), "a") + } + } +} + +func assertEqual[T comparable](a T, b T) { + if a != b { + panic(fmt.Sprintf("%v not equal to %v", a, b)) + } +} + +func assert(v bool) { + if !v { + panic("assertion failed") + } +} diff --git a/tests/runtime-async/async/incomplete-writes/test.go b/tests/runtime-async/async/incomplete-writes/test.go new file mode 100644 index 000000000..fa75b8521 --- /dev/null +++ b/tests/runtime-async/async/incomplete-writes/test.go @@ -0,0 +1,135 @@ +package export_my_test_test_interface + +import ( + "runtime" + . "wit_component/my_test_test_interface" + . "wit_component/wit_types" +) + +type TestThing struct { + pinner runtime.Pinner + handle int32 + value string +} + +func (self *TestThing) Get() string { + return self.value +} + +func (self *TestThing) OnDrop() {} + +func MakeTestThing(value string) *TestThing { + return &TestThing{runtime.Pinner{}, 0, value} +} + +func ShortReadsTest(stream *StreamReader[*TestThing]) *StreamReader[*TestThing] { + tx, rx := MakeStreamTestThing() + + go func() { + defer stream.Drop() + defer tx.Drop() + + things := []*TestThing{} + for !stream.WriterDropped() { + // Read just one item at a time, forcing the writer to + // re-take ownership of any unwritten items between + // writes. + buffer := make([]*TestThing, 1) + count := stream.Read(buffer) + if count == 1 { + things = append(things, buffer[0]) + } + } + + // Write the things all at once. The caller will read them only + // one at a time, forcing us to re-take ownership of any + // unwritten items between writes. + tx.WriteAll(things) + }() + + return rx +} + +func ShortReadsLeaf(stream *StreamReader[*LeafThing]) *StreamReader[*LeafThing] { + tx, rx := MakeStreamMyTestLeafInterfaceLeafThing() + + go func() { + defer stream.Drop() + defer tx.Drop() + + things := []*LeafThing{} + for !stream.WriterDropped() { + // Read just one item at a time, forcing the writer to + // re-take ownership of any unwritten items between + // writes. + buffer := make([]*LeafThing, 1) + count := stream.Read(buffer) + if count == 1 { + things = append(things, buffer[0]) + } + } + + // Write the things all at once. The caller will read them only + // one at a time, forcing us to re-take ownership of any + // unwritten items between writes. + tx.WriteAll(things) + }() + + return rx +} + +func DroppedReaderTest(f1, f2 *FutureReader[*TestThing]) (*FutureReader[*TestThing], *FutureReader[*TestThing]) { + tx1, rx1 := MakeFutureTestThing() + tx2, rx2 := MakeFutureTestThing() + + go func() { + // Drop the first future without reading from it. This will + // force the callee to re-take ownership of the thing it tried + // to write. + f1.Drop() + + thing := f2.Read() + + // Write the thing to the first future, the read end of which + // the calle4 will drop without reading from, forcing us to + // re-take ownership. + assert(!tx1.Write(thing)) + + // Write it again to the second future. This time, the caller + // will read it. + assert(tx2.Write(thing)) + }() + + return rx1, rx2 +} + +func DroppedReaderLeaf(f1, f2 *FutureReader[*LeafThing]) (*FutureReader[*LeafThing], *FutureReader[*LeafThing]) { + tx1, rx1 := MakeFutureMyTestLeafInterfaceLeafThing() + tx2, rx2 := MakeFutureMyTestLeafInterfaceLeafThing() + + go func() { + // Drop the first future without reading from it. This will + // force the callee to re-take ownership of the thing it tried + // to write. + f1.Drop() + + thing := f2.Read() + + // Write the thing to the first future, the read end of which + // the calle4 will drop without reading from, forcing us to + // re-take ownership. + assert(!tx1.Write(thing)) + + // Write it again to the second future. This time, the caller + // will read it. + assert(tx2.Write(thing)) + }() + + return rx1, rx2 +} + +func assert(v bool) { + if !v { + panic("assertion failed") + } +} diff --git a/tests/runtime-async/async/incomplete-writes/test.wit b/tests/runtime-async/async/incomplete-writes/test.wit new file mode 100644 index 000000000..6e3604133 --- /dev/null +++ b/tests/runtime-async/async/incomplete-writes/test.wit @@ -0,0 +1,38 @@ +//@ dependencies = ['test', 'leaf'] + +package my:test; + +interface leaf-interface { + resource leaf-thing { + constructor(s: string); + get: func() -> string; + } +} + +interface test-interface { + use leaf-interface.{leaf-thing}; + + resource test-thing { + constructor(s: string); + get: func() -> string; + } + + short-reads-test: async func(s: stream) -> stream; + short-reads-leaf: async func(s: stream) -> stream; + + dropped-reader-test: async func(f1: future, f2: future) -> tuple, future>; + dropped-reader-leaf: async func(f1: future, f2: future) -> tuple, future>; +} + +world leaf { + export leaf-interface; +} + +world test { + export test-interface; +} + +world runner { + import test-interface; + export run: async func(); +}