Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/shelldoc/cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ executes them and compares their output with the content of the code block.`,
func init() {
runCmd.Flags().StringVarP(&runContext.ShellName, "shell", "s", "", "The shell to invoke (default: $SHELL)")
runCmd.Flags().BoolVarP(&runContext.FailureStops, "fail", "f", false, "Stop on the first failure")
runCmd.Flags().BoolVarP(&runContext.MergeStderr, "merge-stderr", "m", false, "Merge stderr into stdout (2>&1) instead of capturing separately")
runCmd.Flags().StringVarP(&runContext.XMLOutputFile, "xml", "x", "", "Write results to the specified output file in JUnitXML format")
runCmd.Flags().BoolVarP(&runContext.ReplaceDots, "replace-dots-in-xml-classname", "d", true, "When using filenames as classnames, replace dots with a unicode circle")
runCmd.Flags().BoolVarP(&runContext.DryRun, "dry-run", "n", false, "Preview commands without executing them")
Expand Down
2 changes: 2 additions & 0 deletions pkg/junitxml/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ type JUnitTestCase struct {
SkipMessage *JUnitSkipMessage `xml:"skipped,omitempty"`
Failure *JUnitFailure `xml:"failure,omitempty"`
Error *JUnitError `xml:"error,omitempty"`
SystemOut string `xml:"system-out,omitempty"`
SystemErr string `xml:"system-err,omitempty"`
}

// JUnitSkipMessage contains the reason why a testcase was skipped.
Expand Down
1 change: 1 addition & 0 deletions pkg/run/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type Context struct {
Verbose bool
FailureStops bool
XMLOutputFile string
MergeStderr bool
ReplaceDots bool
DryRun bool
Timeout time.Duration
Expand Down
6 changes: 5 additions & 1 deletion pkg/run/interactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (runCtx *Context) performInteractions(ctx context.Context, inputfile string
return nil, err
}
// start a background shell, it will run until the function ends
currentShell, err := shell.StartShell(shellpath)
currentShell, err := shell.StartShell(shellpath, runCtx.MergeStderr)
if err != nil {
return nil, fmt.Errorf("unable to start shell: %v", err)
}
Expand Down Expand Up @@ -106,6 +106,10 @@ func (runCtx *Context) performInteractions(ctx context.Context, inputfile string
}
testcase, err := runCtx.performTestCase(ctx, interaction, &currentShell)
testcase.Classname = inputfile // testcase is always returned, even if err is not nil
if len(runCtx.XMLOutputFile) > 0 {
testcase.SystemOut = strings.Join(interaction.Output, "\n")
testcase.SystemErr = strings.Join(interaction.ErrorOutput, "\n")
}
if runCtx.ReplaceDots {
testcase.Classname = strings.ReplaceAll(inputfile, ".", "●")
}
Expand Down
73 changes: 52 additions & 21 deletions pkg/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ var ErrCancelled = errors.New("command execution cancelled")

// Shell represents the shell process that runs in the background and executes the commands.
type Shell struct {
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
mergeStderr bool
}

// DetectShell returns the path to the selected shell or the content of $SHELL
Expand All @@ -48,8 +49,10 @@ func DetectShell(selected string) (string, error) {
return selected, nil
}

// StartShell starts a shell as a background process
func StartShell(shell string) (Shell, error) {
// StartShell starts a shell as a background process.
// When mergeStderr is true, stderr from each command is redirected into stdout (2>&1).
// When false, stderr is captured separately via a temp file and returned alongside stdout.
func StartShell(shell string, mergeStderr bool) (Shell, error) {
cmd := exec.Command(shell)
stdin, err := cmd.StdinPipe()
if err != nil {
Expand All @@ -63,37 +66,55 @@ func StartShell(shell string) (Shell, error) {
if err != nil {
return Shell{}, fmt.Errorf("Unable to start shell %s: %v", shell, err)
}
return Shell{cmd, stdin, stdout}, nil
return Shell{cmd, stdin, stdout, mergeStderr}, nil
}

// commandResult holds the result of a command execution
type commandResult struct {
output []string
stdout []string
stderr []string
rc int
err error
}

// ExecuteCommand runs a command in the shell and returns its output and exit code.
// ExecuteCommand runs a command in the shell and returns its stdout, stderr, exit code, and any error.
// The context can be used to cancel execution (e.g., on SIGINT).
// The timeout parameter specifies a per-command timeout (0 means no timeout).
func (shell *Shell) ExecuteCommand(ctx context.Context, command string, timeout time.Duration) ([]string, int, error) {
// When shell.mergeStderr is true, stderr is redirected into stdout via 2>&1 and the returned stderr slice is nil.
// When false, stderr is captured to a temp file and returned separately.
func (shell *Shell) ExecuteCommand(ctx context.Context, command string, timeout time.Duration) ([]string, []string, int, error) {
const (
beginMarker = ">>>>>>>>>>SHELLDOC_MARKER>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
endMarker = "<<<<<<<<<<SHELLDOC_MARKER"
)
instruction := fmt.Sprintf("%s", strings.TrimSpace(command))
io.WriteString(shell.stdin, fmt.Sprintf("echo \"%s\"\n", beginMarker))
io.WriteString(shell.stdin, fmt.Sprintf("%s; echo \"%s $?\"\n", instruction, endMarker))

// Build the instruction, routing stderr as configured.
trimmed := strings.TrimSpace(command)
var stderrFile string
var instruction string
if shell.mergeStderr {
instruction = fmt.Sprintf("{ %s; } 2>&1; echo \"%s $?\"\n", trimmed, endMarker)
} else {
f, err := os.CreateTemp("", "shelldoc_stderr_*")
if err == nil {
stderrFile = f.Name()
f.Close()
}
instruction = fmt.Sprintf("{ %s; } 2>%s; echo \"%s $?\"\n", trimmed, stderrFile, endMarker)
}

beginEx := fmt.Sprintf("^%s$", beginMarker)
beginRx := regexp.MustCompile(beginEx)
endEx := fmt.Sprintf("^%s (.+)$", endMarker)
endRx := regexp.MustCompile(endEx)

io.WriteString(shell.stdin, fmt.Sprintf("echo \"%s\"\n", beginMarker))
io.WriteString(shell.stdin, instruction)

// Run the scanner in a goroutine to support timeout and cancellation
resultCh := make(chan commandResult, 1)
go func() {
var output []string
var stdout []string
var rc int
beginFound := false
scanner := bufio.NewScanner(shell.stdout)
Expand All @@ -110,35 +131,45 @@ func (shell *Shell) ExecuteCommand(ctx context.Context, command string, timeout
if len(match) > 1 {
value, err := strconv.Atoi(match[1])
if err != nil {
resultCh <- commandResult{nil, -1, fmt.Errorf("unable to read exit code for shell command: %v", err)}
resultCh <- commandResult{nil, nil, -1, fmt.Errorf("unable to read exit code for shell command: %v", err)}
return
}
rc = value
break
}
output = append(output, line)
stdout = append(stdout, line)
}
// Read stderr from temp file if capturing separately
var stderr []string
if stderrFile != "" {
if data, err := os.ReadFile(stderrFile); err == nil {
os.Remove(stderrFile)
if len(data) > 0 {
stderr = strings.Split(strings.TrimRight(string(data), "\n"), "\n")
}
}
}
resultCh <- commandResult{output, rc, nil}
resultCh <- commandResult{stdout, stderr, rc, nil}
}()

// Wait for result, timeout, or context cancellation
if timeout > 0 {
select {
case result := <-resultCh:
return result.output, result.rc, result.err
return result.stdout, result.stderr, result.rc, result.err
case <-time.After(timeout):
return nil, -1, ErrTimeout
return nil, nil, -1, ErrTimeout
case <-ctx.Done():
return nil, -1, ErrCancelled
return nil, nil, -1, ErrCancelled
}
}

// No timeout specified, wait for result or context cancellation
select {
case result := <-resultCh:
return result.output, result.rc, result.err
return result.stdout, result.stderr, result.rc, result.err
case <-ctx.Done():
return nil, -1, ErrCancelled
return nil, nil, -1, ErrCancelled
}
}

Expand Down
80 changes: 61 additions & 19 deletions pkg/shell/shell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,43 @@ func TestMain(m *testing.M) {
}
func TestShellLifeCycle(t *testing.T) {
// The most basic test, start a shell and exit it again
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
require.NoError(t, shell.Exit(), "Exiting ad running shell should work")
}

func TestShellLifeCycleRepeated(t *testing.T) {
// Can the program start and stop a shell repeatedly?
for counter := 0; counter < 16; counter++ {
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
require.NoError(t, shell.Exit(), "Exiting ad running shell should work")
}
}

func TestReturnCodes(t *testing.T) {
// Does the shell report return codes corrrectly?
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Exit()
ctx := context.Background()
{
output, rc, err := shell.ExecuteCommand(ctx, "true", 0)
stdout, _, rc, err := shell.ExecuteCommand(ctx, "true", 0)
require.NoError(t, err, "The true command is a builtin and should always work")
require.Equal(t, 0, rc, "The exit code of true should always be zero")
require.Empty(t, output, "true does not say a word")
require.Empty(t, stdout, "true does not say a word")
}
{
output, rc, err := shell.ExecuteCommand(ctx, "false", 0)
stdout, _, rc, err := shell.ExecuteCommand(ctx, "false", 0)
require.NoError(t, err, "The false command is a builtin and should always work")
require.NotEqual(t, 0, rc, "The exit code of false should never be zero")
require.Empty(t, output, "false does not say a word")
require.Empty(t, stdout, "false does not say a word")
}
}

func TestCaptureOutput(t *testing.T) {
// Does the shell capture and return the lines printed by the command correctly?
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Exit()
ctx := context.Background()
Expand All @@ -67,48 +67,90 @@ func TestCaptureOutput(t *testing.T) {
hello = "Hello"
world = "World"
)
output, rc, err := shell.ExecuteCommand(ctx, fmt.Sprintf("echo %s && echo %s", hello, world), 0)
stdout, _, rc, err := shell.ExecuteCommand(ctx, fmt.Sprintf("echo %s && echo %s", hello, world), 0)
require.NoError(t, err, "The echo command is a builtin and should always work")
require.Equal(t, 0, rc, "The exit code of echo should be zero")
require.Len(t, output, 2, "echo was called twice")
require.Equal(t, output[0], hello, "you had one job, echo")
require.Equal(t, output[1], world, "actually, two")
require.Len(t, stdout, 2, "echo was called twice")
require.Equal(t, stdout[0], hello, "you had one job, echo")
require.Equal(t, stdout[1], world, "actually, two")
}
}

func TestTimeout(t *testing.T) {
// Does the timeout work correctly?
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Kill() // Use Kill since shell may be in inconsistent state after timeout
ctx := context.Background()

// Command that completes within timeout should succeed
output, rc, err := shell.ExecuteCommand(ctx, "echo quick", 5*time.Second)
stdout, _, rc, err := shell.ExecuteCommand(ctx, "echo quick", 5*time.Second)
require.NoError(t, err, "Fast command should not timeout")
require.Equal(t, 0, rc)
require.Equal(t, []string{"quick"}, output)
require.Equal(t, []string{"quick"}, stdout)
}

func TestTimeoutExpires(t *testing.T) {
// Does timeout trigger correctly for slow commands?
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Kill()
ctx := context.Background()

// Command that takes longer than timeout should fail
start := time.Now()
_, _, err = shell.ExecuteCommand(ctx, "sleep 10", 100*time.Millisecond)
_, _, _, err = shell.ExecuteCommand(ctx, "sleep 10", 100*time.Millisecond)
elapsed := time.Since(start)

require.ErrorIs(t, err, ErrTimeout, "Slow command should timeout")
require.Less(t, elapsed, 1*time.Second, "Timeout should trigger quickly, not wait for command")
}

func TestCaptureStderr(t *testing.T) {
// Does the shell capture stderr separately from stdout?
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Exit()
ctx := context.Background()

stdout, stderr, rc, err := shell.ExecuteCommand(ctx, "echo out && echo err >&2", 0)
require.NoError(t, err)
require.Equal(t, 0, rc)
require.Equal(t, []string{"out"}, stdout, "stdout should contain only the stdout line")
require.Equal(t, []string{"err"}, stderr, "stderr should contain only the stderr line")
}

func TestMergeStderr(t *testing.T) {
// Does --merge-stderr combine stderr into stdout?
shell, err := StartShell(shellpath, true)
require.NoError(t, err, "Starting a shell should work")
defer shell.Exit()
ctx := context.Background()

stdout, stderr, rc, err := shell.ExecuteCommand(ctx, "echo out && echo err >&2", 0)
require.NoError(t, err)
require.Equal(t, 0, rc)
require.Contains(t, stdout, "out", "stdout should contain the stdout line")
require.Contains(t, stdout, "err", "merged stderr should appear in stdout")
require.Empty(t, stderr, "stderr slice should be empty when merging")
}

func TestStderrDoesNotPollutestdout(t *testing.T) {
// Stderr output must not bleed into stdout when captured separately.
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Exit()
ctx := context.Background()

stdout, _, rc, err := shell.ExecuteCommand(ctx, "echo only-stdout && echo only-stderr >&2", 0)
require.NoError(t, err)
require.Equal(t, 0, rc)
require.Equal(t, []string{"only-stdout"}, stdout, "stderr must not appear in stdout")
}

func TestContextCancellation(t *testing.T) {
// Does context cancellation work correctly?
shell, err := StartShell(shellpath)
shell, err := StartShell(shellpath, false)
require.NoError(t, err, "Starting a shell should work")
defer shell.Kill()

Expand All @@ -118,6 +160,6 @@ func TestContextCancellation(t *testing.T) {
cancel()

// Command should fail with ErrCancelled
_, _, err = shell.ExecuteCommand(ctx, "sleep 10", 0)
_, _, _, err = shell.ExecuteCommand(ctx, "sleep 10", 0)
require.ErrorIs(t, err, ErrCancelled, "Command should be cancelled")
}
7 changes: 5 additions & 2 deletions pkg/tokenizer/interaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ type Interaction struct {
ResultCode int
// Comment contains an explanation of the ResultCode after execution
Comment string
// Output contains the output of the interaction after it has been executed as individual lines
// Output contains the stdout of the interaction after it has been executed as individual lines
Output []string
// ErrorOutput contains the stderr of the interaction after it has been executed as individual lines
ErrorOutput []string
}

// Describe returns a human-readable description of the interaction
Expand Down Expand Up @@ -162,8 +164,9 @@ func (interaction *Interaction) Execute(ctx context.Context, sh *shell.Shell, gl
}

// execute the command in the shell
output, rc, err := sh.ExecuteCommand(ctx, interaction.Cmd, timeout)
output, errOutput, rc, err := sh.ExecuteCommand(ctx, interaction.Cmd, timeout)
interaction.Output = output
interaction.ErrorOutput = errOutput
// compare the results
if err == shell.ErrCancelled {
interaction.ResultCode = ResultCancelled
Expand Down
Loading