package ast

import (
	"fmt"
	"os"
	"sync"

	"github.com/dominikbraun/graph"
	"github.com/dominikbraun/graph/draw"
	"golang.org/x/sync/errgroup"
)

// TaskfileGraph is a graph of Taskfiles in which each vertex is identified by
// its URI. It embeds a sync.Mutex for callers that need to serialize access;
// the methods defined below do not acquire that lock themselves.
type TaskfileGraph struct {
	sync.Mutex
	graph.Graph[string, *TaskfileVertex]
}

// A TaskfileVertex is a vertex on the Taskfile DAG.
type TaskfileVertex struct {
	// URI is the value the vertex is hashed by in the graph.
	URI string
	// Taskfile is the Taskfile stored at this vertex.
	Taskfile *Taskfile
}

// taskfileHash returns the graph key of a vertex, which is its URI.
func taskfileHash(vertex *TaskfileVertex) string {
	return vertex.URI
}

// NewTaskfileGraph returns an empty TaskfileGraph configured as a directed,
// rooted graph that rejects edges which would create a cycle.
func NewTaskfileGraph() *TaskfileGraph {
	return &TaskfileGraph{
		sync.Mutex{},
		graph.New(taskfileHash,
			graph.Directed(),
			graph.PreventCycles(),
			graph.Rooted(),
		),
	}
}

// Visualize writes the graph in DOT format to the file at filename, creating
// or truncating it. It returns the first error encountered while creating,
// writing or closing the file.
func (tfg *TaskfileGraph) Visualize(filename string) (err error) {
	f, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close on a file that was just written can mean the data
		// never reached disk, so report it unless an earlier error already won.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	return draw.DOT(tfg.Graph, f)
}

// Merge folds every included Taskfile into the Taskfiles that include it and
// returns the Taskfile held by the first vertex of the topological order,
// which is treated as the root.
//
// Vertices are processed from the end of the topological order towards the
// root, so a Taskfile is merged into its parents only after its own includes
// have been merged into it. Each edge is expected to carry a []*Include in its
// properties; an edge with any other data is reported as an error. An error is
// also returned when the graph has no vertices.
func (tfg *TaskfileGraph) Merge() (*Taskfile, error) {
	hashes, err := graph.TopologicalSort(tfg.Graph)
	if err != nil {
		return nil, err
	}

	// Without this guard, the root lookup via hashes[0] below would panic on
	// an empty graph.
	if len(hashes) == 0 {
		return nil, fmt.Errorf("task: cannot merge an empty Taskfile graph")
	}

	predecessorMap, err := tfg.PredecessorMap()
	if err != nil {
		return nil, err
	}

	// Loop over each vertex in reverse topological order except for the root vertex.
	// This gives us a loop over every included Taskfile in an order which is safe to merge.
	for i := len(hashes) - 1; i > 0; i-- {
		hash := hashes[i]

		// Get the included vertex
		includedVertex, err := tfg.Vertex(hash)
		if err != nil {
			return nil, err
		}

		// Error group used to run the merge of this Taskfile into each of its parents.
		var g errgroup.Group

		// Loop over every edge whose source is a vertex that includes the current vertex.
		for _, edge := range predecessorMap[hash] {
			// Run the merge into this parent on its own goroutine.
			g.Go(func() error {
				// Get the base (including) vertex
				vertex, err := tfg.Vertex(edge.Source)
				if err != nil {
					return err
				}

				// Get the merge options
				includes, ok := edge.Properties.Data.([]*Include)
				if !ok {
					return fmt.Errorf("task: Failed to get merge options")
				}

				// Merge the included Taskfile into the parent Taskfile once per include.
				for _, include := range includes {
					if err := vertex.Taskfile.Merge(
						includedVertex.Taskfile,
						include,
					); err != nil {
						return err
					}
				}

				return nil
			})

			// Waiting right after each Go call means the merges into the
			// different parents run one at a time rather than concurrently.
			if err := g.Wait(); err != nil {
				return nil, err
			}
		}

		// Every goroutine has already been waited on above; this final Wait
		// only re-checks the group's recorded error.
		if err := g.Wait(); err != nil {
			return nil, err
		}
	}

	// Get the root vertex
	rootVertex, err := tfg.Vertex(hashes[0])
	if err != nil {
		return nil, err
	}

	// Give every task in the merged root Taskfile its name, replacing nil
	// entries with an empty Task first. The callback always returns nil, so the
	// result of Range is discarded.
	_ = rootVertex.Taskfile.Tasks.Range(func(name string, task *Task) error {
		if task == nil {
			task = &Task{}
			rootVertex.Taskfile.Tasks.Set(name, task)
		}
		task.Task = name
		return nil
	})

	return rootVertex.Taskfile, nil
}