Skip to content
12 changes: 10 additions & 2 deletions src/uu/stty/src/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ use nix::sys::termios::{
SpecialCharacterIndices as S,
};

#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum BaudType {
Input,
Output,
Both,
}

#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AllFlags<'a> {
Expand All @@ -38,7 +46,7 @@ pub enum AllFlags<'a> {
target_os = "netbsd",
target_os = "openbsd"
))]
Baud(u32),
Baud(u32, BaudType),
#[cfg(not(any(
target_os = "freebsd",
target_os = "dragonfly",
Expand All @@ -47,7 +55,7 @@ pub enum AllFlags<'a> {
target_os = "netbsd",
target_os = "openbsd"
)))]
Baud(BaudRate),
Baud(BaudRate, BaudType),
ControlFlags((&'a Flag<C>, bool)),
InputFlags((&'a Flag<I>, bool)),
LocalFlags((&'a Flag<L>, bool)),
Expand Down
117 changes: 60 additions & 57 deletions src/uu/stty/src/stty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// spell-checker:ignore isig icanon iexten echoe crterase echok echonl noflsh xcase tostop echoprt prterase echoctl ctlecho echoke crtkill flusho extproc
// spell-checker:ignore lnext rprnt susp swtch vdiscard veof veol verase vintr vkill vlnext vquit vreprint vstart vstop vsusp vswtc vwerase werase
// spell-checker:ignore sigquit sigtstp
// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb NCCS
// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb NCCS cfsetispeed
// spell-checker:ignore notaflag notacombo notabaud

mod flags;
Expand All @@ -21,7 +21,7 @@ use clap::{Arg, ArgAction, ArgMatches, Command};
use nix::libc::{O_NONBLOCK, TIOCGWINSZ, TIOCSWINSZ, c_ushort};
use nix::sys::termios::{
ControlFlags, InputFlags, LocalFlags, OutputFlags, SetArg, SpecialCharacterIndices as S,
Termios, cfgetospeed, cfsetospeed, tcgetattr, tcsetattr,
Termios, cfgetospeed, cfsetispeed, cfsetospeed, tcgetattr, tcsetattr,
};
use nix::{ioctl_read_bad, ioctl_write_ptr_bad};
use std::cmp::Ordering;
Expand Down Expand Up @@ -274,19 +274,24 @@ fn stty(opts: &Options) -> UResult<()> {
let mut args_iter = args.iter();
while let Some(&arg) = args_iter.next() {
match arg {
"ispeed" | "ospeed" => match args_iter.next() {
"ispeed" => match args_iter.next() {
Some(speed) => {
if let Some(baud_flag) = string_to_baud(speed) {
if let Some(baud_flag) = string_to_baud(speed, flags::BaudType::Input) {
valid_args.push(ArgOptions::Flags(baud_flag));
} else {
return Err(USimpleError::new(
1,
translate!(
"stty-error-invalid-speed",
"arg" => *arg,
"speed" => *speed,
),
));
return invalid_speed(arg, speed);
}
}
None => {
return missing_arg(arg);
}
},
"ospeed" => match args_iter.next() {
Some(speed) => {
if let Some(baud_flag) = string_to_baud(speed, flags::BaudType::Output) {
valid_args.push(ArgOptions::Flags(baud_flag));
} else {
return invalid_speed(arg, speed);
}
}
None => {
Expand Down Expand Up @@ -383,12 +388,12 @@ fn stty(opts: &Options) -> UResult<()> {
return missing_arg(arg);
}
// baud rate
} else if let Some(baud_flag) = string_to_baud(arg) {
} else if let Some(baud_flag) = string_to_baud(arg, flags::BaudType::Both) {
valid_args.push(ArgOptions::Flags(baud_flag));
// non control char flag
} else if let Some(flag) = string_to_flag(arg) {
let remove_group = match flag {
AllFlags::Baud(_) => false,
AllFlags::Baud(_, _) => false,
AllFlags::ControlFlags((flag, remove)) => {
check_flag_group(flag, remove)
}
Expand Down Expand Up @@ -417,7 +422,7 @@ fn stty(opts: &Options) -> UResult<()> {
for arg in &valid_args {
match arg {
ArgOptions::Mapping(mapping) => apply_char_mapping(&mut termios, mapping),
ArgOptions::Flags(flag) => apply_setting(&mut termios, flag),
ArgOptions::Flags(flag) => apply_setting(&mut termios, flag)?,
ArgOptions::Special(setting) => {
apply_special_setting(&mut termios, setting, opts.file.as_raw_fd())?;
}
Expand Down Expand Up @@ -468,6 +473,17 @@ fn invalid_integer_arg<T>(arg: &str) -> Result<T, Box<dyn UError>> {
))
}

fn invalid_speed<T>(arg: &str, speed: &str) -> Result<T, Box<dyn UError>> {
Err(UUsageError::new(
1,
translate!(
"stty-error-invalid-speed",
"arg" => arg,
"speed" => speed,
),
))
}

/// GNU uses different error messages if values overflow or underflow a u8,
/// this function returns the appropriate error message in the case of overflow or underflow, or u8 on success
fn parse_u8_or_err(arg: &str) -> Result<u8, String> {
Expand Down Expand Up @@ -719,7 +735,7 @@ fn parse_baud_with_rounding(normalized: &str) -> Option<u32> {
Some(value)
}

fn string_to_baud(arg: &str) -> Option<AllFlags<'_>> {
fn string_to_baud(arg: &str, baud_type: flags::BaudType) -> Option<AllFlags<'_>> {
// Reject invalid formats
if arg != arg.trim_end()
|| arg.trim().starts_with('-')
Expand All @@ -744,7 +760,7 @@ fn string_to_baud(arg: &str) -> Option<AllFlags<'_>> {
target_os = "netbsd",
target_os = "openbsd"
))]
return Some(AllFlags::Baud(value));
return Some(AllFlags::Baud(value, baud_type));

#[cfg(not(any(
target_os = "freebsd",
Expand All @@ -757,7 +773,7 @@ fn string_to_baud(arg: &str) -> Option<AllFlags<'_>> {
{
for (text, baud_rate) in BAUD_RATES {
if text.parse::<u32>().ok() == Some(value) {
return Some(AllFlags::Baud(*baud_rate));
return Some(AllFlags::Baud(*baud_rate, baud_type));
}
}
None
Expand Down Expand Up @@ -940,9 +956,9 @@ fn print_flags<T: TermiosFlag>(
}

/// Apply a single setting
fn apply_setting(termios: &mut Termios, setting: &AllFlags) {
fn apply_setting(termios: &mut Termios, setting: &AllFlags) -> nix::Result<()> {
match setting {
AllFlags::Baud(_) => apply_baud_rate_flag(termios, setting),
AllFlags::Baud(_, _) => apply_baud_rate_flag(termios, setting)?,
AllFlags::ControlFlags((setting, disable)) => {
setting.flag.apply(termios, !disable);
}
Expand All @@ -956,34 +972,21 @@ fn apply_setting(termios: &mut Termios, setting: &AllFlags) {
setting.flag.apply(termios, !disable);
}
}
Ok(())
}

fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) {
// BSDs use a u32 for the baud rate, so any decimal number applies.
#[cfg(any(
target_os = "freebsd",
target_os = "dragonfly",
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
))]
if let AllFlags::Baud(n) = input {
cfsetospeed(termios, *n).expect("Failed to set baud rate");
}

// Other platforms use an enum.
#[cfg(not(any(
target_os = "freebsd",
target_os = "dragonfly",
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
)))]
if let AllFlags::Baud(br) = input {
cfsetospeed(termios, *br).expect("Failed to set baud rate");
fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) -> nix::Result<()> {
if let AllFlags::Baud(rate, baud_type) = input {
match baud_type {
flags::BaudType::Input => cfsetispeed(termios, *rate)?,
flags::BaudType::Output => cfsetospeed(termios, *rate)?,
flags::BaudType::Both => {
cfsetispeed(termios, *rate)?;
cfsetospeed(termios, *rate)?;
}
}
}
Ok(())
}

fn apply_char_mapping(termios: &mut Termios, mapping: &(S, u8)) {
Expand Down Expand Up @@ -1446,10 +1449,10 @@ mod tests {
target_os = "openbsd"
)))]
{
assert!(string_to_baud("9600").is_some());
assert!(string_to_baud("115200").is_some());
assert!(string_to_baud("38400").is_some());
assert!(string_to_baud("19200").is_some());
assert!(string_to_baud("9600", flags::BaudType::Both).is_some());
assert!(string_to_baud("115200", flags::BaudType::Both).is_some());
assert!(string_to_baud("38400", flags::BaudType::Both).is_some());
assert!(string_to_baud("19200", flags::BaudType::Both).is_some());
}

#[cfg(any(
Expand All @@ -1461,10 +1464,10 @@ mod tests {
target_os = "openbsd"
))]
{
assert!(string_to_baud("9600").is_some());
assert!(string_to_baud("115200").is_some());
assert!(string_to_baud("1000000").is_some());
assert!(string_to_baud("0").is_some());
assert!(string_to_baud("9600", flags::BaudType::Both).is_some());
assert!(string_to_baud("115200", flags::BaudType::Both).is_some());
assert!(string_to_baud("1000000", flags::BaudType::Both).is_some());
assert!(string_to_baud("0", flags::BaudType::Both).is_some());
}
}

Expand All @@ -1479,10 +1482,10 @@ mod tests {
target_os = "openbsd"
)))]
{
assert_eq!(string_to_baud("995"), None);
assert_eq!(string_to_baud("invalid"), None);
assert_eq!(string_to_baud(""), None);
assert_eq!(string_to_baud("abc"), None);
assert_eq!(string_to_baud("995", flags::BaudType::Both), None);
assert_eq!(string_to_baud("invalid", flags::BaudType::Both), None);
assert_eq!(string_to_baud("", flags::BaudType::Both), None);
assert_eq!(string_to_baud("abc", flags::BaudType::Both), None);
}
}

Expand Down
65 changes: 65 additions & 0 deletions tests/by-util/test_stty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,71 @@ fn test_stty_uses_stdin() {
.stdout_contains("columns 100");
}

#[test]
#[cfg(unix)]
fn test_ispeed_ospeed_valid_speeds() {
let (path, _controller, _replica) = pty_path();
let (_at, ts) = at_and_ts!();

// Test various valid baud rates for both ispeed and ospeed
let test_cases = [
("ispeed", "50"),
("ispeed", "9600"),
("ispeed", "19200"),
("ospeed", "1200"),
("ospeed", "9600"),
("ospeed", "38400"),
];

for (arg, speed) in test_cases {
let result = ts.ucmd().args(&["--file", &path, arg, speed]).run();
let exp_result = unwrap_or_return!(expected_result(&ts, &["--file", &path, arg, speed]));
let normalized_stderr = normalize_stderr(result.stderr_str());

result
.stdout_is(exp_result.stdout_str())
.code_is(exp_result.code());
assert_eq!(normalized_stderr, exp_result.stderr_str());
}
}

#[test]
#[cfg(all(
unix,
not(any(
target_os = "freebsd",
target_os = "dragonfly",
target_os = "ios",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd"
))
))]
#[ignore = "Issue: #9547"]
fn test_ispeed_ospeed_invalid_speeds() {
let (path, _controller, _replica) = pty_path();
let (_at, ts) = at_and_ts!();

// Test invalid speed values (non-standard baud rates)
let test_cases = [
("ispeed", "12345"),
("ospeed", "99999"),
("ispeed", "abc"),
("ospeed", "xyz"),
];

for (arg, speed) in test_cases {
let result = ts.ucmd().args(&["--file", &path, arg, speed]).run();
let exp_result = unwrap_or_return!(expected_result(&ts, &["--file", &path, arg, speed]));
let normalized_stderr = normalize_stderr(result.stderr_str());

result
.stdout_is(exp_result.stdout_str())
.code_is(exp_result.code());
assert_eq!(normalized_stderr, exp_result.stderr_str());
}
}

#[test]
#[cfg(unix)]
fn test_columns_env_wrapping() {
Expand Down
Loading