diff --git a/pkg/blockcontroller/blockcontroller.go b/pkg/blockcontroller/blockcontroller.go index 524a66c10b..48ba01bb70 100644 --- a/pkg/blockcontroller/blockcontroller.go +++ b/pkg/blockcontroller/blockcontroller.go @@ -19,6 +19,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/jobcontroller" "github.com/wavetermdev/waveterm/pkg/remote" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" + "github.com/wavetermdev/waveterm/pkg/util/ds" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/waveobj" @@ -75,10 +76,17 @@ type Controller interface { // Registry for all controllers var ( - controllerRegistry = make(map[string]Controller) - registryLock sync.RWMutex + controllerRegistry = make(map[string]Controller) + registryLock sync.RWMutex + blockResyncMutexMap = ds.MakeSyncMap[*sync.Mutex]() ) +func getBlockResyncMutex(blockId string) *sync.Mutex { + return blockResyncMutexMap.GetOrCreate(blockId, func() *sync.Mutex { + return &sync.Mutex{} + }) +} + // Registry operations func getController(blockId string) Controller { registryLock.RLock() @@ -145,6 +153,10 @@ func ResyncController(ctx context.Context, tabId string, blockId string, rtOpts return fmt.Errorf("invalid tabId or blockId passed to ResyncController") } + mu := getBlockResyncMutex(blockId) + mu.Lock() + defer mu.Unlock() + blockData, err := wstore.DBMustGet[*waveobj.Block](ctx, blockId) if err != nil { return fmt.Errorf("error getting block: %w", err) diff --git a/pkg/blockcontroller/durableshellcontroller.go b/pkg/blockcontroller/durableshellcontroller.go index a208a3df75..d3481b172d 100644 --- a/pkg/blockcontroller/durableshellcontroller.go +++ b/pkg/blockcontroller/durableshellcontroller.go @@ -163,7 +163,7 @@ func (dsc *DurableShellController) Start(ctx context.Context, blockMeta waveobj. if jobId == "" { log.Printf("block %q starting new durable shell\n", dsc.BlockId) - newJobId, err := dsc.startNewJob(ctx, blockMeta, dsc.ConnName) + newJobId, err := dsc.startNewJob(ctx, blockMeta, dsc.ConnName, rtOpts) if err != nil { return fmt.Errorf("failed to start new job: %w", err) } @@ -218,11 +218,14 @@ func (dsc *DurableShellController) SendInput(inputUnion *BlockInputUnion) error return jobcontroller.SendInput(context.Background(), data) } -func (dsc *DurableShellController) startNewJob(ctx context.Context, blockMeta waveobj.MetaMapType, connName string) (string, error) { +func (dsc *DurableShellController) startNewJob(ctx context.Context, blockMeta waveobj.MetaMapType, connName string, rtOpts *waveobj.RuntimeOpts) (string, error) { termSize := waveobj.TermSize{ Rows: shellutil.DefaultTermRows, Cols: shellutil.DefaultTermCols, } + if rtOpts != nil && rtOpts.TermSize.Rows > 0 && rtOpts.TermSize.Cols > 0 { + termSize = rtOpts.TermSize + } cmdStr := blockMeta.GetString(waveobj.MetaKey_Cmd, "") cwd := blockMeta.GetString(waveobj.MetaKey_CmdCwd, "") opts, err := remote.ParseOpts(connName) diff --git a/pkg/util/ds/syncmap.go b/pkg/util/ds/syncmap.go index a422343ac5..99b4095efc 100644 --- a/pkg/util/ds/syncmap.go +++ b/pkg/util/ds/syncmap.go @@ -62,3 +62,14 @@ func (sm *SyncMap[T]) TestAndSet(key string, newValue T, testFn func(T, bool) bo } return false } + +func (sm *SyncMap[T]) GetOrCreate(key string, createFn func() T) T { + sm.lock.Lock() + defer sm.lock.Unlock() + if v, ok := sm.m[key]; ok { + return v + } + v := createFn() + sm.m[key] = v + return v +}