wasm.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. //go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))
// The build constraint above is based on the cgo directives in wasmtime-go's ffi.go:
  3. // https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
  4. package wasm
  5. import (
  6. "context"
  7. "crypto/sha256"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "log/slog"
  12. "net/http"
  13. "os"
  14. "path/filepath"
  15. "runtime"
  16. "runtime/trace"
  17. "strings"
  18. wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
  19. "golang.org/x/sync/singleflight"
  20. "google.golang.org/grpc"
  21. "google.golang.org/grpc/codes"
  22. "google.golang.org/grpc/status"
  23. "google.golang.org/protobuf/proto"
  24. "google.golang.org/protobuf/reflect/protoreflect"
  25. "github.com/sqlc-dev/sqlc/internal/cache"
  26. "github.com/sqlc-dev/sqlc/internal/info"
  27. "github.com/sqlc-dev/sqlc/internal/plugin"
  28. )
  29. // This version must be updated whenever the wasmtime-go dependency is updated
  30. const wasmtimeVersion = `v14.0.0`
  31. func cacheDir() (string, error) {
  32. cache := os.Getenv("SQLCCACHE")
  33. if cache != "" {
  34. return cache, nil
  35. }
  36. cacheHome := os.Getenv("XDG_CACHE_HOME")
  37. if cacheHome == "" {
  38. home, err := os.UserHomeDir()
  39. if err != nil {
  40. return "", err
  41. }
  42. cacheHome = filepath.Join(home, ".cache")
  43. }
  44. return filepath.Join(cacheHome, "sqlc"), nil
  45. }
  46. var flight singleflight.Group
  47. // Verify the provided sha256 is valid.
  48. func (r *Runner) getChecksum(ctx context.Context) (string, error) {
  49. if r.SHA256 != "" {
  50. return r.SHA256, nil
  51. }
  52. // TODO: Add a log line here about something
  53. _, sum, err := r.fetch(ctx, r.URL)
  54. if err != nil {
  55. return "", err
  56. }
  57. slog.Warn("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work", "sha256", sum)
  58. return sum, nil
  59. }
  60. func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
  61. expected, err := r.getChecksum(ctx)
  62. if err != nil {
  63. return nil, err
  64. }
  65. value, err, _ := flight.Do(expected, func() (interface{}, error) {
  66. return r.loadSerializedModule(ctx, engine, expected)
  67. })
  68. if err != nil {
  69. return nil, err
  70. }
  71. data, ok := value.([]byte)
  72. if !ok {
  73. return nil, fmt.Errorf("returned value was not a byte slice")
  74. }
  75. return wasmtime.NewModuleDeserialize(engine, data)
  76. }
  77. func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
  78. cacheDir, err := cache.PluginsDir()
  79. if err != nil {
  80. return nil, err
  81. }
  82. pluginDir := filepath.Join(cacheDir, expectedSha)
  83. modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
  84. modPath := filepath.Join(pluginDir, modName)
  85. _, staterr := os.Stat(modPath)
  86. if staterr == nil {
  87. data, err := os.ReadFile(modPath)
  88. if err != nil {
  89. return nil, err
  90. }
  91. return data, nil
  92. }
  93. wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
  94. if err != nil {
  95. return nil, err
  96. }
  97. moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
  98. module, err := wasmtime.NewModule(engine, wmod)
  99. moduRegion.End()
  100. if err != nil {
  101. return nil, fmt.Errorf("define wasi: %w", err)
  102. }
  103. err = os.Mkdir(pluginDir, 0755)
  104. if err != nil && !os.IsExist(err) {
  105. return nil, fmt.Errorf("mkdirall: %w", err)
  106. }
  107. out, err := module.Serialize()
  108. if err != nil {
  109. return nil, fmt.Errorf("serialize: %w", err)
  110. }
  111. if err := os.WriteFile(modPath, out, 0444); err != nil {
  112. return nil, fmt.Errorf("cache wasm: %w", err)
  113. }
  114. return out, nil
  115. }
  116. func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
  117. var body io.ReadCloser
  118. switch {
  119. case strings.HasPrefix(uri, "file://"):
  120. file, err := os.Open(strings.TrimPrefix(uri, "file://"))
  121. if err != nil {
  122. return nil, "", fmt.Errorf("os.Open: %s %w", uri, err)
  123. }
  124. body = file
  125. case strings.HasPrefix(uri, "https://"):
  126. req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
  127. if err != nil {
  128. return nil, "", fmt.Errorf("http.Get: %s %w", uri, err)
  129. }
  130. req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH))
  131. resp, err := http.DefaultClient.Do(req)
  132. if err != nil {
  133. return nil, "", fmt.Errorf("http.Get: %s %w", r.URL, err)
  134. }
  135. body = resp.Body
  136. default:
  137. return nil, "", fmt.Errorf("unknown scheme: %s", r.URL)
  138. }
  139. defer body.Close()
  140. wmod, err := io.ReadAll(body)
  141. if err != nil {
  142. return nil, "", fmt.Errorf("readall: %w", err)
  143. }
  144. sum := sha256.Sum256(wmod)
  145. actual := fmt.Sprintf("%x", sum)
  146. return wmod, actual, nil
  147. }
  148. func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
  149. pluginDir := filepath.Join(cache, expected)
  150. pluginPath := filepath.Join(pluginDir, "plugin.wasm")
  151. _, staterr := os.Stat(pluginPath)
  152. uri := r.URL
  153. if staterr == nil {
  154. uri = "file://" + pluginPath
  155. }
  156. wmod, actual, err := r.fetch(ctx, uri)
  157. if err != nil {
  158. return nil, err
  159. }
  160. if expected != actual {
  161. return nil, fmt.Errorf("invalid checksum: expected %s, got %s", expected, actual)
  162. }
  163. if staterr != nil {
  164. err := os.Mkdir(pluginDir, 0755)
  165. if err != nil && !os.IsExist(err) {
  166. return nil, fmt.Errorf("mkdirall: %w", err)
  167. }
  168. if err := os.WriteFile(pluginPath, wmod, 0444); err != nil {
  169. return nil, fmt.Errorf("cache wasm: %w", err)
  170. }
  171. }
  172. return wmod, nil
  173. }
  174. // removePGCatalog removes the pg_catalog schema from the request. There is a
  175. // mysterious (reason unknown) bug with wasm plugins when a large amount of
  176. // tables (like there are in the catalog) are sent.
  177. // @see https://github.com/sqlc-dev/sqlc/pull/1748
  178. func removePGCatalog(req *plugin.GenerateRequest) {
  179. if req.Catalog == nil || req.Catalog.Schemas == nil {
  180. return
  181. }
  182. filtered := make([]*plugin.Schema, 0, len(req.Catalog.Schemas))
  183. for _, schema := range req.Catalog.Schemas {
  184. if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
  185. continue
  186. }
  187. filtered = append(filtered, schema)
  188. }
  189. req.Catalog.Schemas = filtered
  190. }
  191. func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
  192. req, ok := args.(protoreflect.ProtoMessage)
  193. if !ok {
  194. return status.Error(codes.InvalidArgument, "args isn't a protoreflect.ProtoMessage")
  195. }
  196. // Remove the pg_catalog schema. Its sheer size causes unknown issues with wasm plugins
  197. genReq, ok := req.(*plugin.GenerateRequest)
  198. if ok {
  199. removePGCatalog(genReq)
  200. req = genReq
  201. }
  202. stdinBlob, err := proto.Marshal(req)
  203. if err != nil {
  204. return fmt.Errorf("failed to encode codegen request: %w", err)
  205. }
  206. engine := wasmtime.NewEngine()
  207. module, err := r.loadModule(ctx, engine)
  208. if err != nil {
  209. return fmt.Errorf("loadModule: %w", err)
  210. }
  211. linker := wasmtime.NewLinker(engine)
  212. if err := linker.DefineWasi(); err != nil {
  213. return err
  214. }
  215. dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
  216. if err != nil {
  217. return fmt.Errorf("temp dir: %w", err)
  218. }
  219. defer os.RemoveAll(dir)
  220. stdinPath := filepath.Join(dir, "stdin")
  221. stderrPath := filepath.Join(dir, "stderr")
  222. stdoutPath := filepath.Join(dir, "stdout")
  223. if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
  224. return fmt.Errorf("write file: %w", err)
  225. }
  226. // Configure WASI imports to write stdout into a file.
  227. wasiConfig := wasmtime.NewWasiConfig()
  228. wasiConfig.SetArgv([]string{"plugin.wasm", method})
  229. wasiConfig.SetStdinFile(stdinPath)
  230. wasiConfig.SetStdoutFile(stdoutPath)
  231. wasiConfig.SetStderrFile(stderrPath)
  232. keys := []string{"SQLC_VERSION"}
  233. vals := []string{info.Version}
  234. for _, key := range r.Env {
  235. keys = append(keys, key)
  236. vals = append(vals, os.Getenv(key))
  237. }
  238. wasiConfig.SetEnv(keys, vals)
  239. store := wasmtime.NewStore(engine)
  240. store.SetWasi(wasiConfig)
  241. linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
  242. err = linker.DefineModule(store, "", module)
  243. linkRegion.End()
  244. if err != nil {
  245. return fmt.Errorf("define wasi: %w", err)
  246. }
  247. // Run the function
  248. fn, err := linker.GetDefault(store, "")
  249. if err != nil {
  250. return fmt.Errorf("wasi: get default: %w", err)
  251. }
  252. callRegion := trace.StartRegion(ctx, "call _start")
  253. _, err = fn.Call(store)
  254. callRegion.End()
  255. if cerr := checkError(err, stderrPath); cerr != nil {
  256. return cerr
  257. }
  258. // Print WASM stdout
  259. stdoutBlob, err := os.ReadFile(stdoutPath)
  260. if err != nil {
  261. return fmt.Errorf("read file: %w", err)
  262. }
  263. resp, ok := reply.(protoreflect.ProtoMessage)
  264. if !ok {
  265. return fmt.Errorf("reply isn't a GenerateResponse")
  266. }
  267. if err := proto.Unmarshal(stdoutBlob, resp); err != nil {
  268. return err
  269. }
  270. return nil
  271. }
  272. func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
  273. return nil, status.Error(codes.Unimplemented, "")
  274. }
  275. func checkError(err error, stderrPath string) error {
  276. if err == nil {
  277. return err
  278. }
  279. var wtError *wasmtime.Error
  280. if errors.As(err, &wtError) {
  281. if code, ok := wtError.ExitStatus(); ok {
  282. if code == 0 {
  283. return nil
  284. }
  285. }
  286. }
  287. // Print WASM stdout
  288. stderrBlob, rferr := os.ReadFile(stderrPath)
  289. if rferr == nil && len(stderrBlob) > 0 {
  290. return errors.New(string(stderrBlob))
  291. }
  292. return fmt.Errorf("call: %w", err)
  293. }