Coverage for ci/cbindgen.py: 96%

211 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-12-01 19:40 +0000

1# mypy: ignore-errors 

2 

3# This is a small script to parse the header files from wasmtime and generate 

4# appropriate function definitions in Python for each exported function. This 

5# also reflects types into Python with `ctypes`. While there's at least one 

6# other generator that does this already, it seemed to not quite fit our purposes 

7# with lots of extra and unnecessary boilerplate. 

8 

9from pycparser import c_ast, parse_file 

10import sys 

11 

class Visitor(c_ast.NodeVisitor):
    """Walk a pycparser AST of the wasmtime headers and emit ctypes bindings.

    The generated Python module source accumulates in ``self.ret``. Only
    structs, unions, enums, typedefs and function prototypes whose names
    start with ``was`` are emitted; function definitions are ignored.

    NOTE(review): the whitespace inside the emitted string literals (e.g. the
    single-space class-body indents) is reproduced verbatim from the reviewed
    copy of this file, whose whitespace looks collapsed; confirm it against
    the committed ``wasmtime/_bindings.py``.
    """

    def __init__(self):
        # Generated module text, starting with the fixed header and imports.
        self.ret = ''
        self.ret += '# flake8: noqa\n'
        self.ret += '#\n'
        self.ret += '# This is a procedurally generated file, DO NOT EDIT\n'
        self.ret += '# instead edit `./ci/cbindgen.py` at the root of the repo\n'
        self.ret += '\n'
        self.ret += 'import ctypes\n'
        self.ret += 'from typing import Any\n'
        self.ret += 'from enum import Enum, auto\n'
        self.ret += 'from ._ffi import dll, wasm_val_t, wasm_ref_t\n'
        # Struct names already emitted as opaque `ctypes.Structure` classes
        # (used as a set). A later full definition assigns `_fields_` to the
        # existing class instead of redefining it.
        self.forward_declared = {}

    # Skip all function definitions, we don't bind those
    def visit_FuncDef(self, node):
        pass

    def visit_Struct(self, node):
        """Emit a ``ctypes.Structure`` subclass for a ``was*`` struct.

        A struct without a member list is emitted as an opaque placeholder
        class; a later definition of the same struct then assigns
        ``_fields_`` to that class. Anonymous nested struct/union members are
        emitted first as their own classes under synthesized names.
        """
        if not node.name or not node.name.startswith('was'):
            return

        # This is hand-generated since it has an anonymous union in it
        if node.name == 'wasm_val_t' or node.name == 'wasm_ref_t':
            return

        self.ret += "\n"
        if not node.decls:
            self.forward_declared[node.name] = True
            self.ret += "class {}(ctypes.Structure):\n".format(node.name)
            self.ret += " pass\n"
            return

        # Give anonymous members a type name and a field name.
        # NOTE(review): the nested type is named `<struct>_anon_N` while its
        # field is named `_anon_{N+1}` (the counter is bumped in between).
        # Kept as-is so the generated bindings do not change.
        anon_decl = 0
        for decl in node.decls:
            if not decl.name:
                assert isinstance(decl.type, (c_ast.Struct, c_ast.Union))
                decl.type.name = node.name + '_anon_' + str(anon_decl)
                # Dispatches to visit_Struct or visit_Union by node class.
                self.visit(decl.type)
                anon_decl += 1
                decl.name = '_anon_' + str(anon_decl)

        if node.name in self.forward_declared:
            self.ret += "{}._fields_ = [\n".format(node.name)
        else:
            self.ret += "class {}(ctypes.Structure):\n".format(node.name)
            self.ret += " _fields_ = [\n"

        # Field names go through `name()` (as in visit_Union) so a `global`
        # member becomes `global_` and the annotations below stay valid Python.
        for decl in node.decls:
            self.ret += " (\"{}\", {}),\n".format(name(decl.name), type_name(decl.type))
        self.ret += " ]\n"

        # Class-level annotations are only emitted when this call created the
        # class; a forward-declared class just gets `_fields_` assigned.
        if node.name not in self.forward_declared:
            for decl in node.decls:
                self.ret += " {}: {}\n".format(name(decl.name), type_name(decl.type, typing=True))

    def visit_Union(self, node):
        """Emit a ``ctypes.Union`` subclass for a ``was*`` union.

        Member names go through ``name()`` so a ``global`` member becomes
        ``global_``.
        """
        if not node.name or not node.name.startswith('was'):
            return
        # Only complete unions (with a member list) are supported.
        assert node.decls

        self.ret += "\n"
        self.ret += "class {}(ctypes.Union):\n".format(node.name)
        self.ret += " _fields_ = [\n"
        for decl in node.decls:
            self.ret += " (\"{}\", {}),\n".format(name(decl.name), type_name(decl.type))
        self.ret += " ]\n"
        for decl in node.decls:
            self.ret += " {}: {}".format(name(decl.name), type_name(decl.type, typing=True))
            # The `v128` member's annotation is marked `# type: ignore`.
            if decl.name == 'v128':
                self.ret += ' # type: ignore'
            self.ret += "\n"

    def visit_Enum(self, node):
        """Emit a Python ``Enum`` class for a ``was*`` C enum definition."""
        if not node.name or not node.name.startswith('was'):
            return
        # A bare `enum X` reference (e.g. inside a typedef) has no enumerator
        # list; there is nothing to emit for it.
        if node.values is None:
            return

        self.ret += "\n"
        self.ret += "class {}(Enum):\n".format(node.name)
        for enumerator in node.values.enumerators:
            if enumerator.value:
                self.ret += " {} = {}\n".format(enumerator.name, enumerator.value.value)
            else:
                # NOTE(review): `auto()` starts at 1 (then previous + 1), while C
                # gives an unvalued leading enumerator 0, so leading members are
                # off by one relative to C. Left unchanged because fixing it
                # changes the generated bindings.
                self.ret += " {} = auto()\n".format(enumerator.name)

    def visit_Typedef(self, node):
        """Emit the type behind a ``was*`` typedef, plus an alias if needed."""
        if not node.name or not node.name.startswith('was'):
            return

        # Give anonymous structs/unions in typedefs a name by default:
        # `typedef struct {...} foo_t;` names the struct `foo`.
        if isinstance(node.type, c_ast.TypeDecl):
            inner = node.type.type
            if isinstance(inner, (c_ast.Struct, c_ast.Union)) and \
                    inner.name is None and node.name.endswith('_t'):
                inner.name = node.name[:-2]

        self.visit(node.type)
        # Only emit an alias when the typedef name differs from its target.
        tyname = type_name(node.type)
        if tyname != node.name:
            self.ret += "\n"
            if isinstance(node.type, c_ast.ArrayDecl):
                # Fixed-size arrays become ctypes array types (`elem * N`).
                self.ret += "{} = {} * {}\n".format(node.name, type_name(node.type.type), node.type.dim.value)
            else:
                self.ret += "{} = {}\n".format(node.name, tyname)

    def visit_FuncDecl(self, node):
        """Emit a ``dll``-backed binding plus a typed wrapper for a prototype.

        For each ``was*`` function this writes ``_<fn> = dll.<fn>`` with its
        ``restype``/``argtypes`` and a thin Python function ``<fn>`` calling it.

        Raises:
            RuntimeError: if the return declarator is neither a plain type
                nor a single pointer to one.
        """
        if isinstance(node.type, c_ast.TypeDecl):
            ptr = False
            ty = node.type
        elif isinstance(node.type, c_ast.PtrDecl) and \
                isinstance(node.type.type, c_ast.TypeDecl):
            ptr = True
            ty = node.type.type
        else:
            # e.g. `T **f(...)` or a function returning a function pointer;
            # these previously failed with UnboundLocalError/AttributeError.
            raise RuntimeError("unsupported function declarator {}".format(node.type))

        # Not called `name` so the module-level `name()` helper isn't shadowed.
        func_name = ty.declname
        # This is probably a type, skip it
        if func_name.endswith('_t'):
            return
        # Skip anything not related to wasi or wasm
        if not func_name.startswith('was'):
            return

        # This function forward-declares `wasm_instance_t` which doesn't work
        # with this binding generator, but for now this isn't used anyway so
        # just skip it.
        if func_name == 'wasm_frame_instance':
            return

        ret = ty.type

        argpairs = []
        argtypes = []
        argnames = []
        if node.args:
            for i, param in enumerate(node.args.params):
                argname = param.name
                # Unnamed parameters and Python keywords get a positional name.
                if not argname or argname == "import" or argname == "global":
                    argname = "arg{}".format(i)
                tyname = type_name(param.type)
                # A leading `void` parameter (`f(void)`) means "no arguments".
                if i == 0 and tyname == "None":
                    continue
                argpairs.append("{}: Any".format(argname))
                argnames.append(argname)
                argtypes.append(tyname)

        restype = type_name(ret, ptr)

        # It seems like this is the actual return value of the function, not a
        # pointer. Special-case this so the type-checking agrees with runtime.
        # NOTE(review): `type_name` spells void pointers "ctypes.c_void_p",
        # never bare "c_void_p", so this branch never fires and `void *`
        # returns are annotated `ctypes._Pointer`. Fixing the comparison
        # changes the generated bindings, so it is left as-is for now.
        if restype == 'c_void_p':
            retty = 'int'
        else:
            retty = type_name(node.type, ptr, typing=True)

        self.ret += "\n"
        self.ret += "_{0} = dll.{0}\n".format(func_name)
        self.ret += "_{}.restype = {}\n".format(func_name, restype)
        self.ret += "_{}.argtypes = [{}]\n".format(func_name, ', '.join(argtypes))
        self.ret += "def {}({}) -> {}:\n".format(func_name, ', '.join(argpairs), retty)
        self.ret += " return _{}({}) # type: ignore\n".format(func_name, ', '.join(argnames))

170 

171 

def name(name):
    """Map a C member name to a usable Python identifier.

    Only ``global`` (a Python keyword) is rewritten, to ``global_``; any
    other value, including ``None``, is passed through unchanged.
    """
    return 'global_' if name == 'global' else name

176 

177 

# Single-word C type names with a fixed translation, mapped to the pair
# (spelling used in type annotations, spelling used for ctypes).
_SCALAR_TYPES = {
    "void": ("None", "None"),
    "_Bool": ("bool", "ctypes.c_bool"),
    "byte_t": ("ctypes.c_ubyte", "ctypes.c_ubyte"),
    "int8_t": ("ctypes.c_int8", "ctypes.c_int8"),
    "uint8_t": ("ctypes.c_uint8", "ctypes.c_uint8"),
    "int16_t": ("ctypes.c_int16", "ctypes.c_int16"),
    "uint16_t": ("ctypes.c_uint16", "ctypes.c_uint16"),
    "int32_t": ("int", "ctypes.c_int32"),
    "uint32_t": ("int", "ctypes.c_uint32"),
    "uint64_t": ("int", "ctypes.c_uint64"),
    "int64_t": ("int", "ctypes.c_int64"),
    "float32_t": ("float", "ctypes.c_float"),
    "float64_t": ("float", "ctypes.c_double"),
    "size_t": ("int", "ctypes.c_size_t"),
    "ptrdiff_t": ("int", "ctypes.c_ssize_t"),
    "char": ("ctypes.c_char", "ctypes.c_char"),
    "int": ("int", "ctypes.c_int"),
}


def type_name(ty, ptr=False, typing=False):
    """Spell a pycparser type node as a Python expression string.

    Args:
        ty: pycparser type node; ``TypeDecl`` wrappers are stripped first.
        ptr: when true, ``ty`` is the pointee of a pointer.
        typing: when true, return a spelling for type annotations rather
            than a ctypes type expression.

    Returns:
        The Python source text naming the type.

    Raises:
        RuntimeError: for node kinds this generator does not understand.
    """
    while isinstance(ty, c_ast.TypeDecl):
        ty = ty.type

    if ptr:
        if typing:
            return "ctypes._Pointer"
        is_void = isinstance(ty, c_ast.IdentifierType) and ty.names[0] == "void"
        if is_void:
            return "ctypes.c_void_p"
        if not isinstance(ty, c_ast.FuncDecl):
            return "ctypes.POINTER({})".format(type_name(ty, False, typing))
        # Pointers to functions fall through to the CFUNCTYPE spelling below.

    if isinstance(ty, c_ast.IdentifierType):
        words = ty.names
        if words == ['unsigned', 'char']:
            return "int" if typing else "ctypes.c_ubyte"
        assert len(words) == 1
        word = words[0]
        if word in _SCALAR_TYPES:
            annotation, ctype = _SCALAR_TYPES[word]
            return annotation if typing else ctype
        if typing:
            # ctypes values can't stand as typedefs, so just use the pointer type here
            if 'callback_t' in word:
                return "ctypes._Pointer"
            if 'size' in word or 'pages' in word:
                return "int"
        # Otherwise refer to the generated alias/class of the same name.
        return word

    if isinstance(ty, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
        return ty.name

    if isinstance(ty, c_ast.FuncDecl):
        # TODO: apparently errors are thrown if we faithfully represent the
        # pointer type here, seems odd?
        if isinstance(ty.type, c_ast.PtrDecl):
            restype = "ctypes.c_size_t"
        else:
            restype = type_name(ty.type)
        params = ty.args.params if ty.args and ty.args.params else []
        signature = [restype] + [type_name(param.type) for param in params]
        return "ctypes.CFUNCTYPE({})".format(', '.join(signature))

    if isinstance(ty, (c_ast.PtrDecl, c_ast.ArrayDecl)):
        # Arrays are passed like pointers to their element type.
        return type_name(ty.type, True, typing)

    raise RuntimeError("unknown {}".format(ty))

257 

258 

def run():
    """Parse ``wasmtime.h`` and return the generated bindings module text.

    The header is preprocessed with ``gcc -E``. The ``-D`` flags below
    define compiler-specific constructs (attributes, inline asm, static
    asserts, restrict/extension qualifiers, the va_list builtin) away before
    pycparser sees the code. Paths are relative to the current working
    directory, presumably the repository root.
    """
    header = './wasmtime/include/wasmtime.h'
    preprocessor_args = [
        '-E',
        '-I./wasmtime/include',
        '-D__attribute__(x)=',
        '-D__asm__(x)=',
        '-D__asm(x)=',
        '-D__volatile__(x)=',
        '-D_Static_assert(x, y)=',
        '-Dstatic_assert(x, y)=',
        '-D__restrict=',
        '-D__restrict__=',
        '-D__extension__=',
        '-D__inline__=',
        '-D__signed=',
        '-D__builtin_va_list=int',
    ]
    tree = parse_file(header, use_cpp=True, cpp_path='gcc', cpp_args=preprocessor_args)

    generator = Visitor()
    generator.visit(tree)
    return generator.ret

285 

# Executed as a script: regenerate the committed bindings from the headers.
# Imported on Linux: fail loudly if the committed bindings are stale.
# The file is read/written as UTF-8 explicitly so the result does not depend
# on the machine's locale encoding.
if __name__ == "__main__":
    with open("wasmtime/_bindings.py", "w", encoding="utf-8") as f:
        f.write(run())
elif sys.platform == 'linux':
    with open("wasmtime/_bindings.py", "r", encoding="utf-8") as f:
        contents = f.read()
    if contents != run():
        raise RuntimeError("bindings need an update, run this script")