Coverage for wasmtime/_func.py: 96%

163 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-20 16:25 +0000

1from contextlib import contextmanager 

2from ctypes import POINTER, byref, CFUNCTYPE, c_void_p, cast 

3import ctypes 

4from wasmtime import Store, FuncType, Val, Trap, WasmtimeError 

5from . import _ffi as ffi 

6from ._extern import wrap_extern 

7from typing import Callable, Optional, Generic, TypeVar, List, Union, Tuple, cast as cast_type, Sequence, Any 

8from ._exportable import AsExtern 

9from ._store import Storelike 

10 

11 

# Type parameter for the generic `Slab` container.
T = TypeVar("T")

# Python exception raised by the most recent host callback. `trampoline`
# parks it here so it can be re-raised once control is back in Python.
LAST_EXCEPTION: Optional[Exception] = None

# Table of registered host callbacks, indexed by the integer handed to the
# native side. Only declared here; the value is assigned after `Slab`.
FUNCTIONS: "Slab[Tuple]"

15 

16 

class Func:
    """
    A WebAssembly function that lives in a `Store`.

    A `Func` is either created from a Python callable (see `__init__`) or
    wraps an existing raw `wasmtime_func_t` (see `_from_raw`).
    """

    _func: ffi.wasmtime_func_t

    def __init__(self, store: Storelike, ty: FuncType, func: Callable, access_caller: bool = False):
        """
        Creates a new func in `store` with the given `ty` which calls the
        closure `func`.

        The `func` is called with the parameters natively and they'll have
        native Python values rather than being wrapped in `Val`. If
        `access_caller` is set to `True` then the first argument given to
        `func` is an instance of type `Caller`.

        Raises `TypeError` if `store` is not a `Store` or `ty` is not a
        `FuncType`.
        """

        if not isinstance(store, Store):
            raise TypeError("expected a Store")
        if not isinstance(ty, FuncType):
            raise TypeError("expected a FuncType")
        # Evaluate the arguments that can raise (presumably e.g. when the
        # store or type has already been closed) *before* reserving a slot
        # in FUNCTIONS. The only code that releases a slot is `finalize`,
        # which is registered by `wasmtime_func_new` below. If we got this
        # far without making that call, the slot would leak.
        context = store._context()
        ty_ptr = ty.ptr()
        idx = FUNCTIONS.allocate((func, ty.results, access_caller))
        _func = ffi.wasmtime_func_t()
        ffi.wasmtime_func_new(
            context,
            ty_ptr,
            trampoline,
            idx,
            finalize,
            byref(_func))
        self._func = _func

    @classmethod
    def _from_raw(cls, func: ffi.wasmtime_func_t) -> "Func":
        # Bypass `__init__`: the function already exists natively and only
        # needs a Python wrapper.
        obj: "Func" = cls.__new__(cls)
        obj._func = func
        return obj

    def type(self, store: Storelike) -> FuncType:
        """
        Gets the type of this func as a `FuncType`
        """
        ptr = ffi.wasmtime_func_type(store._context(), byref(self._func))
        return FuncType._from_ptr(ptr, None)

    def __call__(self, store: Storelike, *params: Any) -> Any:
        """
        Calls this function with the given parameters

        Parameters can either be a `Val` or a native python value which can be
        converted to a `Val` of the corresponding correct type

        Returns `None` if this func has 0 return types
        Returns a single value if the func has 1 return type
        Returns a list if the func has more than 1 return type

        Because this is `__call__`, a `Func` is invoked like a regular
        function: `func(store, *params)`.

        Raises `WasmtimeError` if the number of `params` does not match the
        function type, or if the native call reports an error. A trap is
        raised as a `Trap`. If a Python host callback raised during the call,
        that exception is re-raised instead.
        """

        ty = self.type(store)
        param_tys = ty.params
        if len(params) > len(param_tys):
            raise WasmtimeError("too many parameters provided: given %s, expected %s" %
                                (len(params), len(param_tys)))
        if len(params) < len(param_tys):
            raise WasmtimeError("too few parameters provided: given %s, expected %s" %
                                (len(params), len(param_tys)))

        params_ptr = (ffi.wasmtime_val_t * len(params))()
        # Count of leading `params_ptr` entries that were converted and must
        # therefore be unrooted in the `finally` below, even if conversion of
        # a later parameter (or the call itself) fails.
        params_set = 0
        try:
            for val in params:
                params_ptr[params_set] = Val._convert_to_raw(store, param_tys[params_set], val)
                params_set += 1

            result_tys = ty.results
            results_ptr = (ffi.wasmtime_val_t * len(result_tys))()

            with enter_wasm(store) as trap:
                error = ffi.wasmtime_func_call(
                    store._context(),
                    byref(self._func),
                    params_ptr,
                    len(params),
                    results_ptr,
                    len(result_tys),
                    trap)
                if error:
                    raise WasmtimeError._from_ptr(error)
        finally:
            for i in range(params_set):
                ffi.wasmtime_val_unroot(store._context(), byref(params_ptr[i]))

        results = [Val._from_raw(store, results_ptr[i]).value for i in range(len(result_tys))]
        if not results:
            return None
        if len(results) == 1:
            return results[0]
        return results

    def _as_extern(self) -> ffi.wasmtime_extern_t:
        # Package this func as a generic extern (e.g. for use as an import).
        union = ffi.wasmtime_extern_union(func=self._func)
        return ffi.wasmtime_extern_t(ffi.WASMTIME_EXTERN_FUNC, union)

121 

122 

class Caller:
    """
    Handle to the calling instance, given to a host function as its first
    argument when it was created with `Func(..., access_caller=True)`.

    The handle is only meaningful while that host call is running. Once
    `_invalidate` has been called, `get` returns `None`, `__getitem__`
    raises `KeyError`, and `_context` raises `ValueError`.
    """

    __ptr: "Optional[ctypes._Pointer[ffi.wasmtime_caller_t]]"
    __context: "Optional[ctypes._Pointer[ffi.wasmtime_context_t]]"

    def __init__(self, ptr: "ctypes._Pointer[ffi.wasmtime_caller_t]"):
        self.__ptr = ptr
        self.__context = ffi.wasmtime_caller_context(ptr)

    def __getitem__(self, name: str) -> AsExtern:
        """
        Looks up an export with `name` on the calling module.

        Behaves like `get`, but raises `KeyError` where `get` would return
        `None`. That happens when `name` isn't exported by the calling module
        or when this `Caller` has been invalidated.
        """

        ret = self.get(name)
        if ret is None:
            raise KeyError("failed to find export {}".format(name))
        return ret

    def get(self, name: str) -> Optional[AsExtern]:
        """
        Looks up an export with `name` on the calling module.

        Returns the export, as wrapped by `wrap_extern`, when the lookup
        succeeds. Returns `None` if the export isn't found, or if this
        `Caller` has been invalidated because it outlived the host call it
        was created for.
        """

        # Encode first so that a non-`str` argument fails even when this
        # `Caller` has already been invalidated.
        name_bytes = name.encode('utf-8')
        name_buf = ffi.create_string_buffer(name_bytes)

        # An invalidated caller no longer refers to a live native caller.
        if self.__ptr is None:
            return None

        item = ffi.wasmtime_extern_t()
        ok = ffi.wasmtime_caller_export_get(self.__ptr, name_buf, len(name_bytes), byref(item))
        if ok:
            return wrap_extern(item)
        else:
            return None

    def _context(self) -> "ctypes._Pointer[ffi.wasmtime_context_t]":
        # Lets a `Caller` stand in wherever a `Storelike` is expected.
        if self.__context is None:
            raise ValueError("caller is no longer valid")
        return self.__context

    def _invalidate(self) -> None:
        # Drop the native pointers so later use of this handle is detected.
        self.__ptr = None
        self.__context = None

179 

180 

def extract_val(val: Val) -> Any:
    """
    Return the native Python value held by `val`.

    If that value is `None`, return `val` itself (still wrapped).
    """
    native = val.value
    return val if native is None else native

186 

187 

@ffi.wasmtime_func_callback_t
def trampoline(idx, caller, params, nparams, results, nresults):  # type: ignore
    """
    Native entry point shared by every host function made by `Func.__init__`.

    `idx` selects the `(callable, result types, access_caller)` entry in
    FUNCTIONS. Returns 0 on success. If the Python callable (or result
    conversion) raises, the exception is stored in LAST_EXCEPTION and the
    address of a new trap is returned instead.
    """
    global LAST_EXCEPTION
    handle = Caller(caller)
    try:
        # Slot 0 is passed as a NULL pointer, which ctypes hands over as None.
        callback, result_types, wants_caller = FUNCTIONS.get(idx or 0)

        args = [handle] if wants_caller else []
        args.extend(
            Val._from_raw(handle, params[n], owned=False).value
            for n in range(nparams))

        produced = callback(*args)

        if nresults == 0:
            if produced is not None:
                raise WasmtimeError(
                    "callback produced results when it shouldn't")
        elif nresults == 1:
            results[0] = Val._convert_to_raw(handle, result_types[0], produced)
        else:
            if len(produced) != nresults:
                raise WasmtimeError("callback produced wrong number of results")
            for slot, value in enumerate(produced):
                results[slot] = Val._convert_to_raw(handle, result_types[slot], value)
        return 0
    except Exception as exc:
        LAST_EXCEPTION = exc
        trap_ptr = Trap("python exception")._consume()
        return cast(trap_ptr, c_void_p).value
    finally:
        # The native caller pointer is only valid for the duration of this call.
        handle._invalidate()

219 

220 

@CFUNCTYPE(None, c_void_p)
def finalize(idx):  # type: ignore
    """
    Finalizer registered with `wasmtime_func_new`: frees the FUNCTIONS slot
    whose index was passed as the function's environment pointer.
    """
    # Slot 0 arrives as a NULL pointer, which ctypes delivers as None.
    slot = 0 if idx is None else idx
    FUNCTIONS.deallocate(slot)

224 

225 

class Slab(Generic[T]):
    """
    Free-list allocator that maps small integer indices to stored values.

    An occupied entry of `list` holds its value. A vacant entry holds the
    index of the next vacant entry. `next` is the head of that chain, and
    equals `len(list)` when no freed entry is available for reuse.
    """

    list: List[Union[int, T]]
    next: int

    def __init__(self) -> None:
        self.list = []
        self.next = 0

    def allocate(self, val: T) -> int:
        """Store `val` and return the index it can be retrieved with."""
        slot = self.next
        if slot == len(self.list):
            # Free chain is empty: grow the table by one entry.
            self.list.append(val)
            self.next = slot + 1
        else:
            # Reuse a freed entry; it holds the following link of the chain.
            self.next = cast_type(int, self.list[slot])
            self.list[slot] = val
        return slot

    def get(self, idx: int) -> T:
        """Return the value stored at `idx` (no check that it is occupied)."""
        stored = self.list[idx]
        return cast_type(T, stored)

    def deallocate(self, idx: int) -> None:
        """Release `idx`, pushing it onto the front of the free chain."""
        self.list[idx], self.next = self.next, idx


# Process-wide registry of host callbacks used by `trampoline`/`finalize`.
FUNCTIONS = Slab()

255 

256 

@contextmanager
def enter_wasm(store: Storelike):  # type: ignore
    """
    Wrap a native call into wasm that may trap.

    Yields a reference to a `wasm_trap_t*` out-slot for the callee to fill.
    Python exceptions stashed by a host callback take precedence:
    - If a trap was written, any pending Python exception is raised;
      otherwise the trap is raised as a `Trap`.
    - If a `WasmtimeError` escapes the `with` body, any pending Python
      exception is raised in its place; otherwise the error propagates.

    `store` is not used.
    """
    try:
        trap_slot = POINTER(ffi.wasm_trap_t)()
        yield byref(trap_slot)
        if not trap_slot:
            return
        pending_trap = Trap._from_ptr(trap_slot)
        maybe_raise_last_exn()
        raise pending_trap
    except WasmtimeError:
        maybe_raise_last_exn()
        raise

269 

270 

def maybe_raise_last_exn() -> None:
    """
    Re-raise the exception parked in LAST_EXCEPTION, if any.

    LAST_EXCEPTION is reset to `None` before raising. Returns normally when
    nothing is pending.
    """
    global LAST_EXCEPTION
    pending, LAST_EXCEPTION = LAST_EXCEPTION, None
    if pending is not None:
        raise pending