Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 55 additions & 56 deletions cmd/gen-source-variants/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import (
"path/filepath"
"testing"

"gotest.tools/v3/assert"
"gotest.tools/v3/assert/cmp"
"github.com/stretchr/testify/assert"
)

func TestExtractSourceFields(t *testing.T) {
Expand Down Expand Up @@ -49,13 +48,13 @@ type NonSourceType struct{}
`

err := os.WriteFile(specFile, []byte(specContent), 0644)
assert.NilError(t, err)
assert.NoError(t, err)

// Change to temp directory so the parser can find spec.go
t.Chdir(tempDir)

fields, err := extractSourceFields()
assert.NilError(t, err)
assert.NoError(t, err)

// Verify we found the expected source fields
expectedFields := []SourceField{
Expand All @@ -67,7 +66,7 @@ type NonSourceType struct{}
{Name: "Inline", TypeName: "SourceInline"},
}

assert.DeepEqual(t, fields, expectedFields)
assert.Equal(t, expectedFields, fields)
}

func TestExtractSourceFields_NoSourceStruct(t *testing.T) {
Expand All @@ -82,12 +81,12 @@ type OtherStruct struct {
`

err := os.WriteFile(specFile, []byte(specContent), 0644)
assert.NilError(t, err)
assert.NoError(t, err)

t.Chdir(tempDir)

fields, err := extractSourceFields()
assert.NilError(t, err)
assert.NoError(t, err)
assert.Equal(t, len(fields), 0)
}

Expand All @@ -104,13 +103,13 @@ type Source struct {
`

err := os.WriteFile(specFile, []byte(specContent), 0644)
assert.NilError(t, err)
assert.NoError(t, err)

t.Chdir(tempDir)

_, err = extractSourceFields()
assert.Check(t, err != nil)
assert.Check(t, cmp.Contains(err.Error(), "failed to parse"))
assert.True(t, err != nil)
assert.Contains(t, err.Error(), "failed to parse")
}

func TestGenerateCode(t *testing.T) {
Expand All @@ -121,68 +120,68 @@ func TestGenerateCode(t *testing.T) {
}

code, err := generateCode(fields)
assert.NilError(t, err)
assert.NoError(t, err)

// Verify the generated code is valid Go
fset := token.NewFileSet()
_, err = parser.ParseFile(fset, "", code, parser.ParseComments)
assert.NilError(t, err)
assert.NoError(t, err)

codeStr := string(code)

// Check for expected function signatures
assert.Check(t, cmp.Contains(codeStr, "func (s *Source) validateSourceVariants() error"))
assert.Check(t, cmp.Contains(codeStr, "func (s *Source) toInterface() source"))
assert.Contains(t, codeStr, "func (s *Source) validateSourceVariants() error")
assert.Contains(t, codeStr, "func (s *Source) toInterface() source")

// Check for package declaration and imports
assert.Check(t, cmp.Contains(codeStr, "package dalec"))
assert.Check(t, cmp.Contains(codeStr, `import (`))
assert.Check(t, cmp.Contains(codeStr, `"fmt"`))
assert.Contains(t, codeStr, "package dalec")
assert.Contains(t, codeStr, `import (`)
assert.Contains(t, codeStr, `"fmt"`)

// Check for generated comment
assert.Check(t, cmp.Contains(codeStr, "Code generated by cmd/gen-source-variants. DO NOT EDIT."))
assert.Contains(t, codeStr, "Code generated by cmd/gen-source-variants. DO NOT EDIT.")

// Check that all fields are present in validation
for _, field := range fields {
expectedCheck := "if s." + field.Name + " != nil {"
assert.Check(t, cmp.Contains(codeStr, expectedCheck))
assert.Contains(t, codeStr, expectedCheck)
}

// Check that all fields are present in toInterface
for _, field := range fields {
expectedCase := "case s." + field.Name + " != nil:"
expectedReturn := "return s." + field.Name
assert.Check(t, cmp.Contains(codeStr, expectedCase))
assert.Check(t, cmp.Contains(codeStr, expectedReturn))
assert.Contains(t, codeStr, expectedCase)
assert.Contains(t, codeStr, expectedReturn)
}

// Check validation logic
assert.Check(t, cmp.Contains(codeStr, "count := 0"))
assert.Check(t, cmp.Contains(codeStr, "count++"))
assert.Check(t, cmp.Contains(codeStr, `return fmt.Errorf("no non-nil source variant")`))
assert.Check(t, cmp.Contains(codeStr, `return fmt.Errorf("more than one source variant defined")`))
assert.Contains(t, codeStr, "count := 0")
assert.Contains(t, codeStr, "count++")
assert.Contains(t, codeStr, `return fmt.Errorf("no non-nil source variant")`)
assert.Contains(t, codeStr, `return fmt.Errorf("more than one source variant defined")`)

// Check toInterface logic
assert.Check(t, cmp.Contains(codeStr, "panic(errNoSourceVariant)"))
assert.Contains(t, codeStr, "panic(errNoSourceVariant)")
}

func TestGenerateCode_EmptyFields(t *testing.T) {
fields := []SourceField{}

code, err := generateCode(fields)
assert.NilError(t, err)
assert.NoError(t, err)

// Verify the generated code is still valid Go
fset := token.NewFileSet()
_, err = parser.ParseFile(fset, "", code, parser.ParseComments)
assert.NilError(t, err)
assert.NoError(t, err)

codeStr := string(code)

// Should still have the basic structure
assert.Check(t, cmp.Contains(codeStr, "func (s *Source) validateSourceVariants() error"))
assert.Check(t, cmp.Contains(codeStr, "func (s *Source) toInterface() source"))
assert.Check(t, cmp.Contains(codeStr, "count := 0"))
assert.Contains(t, codeStr, "func (s *Source) validateSourceVariants() error")
assert.Contains(t, codeStr, "func (s *Source) toInterface() source")
assert.Contains(t, codeStr, "count := 0")
}

func TestGenerateCode_SingleField(t *testing.T) {
Expand All @@ -191,14 +190,14 @@ func TestGenerateCode_SingleField(t *testing.T) {
}

code, err := generateCode(fields)
assert.NilError(t, err)
assert.NoError(t, err)

codeStr := string(code)

// Check for the single field
assert.Check(t, cmp.Contains(codeStr, "if s.Git != nil {"))
assert.Check(t, cmp.Contains(codeStr, "case s.Git != nil:"))
assert.Check(t, cmp.Contains(codeStr, "return s.Git"))
assert.Contains(t, codeStr, "if s.Git != nil {")
assert.Contains(t, codeStr, "case s.Git != nil:")
assert.Contains(t, codeStr, "return s.Git")
}

func TestGenerateCode_FormattingPreserved(t *testing.T) {
Expand All @@ -208,14 +207,14 @@ func TestGenerateCode_FormattingPreserved(t *testing.T) {
}

code, err := generateCode(fields)
assert.NilError(t, err)
assert.NoError(t, err)

// Verify that the code is properly formatted by re-formatting it
reformatted, err := format.Source(code)
assert.NilError(t, err)
assert.NoError(t, err)

// Should be identical since generateCode already formats
assert.Check(t, bytes.Equal(code, reformatted))
assert.True(t, bytes.Equal(code, reformatted))
}

func TestMain_Integration(t *testing.T) {
Expand All @@ -238,7 +237,7 @@ type SourceHTTP struct{}
`

err := os.WriteFile(specFile, []byte(specContent), 0644)
assert.NilError(t, err)
assert.NoError(t, err)

t.Chdir(tempDir)

Expand All @@ -260,21 +259,21 @@ type SourceHTTP struct{}

// Verify output file was created
_, err = os.Stat(outputFile)
assert.NilError(t, err)
assert.NoError(t, err)

// Verify output file contents
content, err := os.ReadFile(outputFile)
assert.NilError(t, err)
assert.NoError(t, err)

contentStr := string(content)
assert.Check(t, cmp.Contains(contentStr, "package dalec"))
assert.Check(t, cmp.Contains(contentStr, "func (s *Source) validateSourceVariants() error"))
assert.Check(t, cmp.Contains(contentStr, "func (s *Source) toInterface() source"))
assert.Contains(t, contentStr, "package dalec")
assert.Contains(t, contentStr, "func (s *Source) validateSourceVariants() error")
assert.Contains(t, contentStr, "func (s *Source) toInterface() source")

// Verify it's valid Go
fset := token.NewFileSet()
_, err = parser.ParseFile(fset, "", content, parser.ParseComments)
assert.NilError(t, err)
assert.NoError(t, err)
}

func TestMain_InvalidArgs(t *testing.T) {
Expand All @@ -290,7 +289,7 @@ func TestMain_InvalidArgs(t *testing.T) {

// We know from reading main() that it checks len(os.Args) != 2
// Since we can't easily test os.Exit, we'll verify the condition
assert.Check(t, len(os.Args) != 2, "Invalid args should trigger exit condition")
assert.True(t, len(os.Args) != 2, "Invalid args should trigger exit condition")
}

func TestFieldSorting(t *testing.T) {
Expand All @@ -302,22 +301,22 @@ func TestFieldSorting(t *testing.T) {
}

code, err := generateCode(fields)
assert.NilError(t, err)
assert.NoError(t, err)

codeStr := string(code)

// Check that all fields are present
assert.Check(t, cmp.Contains(codeStr, "AField"))
assert.Check(t, cmp.Contains(codeStr, "MField"))
assert.Check(t, cmp.Contains(codeStr, "ZField"))
assert.Contains(t, codeStr, "AField")
assert.Contains(t, codeStr, "MField")
assert.Contains(t, codeStr, "ZField")

// Check that the validation logic contains all fields
assert.Check(t, cmp.Contains(codeStr, "if s.AField != nil {"))
assert.Check(t, cmp.Contains(codeStr, "if s.MField != nil {"))
assert.Check(t, cmp.Contains(codeStr, "if s.ZField != nil {"))
assert.Contains(t, codeStr, "if s.AField != nil {")
assert.Contains(t, codeStr, "if s.MField != nil {")
assert.Contains(t, codeStr, "if s.ZField != nil {")

// Check that the toInterface logic contains all fields
assert.Check(t, cmp.Contains(codeStr, "case s.AField != nil:"))
assert.Check(t, cmp.Contains(codeStr, "case s.MField != nil:"))
assert.Check(t, cmp.Contains(codeStr, "case s.ZField != nil:"))
assert.Contains(t, codeStr, "case s.AField != nil:")
assert.Contains(t, codeStr, "case s.MField != nil:")
assert.Contains(t, codeStr, "case s.ZField != nil:")
}
Loading