Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 31 additions & 37 deletions ztest/ztest.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,43 +348,35 @@ func (z *ZTest) ShouldSkip(path string) string {
}

func (z *ZTest) RunScript(ctx context.Context, shellPath, testDir string, tempDir func() string) error {
return z.run(func(r exec.Runtime) error {
return z.runScript(ctx, shellPath, testDir, tempDir(), r)
})
}

func (z *ZTest) RunInternal(ctx context.Context) error {
return z.run(func(r exec.Runtime) error {
return z.diffInternal(z.runInternal(ctx, r))
})
}

func (z *ZTest) run(fn func(exec.Runtime) error) error {
if err := z.check(); err != nil {
return fmt.Errorf("bad yaml format: %w", err)
}
serr := runsh(ctx, shellPath, testDir, tempDir(), z)
serr := fn(exec.RuntimeSAM)
if !z.Vector {
return serr
}
if serr != nil {
serr = fmt.Errorf("=== sequence ===\n%w", serr)
}
verr := runsh(ctx, shellPath, testDir, tempDir(), z, "SUPER_RUNTIME=vam")
verr := fn(exec.RuntimeVAM)
if verr != nil {
verr = fmt.Errorf("=== vector ===\n%w", verr)
}
return errors.Join(serr, verr)
}

func (z *ZTest) RunInternal(ctx context.Context) error {
if err := z.check(); err != nil {
return fmt.Errorf("bad yaml format: %w", err)
}
outputFlags := append([]string{"-f=sup", "-pretty=0"}, strings.Fields(z.OutputFlags)...)
inputFlags := strings.Fields(z.InputFlags)
if z.Vector {
verr := z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, true))
if verr != nil {
verr = fmt.Errorf("=== vector ===\n%w", verr)
}
serr := z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, false))
if serr != nil {
serr = fmt.Errorf("=== sequence ===\n%w", serr)
}
return errors.Join(verr, serr)
}
return z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, false))
}

func (z *ZTest) diffInternal(out string, err error) error {
var outDiffErr, errDiffErr error
if z.Output != out {
Expand Down Expand Up @@ -436,9 +428,9 @@ func diffErr(name, expected, actual string) error {
return fmt.Errorf("expected and actual %s differ:\n%s", name, diff)
}

func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraEnv ...string) error {
func (z *ZTest) runScript(ctx context.Context, path, testDir, tempDir string, r exec.Runtime) error {
var stdin io.Reader
for _, f := range zt.Inputs {
for _, f := range z.Inputs {
b, _, err := f.load(testDir)
if err != nil {
return err
Expand All @@ -451,12 +443,16 @@ func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraE
return err
}
}
stdout, stderr, err := RunShell(ctx, tempDir, path, zt.Script, stdin, zt.Env, extraEnv)
var extraEnv []string
if r == exec.RuntimeVAM {
extraEnv = []string{"SUPER_RUNTIME=vam"}
}
stdout, stderr, err := RunShell(ctx, tempDir, path, z.Script, stdin, z.Env, extraEnv)
if err != nil {
return fmt.Errorf("script failed: %w\n=== stdout ===\n%s=== stderr ===\n%s",
err, stdout, stderr)
}
for _, f := range zt.Outputs {
for _, f := range z.Outputs {
var actual string
switch f.Name {
case "stdout":
Expand Down Expand Up @@ -484,20 +480,20 @@ func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraE
return nil
}

// runInternal runs query over input and returns the output. input
// may be in any format recognized by "super -i auto" and may be gzip-compressed.
// outputFlags may contain any flags accepted by cli/outputflags.Flags.
func runInternal(ctx context.Context, query string, input *string, outputFlags, inputFlags []string, vector bool) (string, error) {
ast, err := parser.ParseText(query)
func (z *ZTest) runInternal(ctx context.Context, r exec.Runtime) (string, error) {
ast, err := parser.ParseText(z.SPQ)
if err != nil {
return "", err
}
args := []string{"-f=sup", "-pretty=0"}
args = append(args, strings.Fields(z.OutputFlags)...)
args = append(args, strings.Fields(z.InputFlags)...)
var fs flag.FlagSet
var inflags inputflags.Flags
var outflags outputflags.Flags
inflags.SetFlags(&fs, true)
outflags.SetFlags(&fs)
if err := fs.Parse(append(inputFlags, outputFlags...)); err != nil {
if err := fs.Parse(args); err != nil {
return "", err
}
if err := inflags.Init(); err != nil {
Expand All @@ -507,17 +503,15 @@ func runInternal(ctx context.Context, query string, input *string, outputFlags,
return "", err
}
eng := storage.NewInternalEngine()
if input != nil {
if i := z.Input; i != nil {
ast.PrependFileScan([]string{"stdio:stdin"})
eng.AddReader("stdio:stdin", strings.NewReader(*input))
eng.AddReader("stdio:stdin", strings.NewReader(*i))
}
env := exec.NewEnvironment(eng, nil)
env.Dynamic = inflags.Dynamic
env.ReaderOpts = inflags.ReaderOpts
env.Runtime = r
env.SampleSize = inflags.SampleSize
if vector {
env.Runtime = exec.RuntimeVAM
}
q, err := runtime.CompileQuery(ctx, super.NewContext(), compiler.NewCompilerWithEnv(env), ast, nil)
if err != nil {
return "", err
Expand Down
Loading