diff --git a/pipe/command.go b/pipe/command.go index 908516b..eb2a847 100644 --- a/pipe/command.go +++ b/pipe/command.go @@ -81,8 +81,8 @@ func (s *commandStage) Process() *os.Process { func (s *commandStage) Requirements() StageRequirements { return StageRequirements{ - StdinNeedsFile: true, - StdoutNeedsFile: true, + Stdin: StreamPreferFile, + Stdout: StreamPreferFile, } } diff --git a/pipe/pipe_matching_test.go b/pipe/pipe_matching_test.go index badb990..89c5ae7 100644 --- a/pipe/pipe_matching_test.go +++ b/pipe/pipe_matching_test.go @@ -49,14 +49,10 @@ func writeCloser() io.WriteCloser { } func newPipeSniffingStage( - stdinNeedsFile bool, stdinExpectation ioExpectation, - stdoutNeedsFile bool, stdoutExpectation ioExpectation, + req pipe.StageRequirements, stdinExpectation, stdoutExpectation ioExpectation, ) *pipeSniffingStage { return &pipeSniffingStage{ - requirements: pipe.StageRequirements{ - StdinNeedsFile: stdinNeedsFile, - StdoutNeedsFile: stdoutNeedsFile, - }, + requirements: req, expect: pipeExpectations{ stdin: stdinExpectation, stdout: stdoutExpectation, @@ -68,8 +64,11 @@ func newPipeSniffingFunc( stdinExpectation, stdoutExpectation ioExpectation, ) *pipeSniffingStage { return newPipeSniffingStage( - false, stdinExpectation, - false, stdoutExpectation, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + stdinExpectation, stdoutExpectation, ) } @@ -77,8 +76,11 @@ func newPipeSniffingCmd( stdinExpectation, stdoutExpectation ioExpectation, ) *pipeSniffingStage { return newPipeSniffingStage( - true, stdinExpectation, - true, stdoutExpectation, + pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamPreferFile, + }, + stdinExpectation, stdoutExpectation, ) } @@ -325,16 +327,25 @@ func TestPipeTypes(t *testing.T) { opts: []pipe.Option{}, stages: []pipe.Stage{ newPipeSniffingStage( - false, expectNil, - false, expectOther, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: 
pipe.StreamAcceptAny, + }, + expectNil, expectOther, ), newPipeSniffingStage( - false, expectOther, - true, expectFile, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamPreferFile, + }, + expectOther, expectFile, ), newPipeSniffingStage( - false, expectFile, - false, expectNil, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectFile, expectNil, ), }, }, @@ -343,16 +354,25 @@ func TestPipeTypes(t *testing.T) { opts: []pipe.Option{}, stages: []pipe.Stage{ newPipeSniffingStage( - false, expectNil, - false, expectFile, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectNil, expectFile, ), newPipeSniffingStage( - true, expectFile, - false, expectOther, + pipe.StageRequirements{ + Stdin: pipe.StreamPreferFile, + Stdout: pipe.StreamAcceptAny, + }, + expectFile, expectOther, ), newPipeSniffingStage( - false, expectOther, - false, expectNil, + pipe.StageRequirements{ + Stdin: pipe.StreamAcceptAny, + Stdout: pipe.StreamAcceptAny, + }, + expectOther, expectNil, ), }, }, diff --git a/pipe/pipeline.go b/pipe/pipeline.go index 085fdd3..44fedd1 100644 --- a/pipe/pipeline.go +++ b/pipe/pipeline.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "os" "sync/atomic" ) @@ -218,50 +217,6 @@ func (p *Pipeline) AddWithIgnoredError(em ErrorMatcher, stages ...Stage) { } } -type stageStarter struct { - requirements StageRequirements - stdin *InputStream - stdout *OutputStream -} - -func (requirement StreamRequirement) validate() error { - switch requirement { - case StreamOptional, StreamForbidden: - return nil - default: - return fmt.Errorf("invalid stream requirement %d", requirement) - } -} - -func (requirements StageRequirements) validate(s Stage, stdinConnected, stdoutConnected bool) error { - if err := requirements.Stdin.validate(); err != nil { - return fmt.Errorf("stdin: %w", err) - } - if err := requirements.Stdout.validate(); err != nil { - return 
fmt.Errorf("stdout: %w", err) - } - if requirements.Stdin == StreamForbidden && stdinConnected { - return fmt.Errorf("stage %q forbids stdin, but stdin is connected", s.Name()) - } - if requirements.Stdout == StreamForbidden && stdoutConnected { - return fmt.Errorf("stage %q forbids stdout, but stdout is connected", s.Name()) - } - return nil -} - -func (p *Pipeline) abortBeforeStart(s Stage, err error) error { - _ = p.stdout.Close() - p.cancel() - p.eventHandler(&Event{ - Command: s.Name(), - Msg: "failed to start pipeline stage", - Err: err, - }) - return fmt.Errorf( - "starting pipeline stage %q: %w", s.Name(), err, - ) -} - func (p *Pipeline) stageOptions() StageOptions { return StageOptions{Env: p.env, PanicHandler: p.panicHandler} } @@ -309,50 +264,68 @@ func (p *Pipeline) Start(ctx context.Context) error { // We need to decide how to start the stages, especially what // pipes to use to connect adjacent stages (`os.Pipe()` vs. // `io.Pipe()`) based on the two stages' requirements. - stageStarters := make([]stageStarter, len(p.stages)) + stageJoiners := make([]stageJoiner, len(p.stages)+1) + + // Arrange for the input of the 0th stage to come from `p.stdin`: + stageJoiners[0].nextStdin = p.stdin + + // Arrange for the output of the last stage to go to `p.stdout`: + stageJoiners[len(p.stages)].prevStdout = p.stdout + + // closePipes closes all of the streams that are currently stored + // in the joiners. This should be called if startup fails. As we + // call `Stage.Start()` and pass that method streams, we clear + // them from the corresponding joiners to avoid closing them + // twice. 
+ closePipes := func() { + for _, sj := range stageJoiners { + _ = sj.closePipe() + } + } - // Collect information about each stage's type and requirements: + // Store the stages in the joiners, and verify that the stages' + // requirements are well-formed: for i, s := range p.stages { - stageStarters[i].requirements = s.Requirements() - - err := stageStarters[i].requirements.validate( - s, - i > 0 || p.stdin != nil, - i < len(p.stages)-1 || p.stdout != nil, - ) - if err != nil { - return p.abortBeforeStart(s, err) + // Make sure that the stage's requirements are well-formed: + requirements := s.Requirements() + if err := requirements.Stdin.Validate(); err != nil { + return fmt.Errorf("stdin: %w", err) + } + if err := requirements.Stdout.Validate(); err != nil { + return fmt.Errorf("stdout: %w", err) } - } - if p.stdin != nil { - // Arrange for the input of the 0th stage to come from - // `p.stdin`: - stageStarters[0].stdin = p.stdin + stageJoiners[i].nextStage = s + stageJoiners[i].nextStageReq = requirements + stageJoiners[i+1].prevStage = s + stageJoiners[i+1].prevStageReq = requirements } - if p.stdout != nil { - // Arrange for the output of the last stage to go to - // `p.stdout`: - stageStarters[len(p.stages)-1].stdout = p.stdout + // Check that each of the stages' requirements are satisfiable: + for i := range stageJoiners { + if err := stageJoiners[i].validate(); err != nil { + closePipes() + return err + } } - // Clean up any processes and pipes that have been created. `i` is the - // index of the stage that failed to start. If the stage already received - // its streams, it owns any closing stream. - abort := func(i int, err error, closeFailedStageStdin bool) error { - // If the failing stage never received its stdin, close the pipe that - // the previous stage was writing to. That should cause it to exit - // even if it's not minding its context. 
- if closeFailedStageStdin { - _ = stageStarters[i].stdin.Close() + // Create the "inner" pipes (i.e, all but the first and last + // `stageJoiners`): + for i := 1; i < len(stageJoiners)-1; i++ { + if err := stageJoiners[i].createPipe(); err != nil { + closePipes() + return err } + } - // If stdout was supplied with WithStdoutCloser but the final stage - // was never started, then the pipeline still owns that closer. - if i < len(p.stages)-1 { - _ = p.stdout.Close() - } + // We're about to start up the stages, one by one. If something + // goes wrong during that process, this function should be called + // to kill any stages that have already been started and to close + // any pipes that have not yet been passed to a stage. `i` is the + // index of the stage that failed to start. If the stage already + // received its streams, it is responsible for closing them. + abort := func(i int, err error) error { + closePipes() // Kill and wait for any stages that have been started // already to finish: @@ -370,51 +343,20 @@ func (p *Pipeline) Start(ctx context.Context) error { ) } - // Loop over all but the last stage, starting them. By the time we - // get to a stage, its stdin will have already been determined, - // but we still need to figure out its stdout and set the stdin - // that will be used for the subsequent stage. 
- for i, s := range p.stages[:len(p.stages)-1] { - ss := &stageStarters[i] - nextSS := &stageStarters[i+1] - - // We need to generate a pipe pair for this stage to use - // to communicate with its successor: - if ss.requirements.StdoutNeedsFile || nextSS.requirements.StdinNeedsFile { - // Use an OS-level pipe for the communication: - nextStdin, stdout, err := os.Pipe() - if err != nil { - return abort(i, err, true) - } - nextSS.stdin = ClosingInput(nextStdin) - ss.stdout = ClosingOutput(stdout) - } else { - nextStdin, stdout := io.Pipe() - nextSS.stdin = ClosingInput(nextStdin) - ss.stdout = ClosingOutput(stdout) - } - if err := s.Start( - ctx, p.stageOptions(), - ss.stdin, ss.stdout, - ); err != nil { - _ = nextSS.stdin.Close() - return abort(i, err, false) - } - } + // Loop over all of the stages, starting them in order. + for i, s := range p.stages { + prevSJ := &stageJoiners[i] + nextSJ := &stageJoiners[i+1] - // The last stage needs special handling, because its stdout - // doesn't need to flow into another stage (it's already set in - // `ss.stdout` if it's needed). - { - i := len(p.stages) - 1 - s := p.stages[i] - ss := &stageStarters[i] + err := s.Start(ctx, p.stageOptions(), prevSJ.nextStdin, nextSJ.prevStdout) - if err := s.Start( - ctx, p.stageOptions(), - ss.stdin, ss.stdout, - ); err != nil { - return abort(i, err, false) + // Even if that stage failed to start, we are no longer + // responsible for closing its streams: + prevSJ.nextStdin = nil + nextSJ.prevStdout = nil + + if err != nil { + return abort(i, err) } } diff --git a/pipe/stage.go b/pipe/stage.go index 9a35254..0a086ba 100644 --- a/pipe/stage.go +++ b/pipe/stage.go @@ -155,25 +155,11 @@ type StageOptions struct { // StagePanicHandler is a function that handles panics in a pipeline's stages. type StagePanicHandler func(p any) error -type StreamRequirement int - -const ( - // StreamOptional means the stream may be connected or nil. 
- StreamOptional StreamRequirement = iota - - // StreamForbidden means the stream must be nil. - StreamForbidden -) - -// StageRequirements describes what a Stage needs from the streams connected to -// its stdin and stdout. The zero value is correct for stages that are happy -// with arbitrary io.Reader/io.Writer streams, such as Function stages. +// StageRequirements describes what a Stage needs from the streams +// connected to its stdin and stdout. The zero value is correct for +// stages that are happy with arbitrary io.Reader/io.Writer streams, +// such as Function stages. type StageRequirements struct { Stdin StreamRequirement Stdout StreamRequirement - - // {Stdin,Stdout}NeedsFile indicate that, if stdio is connected, the - // stage requires it to be backed by an *os.File (a real file descriptor) - StdinNeedsFile bool - StdoutNeedsFile bool } diff --git a/pipe/stage_joiner.go b/pipe/stage_joiner.go new file mode 100644 index 0000000..03569cb --- /dev/null +++ b/pipe/stage_joiner.go @@ -0,0 +1,133 @@ +package pipe + +import ( + "errors" + "fmt" + "io" + "os" +) + +// stageJoiner is a helper type that helps join two adjacent stages +// together. stageJoiners[i] tells how to connect stage `i-1` to stage +// `i`. From the point of view of stages, `stageJoiners[i].nextStdin` +// and `stageJoiners[i+1].prevStdout` are the input and output +// streams, respectively, of `stage[i]`. The first and last elements +// of `stageJoiners` manage `p.stdin` and `p.stdout`, respectively. +// Schematically, the data flows through like this: +// +// p.stdin == stageJoiners[0].nextStdin → +// stage[0] → +// stageJoiners[1].prevStdout → stageJoiners[1].nextStdin → +// stage[1] → +// stageJoiners[2].prevStdout → stageJoiners[2].nextStdin → +// stage[2] → +// ... → +// stageJoiners[i].prevStdout → stageJoiners[i].nextStdin → +// stage[i] → +// stageJoiners[i+1].prevStdout → stageJoiners[i+1].nextStdin → +// ... 
→ +// stageJoiners[len(stages)-1].prevStdout → stageJoiners[len(stages)-1].nextStdin → +// stage[len(stages)-1] → +// stageJoiners[len(stages)].prevStdout == p.stdout +// +// In other words, each stage is started with the following streams: +// +// stage[0] stdin=p.stdin stdout=stageJoiners[1].prevStdout +// stage[1] stdin=stageJoiners[1].nextStdin stdout=stageJoiners[2].prevStdout +// stage[2] stdin=stageJoiners[2].nextStdin stdout=stageJoiners[3].prevStdout +// ... +// stage[i] stdin=stageJoiners[i].nextStdin stdout=stageJoiners[i+1].prevStdout +// ... +// stage[len(stages)-1] stdin=stageJoiners[len(stages)-1].nextStdin stdout=p.stdout +type stageJoiner struct { + // prevStage holds the stage that needs to write to the pipe (nil for `stageJoiners[0]`). + prevStage Stage + + // prevStageReq caches `prevStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `prevStage` is nil. + prevStageReq StageRequirements + + // prevStdout will be used as the stdout of `prevStage`. It is + // usually the "write" end of the `(nextStdin, prevStdout)` pipe + // pair, with the connected pipe ends in the same `stageJoiner` + // instance; in the last joiner it is `p.stdout` instead. + prevStdout *OutputStream + + // nextStage holds the stage that needs to read from the pipe (nil for the last joiner). + nextStage Stage + + // nextStageReq caches `nextStage.Requirements()` so that it + // doesn't have to be recomputed. It is the zero value if + // `nextStage` is nil. + nextStageReq StageRequirements + + // nextStdin will be used as the stdin of `nextStage`. It is + // usually the "read" end of the `(nextStdin, prevStdout)` pipe + // pair; in the first joiner it is `p.stdin` instead. + nextStdin *InputStream +} + +// needFilePipe returns `true` if the pipe that joins the two adjacent +// stages should be an `os.Pipe()` rather than an `io.Pipe()`, i.e. if either stage declares `StreamPreferFile` for its end of the pipe. 
+func (sj *stageJoiner) needFilePipe() bool { + return sj.prevStageReq.Stdout == StreamPreferFile || + sj.nextStageReq.Stdin == StreamPreferFile +} + +func (sj *stageJoiner) createPipe() error { + var r io.ReadCloser + var w io.WriteCloser + if sj.needFilePipe() { + var err error + r, w, err = os.Pipe() + if err != nil { + return fmt.Errorf("creating os.Pipe: %w", err) + } + } else { + r, w = io.Pipe() + } + + sj.prevStdout = ClosingOutput(w) + sj.nextStdin = ClosingInput(r) + + return nil +} + +// closePipe closes the streams currently stored in `prevStdout` and +// `nextStdin` (normally the two ends of the pipe that was allocated by +// `createPipe()`). A stream should only still be stored here if it was never passed to a stage's `Start()` method (otherwise the stage is +// responsible for closing its stdin and stdout). NOTE(review): either field can be nil here, so `Close()` presumably tolerates a nil stream; confirm. +func (sj *stageJoiner) closePipe() error { + return errors.Join( + sj.prevStdout.Close(), + sj.nextStdin.Close(), + ) +} + +// validate verifies that the adjacent stages' stream requirements are +// satisfiable, in particular that a stage that forbids its stdin or +// stdout is not connected to anything. +func (sj *stageJoiner) validate() error { + // `prevStage`'s stdout is connected if there is a `nextStage` to + // consume it (in which case an inner pipe will be created) or if + // a stream (`p.stdout`) has already been stored in `prevStdout`. + if sj.prevStage != nil && sj.prevStageReq.Stdout == StreamForbidden && + (sj.nextStage != nil || sj.prevStdout != nil) { + return fmt.Errorf( + "stage %q forbids stdout, but stdout is connected", sj.prevStage.Name(), + ) + } + + // `nextStage`'s stdin is connected if there is a `prevStage` to + // produce it (in which case an inner pipe will be created) or if + // a stream (`p.stdin`) has already been stored in `nextStdin`. 
+ if sj.nextStage != nil && sj.nextStageReq.Stdin == StreamForbidden && + (sj.prevStage != nil || sj.nextStdin != nil) { + return fmt.Errorf( + "stage %q forbids stdin, but stdin is connected", sj.nextStage.Name(), + ) + } + + return nil +} diff --git a/pipe/stream_requirement.go b/pipe/stream_requirement.go new file mode 100644 index 0000000..ddff829 --- /dev/null +++ b/pipe/stream_requirement.go @@ -0,0 +1,38 @@ +package pipe + +import "fmt" + +// StreamRequirement describes a `Stage`'s requirement for its stdin +// or stdout, namely whether it can be anything, whether it should +// preferably be an `*os.File`, or whether it must be `nil`. The zero +// value `StreamAcceptAny` is a valid value that indicates that the +// stage has no particular requirements or preferences for its +// stdin/stdout, such as a typical `Function` stage. +type StreamRequirement int + +const ( + // StreamAcceptAny indicates that the stage has not declared any + // requirement: it accepts any kind of stream, possibly even `nil`. + StreamAcceptAny StreamRequirement = iota + + // StreamPreferFile indicates that the stage prefers the + // corresponding stream to be backed by an `*os.File` (a real file + // descriptor), but it can work with any io.Reader/io.Writer. + StreamPreferFile + + // StreamForbidden indicates that the stage requires the + // corresponding stream to be nil. It won't read/write the stream + // or close it. + StreamForbidden +) + +// Validate checks that `requirement` is one of the declared constants +// and returns an error otherwise. +func (requirement StreamRequirement) Validate() error { + switch requirement { + case StreamAcceptAny, StreamPreferFile, StreamForbidden: + return nil + default: + return fmt.Errorf("invalid stream requirement %d", requirement) + } +}