diff --git a/runner/internal/common/types/types.go b/runner/internal/common/types/types.go index 057c0248ca..ac27558561 100644 --- a/runner/internal/common/types/types.go +++ b/runner/internal/common/types/types.go @@ -3,12 +3,13 @@ package types type TerminationReason string const ( - TerminationReasonExecutorError TerminationReason = "executor_error" - TerminationReasonCreatingContainerError TerminationReason = "creating_container_error" - TerminationReasonContainerExitedWithError TerminationReason = "container_exited_with_error" - TerminationReasonDoneByRunner TerminationReason = "done_by_runner" - TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user" - TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server" - TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded" - TerminationReasonLogQuotaExceeded TerminationReason = "log_quota_exceeded" + TerminationReasonExecutorError TerminationReason = "executor_error" + TerminationReasonCreatingContainerError TerminationReason = "creating_container_error" + TerminationReasonContainerExitedWithError TerminationReason = "container_exited_with_error" + TerminationReasonDoneByRunner TerminationReason = "done_by_runner" + TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user" + TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server" + TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded" + TerminationReasonLogQuotaExceeded TerminationReason = "log_quota_exceeded" + TerminationReasonDataTransferQuotaExceeded TerminationReason = "data_transfer_quota_exceeded" ) diff --git a/runner/internal/runner/api/http.go b/runner/internal/runner/api/http.go index 34220acc6e..800da516f2 100644 --- a/runner/internal/runner/api/http.go +++ b/runner/internal/runner/api/http.go @@ -194,6 +194,24 @@ func (s *Server) stopPostHandler(w http.ResponseWriter, r *http.Request) (interf return nil, nil } +func (s *Server) terminatePostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + var body schemas.TerminateBody + if err := api.DecodeJSONBody(w, r, &body, true); err != nil { + return nil, err + } + ctx := r.Context() + log.Error(ctx, "Terminate requested", "reason", body.Reason, "message", body.Message) + // No executor.Lock() needed — SetJobStateWithTerminationReason acquires its own lock. + // Using the external lock would deadlock with io.Copy holding it during job execution. + s.executor.SetJobStateWithTerminationReason( + ctx, + schemas.JobStateFailed, + body.Reason, + body.Message, + ) + return nil, nil +} + func isMaxBytesError(err error) bool { var maxBytesError *http.MaxBytesError return errors.As(err, &maxBytesError) diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go index 11b76d887e..4783612326 100644 --- a/runner/internal/runner/api/server.go +++ b/runner/internal/runner/api/server.go @@ -68,6 +68,7 @@ func NewServer(ctx context.Context, address string, version string, ex executor. r.AddHandler("POST", "/api/run", s.runPostHandler) r.AddHandler("GET", "/api/pull", s.pullGetHandler) r.AddHandler("POST", "/api/stop", s.stopPostHandler) + r.AddHandler("POST", "/api/terminate", s.terminatePostHandler) r.AddHandler("GET", "/logs_ws", s.logsWsGetHandler) return s, nil } diff --git a/runner/internal/runner/schemas/schemas.go b/runner/internal/runner/schemas/schemas.go index 47706228cd..a3f83fee0a 100644 --- a/runner/internal/runner/schemas/schemas.go +++ b/runner/internal/runner/schemas/schemas.go @@ -39,6 +39,11 @@ type SubmitBody struct { LogQuotaHour int `json:"log_quota_hour"` // bytes per hour, 0 = unlimited } +type TerminateBody struct { + Reason types.TerminationReason `json:"reason"` + Message string `json:"message"` +} + type PullResponse struct { JobStates []JobStateEvent `json:"job_states"` JobLogs []LogEvent `json:"job_logs"` diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 6acfb27a51..019ec7e36e 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "net/http" "os" "os/exec" "os/user" @@ -37,6 +38,7 @@ import ( "github.com/dstackai/dstack/runner/internal/common/types" "github.com/dstackai/dstack/runner/internal/shim/backends" "github.com/dstackai/dstack/runner/internal/shim/host" + "github.com/dstackai/dstack/runner/internal/shim/netmeter" ) // TODO: Allow for configuration via cli arguments or environment variables. @@ -380,7 +382,8 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error { if err := d.tasks.Update(task); err != nil { return fmt.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err) } - err = d.waitContainer(ctx, &task) + + err = d.waitContainerWithQuota(ctx, &task, cfg) } if err != nil { log.Error(ctx, "failed to run container", "err", err) @@ -910,6 +913,49 @@ func (d *DockerRunner) waitContainer(ctx context.Context, task *Task) error { return nil } +// waitContainerWithQuota waits for the container to finish, optionally enforcing +// a data transfer quota. If the quota is exceeded, it notifies the runner +// (so the server reads the termination reason via /api/pull) and stops the container. +func (d *DockerRunner) waitContainerWithQuota(ctx context.Context, task *Task, cfg TaskConfig) error { + if cfg.DataTransferQuota <= 0 { + return d.waitContainer(ctx, task) + } + + nm := netmeter.New(task.ID, cfg.DataTransferQuota) + if err := nm.Start(ctx); err != nil { + errMessage := fmt.Sprintf("data transfer quota configured but metering unavailable: %s", err) + log.Error(ctx, errMessage) + task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage) + return fmt.Errorf("data transfer meter: %w", err) + } + defer nm.Stop() + + waitDone := make(chan error, 1) + go func() { waitDone <- d.waitContainer(ctx, task) }() + + select { + case err := <-waitDone: + return err + case <-nm.Exceeded(): + log.Error(ctx, "Data transfer quota exceeded", "task", task.ID, "quota", cfg.DataTransferQuota) + terminateMsg := fmt.Sprintf("Outbound data transfer exceeded quota of %d bytes", cfg.DataTransferQuota) + if err := terminateRunner(ctx, d.dockerParams.RunnerHTTPPort(), + types.TerminationReasonDataTransferQuotaExceeded, terminateMsg); err != nil { + log.Error(ctx, "failed to notify runner of termination", "err", err) + } + stopTimeout := 10 + stopOpts := container.StopOptions{Timeout: &stopTimeout} + if err := d.client.ContainerStop(ctx, task.containerID, stopOpts); err != nil { + log.Error(ctx, "failed to stop container after quota exceeded", "err", err) + } + <-waitDone + // The runner already set the job state with the termination reason. + // The server will read it via /api/pull. + task.SetStatusTerminated(string(types.TerminationReasonDoneByRunner), "") + return nil + } +} + func encodeRegistryAuth(username string, password string) (string, error) { if username == "" && password == "" { return "", nil @@ -1180,6 +1226,31 @@ func getContainerLastLogs(ctx context.Context, client docker.APIClient, containe return lines, nil } +// terminateRunner calls the runner's /api/terminate endpoint to set the job termination state. +// This allows the server to read the termination reason via /api/pull before the container dies. +func terminateRunner(ctx context.Context, runnerPort int, reason types.TerminationReason, message string) error { + url := fmt.Sprintf("http://localhost:%d/api/terminate", runnerPort) + body := fmt.Sprintf(`{"reason":%q,"message":%q}`, reason, message) + // 5s is generous for a localhost HTTP call; if the runner doesn't respond in time, + // we proceed with stopping the container anyway (the server will handle the termination). + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(body)) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status: %d", resp.StatusCode) + } + return nil +} + /* DockerParameters interface implementation for CLIArgs */ func (c *CLIArgs) DockerPrivileged() bool { @@ -1228,6 +1299,10 @@ func (c *CLIArgs) DockerPorts() []int { return []int{c.Runner.HTTPPort, c.Runner.SSHPort} } +func (c *CLIArgs) RunnerHTTPPort() int { + return c.Runner.HTTPPort +} + func (c *CLIArgs) MakeRunnerDir(name string) (string, error) { runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", name) if err := os.MkdirAll(runnerTemp, 0o755); err != nil { diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 18f8c31fca..3723f53e3a 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -123,6 +123,10 @@ func (c *dockerParametersMock) DockerPorts() []int { return []int{} } +func (c *dockerParametersMock) RunnerHTTPPort() int { + return 10999 +} + func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) { return nil, nil } diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index d50fe6e297..40e5cb8a15 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -9,6 +9,7 @@ type DockerParameters interface { DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string DockerMounts(string) ([]mount.Mount, error) DockerPorts() []int + RunnerHTTPPort() int MakeRunnerDir(name string) (string, error) DockerPJRTDevice() string } @@ -97,10 +98,11 @@ type TaskConfig struct { InstanceMounts []InstanceMountPoint `json:"instance_mounts"` // GPUDevices allows the server to set gpu devices instead of relying on the runner default logic. // E.g. passing nvidia devices directly instead of using nvidia-container-toolkit. - GPUDevices []GPUDevice `json:"gpu_devices"` - HostSshUser string `json:"host_ssh_user"` - HostSshKeys []string `json:"host_ssh_keys"` - ContainerSshKeys []string `json:"container_ssh_keys"` + GPUDevices []GPUDevice `json:"gpu_devices"` + HostSshUser string `json:"host_ssh_user"` + HostSshKeys []string `json:"host_ssh_keys"` + ContainerSshKeys []string `json:"container_ssh_keys"` + DataTransferQuota int64 `json:"data_transfer_quota"` // total bytes for job lifetime; 0 = unlimited } type TaskListItem struct { diff --git a/runner/internal/shim/netmeter/netmeter.go b/runner/internal/shim/netmeter/netmeter.go new file mode 100644 index 0000000000..bf438b9466 --- /dev/null +++ b/runner/internal/shim/netmeter/netmeter.go @@ -0,0 +1,264 @@ +package netmeter + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strconv" + "strings" + "sync" + "time" + + "github.com/dstackai/dstack/runner/internal/common/log" +) + +const ( + pollInterval = 10 * time.Second + chainPrefix = "dstack-nm-" +) + +// NetMeter monitors outbound data transfer using iptables byte counters. +// It excludes private/VPC traffic and counts only external (billable) bytes. +// When cumulative bytes exceed the configured quota, the Exceeded() channel is closed. +type NetMeter struct { + quota int64 // total bytes for job lifetime + chainName string // unique iptables chain name + + exceeded chan struct{} + exceededOnce sync.Once + stopCh chan struct{} + stopped chan struct{} +} + +// New creates a new NetMeter with the given quota in bytes. +func New(taskID string, quota int64) *NetMeter { + // Use first 8 chars of task ID for chain name uniqueness + suffix := taskID + if len(suffix) > 8 { + suffix = suffix[:8] + } + return &NetMeter{ + quota: quota, + chainName: chainPrefix + suffix, + exceeded: make(chan struct{}), + stopCh: make(chan struct{}), + stopped: make(chan struct{}), + } +} + +// Start sets up iptables rules and begins polling byte counters. +func (m *NetMeter) Start(ctx context.Context) error { + if err := checkIptables(); err != nil { + return fmt.Errorf("iptables not available: %w", err) + } + + if err := m.setupChain(ctx); err != nil { + return fmt.Errorf("setup iptables chain: %w", err) + } + + go m.pollLoop(ctx) + return nil +} + +// Stop signals the poll loop to stop and cleans up iptables rules. +func (m *NetMeter) Stop() { + close(m.stopCh) + <-m.stopped +} + +// Exceeded returns a channel that is closed when the quota is exceeded. +func (m *NetMeter) Exceeded() <-chan struct{} { + return m.exceeded +} + +func checkIptables() error { + _, err := exec.LookPath("iptables") + return err +} + +func (m *NetMeter) setupChain(ctx context.Context) error { + // Create the chain + if err := iptables(ctx, "-N", m.chainName); err != nil { + return fmt.Errorf("create chain: %w", err) + } + + // Add exclusion rules for private/internal traffic (these RETURN without counting) + privateCIDRs := []struct { + cidr string + comment string + }{ + {"10.0.0.0/8", "VPC/private"}, + {"172.16.0.0/12", "VPC/private"}, + {"192.168.0.0/16", "VPC/private"}, + {"169.254.0.0/16", "link-local/metadata"}, + {"127.0.0.0/8", "loopback"}, + } + for _, p := range privateCIDRs { + if err := iptables(ctx, "-A", m.chainName, "-d", p.cidr, "-j", "RETURN"); err != nil { + m.cleanup(ctx) + return fmt.Errorf("add exclusion rule for %s: %w", p.comment, err) + } + } + + // Add catch-all counting rule (counts all remaining = external/billable bytes) + if err := iptables(ctx, "-A", m.chainName, "-j", "RETURN"); err != nil { + m.cleanup(ctx) + return fmt.Errorf("add counting rule: %w", err) + } + + // Insert jump from OUTPUT chain (catches host-mode Docker and host processes) + if err := iptables(ctx, "-I", "OUTPUT", "-j", m.chainName); err != nil { + m.cleanup(ctx) + return fmt.Errorf("insert OUTPUT jump: %w", err) + } + + // Insert jump from FORWARD chain (catches bridge-mode Docker traffic) + if err := iptables(ctx, "-I", "FORWARD", "-j", m.chainName); err != nil { + m.cleanup(ctx) + return fmt.Errorf("insert FORWARD jump: %w", err) + } + + return nil +} + +func (m *NetMeter) cleanup(ctx context.Context) { + // Remove jumps from OUTPUT and FORWARD (ignore errors — may not exist if setup failed partway) + _ = iptables(ctx, "-D", "OUTPUT", "-j", m.chainName) + _ = iptables(ctx, "-D", "FORWARD", "-j", m.chainName) + // Flush and delete chain + _ = iptables(ctx, "-F", m.chainName) + _ = iptables(ctx, "-X", m.chainName) +} + +func (m *NetMeter) pollLoop(ctx context.Context) { + defer close(m.stopped) + defer m.cleanup(ctx) + + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + bytes, err := m.readCounter(ctx) + if err != nil { + log.Error(ctx, "failed to read network counter", "chain", m.chainName, "err", err) + continue + } + if bytes > m.quota { + log.Error(ctx, "data transfer quota exceeded", + "chain", m.chainName, "bytes", bytes, "quota", m.quota) + m.exceededOnce.Do(func() { close(m.exceeded) }) + return + } + } + } +} + +// readCounter reads the cumulative byte count from the catch-all rule (last rule in chain). +func (m *NetMeter) readCounter(ctx context.Context) (int64, error) { + output, err := iptablesOutput(ctx, "-L", m.chainName, "-v", "-x", "-n") + if err != nil { + return 0, err + } + return parseByteCounter(output, m.chainName) +} + +// parseByteCounter extracts the byte count from the last rule (catch-all counting rule) +// in the iptables -L -v -x -n output. +// +// Example output: +// +// Chain dstack-nm-abcd1234 (1 references) +// pkts bytes target prot opt in out source destination +// 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 +// 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 +// 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 +// 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 +// 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 +// 123 456789 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +// +// The last rule (destination 0.0.0.0/0) is the catch-all; its bytes field is what we want. +func parseByteCounter(output string, chainName string) (int64, error) { + lines := strings.Split(strings.TrimSpace(output), "\n") + + // Find lines that are rule entries (skip header lines) + var lastRuleLine string + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + // Skip "Chain ..." and column header lines + if strings.HasPrefix(trimmed, "Chain ") { + continue + } + if strings.HasPrefix(trimmed, "pkts") { + continue + } + lastRuleLine = trimmed + } + + if lastRuleLine == "" { + return 0, fmt.Errorf("no rules found in chain %s", chainName) + } + + // Parse the bytes field (second field in the line) + fields := strings.Fields(lastRuleLine) + if len(fields) < 2 { + return 0, fmt.Errorf("unexpected rule format: %q", lastRuleLine) + } + + byteCount, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("parse byte count %q: %w", fields[1], err) + } + + return byteCount, nil +} + +func iptables(ctx context.Context, args ...string) error { + cmd := exec.CommandContext(ctx, "iptables", args...) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("iptables %s: %s: %w", strings.Join(args, " "), stderr.String(), err) + } + return nil +} + +func iptablesOutput(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "iptables", args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("iptables %s: %s: %w", strings.Join(args, " "), stderr.String(), err) + } + return stdout.String(), nil +} + +// CleanupOrphanedChains removes any leftover dstack-nm-* chains from previous runs. +// Call this on shim startup. +func CleanupOrphanedChains(ctx context.Context) { + output, err := iptablesOutput(ctx, "-L", "-n") + if err != nil { + return + } + for _, line := range strings.Split(output, "\n") { + if strings.HasPrefix(line, "Chain "+chainPrefix) { + fields := strings.Fields(line) + if len(fields) >= 2 { + chainName := fields[1] + log.Info(ctx, "cleaning up orphaned data transfer meter chain", "chain", chainName) + _ = iptables(ctx, "-D", "OUTPUT", "-j", chainName) + _ = iptables(ctx, "-D", "FORWARD", "-j", chainName) + _ = iptables(ctx, "-F", chainName) + _ = iptables(ctx, "-X", chainName) + } + } + } +} diff --git a/runner/internal/shim/netmeter/netmeter_test.go b/runner/internal/shim/netmeter/netmeter_test.go new file mode 100644 index 0000000000..055843e930 --- /dev/null +++ b/runner/internal/shim/netmeter/netmeter_test.go @@ -0,0 +1,98 @@ +package netmeter + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseByteCounter(t *testing.T) { + tests := []struct { + name string + output string + chain string + expected int64 + expectErr bool + }{ + { + name: "typical output with traffic", + output: `Chain dstack-nm-abcd1234 (1 references) + pkts bytes target prot opt in out source destination + 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 + 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 + 123 456789 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +`, + chain: "dstack-nm-abcd1234", + expected: 456789, + }, + { + name: "zero traffic", + output: `Chain dstack-nm-abcd1234 (1 references) + pkts bytes target prot opt in out source destination + 0 0 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 + 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +`, + chain: "dstack-nm-abcd1234", + expected: 0, + }, + { + name: "large byte count", + output: `Chain dstack-nm-test1234 (1 references) + pkts bytes target prot opt in out source destination + 10000 5000000 RETURN all -- * * 0.0.0.0/0 10.0.0.0/8 + 0 0 RETURN all -- * * 0.0.0.0/0 172.16.0.0/12 + 0 0 RETURN all -- * * 0.0.0.0/0 192.168.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 169.254.0.0/16 + 0 0 RETURN all -- * * 0.0.0.0/0 127.0.0.0/8 + 500000 107374182400 RETURN all -- * * 0.0.0.0/0 0.0.0.0/0 +`, + chain: "dstack-nm-test1234", + expected: 107374182400, // ~100 GB + }, + { + name: "empty output", + output: "", + chain: "dstack-nm-abcd1234", + expectErr: true, + }, + { + name: "only headers no rules", + output: `Chain dstack-nm-abcd1234 (1 references) + pkts bytes target prot opt in out source destination +`, + chain: "dstack-nm-abcd1234", + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseByteCounter(tt.output, tt.chain) + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestNew(t *testing.T) { + nm := New("abcdefghijklmnop", 1000000) + assert.Equal(t, int64(1000000), nm.quota) + assert.Equal(t, "dstack-nm-abcdefgh", nm.chainName) +} + +func TestNew_ShortID(t *testing.T) { + nm := New("abc", 500) + assert.Equal(t, "dstack-nm-abc", nm.chainName) +} diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index fdb7b58cd2..eba3a12653 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -152,6 +152,7 @@ class JobTerminationReason(str, Enum): EXECUTOR_ERROR = "executor_error" MAX_DURATION_EXCEEDED = "max_duration_exceeded" LOG_QUOTA_EXCEEDED = "log_quota_exceeded" + DATA_TRANSFER_QUOTA_EXCEEDED = "data_transfer_quota_exceeded" def to_status(self) -> JobStatus: mapping = { @@ -175,6 +176,7 @@ def to_status(self) -> JobStatus: self.EXECUTOR_ERROR: JobStatus.FAILED, self.MAX_DURATION_EXCEEDED: JobStatus.TERMINATED, self.LOG_QUOTA_EXCEEDED: JobStatus.FAILED, + self.DATA_TRANSFER_QUOTA_EXCEEDED: JobStatus.FAILED, } return mapping[self] @@ -208,6 +210,7 @@ def to_error(self) -> Optional[str]: JobTerminationReason.EXECUTOR_ERROR: "executor error", JobTerminationReason.MAX_DURATION_EXCEEDED: "max duration exceeded", JobTerminationReason.LOG_QUOTA_EXCEEDED: "log quota exceeded", + JobTerminationReason.DATA_TRANSFER_QUOTA_EXCEEDED: "data transfer quota exceeded", } return error_mapping.get(self) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 590b9907bd..af379c3b78 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -34,6 +34,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, @@ -1104,6 +1105,9 @@ def _process_provisioning_with_shim( memory = None network_mode = NetworkMode.HOST image_name = resolve_provisioning_image_name(job_spec, jpd) + data_transfer_quota = 0 + if jpd.backend == BackendType.AWS: + data_transfer_quota = server_settings.SERVER_DATA_TRANSFER_QUOTA_PER_JOB_AWS if shim_client.is_api_v2_supported(): shim_client.submit_task( task_id=job_model.id, @@ -1126,6 +1130,7 @@ def _process_provisioning_with_shim( host_ssh_keys=[ssh_key] if ssh_key else [], container_ssh_keys=public_keys, instance_id=jpd.instance_id, + data_transfer_quota=data_transfer_quota, ) else: submitted = shim_client.submit( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 7adcabdf7d..1bd20a0ea9 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -37,6 +37,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server import settings as server_settings from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -785,6 +786,9 @@ def _process_provisioning_with_shim( memory = None network_mode = NetworkMode.HOST image_name = resolve_provisioning_image_name(job_spec, jpd) + data_transfer_quota = 0 + if jpd.backend == BackendType.AWS: + data_transfer_quota = server_settings.SERVER_DATA_TRANSFER_QUOTA_PER_JOB_AWS if shim_client.is_api_v2_supported(): shim_client.submit_task( task_id=job_model.id, @@ -807,6 +811,7 @@ def _process_provisioning_with_shim( host_ssh_keys=[ssh_key] if ssh_key else [], container_ssh_keys=public_keys, instance_id=jpd.instance_id, + data_transfer_quota=data_transfer_quota, ) else: submitted = shim_client.submit( diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 549ff7914e..5fadde06de 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -245,6 +245,7 @@ class TaskSubmitRequest(CoreModel): host_ssh_user: str host_ssh_keys: list[str] container_ssh_keys: list[str] + data_transfer_quota: int = 0 # total bytes; 0 = unlimited class TaskTerminateRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 4b78eefeea..c3500e4a9c 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -416,6 +416,7 @@ def submit_task( host_ssh_keys: list[str], container_ssh_keys: list[str], instance_id: str, + data_transfer_quota: int = 0, ) -> None: if not self.is_api_v2_supported(): raise ShimAPIVersionError() @@ -439,6 +440,7 @@ def submit_task( host_ssh_user=host_ssh_user, host_ssh_keys=host_ssh_keys, container_ssh_keys=container_ssh_keys, + data_transfer_quota=data_transfer_quota, ) self._request("POST", "/api/tasks", body, raise_for_status=True) diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 01216cff31..e46d3f1fc4 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -138,6 +138,20 @@ os.getenv("DSTACK_SERVER_LOG_QUOTA_PER_JOB_HOUR", 50 * 1024 * 1024) # 50 MB ) +# Per-job data transfer quota for AWS backend: maximum total outbound bytes to external IPs. +# 0 = unlimited. Only applied to instances running on AWS. +# Limitations: +# - Meters all outbound traffic to non-private IPs (excludes 10.0.0.0/8, 172.16.0.0/12, +# 192.168.0.0/16, 169.254.0.0/16). This covers inter-region and internet egress. +# - Does not differentiate by destination region — the same quota applies regardless of +# whether traffic goes to another AWS region ($0.01-0.02/GB) or the internet ($0.09/GB). +# - Only effective on Linux instances with iptables available. +# Task fails with executor_error on systems without iptables if quota is set. +# To add support for other backends, add DSTACK_SERVER_DATA_TRANSFER_QUOTA_PER_JOB_GCP, etc. +SERVER_DATA_TRANSFER_QUOTA_PER_JOB_AWS = int( + os.getenv("DSTACK_SERVER_DATA_TRANSFER_QUOTA_PER_JOB_AWS", 0) # disabled by default +) + # Development settings SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index a52924a552..eb0cb6b709 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -573,6 +573,7 @@ async def test_provisioning_shim_with_volumes( host_ssh_keys=["user_ssh_key"], container_ssh_keys=[project_ssh_pub_key, "user_ssh_key"], instance_id=job_provisioning_data.instance_id, + data_transfer_quota=0, ) await session.refresh(job) assert job.status == JobStatus.PULLING diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 66b38f331f..862565c648 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -389,6 +389,7 @@ async def test_provisioning_shim_with_volumes( host_ssh_keys=["user_ssh_key"], container_ssh_keys=[project_ssh_pub_key, "user_ssh_key"], instance_id=job_provisioning_data.instance_id, + data_transfer_quota=0, ) await session.refresh(job) assert job.status == JobStatus.PULLING diff --git a/src/tests/_internal/server/services/runner/test_client.py b/src/tests/_internal/server/services/runner/test_client.py index 588c231a19..dbbece0a58 100644 --- a/src/tests/_internal/server/services/runner/test_client.py +++ b/src/tests/_internal/server/services/runner/test_client.py @@ -455,6 +455,7 @@ def test_submit_task(self, client: ShimClient, adapter: requests_mock.Adapter): "host_ssh_user": "dstack", "host_ssh_keys": ["host_key"], "container_ssh_keys": ["project_key", "user_key"], + "data_transfer_quota": 0, } self.assert_request(adapter, 1, "POST", "/api/tasks", expected_request)