diff --git a/Cargo.lock b/Cargo.lock index 2d97e92..e7d1e8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -609,7 +609,7 @@ dependencies = [ [[package]] name = "daft" -version = "0.4.1-alpha0" +version = "0.5.0-alpha0" dependencies = [ "anyhow", "aws-config", @@ -617,6 +617,8 @@ dependencies = [ "aws-sdk-sts", "clap", "comfy-table", + "open", + "regex", "rstest", "serde", "serde_yaml", @@ -1100,6 +1102,25 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -1236,6 +1257,17 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "open" +version = "5.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2483562e62ea94312f3576a7aca397306df7990b8d89033e18766744377ef95" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl-probe" version = "0.1.5" @@ -1271,6 +1303,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "percent-encoding" version = "2.3.1" diff --git a/Cargo.toml b/Cargo.toml index e485945..98e4e29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "daft" -version = "0.4.1-alpha0" +version = "0.5.0-alpha0" edition = "2021" description = "A simple launcher for spinning up and managing Ray clusters for Daft" license = "LICENSE" @@ -13,6 +13,8 @@ serde_yaml = "0.9" tempdir = "0.3" toml = "0.8" comfy-table = "7.1" +regex = "1.11.1" +open = "5.3.2" [dependencies.anyhow] version = "1.0" diff --git a/src/sql.py b/assets/sql.py similarity index 100% rename from src/sql.py rename to assets/sql.py diff --git a/src/template_byoc.toml b/assets/template-byoc.toml similarity index 55% rename from src/template_byoc.toml rename to assets/template-byoc.toml index 38f1564..504b6c2 100644 --- a/src/template_byoc.toml +++ b/assets/template-byoc.toml @@ -1,8 +1,10 @@ -# This is a template configuration file for daft-launcher with a BYOC provider +# This is a template configuration file for daft-launcher with Kubernetes provider [setup] name = "my-daft-cluster" -version = "" +requires = "" +python-version = "" +ray-version = "" [setup.byoc] namespace = "default" # Optional, defaults to "default" diff --git a/src/template_provisioned.toml b/assets/template-provisioned.toml similarity index 61% rename from src/template_provisioned.toml rename to assets/template-provisioned.toml index 45cc7cb..3d59c14 100644 --- a/src/template_provisioned.toml +++ b/assets/template-provisioned.toml @@ -1,10 +1,12 @@ -# This is a template configuration file for daft-launcher with a provisioned provider +# This is a template configuration file for daft-launcher with AWS provider [setup] name = "my-daft-cluster" -version = "" +requires = "" +python-version = "" +ray-version = "" -# Provisioned (AWS) configuration +# AWS-specific configuration [setup.provisioned] region = "us-west-2" number-of-workers = 4 @@ -14,6 +16,7 @@ instance-type = "i3.2xlarge" image-id = "ami-04dd23e62ed049936" iam-instance-profile-name = "YourInstanceProfileName" # Optional dependencies = [] # Optional additional Python packages to install +run = [] # Optional commands to run during cluster-node initialization # Job definitions [[job]] diff --git a/src/main.rs b/src/main.rs index 3c378dd..9feb9b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,21 @@ +macro_rules! asset { + ( + $($path_segment:literal),* + $(,)? + ) => { + include_str!( + concat! { + env!("CARGO_MANIFEST_DIR"), + "/assets", + $( + "/", + $path_segment, + )* + } + ) + }; +} + macro_rules! not_available_for_byoc { ($command:literal) => { anyhow::bail!(concat!( @@ -9,6 +27,8 @@ macro_rules! not_available_for_byoc { } mod ssh; +#[cfg(test)] +mod tests; use std::{ collections::HashMap, @@ -21,9 +41,6 @@ use std::{ time::Duration, }; -#[cfg(test)] -mod tests; - #[cfg(not(test))] use anyhow::bail; use aws_config::{BehaviorVersion, Region}; @@ -32,6 +49,7 @@ use clap::{Parser, Subcommand, ValueEnum}; use comfy_table::{ modifiers, presets, Attribute, Cell, CellAlignment, Color, ContentArrangement, Table, }; +use regex::Regex; use serde::{Deserialize, Serialize}; use tempdir::TempDir; use tokio::{ @@ -132,13 +150,17 @@ struct Init { #[arg(default_value = ".daft.toml")] path: PathBuf, - /// The provider to use - either 'provisioned' (default) to auto-generate a cluster or 'byoc' for existing Kubernetes clusters + /// The provider to use - either 'aws' (default) to auto-generate a cluster + /// or 'k8s' for existing Kubernetes clusters #[arg(long, default_value_t = DaftProvider::Provisioned)] provider: DaftProvider, } #[derive(Debug, Parser, Clone, PartialEq, Eq)] struct List { + /// A regex to filter for the Ray clusters which match the given name. + regex: Option, + /// The region which to list all the available clusters for. #[arg(long)] region: Option, @@ -170,6 +192,10 @@ struct Connect { #[arg(long, default_value = "8265")] port: u16, + /// Prevent the dashboard from opening automatically. + #[arg(long)] + no_dashboard: bool, + #[clap(flatten)] config_path: ConfigPath, } @@ -203,7 +229,11 @@ struct DaftConfig { struct DaftSetup { name: StrRef, #[serde(deserialize_with = "parse_requirement")] - version: Requirement, + requires: Requirement, + #[serde(deserialize_with = "parse_python_version")] + python_version: Versioning, + #[serde(deserialize_with = "parse_ray_version")] + ray_version: Versioning, #[serde(flatten)] provider_config: ProviderConfig, } @@ -232,6 +262,8 @@ struct AwsConfig { iam_instance_profile_name: Option, #[serde(default)] dependencies: Vec, + #[serde(default)] + run: Vec, } #[derive(Debug, Deserialize, Clone, PartialEq, Eq)] @@ -340,6 +372,34 @@ where } } +fn parse_python_version<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let raw: StrRef = Deserialize::deserialize(deserializer)?; + let requested_py_version = raw + .parse::() + .map_err(serde::de::Error::custom)?; + let minimum_py_requirement = ">=3.9" + .parse::() + .expect("Parsing a static, constant version should always succeed"); + + if minimum_py_requirement.matches(&requested_py_version) { + Ok(requested_py_version) + } else { + Err(serde::de::Error::custom(format!("The minimum supported python version is {minimum_py_requirement}, but your configuration file requested python version {requested_py_version}"))) + } +} + +fn parse_ray_version<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let raw: StrRef = Deserialize::deserialize(deserializer)?; + let version = raw.parse().map_err(serde::de::Error::custom)?; + Ok(version) +} + #[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] enum DaftProvider { Provisioned, @@ -355,7 +415,6 @@ impl ToString for DaftProvider { .to_string() } } - #[derive(Debug, Clone, PartialEq, Eq)] struct DaftJob { command: StrRef, @@ -500,28 +559,12 @@ fn convert( ] .into_iter() .collect(), - setup_commands: { - let mut commands = vec![ - "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), - "uv python install 3.12".into(), - "uv python pin 3.12".into(), - "uv venv".into(), - "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), - "source ~/.bashrc".into(), - "uv pip install boto3 pip py-spy deltalake getdaft ray[default]".into(), - ]; - if !aws_config.dependencies.is_empty() { - let deps = aws_config - .dependencies - .iter() - .map(|dep| format!(r#""{dep}""#)) - .collect::>() - .join(" "); - let deps = format!("uv pip install {deps}").into(); - commands.push(deps); - } - commands - }, + setup_commands: generate_setup_commands( + daft_config.setup.python_version.clone(), + daft_config.setup.ray_version.clone(), + &aws_config.dependencies, + &aws_config.run, + ), }) } @@ -604,6 +647,15 @@ enum NodeType { Worker, } +impl NodeType { + pub fn as_str(self) -> &'static str { + match self { + Self::Head => "head", + Self::Worker => "worker", + } + } +} + impl FromStr for NodeType { type Err = anyhow::Error; @@ -625,7 +677,7 @@ async fn get_ray_clusters_from_aws(region: StrRef) -> anyhow::Result anyhow::Result, + head: bool, + running: bool, +) -> anyhow::Result { let mut table = Table::default(); table .load_preset(presets::UTF8_FULL) @@ -685,11 +742,16 @@ fn print_instances(instances: &[AwsInstance], head: bool, running: bool) { .set_alignment(CellAlignment::Center) .add_attribute(Attribute::Bold) })); + let regex = regex.as_deref().map(Regex::new).transpose()?; for instance in instances.iter().filter(|instance| { if head && instance.node_type != NodeType::Head { return false; } else if running && instance.state != Some(InstanceStateName::Running) { return false; + } else if let Some(regex) = regex.as_ref() { + if !regex.is_match(&instance.regular_name) { + return false; + }; }; true }) { @@ -716,12 +778,13 @@ fn print_instances(instances: &[AwsInstance], head: bool, running: bool) { .map_or("n/a".into(), ToString::to_string); table.add_row(vec![ Cell::new(instance.regular_name.to_string()).fg(Color::Cyan), - Cell::new(&*instance.instance_id), + Cell::new(instance.instance_id.as_ref()), + Cell::new(instance.node_type.as_str()), status, Cell::new(ipv4), ]); } - println!("{table}"); + Ok(table) } async fn assert_is_logged_in_with_aws() -> anyhow::Result<()> { @@ -846,6 +909,71 @@ async fn main() -> anyhow::Result<()> { DaftLauncher::parse().run().await } +fn generate_setup_commands( + python_version: Versioning, + ray_version: Versioning, + dependencies: &[StrRef], + run: &[StrRef], +) -> Vec { + let mut commands = vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + format!("uv python install {python_version}").into(), + format!("uv python pin {python_version}").into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + format!( + r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]=={ray_version}""# + ) + .into(), + ]; + + if !dependencies.is_empty() { + let deps = dependencies + .iter() + .map(|dep| format!(r#""{dep}""#)) + .collect::>() + .join(" "); + let deps = format!("uv pip install {deps}").into(); + commands.push(deps); + } + + commands.extend(run.iter().cloned()); + + commands +} + +async fn get_version_from_env(bin: &str, prefix: &str) -> anyhow::Result { + let output = Command::new(bin) + .arg("--version") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()? + .wait_with_output() + .await?; + + if output.status.success() { + let version = String::from_utf8(output.stdout)? + .strip_prefix(prefix) + .ok_or_else(|| anyhow::anyhow!("Could not parse {bin} version"))? + .trim() + .parse()?; + Ok(version) + } else { + Err(anyhow::anyhow!("Failed to find {bin} executable")) + } +} + +async fn get_python_version_from_env() -> anyhow::Result { + let python_version = get_version_from_env("python", "Python ").await?; + Ok(python_version) +} + +async fn get_ray_version_from_env() -> anyhow::Result { + let python_version = get_version_from_env("ray", "ray, version ").await?; + Ok(python_version) +} + impl DaftLauncher { async fn run(&self) -> anyhow::Result<()> { match &self.sub_command { @@ -866,10 +994,18 @@ impl ConfigCommand { bail!("The path {path:?} already exists; the path given must point to a new location on your filesystem"); } let contents = match provider { - DaftProvider::Byoc => include_str!("template_byoc.toml"), - DaftProvider::Provisioned => include_str!("template_provisioned.toml"), + DaftProvider::Byoc => asset!("template-byoc.toml"), + DaftProvider::Provisioned => asset!("template-provisioned.toml"), } - .replace("", concat!("=", env!("CARGO_PKG_VERSION"))); + .replace("", concat!("=", env!("CARGO_PKG_VERSION"))) + .replace( + "", + get_python_version_from_env().await?.to_string().as_str(), + ) + .replace( + "", + get_ray_version_from_env().await?.to_string().as_str(), + ); fs::write(path, contents).await?; } ConfigCommand::Check(ConfigPath { config }) => { @@ -921,7 +1057,7 @@ impl JobCommand { JobCommand::Sql(Sql { sql, config_path }) => { let daft_config = read_daft_config(&config_path.config).await?; let (temp_sql_dir, sql_path) = create_temp_file("sql.py")?; - fs::write(sql_path, include_str!("sql.py")).await?; + fs::write(sql_path, asset!("sql.py")).await?; let working_dir = temp_sql_dir.path(); let command_segments = vec!["python", "sql.py", sql.as_ref()]; @@ -995,11 +1131,13 @@ impl ProvisionedCommand { ProviderConfig::Byoc(..) => not_available_for_byoc!("kill"), } } - ProvisionedCommand::List(List { - config_path, - region, + &ProvisionedCommand::List(List { + ref config_path, + ref regex, + ref region, head, running, + .. }) => { let daft_config = read_daft_config(&config_path.config).await?; match &daft_config.setup.provider_config { @@ -1008,16 +1146,19 @@ impl ProvisionedCommand { let region = region.as_ref().unwrap_or_else(|| &aws_config.region); let instances = get_ray_clusters_from_aws(region.clone()).await?; - print_instances(&instances, *head, *running); + let table = format_table(&instances, regex.as_deref(), head, running)?; + println!("{table}"); } ProviderConfig::Byoc(..) => not_available_for_byoc!("list"), } } &ProvisionedCommand::Connect(Connect { port, + no_dashboard, ref config_path, }) => { let daft_config = read_daft_config(&config_path.config).await?; + let open_dashboard = !no_dashboard; match &daft_config.setup.provider_config { ProviderConfig::Provisioned(aws_config) => { assert_is_logged_in_with_aws().await?; @@ -1025,10 +1166,14 @@ impl ProvisionedCommand { let ray_config = convert(&daft_config, None)?; let (_temp_dir, ray_path) = create_temp_ray_file()?; write_ray_config(&ray_config, &ray_path).await?; - let _ = ssh::ssh_portforward(ray_path, aws_config, Some(port)) - .await? - .wait_with_output() - .await?; + + let child = ssh::ssh_portforward(ray_path, aws_config, Some(port)).await?; + + if open_dashboard { + open::that("http://localhost:8265")?; + }; + + child.wait_with_output().await?; } ProviderConfig::Byoc(..) => not_available_for_byoc!("connect"), } diff --git a/src/tests.rs b/src/tests.rs index 8aa3e14..a184062 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -75,6 +75,43 @@ async fn test_check(#[case] provider: DaftProvider) { .unwrap(); } +#[rstest::rstest] +#[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec![], vec![], vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + "uv python install 3.9".into(), + "uv python pin 3.9".into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), +])] +#[case("3.9".parse().unwrap(), "2.34".parse().unwrap(), vec!["requests==0.0.0".into()], vec![r#"echo "Hello, world!""#.into()], vec![ + "curl -LsSf https://astral.sh/uv/install.sh | sh".into(), + "uv python install 3.9".into(), + "uv python pin 3.9".into(), + "uv venv".into(), + "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), + "source ~/.bashrc".into(), + r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), + r#"uv pip install "requests==0.0.0""#.into(), + r#"echo "Hello, world!""#.into(), +])] +fn test_generate_setup_commands( + #[case] python_version: Versioning, + #[case] ray_version: Versioning, + #[case] dependencies: Vec, + #[case] run: Vec, + #[case] expected: Vec, +) { + let actual = generate_setup_commands( + python_version, + ray_version, + dependencies.as_slice(), + run.as_slice(), + ); + assert_eq!(actual, expected); +} + /// This tests the core conversion functionality, from a `DaftConfig` to a /// `RayConfig`. /// @@ -106,7 +143,9 @@ pub fn simple_config() -> (DaftConfig, Option, RayConfig) { let daft_config = DaftConfig { setup: DaftSetup { name: test_name.clone(), - version: "=1.2.3".parse().unwrap(), + requires: "=1.2.3".parse().unwrap(), + python_version: "3.12".parse().unwrap(), + ray_version: "2.34".parse().unwrap(), provider_config: ProviderConfig::Provisioned(AwsConfig { region: test_name.clone(), number_of_workers, @@ -116,6 +155,7 @@ pub fn simple_config() -> (DaftConfig, Option, RayConfig) { image_id: test_name.clone(), iam_instance_profile_name: Some(test_name.clone()), dependencies: vec![], + run: vec![r#"echo "Hello, world!""#.into()], }), }, jobs: HashMap::default(), @@ -168,7 +208,8 @@ pub fn simple_config() -> (DaftConfig, Option, RayConfig) { "uv venv".into(), "echo 'source $HOME/.venv/bin/activate' >> ~/.bashrc".into(), "source ~/.bashrc".into(), - "uv pip install boto3 pip py-spy deltalake getdaft ray[default]".into(), + r#"uv pip install boto3 pip py-spy deltalake getdaft "ray[default]==2.34""#.into(), + r#"echo "Hello, world!""#.into(), ], };