diff --git a/crates/guest-rust/rt/src/async_support/stream_support.rs b/crates/guest-rust/rt/src/async_support/stream_support.rs index ec2ed596f..6e7ee7049 100644 --- a/crates/guest-rust/rt/src/async_support/stream_support.rs +++ b/crates/guest-rust/rt/src/async_support/stream_support.rs @@ -28,11 +28,11 @@ fn ceiling(x: usize, y: usize) -> usize { #[doc(hidden)] pub struct StreamVtable { - pub write: fn(future: u32, values: &[T]) -> Pin>>>, + pub write: fn(future: u32, values: &[T]) -> Pin> + '_>>, pub read: fn( future: u32, values: &mut [MaybeUninit], - ) -> Pin>>>, + ) -> Pin> + '_>>, pub cancel_write: fn(future: u32), pub cancel_read: fn(future: u32), pub close_writable: fn(future: u32), diff --git a/crates/rust/src/bindgen.rs b/crates/rust/src/bindgen.rs index 693dc2dd9..b06ea9304 100644 --- a/crates/rust/src/bindgen.rs +++ b/crates/rust/src/bindgen.rs @@ -475,7 +475,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { .as_ref() .map(|ty| { self.gen - .full_type_name_owned(ty, Identifier::StreamOrFuturePayload) + .type_name_owned_with_id(ty, Identifier::StreamOrFuturePayload) }) .unwrap_or_else(|| "()".into()); let ordinal = self.gen.gen.future_payloads.get_index_of(&name).unwrap(); @@ -496,7 +496,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { let op = &operands[0]; let name = self .gen - .full_type_name_owned(payload, Identifier::StreamOrFuturePayload); + .type_name_owned_with_id(payload, Identifier::StreamOrFuturePayload); let ordinal = self.gen.gen.stream_payloads.get_index_of(&name).unwrap(); let path = self.gen.path_to_root(); results.push(format!( diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index 192dd80ac..93e1c6063 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -483,6 +483,8 @@ macro_rules! {macro_name} {{ } fn generate_payloads(&mut self, prefix: &str, func: &Function, interface: Option<&WorldKey>) { + let old_identifier = mem::replace(&mut self.identifier, Identifier::StreamOrFuturePayload); + for (index, ty) in func .find_futures_and_streams(self.resolve) .into_iter() @@ -500,7 +502,7 @@ macro_rules! {macro_name} {{ match &self.resolve.types[ty].kind { TypeDefKind::Future(payload_type) => { let name = if let Some(payload_type) = payload_type { - self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload) + self.type_name_owned(payload_type) } else { "()".into() }; @@ -533,7 +535,7 @@ macro_rules! {macro_name} {{ (String::new(), "let value = ();\n".into()) }; - let box_ = format!("super::super::{}", self.path_to_box()); + let box_ = self.path_to_box(); let code = format!( r#" #[doc(hidden)] @@ -545,7 +547,7 @@ pub mod vtable{ordinal} {{ }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ #[repr(align({align}))] struct Buffer([::core::mem::MaybeUninit::; {size}]); let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]); @@ -558,10 +560,8 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8) -> u32; }} - {box_}::pin(async move {{ - unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} - }}) - }} + unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} + }}) }} fn read(future: u32) -> ::core::pin::Pin<{box_}>>> {{ @@ -571,7 +571,7 @@ pub mod vtable{ordinal} {{ }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ struct Buffer([::core::mem::MaybeUninit::; {size}]); let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]); let address = buffer.0.as_mut_ptr() as *mut u8; @@ -582,15 +582,13 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8) -> u32; }} - {box_}::pin(async move {{ - if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ - {lift} - Some(value) - }} else {{ - None - }} - }}) - }} + if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ + {lift} + Some(value) + }} else {{ + None + }} + }}) }} fn cancel_write(writer: u32) {{ @@ -691,8 +689,7 @@ pub mod vtable{ordinal} {{ } } TypeDefKind::Stream(payload_type) => { - let name = - self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload); + let name = self.type_name_owned(payload_type); if !self.gen.stream_payloads.contains_key(&name) { let ordinal = self.gen.stream_payloads.len(); @@ -747,19 +744,19 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{ (address.clone(), lower, address, lift) }; - let box_ = format!("super::super::{}", self.path_to_box()); + let box_ = self.path_to_box(); let code = format!( r#" #[doc(hidden)] pub mod vtable{ordinal} {{ - fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}>>> {{ + fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}> + '_>> {{ #[cfg(not(target_arch = "wasm32"))] {{ unreachable!(); }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ {lower_address} {lower} @@ -769,27 +766,25 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8, _: u32) -> u32; }} - {box_}::pin(async move {{ - unsafe {{ - {async_support}::await_stream_result( - wit_import, - stream, - address, - u32::try_from(values.len()).unwrap() - ).await - }} - }}) - }} + unsafe {{ + {async_support}::await_stream_result( + wit_import, + stream, + address, + u32::try_from(values.len()).unwrap() + ).await + }} + }}) }} - fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}>>> {{ + fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}> + '_>> {{ #[cfg(not(target_arch = "wasm32"))] {{ unreachable!(); }} #[cfg(target_arch = "wasm32")] - {{ + {box_}::pin(async move {{ {lift_address} #[link(wasm_import_module = "{module}")] @@ -798,22 +793,20 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8, _: u32) -> u32; }} - {box_}::pin(async move {{ - let count = unsafe {{ - {async_support}::await_stream_result( - wit_import, - stream, - address, - u32::try_from(values.len()).unwrap() - ).await - }}; - #[allow(unused)] - if let Some(count) = count {{ - {lift} - }} - count - }}) - }} + let count = unsafe {{ + {async_support}::await_stream_result( + wit_import, + stream, + address, + u32::try_from(values.len()).unwrap() + ).await + }}; + #[allow(unused)] + if let Some(count) = count {{ + {lift} + }} + count + }}) }} fn cancel_write(writer: u32) {{ @@ -916,6 +909,8 @@ pub mod vtable{ordinal} {{ _ => unreachable!(), } } + + self.identifier = old_identifier; } fn generate_guest_import(&mut self, func: &Function, interface: Option<&WorldKey>) { @@ -1699,25 +1694,24 @@ pub mod vtable{ordinal} {{ } } - pub(crate) fn full_type_name_owned(&mut self, ty: &Type, id: Identifier<'i>) -> String { - self.full_type_name( + pub(crate) fn type_name_owned_with_id(&mut self, ty: &Type, id: Identifier<'i>) -> String { + let old_identifier = mem::replace(&mut self.identifier, id); + let name = self.type_name_owned(ty); + self.identifier = old_identifier; + name + } + + fn type_name_owned(&mut self, ty: &Type) -> String { + self.type_name( ty, TypeMode { lifetime: None, lists_borrowed: false, style: TypeOwnershipStyle::Owned, }, - id, ) } - fn full_type_name(&mut self, ty: &Type, mode: TypeMode, id: Identifier<'i>) -> String { - let old_identifier = mem::replace(&mut self.identifier, id); - let name = self.type_name(ty, mode); - self.identifier = old_identifier; - name - } - fn type_name(&mut self, ty: &Type, mode: TypeMode) -> String { let old = mem::take(&mut self.src); self.print_ty(ty, mode); diff --git a/tests/codegen/streams.wit b/tests/codegen/streams.wit index fd00239b7..7ed696ed8 100644 --- a/tests/codegen/streams.wit +++ b/tests/codegen/streams.wit @@ -1,5 +1,19 @@ package foo:foo; +interface transmit { + variant control { + read-stream(string), + read-future(string), + write-stream(string), + write-future(string), + } + + exchange: func(control: stream, + caller-stream: stream, + caller-future1: future, + caller-future2: future) -> tuple, future, future>; +} + interface streams { stream-u8-param: func(x: stream); stream-u16-param: func(x: stream); @@ -82,4 +96,5 @@ interface streams { world the-streams { import streams; export streams; + export transmit; }