Coverage for ci/cbindgen.py: 96%

184 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-20 16:25 +0000

1# type: ignore 

2 

3# This is a small script to parse the header files from wasmtime and generate 

4# appropriate function definitions in Python for each exported function. This 

5# also reflects types into Python with `ctypes`. While there's at least one 

6# other generator that does this already, it didn't quite fit our purposes,

7# adding lots of extra and unnecessary boilerplate.

8 

9from pycparser import c_ast, parse_file 

10 

11 

class Visitor(c_ast.NodeVisitor):
    """Walk a pycparser AST and accumulate generated Python source in ``self.ret``.

    Only declarations whose name starts with ``was`` are emitted. Structs and
    unions become ``ctypes`` ``Structure``/``Union`` subclasses, enums become
    ``enum.Enum`` subclasses, typedefs become module-level aliases, and
    function declarations become a ``dll`` symbol lookup plus a thin wrapper.

    NOTE(review): whitespace inside the generated-code string literals is
    reproduced exactly as it appears in this copy of the file, which looks
    whitespace-collapsed; confirm the indentation against the committed
    ``wasmtime/_bindings.py``.
    """

    def __init__(self):
        # Preamble of the generated module; everything else is appended to it.
        self.ret = ''
        self.ret += '# flake8: noqa\n'
        self.ret += '#\n'
        self.ret += '# This is a procedurally generated file, DO NOT EDIT\n'
        self.ret += '# instead edit `./ci/cbindgen.py` at the root of the repo\n'
        self.ret += '\n'
        self.ret += 'from ctypes import *\n'
        self.ret += 'import ctypes\n'
        self.ret += 'from typing import Any\n'
        self.ret += 'from enum import Enum, auto\n'
        self.ret += 'from ._ffi import dll, wasm_val_t, wasm_ref_t\n'

    @staticmethod
    def _is_wasm_name(ident):
        """Return True if ``ident`` is a non-empty name starting with ``was``."""
        return bool(ident) and ident.startswith('was')

    # Skip all function definitions, we don't bind those
    def visit_FuncDef(self, node):
        pass

    def _emit_record(self, node, base, ignore_v128=False):
        """Append a ctypes ``base`` subclass (``Structure``/``Union``) for ``node``.

        Emits a ``_fields_`` list plus one annotation per field, or a bare
        ``pass`` body when the record has no visible fields. Field names go
        through ``name()`` so C names that are Python keywords stay valid.
        """
        self.ret += "\n"
        self.ret += "class {}({}):\n".format(node.name, base)
        if not node.decls:
            self.ret += " pass\n"
            return
        self.ret += " _fields_ = [\n"
        for decl in node.decls:
            self.ret += " (\"{}\", {}),\n".format(name(decl.name), type_name(decl.type))
        self.ret += " ]\n"
        for decl in node.decls:
            self.ret += " {}: {}".format(name(decl.name), type_name(decl.type, typing=True))
            if ignore_v128 and decl.name == 'v128':
                # Presumably the generated annotation for `v128` doesn't
                # type-check; silence the checker on that one line.
                self.ret += ' # type: ignore'
            self.ret += "\n"

    def visit_Struct(self, node):
        """Emit a ``Structure`` subclass for a ``was*`` struct."""
        if not self._is_wasm_name(node.name):
            return
        # This is hand-generated since it has an anonymous union in it
        if node.name in ('wasm_val_t', 'wasm_ref_t'):
            return
        self._emit_record(node, 'Structure')

    def visit_Union(self, node):
        """Emit a ``Union`` subclass for a ``was*`` union."""
        if not self._is_wasm_name(node.name):
            return
        self._emit_record(node, 'Union', ignore_v128=True)

    def visit_Enum(self, node):
        """Emit an ``Enum`` subclass for a ``was*`` enum definition."""
        if not self._is_wasm_name(node.name):
            return
        # A bare `enum foo` reference carries no enumerator list, so there is
        # nothing to emit for it.
        if node.values is None:
            return

        self.ret += "\n"
        self.ret += "class {}(Enum):\n".format(node.name)
        for enumerator in node.values.enumerators:
            if enumerator.value:
                self.ret += " {} = {}\n".format(enumerator.name, enumerator.value.value)
            else:
                self.ret += " {} = auto()\n".format(enumerator.name)

    def visit_Typedef(self, node):
        """Emit the typedef's target type, then an alias if the names differ."""
        if not self._is_wasm_name(node.name):
            return

        # Give anonymous structs in typedefs a name by default: the typedef
        # name without its `_t` suffix.
        inner = node.type.type if isinstance(node.type, c_ast.TypeDecl) else None
        if (isinstance(inner, c_ast.Struct) and inner.name is None
                and node.name.endswith('_t')):
            inner.name = node.name[:-2]

        self.visit(node.type)
        tyname = type_name(node.type)
        if tyname != node.name:
            self.ret += "\n"
            if isinstance(node.type, c_ast.ArrayDecl):
                self.ret += "{} = {} * {}\n".format(node.name, type_name(node.type.type), node.type.dim.value)
            else:
                self.ret += "{} = {}\n".format(node.name, tyname)

    def visit_FuncDecl(self, node):
        """Emit a ``dll`` binding plus a typed Python wrapper for a function.

        Raises:
            RuntimeError: if the declarator shape (e.g. a pointer-to-pointer
                return) is one this generator does not understand.
        """
        import keyword

        # Find the declarator that carries the function's name; `ptr` records
        # whether the function returns a (single-level) pointer.
        if isinstance(node.type, c_ast.TypeDecl):
            ptr = False
            ty = node.type
        elif isinstance(node.type, c_ast.PtrDecl) and isinstance(node.type.type, c_ast.TypeDecl):
            ptr = True
            ty = node.type.type
        else:
            raise RuntimeError("unsupported function declarator at {}".format(node.coord))

        fname = ty.declname
        # This is probably a type (e.g. a function-pointer typedef), skip it
        if fname.endswith('_t'):
            return
        # Skip anything not related to wasi or wasm
        if not fname.startswith('was'):
            return
        # This function forward-declares `wasm_instance_t` which doesn't work
        # with this binding generator, but for now this isn't used anyway so
        # just skip it.
        if fname == 'wasm_frame_instance':
            return

        ret = ty.type

        argpairs = []
        argtypes = []
        argnames = []
        if node.args:
            for i, param in enumerate(node.args.params):
                tyname = type_name(param.type)
                # A leading `void` parameter means "no arguments": `f(void)`.
                if i == 0 and tyname == "None":
                    continue
                argname = param.name
                # Unnamed parameters and C names that are Python keywords
                # (e.g. `import`, `global`) can't be Python parameter names.
                if not argname or keyword.iskeyword(argname):
                    argname = "arg{}".format(i)
                argpairs.append("{}: Any".format(argname))
                argnames.append(argname)
                argtypes.append(tyname)

        restype = type_name(ret, ptr)
        # It seems like this is the actual return value of the function, not a
        # pointer. Special-case this so the type-checking agrees with runtime.
        if restype == 'c_void_p':
            retty = 'int'
        else:
            retty = type_name(node.type, ptr, typing=True)

        self.ret += "\n"
        self.ret += "_{0} = dll.{0}\n".format(fname)
        self.ret += "_{}.restype = {}\n".format(fname, restype)
        self.ret += "_{}.argtypes = [{}]\n".format(fname, ', '.join(argtypes))
        self.ret += "def {}({}) -> {}:\n".format(fname, ', '.join(argpairs), retty)
        self.ret += " return _{}({}) # type: ignore\n".format(fname, ', '.join(argnames))

153 

154 

def name(name):
    """Return ``name`` made safe to use as a Python identifier.

    C names that collide with Python keywords (e.g. ``global``) get a trailing
    underscore. Every other value, including ``None`` for anonymous members,
    is returned unchanged.
    """
    import keyword

    if keyword.iskeyword(name):
        return name + '_'
    return name

159 

160 

# C scalar type name -> (annotation spelling, ctypes spelling).
# NOTE(review): `byte_t`, `uint8_t` and `char` use their ctypes name even as
# annotations, unlike the other integers; kept as-is so output is unchanged.
_SCALAR_TYPES = {
    "void": ("None", "None"),
    "_Bool": ("bool", "c_bool"),
    "byte_t": ("c_ubyte", "c_ubyte"),
    "uint8_t": ("c_uint8", "c_uint8"),
    "int32_t": ("int", "c_int32"),
    "uint32_t": ("int", "c_uint32"),
    "uint64_t": ("int", "c_uint64"),
    "int64_t": ("int", "c_int64"),
    "float32_t": ("float", "c_float"),
    "float64_t": ("float", "c_double"),
    "size_t": ("int", "c_size_t"),
    "char": ("c_char", "c_char"),
    "int": ("int", "c_int"),
}


def _identifier_type_name(ty, typing):
    """Translate a single-word ``IdentifierType`` node for ``type_name``."""
    if len(ty.names) != 1:
        raise RuntimeError("unsupported multi-word type {}".format(' '.join(ty.names)))
    ident = ty.names[0]
    if ident in _SCALAR_TYPES:
        annotation, ctype = _SCALAR_TYPES[ident]
        return annotation if typing else ctype
    if typing:
        # ctypes values can't stand as typedefs, so just use the pointer type here
        if 'callback_t' in ident:
            return "ctypes._Pointer"
        if 'size' in ident or 'pages' in ident:
            return "int"
    # Any other name refers to a typedef/class emitted elsewhere in the output.
    return ident


def _func_type_name(ty):
    """Translate a ``FuncDecl`` node into a ``CFUNCTYPE(restype, *argtypes)`` spelling."""
    # TODO: apparently errors are thrown if we faithfully represent the
    # pointer type here, seems odd?
    if isinstance(ty.type, c_ast.PtrDecl):
        tys = ["c_size_t"]
    else:
        tys = [type_name(ty.type)]
    # `ty.args` is None for an unprototyped `()` parameter list.
    if ty.args is not None and ty.args.params:
        tys.extend(type_name(param.type) for param in ty.args.params)
    return "CFUNCTYPE({})".format(', '.join(tys))


def type_name(ty, ptr=False, typing=False):
    """Return the Python spelling of the C type ``ty`` for generated code.

    Args:
        ty: a pycparser type node (``TypeDecl``, ``IdentifierType``,
            ``Struct``, ``Union``, ``Enum``, ``FuncDecl``, ``PtrDecl`` or
            ``ArrayDecl``).
        ptr: if true, ``ty`` is the pointee of a pointer.
        typing: if true, return a type-annotation spelling instead of a
            ``ctypes`` expression.

    Raises:
        RuntimeError: for node kinds or multi-word type names this generator
            does not know how to translate.
    """
    while isinstance(ty, c_ast.TypeDecl):
        ty = ty.type

    if ptr:
        if typing:
            return "ctypes._Pointer"
        if isinstance(ty, c_ast.IdentifierType) and ty.names[0] == "void":
            return "c_void_p"
        if not isinstance(ty, c_ast.FuncDecl):
            return "POINTER({})".format(type_name(ty, False, typing))
        # A pointer to a function is spelled as the CFUNCTYPE itself, so fall
        # through to the FuncDecl case below.

    if isinstance(ty, c_ast.IdentifierType):
        return _identifier_type_name(ty, typing)
    if isinstance(ty, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
        return ty.name
    if isinstance(ty, c_ast.FuncDecl):
        return _func_type_name(ty)
    if isinstance(ty, (c_ast.PtrDecl, c_ast.ArrayDecl)):
        return type_name(ty.type, True, typing)
    raise RuntimeError("unknown {}".format(ty))

229 

230 

# Everything below runs at import time on purpose. Running this file as a
# script regenerates `wasmtime/_bindings.py`. Importing it (presumably from
# the test suite) raises if the committed bindings are out of date.
from pathlib import Path

# Resolve paths against the repository root (this file lives in `ci/`), so the
# result doesn't depend on the current working directory.
_ROOT = Path(__file__).resolve().parent.parent
_INCLUDE_DIR = _ROOT / 'wasmtime' / 'include'
_BINDINGS = _ROOT / 'wasmtime' / '_bindings.py'

ast = parse_file(
    str(_INCLUDE_DIR / 'wasmtime.h'),
    use_cpp=True,
    cpp_path='gcc',
    cpp_args=[
        '-E',
        '-I{}'.format(_INCLUDE_DIR),
        # Neutralise GCC/C11 extensions that pycparser cannot parse.
        '-D__attribute__(x)=',
        '-D__asm__(x)=',
        '-D__asm(x)=',
        '-D__volatile__(x)=',
        '-D_Static_assert(x, y)=',
        '-Dstatic_assert(x, y)=',
        '-D__restrict=',
        '-D__restrict__=',
        '-D__extension__=',
        '-D__inline__=',
        '-D__signed=',
        '-D__builtin_va_list=int',
    ],
)

v = Visitor()
v.visit(ast)

if __name__ == "__main__":
    _BINDINGS.write_text(v.ret, encoding='utf-8')
else:
    if _BINDINGS.read_text(encoding='utf-8') != v.ret:
        raise RuntimeError("bindings need an update, run this script")