Skip to content

Commit c5b5956

Browse files
committed
Address feedback
Signed-off-by: James Sturtevant <jsturtevant@gmail.com>
1 parent 34f68b7 commit c5b5956

2 files changed

Lines changed: 50 additions & 25 deletions

File tree

src/hyperlight_wasm/src/sandbox/proto_wasm_sandbox.rs

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17+
use std::collections::HashMap;
18+
1719
use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterType, ReturnType};
1820
use hyperlight_common::flatbuffer_wrappers::host_function_definition::HostFunctionDefinition;
1921
use hyperlight_common::flatbuffer_wrappers::host_function_details::HostFunctionDetails;
@@ -34,8 +36,7 @@ use crate::build_info::BuildInfo;
3436
/// With that `WasmSandbox` you can load a Wasm module through the `load_module` method and get a `LoadedWasmSandbox` which can then execute functions defined in the Wasm module.
3537
pub struct ProtoWasmSandbox {
3638
pub(super) inner: Option<UninitializedSandbox>,
37-
/// Tracks registered host function definitions for pushing to the guest at load time
38-
host_function_definitions: Vec<HostFunctionDefinition>,
39+
host_function_definitions: HashMap<String, HostFunctionDefinition>,
3940
}
4041

4142
impl Registerable for ProtoWasmSandbox {
@@ -44,17 +45,15 @@ impl Registerable for ProtoWasmSandbox {
4445
name: &str,
4546
hf: impl Into<HostFunction<Output, Args>>,
4647
) -> Result<()> {
47-
// Track the host function definition for pushing to guest at load time
48-
self.host_function_definitions.push(HostFunctionDefinition {
49-
function_name: name.to_string(),
50-
parameter_types: Some(Args::TYPE.to_vec()),
51-
return_type: Output::TYPE,
52-
});
53-
5448
self.inner
5549
.as_mut()
5650
.ok_or(new_error!("inner sandbox was none"))
57-
.and_then(|sb| sb.register(name, hf))
51+
.and_then(|sb| sb.register(name, hf))?;
52+
53+
// Track the host function definition for pushing to guest at load time.
54+
// matching hyperlight-core's FunctionRegistry behavior.
55+
self.track_host_function_definition(name, Args::TYPE, Output::TYPE);
56+
Ok(())
5857
}
5958
}
6059

@@ -79,11 +78,15 @@ impl ProtoWasmSandbox {
7978
metrics::counter!(METRIC_TOTAL_PROTO_WASM_SANDBOXES).increment(1);
8079

8180
// HostPrint is always registered by UninitializedSandbox, so include it by default
82-
let host_function_definitions = vec![HostFunctionDefinition {
83-
function_name: "HostPrint".to_string(),
84-
parameter_types: Some(vec![ParameterType::String]),
85-
return_type: ReturnType::Int,
86-
}];
81+
let mut host_function_definitions = HashMap::new();
82+
host_function_definitions.insert(
83+
"HostPrint".to_string(),
84+
HostFunctionDefinition {
85+
function_name: "HostPrint".to_string(),
86+
parameter_types: Some(vec![ParameterType::String]),
87+
return_type: ReturnType::Int,
88+
},
89+
);
8790

8891
Ok(Self {
8992
inner: Some(inner),
@@ -100,7 +103,11 @@ impl ProtoWasmSandbox {
100103
pub fn load_runtime(mut self) -> Result<WasmSandbox> {
101104
// Serialize host function definitions to push to the guest during InitWasmRuntime
102105
let host_function_definitions = HostFunctionDetails {
103-
host_functions: Some(std::mem::take(&mut self.host_function_definitions)),
106+
host_functions: Some(
107+
std::mem::take(&mut self.host_function_definitions)
108+
.into_values()
109+
.collect(),
110+
),
104111
};
105112

106113
let host_function_definitions_bytes: Vec<u8> = (&host_function_definitions)
@@ -132,17 +139,29 @@ impl ProtoWasmSandbox {
132139
name: impl AsRef<str>,
133140
host_func: impl Into<HostFunction<Output, Args>>,
134141
) -> Result<()> {
135-
// Track the host function definition for pushing to guest at load time
136-
self.host_function_definitions.push(HostFunctionDefinition {
137-
function_name: name.as_ref().to_string(),
138-
parameter_types: Some(Args::TYPE.to_vec()),
139-
return_type: Output::TYPE,
140-
});
141-
142142
self.inner
143143
.as_mut()
144144
.ok_or(new_error!("inner sandbox was none"))?
145-
.register(name, host_func)
145+
.register(&name, host_func)?;
146+
147+
self.track_host_function_definition(name.as_ref(), Args::TYPE, Output::TYPE);
148+
Ok(())
149+
}
150+
151+
fn track_host_function_definition(
152+
&mut self,
153+
name: &str,
154+
parameter_types: &[ParameterType],
155+
return_type: ReturnType,
156+
) {
157+
self.host_function_definitions.insert(
158+
name.to_string(),
159+
HostFunctionDefinition {
160+
function_name: name.to_string(),
161+
parameter_types: Some(parameter_types.to_vec()),
162+
return_type,
163+
},
164+
);
146165
}
147166

148167
/// Register the given host printing function `print_func` with `self`.

src/wasm_runtime/src/module.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,18 @@ fn init_wasm_runtime(function_call: &FunctionCall) -> Result<Vec<u8>> {
122122

123123
let bytes = match params.first() {
124124
Some(ParameterValue::VecBytes(ref b)) => b,
125-
_ => {
125+
Some(_) => {
126126
return Err(HyperlightGuestError::new(
127127
ErrorCode::GuestFunctionParameterTypeMismatch,
128128
"InitWasmRuntime: first parameter must be VecBytes".to_string(),
129129
))
130130
}
131+
None => {
132+
return Err(HyperlightGuestError::new(
133+
ErrorCode::GuestFunctionParameterTypeMismatch,
134+
"InitWasmRuntime: expected 1 parameter, got 0".to_string(),
135+
))
136+
}
131137
};
132138

133139
let hfd: hostfuncs::HostFunctionDetails = bytes.as_slice().try_into().map_err(|e| {

0 commit comments

Comments
 (0)