Skip to content

Commit 964dc8e

Browse files
authored
go: restore handles on incomplete writes (#1470)
We now keep track of any resource/stream/future handles we've lowered while writing to a stream or future and restore any that were unwritten so they can be used (e.g. possibly written) again. While I was working on this, clippy pointed out that `std::io::pipe` (which I started using in a previous commit) was added in Rust 1.87, so I've bumped the MSRV to match. Also, the test case I added to cover this revealed another bug, which I've fixed here: we weren't generating valid code for async functions which return tuples. Fixes #1458 Signed-off-by: Joel Dice <[email protected]>
1 parent c86323d commit 964dc8e

File tree

9 files changed

+466
-19
lines changed

9 files changed

+466
-19
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ edition = "2024"
1919
version = "0.49.0"
2020
license = "Apache-2.0 WITH LLVM-exception OR Apache-2.0 OR MIT"
2121
repository = "https://github.com/bytecodealliance/wit-bindgen"
22-
rust-version = "1.85.0"
22+
rust-version = "1.87.0"
2323

2424
[workspace.dependencies]
2525
anyhow = "1.0.72"

crates/go/src/lib.rs

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ impl Go {
420420
false,
421421
imported_type,
422422
);
423+
generator.collect_lifters = true;
423424

424425
let lift_result =
425426
abi::lift_from_memory(resolve, &mut generator, "src".to_string(), &ty);
@@ -432,6 +433,21 @@ impl Go {
432433
"value".to_string(),
433434
&ty,
434435
);
436+
437+
let lifter_count = generator.lifter_count;
438+
let (prefix, suffix) = if lifter_count > 0 {
439+
(
440+
format!("lifters := make([]func(), 0, {lifter_count})\n"),
441+
"\nreturn func() {
442+
for _, lifter := range lifters {
443+
lifter()
444+
}
445+
}",
446+
)
447+
} else {
448+
(String::new(), "\nreturn func() {}")
449+
};
450+
435451
let lower = mem::take(&mut generator.src);
436452
data.extend(InterfaceData::from_generator_and_code(
437453
generator,
@@ -448,8 +464,12 @@ impl Go {
448464
),
449465
format!("wasm_{kind}_lift_{snake}"),
450466
format!(
451-
"func wasm_{kind}_lower_{snake}(pinner *runtime.Pinner, value {payload}, dst unsafe.Pointer) {{
452-
{lower}
467+
"func wasm_{kind}_lower_{snake}(
468+
pinner *runtime.Pinner,
469+
value {payload},
470+
dst unsafe.Pointer,
471+
) func() {{
472+
{prefix}{lower}{suffix}
453473
}}
454474
"
455475
),
@@ -1014,15 +1034,33 @@ impl Go {
10141034
.collect::<Vec<_>>()
10151035
.join(", ");
10161036

1017-
let lift = if let Some(result) = func.result {
1037+
let lift = if let Some(ty) = func.result {
10181038
let result = abi::lift_from_memory(
10191039
resolve,
10201040
&mut generator,
10211041
IMPORT_RETURN_AREA.to_string(),
1022-
&result,
1042+
&ty,
10231043
);
10241044
let code = mem::take(&mut generator.src);
1025-
format!("{code}\nreturn {result}")
1045+
if let Type::Id(ty) = ty
1046+
&& let TypeDefKind::Tuple(tuple) = &resolve.types[ty].kind
1047+
{
1048+
let count = tuple.types.len();
1049+
let tuple = generator.locals.tmp("tuple");
1050+
1051+
let results = (0..count)
1052+
.map(|index| format!("{tuple}.F{index}"))
1053+
.collect::<Vec<_>>()
1054+
.join(", ");
1055+
1056+
format!(
1057+
"{code}
1058+
{tuple} := {result}
1059+
return {results}"
1060+
)
1061+
} else {
1062+
format!("{code}\nreturn {result}")
1063+
}
10261064
} else {
10271065
String::new()
10281066
};
@@ -1403,6 +1441,8 @@ struct FunctionGenerator<'a> {
14031441
need_unsafe: bool,
14041442
need_pinner: bool,
14051443
need_math: bool,
1444+
collect_lifters: bool,
1445+
lifter_count: u32,
14061446
return_area_size: ArchitectureSize,
14071447
return_area_align: Alignment,
14081448
imports: BTreeSet<String>,
@@ -1441,6 +1481,8 @@ impl<'a> FunctionGenerator<'a> {
14411481
need_unsafe: false,
14421482
need_pinner: false,
14431483
need_math: false,
1484+
collect_lifters: false,
1485+
lifter_count: 0,
14441486
return_area_size: ArchitectureSize::default(),
14451487
return_area_align: Alignment::default(),
14461488
imports: BTreeSet::new(),
@@ -1744,13 +1786,18 @@ for index := 0; index < int({length}); index++ {{
17441786
&& let TypeDefKind::Tuple(tuple) = &resolve.types[ty].kind
17451787
{
17461788
let count = tuple.types.len();
1789+
let tuple = self.locals.tmp("tuple");
17471790

17481791
let results = (0..count)
1749-
.map(|index| format!("({result}).F{index}"))
1792+
.map(|index| format!("{tuple}.F{index}"))
17501793
.collect::<Vec<_>>()
17511794
.join(", ");
17521795

1753-
uwriteln!(self.src, "return {results}");
1796+
uwriteln!(
1797+
self.src,
1798+
"{tuple} := {result}
1799+
return {results}"
1800+
);
17541801
} else {
17551802
uwriteln!(self.src, "return {result}");
17561803
}
@@ -2275,7 +2322,25 @@ default:
22752322
| Instruction::HandleLower {
22762323
handle: Handle::Own(_),
22772324
..
2278-
} => results.push(format!("({}).TakeHandle()", operands[0])),
2325+
} => {
2326+
let op = &operands[0];
2327+
if self.collect_lifters {
2328+
self.lifter_count += 1;
2329+
let resource = self.locals.tmp("resource");
2330+
let handle = self.locals.tmp("handle");
2331+
uwriteln!(
2332+
self.src,
2333+
"{resource} := {op}
2334+
{handle} := {resource}.TakeHandle()
2335+
lifters = append(lifters, func() {{
2336+
{resource}.SetHandle({handle})
2337+
}})"
2338+
);
2339+
results.push(handle)
2340+
} else {
2341+
results.push(format!("({op}).TakeHandle()"))
2342+
}
2343+
}
22792344
Instruction::HandleLower {
22802345
handle: Handle::Borrow(_),
22812346
..
@@ -2469,6 +2534,10 @@ func (self *{camel}) TakeHandle() int32 {{
24692534
return self.handle.Take()
24702535
}}
24712536
2537+
func (self *{camel}) SetHandle(handle int32) {{
2538+
self.handle.Set(handle)
2539+
}}
2540+
24722541
func (self *{camel}) Handle() int32 {{
24732542
return self.handle.Use()
24742543
}}
@@ -2525,6 +2594,12 @@ func (self *{camel}) TakeHandle() int32 {{
25252594
return self.handle
25262595
}}
25272596
2597+
func (self *{camel}) SetHandle(handle int32) {{
2598+
if self.handle != handle {{
2599+
panic("invalid handle")
2600+
}}
2601+
}}
2602+
25282603
func (self *{camel}) Drop() {{
25292604
handle := self.handle
25302605
if self.handle != 0 {{
@@ -3011,12 +3086,12 @@ fn func_declaration(resolve: &Resolve, func: &Function) -> (String, bool) {
30113086
}
30123087

30133088
fn maybe_gofmt<'a>(format: Format, code: &'a [u8]) -> Cow<'a, [u8]> {
3014-
return thread::scope(|s| {
3089+
thread::scope(|s| {
30153090
if let Format::True = format
30163091
&& let Ok((reader, mut writer)) = io::pipe()
30173092
{
30183093
s.spawn(move || {
3019-
_ = writer.write_all(&code);
3094+
_ = writer.write_all(code);
30203095
});
30213096

30223097
if let Ok(output) = Command::new("gofmt").stdin(reader).output()
@@ -3027,5 +3102,5 @@ fn maybe_gofmt<'a>(format: Format, code: &'a [u8]) -> Cow<'a, [u8]> {
30273102
}
30283103

30293104
Cow::Borrowed(code)
3030-
});
3105+
})
30313106
}

crates/go/src/wit_future.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type FutureVtable[T any] struct {
1717
DropReadable func(handle int32)
1818
DropWritable func(handle int32)
1919
Lift func(src unsafe.Pointer) T
20-
Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer)
20+
Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer) func()
2121
}
2222

2323
type FutureReader[T any] struct {
@@ -63,6 +63,10 @@ func (self *FutureReader[T]) TakeHandle() int32 {
6363
return self.handle.Take()
6464
}
6565

66+
func (self *FutureReader[T]) SetHandle(handle int32) {
67+
self.handle.Set(handle)
68+
}
69+
6670
func MakeFutureReader[T any](vtable *FutureVtable[T], handleValue int32) *FutureReader[T] {
6771
handle := wit_runtime.MakeHandle(handleValue)
6872
value := &FutureReader[T]{vtable, handle}
@@ -87,24 +91,26 @@ func (self *FutureWriter[T]) Write(item T) bool {
8791
pinner := runtime.Pinner{}
8892
defer pinner.Unpin()
8993

94+
var lifter func()
9095
var buffer unsafe.Pointer
9196
if self.vtable.Lower == nil {
9297
buffer = unsafe.Pointer(unsafe.SliceData([]T{item}))
9398
pinner.Pin(buffer)
9499
} else {
95100
buffer = wit_runtime.Allocate(&pinner, uintptr(self.vtable.Size), uintptr(self.vtable.Align))
96-
self.vtable.Lower(&pinner, item, buffer)
101+
lifter = self.vtable.Lower(&pinner, item, buffer)
97102
}
98103

99104
code, _ := wit_async.FutureOrStreamWait(self.vtable.Write(handle, buffer), handle)
100105

101-
// TODO: restore handles to any unwritten resources, streams, or futures
102-
103106
switch code {
104107
case wit_async.RETURN_CODE_COMPLETED:
105108
return true
106109

107110
case wit_async.RETURN_CODE_DROPPED:
111+
if lifter != nil {
112+
lifter()
113+
}
108114
return false
109115

110116
default:

crates/go/src/wit_runtime.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ func (self *Handle) Take() int32 {
2626
return value
2727
}
2828

29+
func (self *Handle) Set(value int32) {
30+
if value == 0 {
31+
panic("nil handle")
32+
}
33+
if self.value != 0 {
34+
panic("handle already set")
35+
}
36+
self.value = value
37+
}
38+
2939
func (self *Handle) TakeOrNil() int32 {
3040
value := self.value
3141
self.value = 0

crates/go/src/wit_stream.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type StreamVtable[T any] struct {
1717
DropReadable func(handle int32)
1818
DropWritable func(handle int32)
1919
Lift func(src unsafe.Pointer) T
20-
Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer)
20+
Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer) func()
2121
}
2222

2323
type StreamReader[T any] struct {
@@ -78,6 +78,10 @@ func (self *StreamReader[T]) TakeHandle() int32 {
7878
return self.handle.Take()
7979
}
8080

81+
func (self *StreamReader[T]) SetHandle(handle int32) {
82+
self.handle.Set(handle)
83+
}
84+
8185
func MakeStreamReader[T any](vtable *StreamVtable[T], handleValue int32) *StreamReader[T] {
8286
handle := wit_runtime.MakeHandle(handleValue)
8387
value := &StreamReader[T]{vtable, handle, false}
@@ -112,24 +116,33 @@ func (self *StreamWriter[T]) Write(items []T) uint32 {
112116

113117
writeCount := uint32(len(items))
114118

119+
var lifters []func()
115120
var buffer unsafe.Pointer
116121
if self.vtable.Lower == nil {
117122
buffer = unsafe.Pointer(unsafe.SliceData(items))
118123
pinner.Pin(buffer)
119124
} else {
125+
lifters = make([]func(), 0, writeCount)
120126
buffer = wit_runtime.Allocate(
121127
&pinner,
122128
uintptr(self.vtable.Size*writeCount),
123129
uintptr(self.vtable.Align),
124130
)
125131
for index, item := range items {
126-
self.vtable.Lower(&pinner, item, unsafe.Add(buffer, index*int(self.vtable.Size)))
132+
lifters = append(
133+
lifters,
134+
self.vtable.Lower(&pinner, item, unsafe.Add(buffer, index*int(self.vtable.Size))),
135+
)
127136
}
128137
}
129138

130139
code, count := wit_async.FutureOrStreamWait(self.vtable.Write(handle, buffer, writeCount), handle)
131140

132-
// TODO: restore handles to any unwritten resources, streams, or futures
141+
if lifters != nil && count < writeCount {
142+
for _, lifter := range lifters[count:] {
143+
lifter()
144+
}
145+
}
133146

134147
if code == wit_async.RETURN_CODE_DROPPED {
135148
self.readerDropped = true
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package export_my_test_leaf_interface
2+
3+
import "runtime"
4+
5+
type LeafThing struct {
6+
pinner runtime.Pinner
7+
handle int32
8+
value string
9+
}
10+
11+
func (self *LeafThing) Get() string {
12+
return self.value
13+
}
14+
15+
func (self *LeafThing) OnDrop() {}
16+
17+
func MakeLeafThing(value string) *LeafThing {
18+
return &LeafThing{runtime.Pinner{}, 0, value}
19+
}

0 commit comments

Comments
 (0)