Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/guest-rust/rt/src/async_support/stream_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ fn ceiling(x: usize, y: usize) -> usize {

#[doc(hidden)]
pub struct StreamVtable<T> {
pub write: fn(future: u32, values: &[T]) -> Pin<Box<dyn Future<Output = Option<usize>>>>,
pub write: fn(future: u32, values: &[T]) -> Pin<Box<dyn Future<Output = Option<usize>> + '_>>,
pub read: fn(
future: u32,
values: &mut [MaybeUninit<T>],
) -> Pin<Box<dyn Future<Output = Option<usize>>>>,
) -> Pin<Box<dyn Future<Output = Option<usize>> + '_>>,
pub cancel_write: fn(future: u32),
pub cancel_read: fn(future: u32),
pub close_writable: fn(future: u32),
Expand Down
4 changes: 2 additions & 2 deletions crates/rust/src/bindgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
.as_ref()
.map(|ty| {
self.gen
.full_type_name_owned(ty, Identifier::StreamOrFuturePayload)
.type_name_owned_with_id(ty, Identifier::StreamOrFuturePayload)
})
.unwrap_or_else(|| "()".into());
let ordinal = self.gen.gen.future_payloads.get_index_of(&name).unwrap();
Expand All @@ -496,7 +496,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
let op = &operands[0];
let name = self
.gen
.full_type_name_owned(payload, Identifier::StreamOrFuturePayload);
.type_name_owned_with_id(payload, Identifier::StreamOrFuturePayload);
let ordinal = self.gen.gen.stream_payloads.get_index_of(&name).unwrap();
let path = self.gen.path_to_root();
results.push(format!(
Expand Down
116 changes: 55 additions & 61 deletions crates/rust/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,8 @@ macro_rules! {macro_name} {{
}

fn generate_payloads(&mut self, prefix: &str, func: &Function, interface: Option<&WorldKey>) {
let old_identifier = mem::replace(&mut self.identifier, Identifier::StreamOrFuturePayload);

for (index, ty) in func
.find_futures_and_streams(self.resolve)
.into_iter()
Expand All @@ -500,7 +502,7 @@ macro_rules! {macro_name} {{
match &self.resolve.types[ty].kind {
TypeDefKind::Future(payload_type) => {
let name = if let Some(payload_type) = payload_type {
self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload)
self.type_name_owned(payload_type)
} else {
"()".into()
};
Expand Down Expand Up @@ -533,7 +535,7 @@ macro_rules! {macro_name} {{
(String::new(), "let value = ();\n".into())
};

let box_ = format!("super::super::{}", self.path_to_box());
let box_ = self.path_to_box();
let code = format!(
r#"
#[doc(hidden)]
Expand All @@ -545,7 +547,7 @@ pub mod vtable{ordinal} {{
}}

#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
#[repr(align({align}))]
struct Buffer([::core::mem::MaybeUninit::<u8>; {size}]);
let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]);
Expand All @@ -558,10 +560,8 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8) -> u32;
}}

{box_}::pin(async move {{
unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }}
}})
}}
unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }}
}})
}}

fn read(future: u32) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<{name}>>>> {{
Expand All @@ -571,7 +571,7 @@ pub mod vtable{ordinal} {{
}}

#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
struct Buffer([::core::mem::MaybeUninit::<u8>; {size}]);
let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]);
let address = buffer.0.as_mut_ptr() as *mut u8;
Expand All @@ -582,15 +582,13 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8) -> u32;
}}

{box_}::pin(async move {{
if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{
{lift}
Some(value)
}} else {{
None
}}
}})
}}
if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{
{lift}
Some(value)
}} else {{
None
}}
}})
}}

fn cancel_write(writer: u32) {{
Expand Down Expand Up @@ -691,8 +689,7 @@ pub mod vtable{ordinal} {{
}
}
TypeDefKind::Stream(payload_type) => {
let name =
self.full_type_name_owned(payload_type, Identifier::StreamOrFuturePayload);
let name = self.type_name_owned(payload_type);

if !self.gen.stream_payloads.contains_key(&name) {
let ordinal = self.gen.stream_payloads.len();
Expand Down Expand Up @@ -747,19 +744,19 @@ for (index, dst) in values.iter_mut().take(count).enumerate() {{
(address.clone(), lower, address, lift)
};

let box_ = format!("super::super::{}", self.path_to_box());
let box_ = self.path_to_box();
let code = format!(
r#"
#[doc(hidden)]
pub mod vtable{ordinal} {{
fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>>>> {{
fn write(stream: u32, values: &[{name}]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>> + '_>> {{
#[cfg(not(target_arch = "wasm32"))]
{{
unreachable!();
}}

#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
{lower_address}
{lower}

Expand All @@ -769,27 +766,25 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8, _: u32) -> u32;
}}

{box_}::pin(async move {{
unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}}
}})
}}
unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}}
}})
}}

fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>>>> {{
fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}<dyn ::core::future::Future<Output = Option<usize>> + '_>> {{
#[cfg(not(target_arch = "wasm32"))]
{{
unreachable!();
}}

#[cfg(target_arch = "wasm32")]
{{
{box_}::pin(async move {{
{lift_address}

#[link(wasm_import_module = "{module}")]
Expand All @@ -798,22 +793,20 @@ pub mod vtable{ordinal} {{
fn wit_import(_: u32, _: *mut u8, _: u32) -> u32;
}}

{box_}::pin(async move {{
let count = unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}};
#[allow(unused)]
if let Some(count) = count {{
{lift}
}}
count
}})
}}
let count = unsafe {{
{async_support}::await_stream_result(
wit_import,
stream,
address,
u32::try_from(values.len()).unwrap()
).await
}};
#[allow(unused)]
if let Some(count) = count {{
{lift}
}}
count
}})
}}

fn cancel_write(writer: u32) {{
Expand Down Expand Up @@ -916,6 +909,8 @@ pub mod vtable{ordinal} {{
_ => unreachable!(),
}
}

self.identifier = old_identifier;
}

fn generate_guest_import(&mut self, func: &Function, interface: Option<&WorldKey>) {
Expand Down Expand Up @@ -1699,25 +1694,24 @@ pub mod vtable{ordinal} {{
}
}

pub(crate) fn full_type_name_owned(&mut self, ty: &Type, id: Identifier<'i>) -> String {
self.full_type_name(
pub(crate) fn type_name_owned_with_id(&mut self, ty: &Type, id: Identifier<'i>) -> String {
let old_identifier = mem::replace(&mut self.identifier, id);
let name = self.type_name_owned(ty);
self.identifier = old_identifier;
name
}

fn type_name_owned(&mut self, ty: &Type) -> String {
self.type_name(
ty,
TypeMode {
lifetime: None,
lists_borrowed: false,
style: TypeOwnershipStyle::Owned,
},
id,
)
}

fn full_type_name(&mut self, ty: &Type, mode: TypeMode, id: Identifier<'i>) -> String {
let old_identifier = mem::replace(&mut self.identifier, id);
let name = self.type_name(ty, mode);
self.identifier = old_identifier;
name
}

fn type_name(&mut self, ty: &Type, mode: TypeMode) -> String {
let old = mem::take(&mut self.src);
self.print_ty(ty, mode);
Expand Down
15 changes: 15 additions & 0 deletions tests/codegen/streams.wit
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
package foo:foo;

interface transmit {
variant control {
read-stream(string),
read-future(string),
write-stream(string),
write-future(string),
}

exchange: func(control: stream<control>,
caller-stream: stream<string>,
caller-future1: future<string>,
caller-future2: future<string>) -> tuple<stream<string>, future<string>, future<string>>;
}

interface streams {
stream-u8-param: func(x: stream<u8>);
stream-u16-param: func(x: stream<u16>);
Expand Down Expand Up @@ -82,4 +96,5 @@ interface streams {
world the-streams {
import streams;
export streams;
export transmit;
}
Loading