diff --git a/frontend/src/pages/Runs/Details/RunDetails/ConnectToRunWithDevEnvConfiguration/index.tsx b/frontend/src/pages/Runs/Details/RunDetails/ConnectToRunWithDevEnvConfiguration/index.tsx index af63e9c67b..54d03c388d 100644 --- a/frontend/src/pages/Runs/Details/RunDetails/ConnectToRunWithDevEnvConfiguration/index.tsx +++ b/frontend/src/pages/Runs/Details/RunDetails/ConnectToRunWithDevEnvConfiguration/index.tsx @@ -54,7 +54,9 @@ export const ConnectToRunWithDevEnvConfiguration: FC<{ run: IRun }> = ({ run }) const [sshCommand, copySSHCommand] = getSSHCommand(run); const configuration = run.run_spec.configuration as TDevEnvironmentConfiguration; - const openInIDEUrl = `${configuration.ide}://vscode-remote/ssh-remote+${run.run_spec.run_name}/${run.run_spec.working_dir || 'workflow'}`; + const latestSubmission = run.jobs[0]?.job_submissions?.slice(-1)[0]; + const workingDir = latestSubmission?.job_runtime_data?.working_dir ?? '/'; + const openInIDEUrl = `${configuration.ide}://vscode-remote/ssh-remote+${run.run_spec.run_name}${workingDir}`; const ideDisplayName = getIDEDisplayName(configuration.ide); const [configCliCommand, copyCliCommand] = useConfigProjectCliCommand({ projectName: run.project_name }); diff --git a/frontend/src/types/run.d.ts b/frontend/src/types/run.d.ts index 3eac746218..928a022804 100644 --- a/frontend/src/types/run.d.ts +++ b/frontend/src/types/run.d.ts @@ -293,9 +293,15 @@ declare interface IJobProvisioningData { backend_data?: string; } +declare interface IJobRuntimeData { + working_dir?: string | null; + username?: string | null; +} + declare interface IJobSubmission { id: string; job_provisioning_data?: IJobProvisioningData | null; + job_runtime_data?: IJobRuntimeData | null; error_code?: TJobErrorCode | null; submission_num: number; status: TJobStatus; diff --git a/runner/internal/executor/base.go b/runner/internal/executor/base.go index 99e32250cb..fac1266fb0 100644 --- a/runner/internal/executor/base.go +++ b/runner/internal/executor/base.go @@ -12,6 +12,7 @@ type Executor interface { GetHistory(timestamp int64) *schemas.PullResponse GetJobWsLogsHistory() []schemas.LogEvent GetRunnerState() string + GetJobInfo(ctx context.Context) (username string, workingDir string, err error) Run(ctx context.Context) error SetJob(job schemas.SubmitBody) SetJobState(ctx context.Context, state types.JobState) diff --git a/runner/internal/executor/executor.go b/runner/internal/executor/executor.go index 61e18ee3e9..311eddaa10 100644 --- a/runner/internal/executor/executor.go +++ b/runner/internal/executor/executor.go @@ -21,6 +21,7 @@ import ( "github.com/creack/pty" "github.com/dstackai/ansistrip" "github.com/prometheus/procfs" + "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "github.com/dstackai/dstack/runner/consts" @@ -61,6 +62,10 @@ type RunExecutor struct { fileArchiveDir string repoBlobDir string + runnerLogFile *os.File + runnerLogStripper *ansistrip.Writer + runnerLogger *logrus.Entry + run schemas.Run jobSpec schemas.JobSpec jobSubmission schemas.JobSubmission @@ -136,14 +141,26 @@ func NewRunExecutor(tempDir string, dstackDir string, currentUser linuxuser.User }, nil } +// GetJobInfo must be called after SetJob +func (ex *RunExecutor) GetJobInfo(ctx context.Context) (string, string, error) { + // preRun() sets ex.jobUser and ex.jobWorkingDir + if err := ex.preRun(ctx); err != nil { + return "", "", err + } + return ex.jobUser.Username, ex.jobWorkingDir, nil +} + // Run must be called after SetJob and WriteRepoBlob func (ex *RunExecutor) Run(ctx context.Context) (err error) { - runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName)) - if err != nil { - ex.SetJobState(ctx, types.JobStateFailed) - return fmt.Errorf("create runner log file: %w", err) + // If jobStateHistory is not empty, either Run() has already been called or + // preRun() has already been called via GetJobInfo() and failed + if len(ex.jobStateHistory) > 0 { + return errors.New("already running or finished") + } + if err := ex.preRun(ctx); err != nil { + return err } - defer func() { _ = runnerLogFile.Close() }() + defer ex.postRun(ctx) jobLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerJobLogFileName)) if err != nil { @@ -153,7 +170,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { defer func() { _ = jobLogFile.Close() }() defer func() { - // recover goes after runnerLogFile.Close() to keep the log + // recover goes after postRun(), which closes runnerLogFile, to keep the log if r := recover(); r != nil { log.Error(ctx, "Executor PANIC", "err", r) ex.SetJobState(ctx, types.JobStateFailed) @@ -171,21 +188,8 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { } }() - stripper := ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) - defer func() { _ = stripper.Close() }() - logger := io.MultiWriter(runnerLogFile, os.Stdout, stripper) - ctx = log.WithLogger(ctx, log.NewEntry(logger, int(log.DefaultEntry.Logger.Level))) // todo loglevel - log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String()) - - if err := ex.setJobUser(ctx); err != nil { - ex.SetJobStateWithTerminationReason( - ctx, - types.JobStateFailed, - types.TerminationReasonExecutorError, - fmt.Sprintf("Failed to set job user (%s)", err), - ) - return fmt.Errorf("set job user: %w", err) - } + ctx = log.WithLogger(ctx, ex.runnerLogger) + log.Info(ctx, "Run job") // setJobUser sets User.HomeDir to "/" if the original home dir is not set or not accessible, // in that case we skip home dir provisioning @@ -204,16 +208,6 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) { } } - if err := ex.setJobWorkingDir(ctx); err != nil { - ex.SetJobStateWithTerminationReason( - ctx, - types.JobStateFailed, - types.TerminationReasonExecutorError, - fmt.Sprintf("Failed to set job working dir (%s)", err), - ) - return fmt.Errorf("set job working dir: %w", err) - } - if err := ex.setupRepo(ctx); err != nil { ex.SetJobStateWithTerminationReason( ctx, @@ -336,6 +330,66 @@ func (ex *RunExecutor) SetRunnerState(state string) { ex.state = state } +// preRun performs actions that were once part of Run() but were moved to a separate function +// to implement GetJobInfo() +// preRun must not execute long-running operations, as GetJobInfo() is called synchronously +// in the /api/run method +func (ex *RunExecutor) preRun(ctx context.Context) error { + // Already called once + if ex.runnerLogFile != nil { + return nil + } + + // logging is required for the subsequent setJob{User,WorkingDir} calls + runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName)) + if err != nil { + ex.SetJobState(ctx, types.JobStateFailed) + return fmt.Errorf("create runner log file: %w", err) + } + ex.runnerLogFile = runnerLogFile + ex.runnerLogStripper = ansistrip.NewWriter(ex.runnerLogs, AnsiStripFlushInterval, AnsiStripMaxDelay, MaxBufferSize) + runnerLogWriter := io.MultiWriter(ex.runnerLogFile, os.Stdout, ex.runnerLogStripper) + runnerLogLevel := log.DefaultEntry.Logger.Level + ex.runnerLogger = log.NewEntry(runnerLogWriter, int(runnerLogLevel)) + ctx = log.WithLogger(ctx, ex.runnerLogger) + log.Info(ctx, "Logging configured", "log_level", runnerLogLevel.String()) + + // jobUser and jobWorkingDir are required for GetJobInfo() + if err := ex.setJobUser(ctx); err != nil { + ex.SetJobStateWithTerminationReason( + ctx, + types.JobStateFailed, + types.TerminationReasonExecutorError, + fmt.Sprintf("Failed to set job user (%s)", err), + ) + return fmt.Errorf("set job user: %w", err) + } + if err := ex.setJobWorkingDir(ctx); err != nil { + ex.SetJobStateWithTerminationReason( + ctx, + types.JobStateFailed, + types.TerminationReasonExecutorError, + fmt.Sprintf("Failed to set job working dir (%s)", err), + ) + return fmt.Errorf("set job working dir: %w", err) + } + + return nil +} + +func (ex *RunExecutor) postRun(ctx context.Context) { + if ex.runnerLogFile != nil { + if err := ex.runnerLogFile.Close(); err != nil { + log.Error(ctx, "Failed to close runnerLogFile", "err", err) + } + } + if ex.runnerLogStripper != nil { + if err := ex.runnerLogStripper.Close(); err != nil { + log.Error(ctx, "Failed to close runnerLogStripper", "err", err) + } + } +} + // setJobWorkingDir must be called from Run after setJobUser func (ex *RunExecutor) setJobWorkingDir(ctx context.Context) error { var err error diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 87eb96e0af..4d1c7daf54 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -146,18 +146,27 @@ func (s *Server) uploadCodePostHandler(w http.ResponseWriter, r *http.Request) ( func (s *Server) runPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.executor.Lock() - defer s.executor.Unlock() if s.executor.GetRunnerState() != executor.WaitRun { + s.executor.Unlock() return nil, &api.Error{Status: http.StatusConflict} } + s.executor.SetRunnerState(executor.ServeLogs) + s.executor.Unlock() var runCtx context.Context runCtx, s.cancelRun = context.WithCancel(context.Background()) + username, workingDir, err := s.executor.GetJobInfo(runCtx) go func() { _ = s.executor.Run(runCtx) // INFO: all errors are handled inside the Run() s.jobBarrierCh <- nil // notify server that job finished }() - s.executor.SetRunnerState(executor.ServeLogs) + + if err == nil { + return &schemas.JobInfoResponse{ + Username: username, + WorkingDir: workingDir, + }, nil + } return nil, nil } diff --git a/runner/internal/schemas/schemas.go b/runner/internal/schemas/schemas.go index 152637decc..10ab62ea95 100644 --- a/runner/internal/schemas/schemas.go +++ b/runner/internal/schemas/schemas.go @@ -35,7 +35,11 @@ type PullResponse struct { LastUpdated int64 `json:"last_updated"` NoConnectionsSecs int64 `json:"no_connections_secs"` HasMore bool `json:"has_more"` - // todo Result +} + +type JobInfoResponse struct { + WorkingDir string `json:"working_dir"` + Username string `json:"username"` } type Run struct { diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 88c2f38f5e..558b07e26e 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -352,6 +352,10 @@ class JobRuntimeData(CoreModel): volume_names: Optional[list[str]] = None # None for backward compatibility # Virtual shared offer offer: Optional[InstanceOfferWithAvailability] = None # None for backward compatibility + # Resolved working directory and OS username reported by the runner. + # None if the runner hasn't reported them yet or if it's an old runner. + working_dir: Optional[str] = None + username: Optional[str] = None class ClusterInfo(CoreModel): diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 9d3bd04c3b..5916c9054a 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -1116,7 +1116,13 @@ def _submit_job_to_runner( logger.debug("%s: uploading code", fmt(job_model)) runner_client.upload_code(code) logger.debug("%s: starting job", fmt(job_model)) - runner_client.run_job() + job_info = runner_client.run_job() + if job_info is not None: + jrd = get_job_runtime_data(job_model) + if jrd is not None: + jrd.working_dir = job_info.working_dir + jrd.username = job_info.username + job_model.job_runtime_data = jrd.json() switch_job_status(session, job_model, JobStatus.RUNNING) # do not log here, because the runner will send a new status diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 12ff6c6825..89649ddda6 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -46,6 +46,11 @@ class PullResponse(CoreModel): no_connections_secs: Optional[int] = None # Optional for compatibility with old runners +class JobInfoResponse(CoreModel): + working_dir: str + username: str + + class SubmitBody(CoreModel): run: Annotated[ Run, diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index c83a42b744..c31726e76a 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -24,6 +24,7 @@ GPUDevice, HealthcheckResponse, InstanceHealthResponse, + JobInfoResponse, LegacyPullResponse, LegacyStopBody, LegacySubmitBody, @@ -124,9 +125,13 @@ def upload_code(self, file: Union[BinaryIO, bytes]): ) resp.raise_for_status() - def run_job(self): + def run_job(self) -> Optional[JobInfoResponse]: resp = requests.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT) resp.raise_for_status() + if not _is_json_response(resp): + # Old runner or runner failed to get job info + return None + return JobInfoResponse.__response__.parse_obj(resp.json()) def pull(self, timestamp: int) -> PullResponse: resp = requests.get( @@ -617,6 +622,13 @@ def _memory_to_bytes(memory: Optional[Memory]) -> int: return int(memory * 1024**3) +def _is_json_response(response: requests.Response) -> bool: + content_type = response.headers.get("content-type") + if not content_type: + return False + return content_type.split(";", maxsplit=1)[0].strip() == "application/json" + + _TaskID = Union[uuid.UUID, str] _Version = tuple[int, int, int] diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index cca5212576..6bff65dea3 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -456,6 +456,8 @@ def get_job_runtime_data( ports: Optional[dict[int, int]] = None, offer: Optional[InstanceOfferWithAvailability] = None, volume_names: Optional[list[str]] = None, + working_dir: Optional[str] = None, + username: Optional[str] = None, ) -> JobRuntimeData: return JobRuntimeData( network_mode=NetworkMode(network_mode), @@ -465,6 +467,8 @@ def get_job_runtime_data( ports=ports, offer=offer, volume_names=volume_names, + working_dir=working_dir, + username=username, ) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 72d31189f2..675a88d292 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -352,9 +352,9 @@ def attach( if runtime_data is not None and runtime_data.ports is not None: container_ssh_port = runtime_data.ports.get(container_ssh_port, container_ssh_port) - # TODO: get login name from runner in case it's not specified in the run configuration - # (i.e. the default image user is used, and it is not root) - if job.job_spec.user is not None and job.job_spec.user.username is not None: + if runtime_data is not None and runtime_data.username is not None: + container_user = runtime_data.username + elif job.job_spec.user is not None and job.job_spec.user.username is not None: container_user = job.job_spec.user.username else: container_user = "root" diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 0d748f4e91..aad8615bf3 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -44,6 +44,7 @@ from dstack._internal.server.models import JobModel from dstack._internal.server.schemas.runner import ( HealthcheckResponse, + JobInfoResponse, JobStateEvent, PortMapping, PullResponse, @@ -188,6 +189,7 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): run=run, status=JobStatus.PROVISIONING, job_provisioning_data=job_provisioning_data, + job_runtime_data=get_job_runtime_data(), instance=instance, instance_assigned=True, ) @@ -201,6 +203,9 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): runner_client_mock.healthcheck.return_value = HealthcheckResponse( service="dstack-runner", version="0.0.1.dev2" ) + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) await process_running_jobs() SSHTunnelMock.assert_called_once() runner_client_mock.healthcheck.assert_called_once() @@ -210,6 +215,9 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): await session.refresh(job) assert job is not None assert job.status == JobStatus.RUNNING + jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert jrd.working_dir == "/dstack/run" + assert jrd.username == "dstack" @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @@ -416,6 +424,9 @@ async def test_pulling_shim( PortMapping(container=10022, host=32771), PortMapping(container=10999, host=32772), ] + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) await process_running_jobs() @@ -428,10 +439,13 @@ async def test_pulling_shim( await session.refresh(job) assert job is not None assert job.status == JobStatus.RUNNING - assert JobRuntimeData.__response__.parse_raw(job.job_runtime_data).ports == { + jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert jrd.ports == { 10022: 32771, 10999: 32772, } + assert jrd.working_dir == "/dstack/run" + assert jrd.username == "dstack" @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)