123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349 |
- //go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))
// The above build constraint is based on the cgo directives in this file:
- // https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go
- package wasm
- import (
- "context"
- "crypto/sha256"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "net/http"
- "os"
- "path/filepath"
- "runtime"
- "runtime/trace"
- "strings"
- wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
- "golang.org/x/sync/singleflight"
- "google.golang.org/grpc"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- "github.com/sqlc-dev/sqlc/internal/cache"
- "github.com/sqlc-dev/sqlc/internal/info"
- "github.com/sqlc-dev/sqlc/internal/plugin"
- )
// wasmtimeVersion is baked into the file name of every cached, precompiled
// plugin module, so modules compiled by a different wasmtime are never reused.
// Bump it whenever the wasmtime-go dependency is updated.
const wasmtimeVersion = "v14.0.0"
// cacheDir reports the root of sqlc's on-disk cache.
//
// Resolution order:
//  1. $SQLCCACHE, used verbatim when non-empty;
//  2. $XDG_CACHE_HOME/sqlc when XDG_CACHE_HOME is non-empty;
//  3. <user home>/.cache/sqlc otherwise.
//
// An error is returned only when the home directory has to be consulted and
// cannot be determined.
func cacheDir() (string, error) {
	if override := os.Getenv("SQLCCACHE"); override != "" {
		return override, nil
	}
	if xdg := os.Getenv("XDG_CACHE_HOME"); xdg != "" {
		return filepath.Join(xdg, "sqlc"), nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(home, ".cache", "sqlc"), nil
}
- var flight singleflight.Group
- // Verify the provided sha256 is valid.
- func (r *Runner) getChecksum(ctx context.Context) (string, error) {
- if r.SHA256 != "" {
- return r.SHA256, nil
- }
- // TODO: Add a log line here about something
- _, sum, err := r.fetch(ctx, r.URL)
- if err != nil {
- return "", err
- }
- slog.Warn("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work", "sha256", sum)
- return sum, nil
- }
- func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
- expected, err := r.getChecksum(ctx)
- if err != nil {
- return nil, err
- }
- value, err, _ := flight.Do(expected, func() (interface{}, error) {
- return r.loadSerializedModule(ctx, engine, expected)
- })
- if err != nil {
- return nil, err
- }
- data, ok := value.([]byte)
- if !ok {
- return nil, fmt.Errorf("returned value was not a byte slice")
- }
- return wasmtime.NewModuleDeserialize(engine, data)
- }
- func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) {
- cacheDir, err := cache.PluginsDir()
- if err != nil {
- return nil, err
- }
- pluginDir := filepath.Join(cacheDir, expectedSha)
- modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion)
- modPath := filepath.Join(pluginDir, modName)
- _, staterr := os.Stat(modPath)
- if staterr == nil {
- data, err := os.ReadFile(modPath)
- if err != nil {
- return nil, err
- }
- return data, nil
- }
- wmod, err := r.loadWASM(ctx, cacheDir, expectedSha)
- if err != nil {
- return nil, err
- }
- moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule")
- module, err := wasmtime.NewModule(engine, wmod)
- moduRegion.End()
- if err != nil {
- return nil, fmt.Errorf("define wasi: %w", err)
- }
- err = os.Mkdir(pluginDir, 0755)
- if err != nil && !os.IsExist(err) {
- return nil, fmt.Errorf("mkdirall: %w", err)
- }
- out, err := module.Serialize()
- if err != nil {
- return nil, fmt.Errorf("serialize: %w", err)
- }
- if err := os.WriteFile(modPath, out, 0444); err != nil {
- return nil, fmt.Errorf("cache wasm: %w", err)
- }
- return out, nil
- }
- func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) {
- var body io.ReadCloser
- switch {
- case strings.HasPrefix(uri, "file://"):
- file, err := os.Open(strings.TrimPrefix(uri, "file://"))
- if err != nil {
- return nil, "", fmt.Errorf("os.Open: %s %w", uri, err)
- }
- body = file
- case strings.HasPrefix(uri, "https://"):
- req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
- if err != nil {
- return nil, "", fmt.Errorf("http.Get: %s %w", uri, err)
- }
- req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH))
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return nil, "", fmt.Errorf("http.Get: %s %w", r.URL, err)
- }
- body = resp.Body
- default:
- return nil, "", fmt.Errorf("unknown scheme: %s", r.URL)
- }
- defer body.Close()
- wmod, err := io.ReadAll(body)
- if err != nil {
- return nil, "", fmt.Errorf("readall: %w", err)
- }
- sum := sha256.Sum256(wmod)
- actual := fmt.Sprintf("%x", sum)
- return wmod, actual, nil
- }
- func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
- pluginDir := filepath.Join(cache, expected)
- pluginPath := filepath.Join(pluginDir, "plugin.wasm")
- _, staterr := os.Stat(pluginPath)
- uri := r.URL
- if staterr == nil {
- uri = "file://" + pluginPath
- }
- wmod, actual, err := r.fetch(ctx, uri)
- if err != nil {
- return nil, err
- }
- if expected != actual {
- return nil, fmt.Errorf("invalid checksum: expected %s, got %s", expected, actual)
- }
- if staterr != nil {
- err := os.Mkdir(pluginDir, 0755)
- if err != nil && !os.IsExist(err) {
- return nil, fmt.Errorf("mkdirall: %w", err)
- }
- if err := os.WriteFile(pluginPath, wmod, 0444); err != nil {
- return nil, fmt.Errorf("cache wasm: %w", err)
- }
- }
- return wmod, nil
- }
- // removePGCatalog removes the pg_catalog schema from the request. There is a
- // mysterious (reason unknown) bug with wasm plugins when a large amount of
- // tables (like there are in the catalog) are sent.
- // @see https://github.com/sqlc-dev/sqlc/pull/1748
- func removePGCatalog(req *plugin.GenerateRequest) {
- if req.Catalog == nil || req.Catalog.Schemas == nil {
- return
- }
- filtered := make([]*plugin.Schema, 0, len(req.Catalog.Schemas))
- for _, schema := range req.Catalog.Schemas {
- if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
- continue
- }
- filtered = append(filtered, schema)
- }
- req.Catalog.Schemas = filtered
- }
- func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
- req, ok := args.(protoreflect.ProtoMessage)
- if !ok {
- return status.Error(codes.InvalidArgument, "args isn't a protoreflect.ProtoMessage")
- }
- // Remove the pg_catalog schema. Its sheer size causes unknown issues with wasm plugins
- genReq, ok := req.(*plugin.GenerateRequest)
- if ok {
- removePGCatalog(genReq)
- req = genReq
- }
- stdinBlob, err := proto.Marshal(req)
- if err != nil {
- return fmt.Errorf("failed to encode codegen request: %w", err)
- }
- engine := wasmtime.NewEngine()
- module, err := r.loadModule(ctx, engine)
- if err != nil {
- return fmt.Errorf("loadModule: %w", err)
- }
- linker := wasmtime.NewLinker(engine)
- if err := linker.DefineWasi(); err != nil {
- return err
- }
- dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
- if err != nil {
- return fmt.Errorf("temp dir: %w", err)
- }
- defer os.RemoveAll(dir)
- stdinPath := filepath.Join(dir, "stdin")
- stderrPath := filepath.Join(dir, "stderr")
- stdoutPath := filepath.Join(dir, "stdout")
- if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
- return fmt.Errorf("write file: %w", err)
- }
- // Configure WASI imports to write stdout into a file.
- wasiConfig := wasmtime.NewWasiConfig()
- wasiConfig.SetArgv([]string{"plugin.wasm", method})
- wasiConfig.SetStdinFile(stdinPath)
- wasiConfig.SetStdoutFile(stdoutPath)
- wasiConfig.SetStderrFile(stderrPath)
- keys := []string{"SQLC_VERSION"}
- vals := []string{info.Version}
- for _, key := range r.Env {
- keys = append(keys, key)
- vals = append(vals, os.Getenv(key))
- }
- wasiConfig.SetEnv(keys, vals)
- store := wasmtime.NewStore(engine)
- store.SetWasi(wasiConfig)
- linkRegion := trace.StartRegion(ctx, "linker.DefineModule")
- err = linker.DefineModule(store, "", module)
- linkRegion.End()
- if err != nil {
- return fmt.Errorf("define wasi: %w", err)
- }
- // Run the function
- fn, err := linker.GetDefault(store, "")
- if err != nil {
- return fmt.Errorf("wasi: get default: %w", err)
- }
- callRegion := trace.StartRegion(ctx, "call _start")
- _, err = fn.Call(store)
- callRegion.End()
- if cerr := checkError(err, stderrPath); cerr != nil {
- return cerr
- }
- // Print WASM stdout
- stdoutBlob, err := os.ReadFile(stdoutPath)
- if err != nil {
- return fmt.Errorf("read file: %w", err)
- }
- resp, ok := reply.(protoreflect.ProtoMessage)
- if !ok {
- return fmt.Errorf("reply isn't a GenerateResponse")
- }
- if err := proto.Unmarshal(stdoutBlob, resp); err != nil {
- return err
- }
- return nil
- }
- func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
- return nil, status.Error(codes.Unimplemented, "")
- }
- func checkError(err error, stderrPath string) error {
- if err == nil {
- return err
- }
- var wtError *wasmtime.Error
- if errors.As(err, &wtError) {
- if code, ok := wtError.ExitStatus(); ok {
- if code == 0 {
- return nil
- }
- }
- }
- // Print WASM stdout
- stderrBlob, rferr := os.ReadFile(stderrPath)
- if rferr == nil && len(stderrBlob) > 0 {
- return errors.New(string(stderrBlob))
- }
- return fmt.Errorf("call: %w", err)
- }
|