Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,13 @@ mod tests {
}

#[test]
fn arithmetic_scalar(){
fn arithmetic_scalar() {
let qs = "56";
let res = arithmetic(qs.as_bytes());
assert!(res.is_err());
assert_eq!(nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)), res.err().unwrap());
assert_eq!(
nom::Err::Error(nom::error::Error::new(qs.as_bytes(), ErrorKind::Tag)),
res.err().unwrap()
);
}
}
172 changes: 127 additions & 45 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ use nom::branch::alt;
use nom::character::complete::{alphanumeric1, digit1, line_ending, multispace0, multispace1};
use nom::character::is_alphanumeric;
use nom::combinator::{map, not, peek};
use nom::{IResult, InputLength, Parser};
use nom::{Err, IResult, InputLength, Parser};
use std::fmt::{self, Display};
use std::str;
use std::str::FromStr;

use arithmetic::{arithmetic_expression, ArithmeticExpression};
use case::case_when_column;
use column::{Column, FunctionArgument, FunctionArguments, FunctionExpression};
use insert::InsertDataValue;
use keywords::{escape_if_keyword, sql_keyword};
use nom::bytes::complete::{is_not, tag, tag_no_case, take, take_until, take_while1};
use nom::combinator::opt;
use nom::error::{ErrorKind, ParseError};
use nom::error::{Error, ErrorKind, ParseError};
use nom::multi::{fold_many0, many0, many1, separated_list0};
use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple};
use table::Table;
Expand Down Expand Up @@ -354,6 +355,33 @@ impl Display for FieldValueExpression {
}
}

#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum FieldAssignmentValue {
Col(Column),
Expression(FieldValueExpression),
}

impl Display for FieldAssignmentValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Col(ref col) => write!(f, "{}", col),
Self::Expression(ref expr) => write!(f, "{}", expr),
}
}
}

impl From<Column> for FieldAssignmentValue {
fn from(c: Column) -> Self {
Self::Col(c)
}
}

impl From<FieldValueExpression> for FieldAssignmentValue {
fn from(fve: FieldValueExpression) -> Self {
Self::Expression(fve)
}
}

#[inline]
pub fn is_sql_identifier(chr: u8) -> bool {
is_alphanumeric(chr) || chr == '_' as u8 || chr == '@' as u8
Expand Down Expand Up @@ -388,7 +416,7 @@ where
let (inp, _) = first.parse(inp)?;
let (inp, o2) = second.parse(inp)?;
third.parse(inp).map(|(i, _)| (i, o2))
},
}
}
}
}
Expand Down Expand Up @@ -641,7 +669,8 @@ pub fn function_argument_parser(i: &[u8]) -> IResult<&[u8], FunctionArgument> {
// present.
pub fn function_arguments(i: &[u8]) -> IResult<&[u8], (FunctionArgument, bool)> {
let distinct_parser = opt(tuple((tag_no_case("distinct"), multispace1)));
let (remaining_input, (distinct, args)) = tuple((distinct_parser, function_argument_parser))(i)?;
let (remaining_input, (distinct, args)) =
tuple((distinct_parser, function_argument_parser))(i)?;
Ok((remaining_input, (args, distinct.is_some())))
}

Expand Down Expand Up @@ -695,12 +724,25 @@ pub fn column_function(i: &[u8]) -> IResult<&[u8], FunctionExpression> {
FunctionExpression::GroupConcat(FunctionArgument::Column(col.clone()), sep)
},
),
map(tuple((sql_identifier, multispace0, tag("("), separated_list0(tag(","), delimited(multispace0, function_argument_parser, multispace0)), tag(")"))), |tuple| {
let (name, _, _, arguments, _) = tuple;
FunctionExpression::Generic(
str::from_utf8(name).unwrap().to_string(),
FunctionArguments::from(arguments))
})
map(
tuple((
sql_identifier,
multispace0,
tag("("),
separated_list0(
tag(","),
delimited(multispace0, function_argument_parser, multispace0),
),
tag(")"),
)),
|tuple| {
let (name, _, _, arguments, _) = tuple;
FunctionExpression::Generic(
str::from_utf8(name).unwrap().to_string(),
FunctionArguments::from(arguments),
)
},
),
))(i)
}

Expand Down Expand Up @@ -764,11 +806,20 @@ pub fn column_identifier(i: &[u8]) -> IResult<&[u8], Column> {

// Parses a SQL identifier (alphanumeric1 and "_").
pub fn sql_identifier(i: &[u8]) -> IResult<&[u8], &[u8]> {
alt((
let (i, si) = alt((
preceded(not(peek(sql_keyword)), take_while1(is_sql_identifier)),
delimited(tag("`"), take_while1(is_sql_identifier), tag("`")),
delimited(tag("["), take_while1(is_sql_identifier), tag("]")),
))(i)
))(i)?;

if str::from_utf8(si).unwrap_or("0").parse::<usize>().is_ok() {
return Err(Err::Error(Error {
input: i,
code: ErrorKind::IsA,
}));
}

Ok((i, si))
}

// Parse an unsigned integer.
Expand Down Expand Up @@ -822,21 +873,23 @@ pub fn as_alias(i: &[u8]) -> IResult<&[u8], &str> {
)(i)
}

fn field_value_expr(i: &[u8]) -> IResult<&[u8], FieldValueExpression> {
fn field_value_expr(i: &[u8]) -> IResult<&[u8], FieldAssignmentValue> {
alt((
map(arithmetic_expression, |ae| {
FieldValueExpression::Arithmetic(ae).into()
}),
map(column_identifier, |c| c.into()),
map(literal, |l| {
FieldValueExpression::Literal(LiteralExpression {
value: l.into(),
alias: None,
})
}),
map(arithmetic_expression, |ae| {
FieldValueExpression::Arithmetic(ae)
.into()
}),
))(i)
}

fn assignment_expr(i: &[u8]) -> IResult<&[u8], (Column, FieldValueExpression)> {
fn assignment_expr(i: &[u8]) -> IResult<&[u8], (Column, FieldAssignmentValue)> {
separated_pair(
column_identifier_no_alias,
delimited(multispace0, tag("="), multispace0),
Expand All @@ -858,7 +911,7 @@ where
delimited(multispace0, tag("="), multispace0)(i)
}

pub fn assignment_expr_list(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldValueExpression)>> {
pub fn assignment_expr_list(i: &[u8]) -> IResult<&[u8], Vec<(Column, FieldAssignmentValue)>> {
many1(terminated(assignment_expr, opt(ws_sep_comma)))(i)
}

Expand Down Expand Up @@ -1018,25 +1071,48 @@ pub fn value_list(i: &[u8]) -> IResult<&[u8], Vec<Literal>> {
many0(delimited(multispace0, literal, opt(ws_sep_comma)))(i)
}

pub fn insert_data_value_list(i: &[u8]) -> IResult<&[u8], Vec<InsertDataValue>> {
many0(delimited(multispace0, insert_data_value, opt(ws_sep_comma)))(i)
}

pub fn insert_data_value(i: &[u8]) -> IResult<&[u8], InsertDataValue> {
alt((
map(
tuple((
tag_no_case("DEFAULT"),
tag("("),
multispace0,
column_identifier_no_alias,
multispace0,
tag(")"),
)),
|(_, _, _, c, _, _)| InsertDataValue::ColumnDefault(c),
),
map(tag_no_case("DEFAULT"), |_| InsertDataValue::Default),
map(literal, |l| InsertDataValue::Literal(l)),
))(i)
}

// Parse a reference to a named schema.table, with an optional alias
pub fn schema_table_reference(i: &[u8]) -> IResult<&[u8], Table> {
map(
tuple((
opt(pair(sql_identifier, tag("."))),
sql_identifier,
opt(as_alias)
)),
|tup| Table {
name: String::from(str::from_utf8(tup.1).unwrap()),
alias: match tup.2 {
Some(a) => Some(String::from(a)),
None => None,
},
schema: match tup.0 {
Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
None => None,
tuple((
opt(pair(sql_identifier, tag("."))),
sql_identifier,
opt(as_alias),
)),
|tup| Table {
name: String::from(str::from_utf8(tup.1).unwrap()),
alias: match tup.2 {
Some(a) => Some(String::from(a)),
None => None,
},
schema: match tup.0 {
Some((schema, _)) => Some(String::from(str::from_utf8(schema).unwrap())),
None => None,
},
},
})(i)
)(i)
}

// Parse a reference to a named table, with an optional alias
Expand All @@ -1047,7 +1123,7 @@ pub fn table_reference(i: &[u8]) -> IResult<&[u8], Table> {
Some(a) => Some(String::from(a)),
None => None,
},
schema: None,
schema: None,
})(i)
}

Expand Down Expand Up @@ -1137,25 +1213,31 @@ mod tests {
name: String::from("max(addr_id)"),
alias: None,
table: None,
function: Some(Box::new(FunctionExpression::Max(
FunctionArgument::Column(Column::from("addr_id")),
))),
function: Some(Box::new(FunctionExpression::Max(FunctionArgument::Column(
Column::from("addr_id"),
)))),
};
assert_eq!(res.unwrap().1, expected);
}

#[test]
fn simple_generic_function() {
let qlist = ["coalesce(a,b,c)".as_bytes(), "coalesce (a,b,c)".as_bytes(), "coalesce(a ,b,c)".as_bytes(), "coalesce(a, b,c)".as_bytes()];
let qlist = [
"coalesce(a,b,c)".as_bytes(),
"coalesce (a,b,c)".as_bytes(),
"coalesce(a ,b,c)".as_bytes(),
"coalesce(a, b,c)".as_bytes(),
];
for q in qlist.iter() {
let res = column_function(q);
let expected = FunctionExpression::Generic("coalesce".to_string(),
FunctionArguments::from(
vec!(
FunctionArgument::Column(Column::from("a")),
FunctionArgument::Column(Column::from("b")),
FunctionArgument::Column(Column::from("c"))
)));
let expected = FunctionExpression::Generic(
"coalesce".to_string(),
FunctionArguments::from(vec![
FunctionArgument::Column(Column::from("a")),
FunctionArgument::Column(Column::from("b")),
FunctionArgument::Column(Column::from("c")),
]),
);
assert_eq!(res, Ok((&b""[..], expected)));
}
}
Expand Down
Loading