diff --git a/nested_set.go b/nested_set.go index bd1a5f2..6474e56 100644 --- a/nested_set.go +++ b/nested_set.go @@ -323,6 +323,65 @@ func Rebuild(db *gorm.DB, source interface{}, doUpdate bool) (affectedCount int, return } + +// RebuildBatched rebuild nodes as any nestedset which in the scope +// ```nestedset.RebuildBatched(db, &node, true, 1000)``` will rebuild [&node] as nestedset +func RebuildBatched(db *gorm.DB, source interface{}, doUpdate bool, batchSize int) (affectedCount int, err error) { + tx, target, err := parseNode(db, source) + if err != nil { + return + } + err = tx.Transaction(func(tx *gorm.DB) (err error) { + allItems := []*nestedItem{} + err = tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where(formatSQL("", target)). + Order(formatSQL(":parent_id ASC NULLS FIRST, :lft ASC", target)). + Find(&allItems). + Error + + if err != nil { + return + } + initTree(allItems).rebuild() + + var itemsToUpdate []*nestedItem + for _, item := range allItems { + if item.IsChanged { + affectedCount += 1 + if doUpdate { + itemsToUpdate = append(itemsToUpdate, item) + } + } + } + if doUpdate && len(itemsToUpdate) > 0 { + err = batchUpdate(tx, []string{"lft", "rgt", "depth", "children_count"}, target.DbNames, itemsToUpdate, batchSize) + if err != nil { + return + } + } + return nil + }) + return +} + +// batchUpdate performs a batched upsert (update on conflict) for the given columns and items. +func batchUpdate(db *gorm.DB, columns []string, dbNames map[string]string, items []*nestedItem, batchSize int) error { + if len(items) == 0 { + return nil + } + + assignmentMap := map[string]interface{}{} + for _, column := range columns { + column = dbNames[column] + assignmentMap[column] = gorm.Expr("EXCLUDED." + column) + } + + return db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: dbNames["id"]}}, + DoUpdates: clause.Assignments(assignmentMap), + }).CreateInBatches(items, batchSize).Error +} + func moveIsValid(node, to nestedItem) error { validLft, validRgt := node.Lft, node.Rgt if (to.Lft >= validLft && to.Lft <= validRgt) || (to.Rgt >= validLft && to.Rgt <= validRgt) { diff --git a/nested_set_test.go b/nested_set_test.go index 7b188bd..7c785f3 100644 --- a/nested_set_test.go +++ b/nested_set_test.go @@ -344,6 +344,164 @@ func TestRebuild(t *testing.T) { assertNodeEqual(t, lilysDresses, 4, 5, 1, 0, lilysClothing.ID) } +func TestRebuildBatched(t *testing.T) { + const batchSize = 5 + initData() + affectedCount, err := RebuildBatched(db, clothing, true, batchSize) + assert.NoError(t, err) + assert.Equal(t, 0, affectedCount) + reloadCategories() + + assertNodeEqual(t, clothing, 1, 22, 0, 2, 0) + assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID) + assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID) + assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID) + assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID) + assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID) + assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID) + assertNodeEqual(t, eveningGowns, 12, 13, 3, 0, dresses.ID) + assertNodeEqual(t, sunDresses, 14, 15, 3, 0, dresses.ID) + assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID) + assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID) + + sunDresses.Rgt = 123 + sunDresses.Lft = 12 + sunDresses.Depth = 1 + sunDresses.ChildrenCount = 100 + err = db.Updates(&sunDresses).Error + assert.NoError(t, err) + reloadCategories() + assertNodeEqual(t, sunDresses, 12, 123, 1, 100, dresses.ID) + + affectedCount, err = RebuildBatched(db, clothing, true, batchSize) + assert.NoError(t, err) + assert.Equal(t, 2, affectedCount) + reloadCategories() + + assertNodeEqual(t, clothing, 1, 22, 0, 2, 0) + assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID) + assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID) + assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID) + assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID) + assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID) + assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID) + assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID) + assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID) + assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID) + assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID) + + affectedCount, err = RebuildBatched(db, clothing, true, batchSize) + assert.NoError(t, err) + assert.Equal(t, 0, affectedCount) + reloadCategories() + + assertNodeEqual(t, clothing, 1, 22, 0, 2, 0) + assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID) + assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID) + assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID) + assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID) + assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID) + assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID) + assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID) + assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID) + assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID) + assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID) + + hat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Hat", + "ParentID": sql.NullInt64{Valid: false}, + }).(*Category) + + affectedCount, err = RebuildBatched(db, clothing, false, batchSize) + assert.NoError(t, err) + assert.Equal(t, 1, affectedCount) + + affectedCount, err = RebuildBatched(db, clothing, true, batchSize) + assert.NoError(t, err) + assert.Equal(t, 1, affectedCount) + reloadCategories() + hat, _ = findNode(db, hat.ID) + + assertNodeEqual(t, clothing, 1, 22, 0, 2, 0) + assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID) + assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID) + assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID) + assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID) + assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID) + assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID) + assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID) + assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID) + assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID) + assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID) + assertNodeEqual(t, hat, 23, 24, 0, 0, 0) + + jacksClothing := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Jack's Clothing", + "ParentID": sql.NullInt64{Valid: false}, + "UserType": "User", + "UserID": 8686, + }).(*Category) + jacksSuits := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Jack's Suits", + "ParentID": sql.NullInt64{Valid: true, Int64: jacksClothing.ID}, + "UserType": "User", + "UserID": 8686, + }).(*Category) + jacksHat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Jack's Hat", + "UserType": "User", + "UserID": 8686, + "ParentID": sql.NullInt64{Valid: false}, + }).(*Category) + jacksSlacks := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Jack's Slacks", + "ParentID": sql.NullInt64{Valid: true, Int64: jacksClothing.ID}, + "UserType": "User", + "UserID": 8686, + }).(*Category) + + lilysHat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Lily's Hat", + "UserType": "User", + "UserID": 6666, + "ParentID": sql.NullInt64{Valid: false}, + }).(*Category) + lilysClothing := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Lily's Clothing", + "ParentID": sql.NullInt64{Valid: false}, + "UserType": "User", + "UserID": 6666, + }).(*Category) + lilysDresses := *CategoryFactory.MustCreateWithOption(map[string]interface{}{ + "Title": "Lily's Dresses", + "ParentID": sql.NullInt64{Valid: true, Int64: lilysClothing.ID}, + "UserType": "User", + "UserID": 6666, + }).(*Category) + + affectedCount, err = RebuildBatched(db, jacksSuits, true, batchSize) + assert.NoError(t, err) + assert.Equal(t, 4, affectedCount) + affectedCount, err = RebuildBatched(db, lilysHat, true, batchSize) + assert.NoError(t, err) + assert.Equal(t, 3, affectedCount) + jacksClothing, _ = findNode(db, jacksClothing.ID) + jacksSuits, _ = findNode(db, jacksSuits.ID) + jacksSlacks, _ = findNode(db, jacksSlacks.ID) + jacksHat, _ = findNode(db, jacksHat.ID) + lilysHat, _ = findNode(db, lilysHat.ID) + lilysClothing, _ = findNode(db, lilysClothing.ID) + lilysDresses, _ = findNode(db, lilysDresses.ID) + + assertNodeEqual(t, jacksClothing, 1, 6, 0, 2, 0) + assertNodeEqual(t, jacksSuits, 2, 3, 1, 0, jacksClothing.ID) + assertNodeEqual(t, jacksSlacks, 4, 5, 1, 0, jacksClothing.ID) + assertNodeEqual(t, jacksHat, 7, 8, 0, 0, 0) + assertNodeEqual(t, lilysHat, 1, 2, 0, 0, 0) + assertNodeEqual(t, lilysClothing, 3, 6, 0, 1, 0) + assertNodeEqual(t, lilysDresses, 4, 5, 1, 0, lilysClothing.ID) +} + func TestMoveToLeft(t *testing.T) { // case 1 initData()