From 887241362ebbefd2ec0fa3c2e124e195e4f3f7cd Mon Sep 17 00:00:00 2001 From: Noah Treuhaft Date: Fri, 22 May 2026 13:44:25 -0400 Subject: [PATCH] ztest: refactor a little --- ztest/ztest.go | 68 +++++++++++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/ztest/ztest.go b/ztest/ztest.go index 1d68f0bc6..94bef4165 100644 --- a/ztest/ztest.go +++ b/ztest/ztest.go @@ -348,43 +348,35 @@ func (z *ZTest) ShouldSkip(path string) string { } func (z *ZTest) RunScript(ctx context.Context, shellPath, testDir string, tempDir func() string) error { + return z.run(func(r exec.Runtime) error { + return z.runScript(ctx, shellPath, testDir, tempDir(), r) + }) +} + +func (z *ZTest) RunInternal(ctx context.Context) error { + return z.run(func(r exec.Runtime) error { + return z.diffInternal(z.runInternal(ctx, r)) + }) +} + +func (z *ZTest) run(fn func(exec.Runtime) error) error { if err := z.check(); err != nil { return fmt.Errorf("bad yaml format: %w", err) } - serr := runsh(ctx, shellPath, testDir, tempDir(), z) + serr := fn(exec.RuntimeSAM) if !z.Vector { return serr } if serr != nil { serr = fmt.Errorf("=== sequence ===\n%w", serr) } - verr := runsh(ctx, shellPath, testDir, tempDir(), z, "SUPER_RUNTIME=vam") + verr := fn(exec.RuntimeVAM) if verr != nil { verr = fmt.Errorf("=== vector ===\n%w", verr) } return errors.Join(serr, verr) } -func (z *ZTest) RunInternal(ctx context.Context) error { - if err := z.check(); err != nil { - return fmt.Errorf("bad yaml format: %w", err) - } - outputFlags := append([]string{"-f=sup", "-pretty=0"}, strings.Fields(z.OutputFlags)...) - inputFlags := strings.Fields(z.InputFlags) - if z.Vector { - verr := z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, true)) - if verr != nil { - verr = fmt.Errorf("=== vector ===\n%w", verr) - } - serr := z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, false)) - if serr != nil { - serr = fmt.Errorf("=== sequence ===\n%w", serr) - } - return errors.Join(verr, serr) - } - return z.diffInternal(runInternal(ctx, z.SPQ, z.Input, outputFlags, inputFlags, false)) -} - func (z *ZTest) diffInternal(out string, err error) error { var outDiffErr, errDiffErr error if z.Output != out { @@ -436,9 +428,9 @@ func diffErr(name, expected, actual string) error { return fmt.Errorf("expected and actual %s differ:\n%s", name, diff) } -func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraEnv ...string) error { +func (z *ZTest) runScript(ctx context.Context, path, testDir, tempDir string, r exec.Runtime) error { var stdin io.Reader - for _, f := range zt.Inputs { + for _, f := range z.Inputs { b, _, err := f.load(testDir) if err != nil { return err @@ -451,12 +443,16 @@ func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraE return err } } - stdout, stderr, err := RunShell(ctx, tempDir, path, zt.Script, stdin, zt.Env, extraEnv) + var extraEnv []string + if r == exec.RuntimeVAM { + extraEnv = []string{"SUPER_RUNTIME=vam"} + } + stdout, stderr, err := RunShell(ctx, tempDir, path, z.Script, stdin, z.Env, extraEnv) if err != nil { return fmt.Errorf("script failed: %w\n=== stdout ===\n%s=== stderr ===\n%s", err, stdout, stderr) } - for _, f := range zt.Outputs { + for _, f := range z.Outputs { var actual string switch f.Name { case "stdout": @@ -484,20 +480,20 @@ func runsh(ctx context.Context, path, testDir, tempDir string, zt *ZTest, extraE return nil } -// runInternal runs query over input and returns the output. input -// may be in any format recognized by "super -i auto" and may be gzip-compressed. -// outputFlags may contain any flags accepted by cli/outputflags.Flags. -func runInternal(ctx context.Context, query string, input *string, outputFlags, inputFlags []string, vector bool) (string, error) { - ast, err := parser.ParseText(query) +func (z *ZTest) runInternal(ctx context.Context, r exec.Runtime) (string, error) { + ast, err := parser.ParseText(z.SPQ) if err != nil { return "", err } + args := []string{"-f=sup", "-pretty=0"} + args = append(args, strings.Fields(z.OutputFlags)...) + args = append(args, strings.Fields(z.InputFlags)...) var fs flag.FlagSet var inflags inputflags.Flags var outflags outputflags.Flags inflags.SetFlags(&fs, true) outflags.SetFlags(&fs) - if err := fs.Parse(append(inputFlags, outputFlags...)); err != nil { + if err := fs.Parse(args); err != nil { return "", err } if err := inflags.Init(); err != nil { @@ -507,17 +503,15 @@ func runInternal(ctx context.Context, query string, input *string, outputFlags, return "", err } eng := storage.NewInternalEngine() - if input != nil { + if i := z.Input; i != nil { ast.PrependFileScan([]string{"stdio:stdin"}) - eng.AddReader("stdio:stdin", strings.NewReader(*input)) + eng.AddReader("stdio:stdin", strings.NewReader(*i)) } env := exec.NewEnvironment(eng, nil) env.Dynamic = inflags.Dynamic env.ReaderOpts = inflags.ReaderOpts + env.Runtime = r env.SampleSize = inflags.SampleSize - if vector { - env.Runtime = exec.RuntimeVAM - } q, err := runtime.CompileQuery(ctx, super.NewContext(), compiler.NewCompilerWithEnv(env), ast, nil) if err != nil { return "", err