-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgenerator.go
More file actions
110 lines (95 loc) · 2.52 KB
/
generator.go
File metadata and controls
110 lines (95 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package main
import (
"bytes"
"embed"
"fmt"
"log"
"os"
"sort"
"strings"
"text/template"
)
// templatesFS embeds every *.go.tmpl file next to this source file; init
// parses them all into templates.
//
//go:embed *.go.tmpl
var templatesFS embed.FS

// templates is the parsed template set built by init. Its root template is
// named "file.go.tmpl", and that is the one generate executes via
// templates.Execute.
var templates *template.Template
// init builds the package-level templates set: it registers the helper
// functions available to the templates and parses every embedded *.go.tmpl
// file. The root template is "file.go.tmpl". Any parse error is fatal.
func init() {
	tmplFuncs := template.FuncMap{
		// join mirrors strings.Join with the separator first, which reads
		// naturally in a pipeline: {{ .Names | join ", " }}.
		"join": func(sep string, s []string) string { return strings.Join(s, sep) },
		// add returns the sum of its arguments (0 when called with none).
		"add": func(x ...int) int {
			sum := 0
			for _, v := range x {
				sum += v
			}
			return sum
		},
		"n_tabs":                   func(n int) string { return strings.Repeat("\t", n) },
		"snake_case_to_camel_case": snakeCaseToCamelCase,
		"exists": func(m map[string]any, key string) bool {
			_, ok := m[key]
			return ok
		},
	}

	// Functions must be registered before parsing: text/template resolves
	// function names at parse time and rejects unknown ones. Registering
	// them again after ParseFS (as was done previously) is redundant.
	templates = template.New("file.go.tmpl").Funcs(tmplFuncs)
	if _, err := templates.ParseFS(templatesFS, "*.go.tmpl"); err != nil {
		log.Fatalf("ParseFS: %v", err)
	}
}
// generate renders the root template for packageName and entities and writes
// the result to "gentity.gen.go" in the current working directory.
//
// The import set passed to the template is collected from every entity's
// fields via importsByField, plus "encoding/json" when any entity has JSON
// fields. entities is sorted in place by GoName before rendering, so the
// caller's slice is reordered.
//
// It returns the file name it wrote (also on error). err is non-nil if
// template execution, file creation, writing, or closing the file fails; the
// underlying error is wrapped and can be inspected with errors.Is/As.
func generate(packageName string, entities []entity) (newFileName string, err error) {
	newFileName = "gentity.gen.go"

	// Collect the imports the generated code needs, keyed by package alias.
	imports := make(map[string]string)
	for i := range entities {
		e := &entities[i]
		for _, f := range e.Fields {
			importsByField(e, f, imports)
		}
		if len(e.JsonFields) > 0 {
			imports["encoding/json"] = "encoding/json"
		}
	}

	// Order entities by name so the output does not depend on input order.
	sort.Slice(entities, func(i, j int) bool {
		return entities[i].GoName < entities[j].GoName
	})

	data := struct {
		PackageName string
		Entities    []entity
		Imports     map[string]string
	}{packageName, entities, imports}

	// Render fully into memory first so a template error never leaves a
	// truncated file behind.
	var buf bytes.Buffer
	if err = templates.Execute(&buf, data); err != nil {
		return newFileName, fmt.Errorf("execute template failed: %w", err)
	}

	outFile, err := os.Create(newFileName)
	if err != nil {
		return newFileName, fmt.Errorf("create file: %w", err)
	}
	if _, err = buf.WriteTo(outFile); err != nil {
		_ = outFile.Close() // best effort; the write error is the one worth reporting
		return newFileName, fmt.Errorf("failed to write generated file %s: %w", newFileName, err)
	}
	// Close can surface errors from earlier buffered writes, so for a file we
	// just produced its error must not be discarded (a deferred Close would).
	if err = outFile.Close(); err != nil {
		return newFileName, fmt.Errorf("close generated file %s: %w", newFileName, err)
	}
	return newFileName, nil
}
// importsByField records in imports the package that field's Go type comes
// from, keyed by package alias and mapped to an import path.
//
// Only fields used in an index or in the primary key are considered: per the
// original design note, their types appear as method arguments in the
// generated code (the getters), so the generated file must import their
// packages. Any other field is ignored.
//
// The alias is the identifier immediately before the first '.' in
// field.GoType, so "time.Time", "*time.Time", "[]*time.Time", "*[]time.Time",
// "[4]time.Time" and "map[string]time.Time" all resolve to "time". Types with
// no package qualifier (builtins, local types) add nothing. The alias is
// mapped to:
//   - "github.com/jackc/pgx/v5/pgtype" for "pgtype";
//   - entity.Imports[alias] when that key is present (presumably the imports
//     of the file declaring the entity — confirm against the parser);
//   - the alias itself otherwise, which is only correct for single-segment
//     import paths such as "time".
func importsByField(entity *entity, field *field, imports map[string]string) {
	if len(field.InIndexes) == 0 && !field.InPrimaryKey {
		return
	}

	qualified, _, found := strings.Cut(field.GoType, ".")
	if !found {
		return
	}

	// Walk back from the dot over identifier characters; whatever precedes
	// them is type-constructor syntax ("[]", "*", "map[K]", "[N]", ...).
	notIdent := func(r rune) bool {
		isIdent := r == '_' || r >= 0x80 ||
			('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9')
		return !isIdent
	}
	// Every rune notIdent matches is ASCII (one byte), so +1 steps past it;
	// when there is none, LastIndexFunc returns -1 and the whole prefix is kept.
	pkgAlias := qualified[strings.LastIndexFunc(qualified, notIdent)+1:]
	if pkgAlias == "" {
		return
	}

	if pkgAlias == "pgtype" {
		imports[pkgAlias] = "github.com/jackc/pgx/v5/pgtype"
		return
	}
	if imp, ok := entity.Imports[pkgAlias]; ok {
		imports[pkgAlias] = imp
		return
	}
	imports[pkgAlias] = pkgAlias
}