Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions sea-orm-codegen/src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,33 @@ impl Column {
col_type.map(|ty| quote! { column_type = #ty })
}

pub fn get_ts_type_attrs(
&self,
model_extra_derives: &TokenStream,
model_extra_attributes: &TokenStream,
) -> Option<TokenStream> {
if !matches!(self.col_type, ColumnType::Vector(_)) {
return None;
}

let mut attrs = Vec::new();
let tokens = format!("{}{}", model_extra_derives, model_extra_attributes)
.replace(|c: char| c.is_whitespace(), "");

if tokens.contains("ts_rs::TS") || tokens.contains("ts(export)") {
attrs.push(quote! { #[ts(type = "number[]")] });
}
if tokens.contains("specta::Type") || tokens.contains("specta(export)") {
attrs.push(quote! { #[specta(type = "number[]")] });
}

if attrs.is_empty() {
None
} else {
Some(quote! { #(#attrs)* })
}
}

pub fn get_def(&self) -> TokenStream {
fn write_col_def(col_type: &ColumnType) -> TokenStream {
match col_type {
Expand Down Expand Up @@ -369,9 +396,38 @@ mod tests {
make_col!("date_time", ColumnType::DateTime),
make_col!("timestamp", ColumnType::Timestamp),
make_col!("timestamp_tz", ColumnType::TimestampWithTimeZone),
make_col!("embedding", ColumnType::Vector(None)),
]
}

#[test]
fn test_get_ts_type_attrs() {
let col = Column {
name: "embedding".to_owned(),
col_type: ColumnType::Vector(None),
auto_increment: false,
not_null: false,
unique: false,
unique_key: None,
};

let ts_attr = col
.get_ts_type_attrs(
&quote! { ts_rs::TS },
&TokenStream::new(),
)
.expect("Expected ts attribute");
assert_eq!(ts_attr.to_string(), "# [ts (type = \"number[]\")]");

let specta_attr = col
.get_ts_type_attrs(
&quote! { specta::Type },
&TokenStream::new(),
)
.expect("Expected specta attribute");
assert_eq!(specta_attr.to_string(), "# [specta (type = \"number[]\")]");
}

#[test]
fn test_get_name_snake_case() {
let columns = setup();
Expand Down
49 changes: 49 additions & 0 deletions sea-orm-codegen/src/entity/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2896,6 +2896,55 @@ mod tests {
Ok(())
}

#[test]
fn test_gen_with_ts_vector_support() -> io::Result<()> {
let entity = Entity {
table_name: "document".to_owned(),
columns: vec![
Column {
name: "id".to_owned(),
col_type: ColumnType::Integer,
auto_increment: true,
not_null: true,
unique: false,
unique_key: None,
},
Column {
name: "embedding".to_owned(),
col_type: ColumnType::Vector(None),
auto_increment: false,
not_null: false,
unique: false,
unique_key: None,
},
],
relations: vec![],
conjunct_relations: vec![],
primary_keys: vec![PrimaryKey {
name: "id".to_owned(),
}],
};

let generated = generated_to_string(EntityWriter::gen_compact_code_blocks(
&entity,
&WithSerde::None,
&default_column_option(),
&None,
false,
false,
&bonus_derive(["ts_rs::TS"]),
&TokenStream::new(),
&TokenStream::new(),
false,
true,
));

assert!(generated.contains("# [ts (type = \"number[]\")]"));
assert!(generated.contains("pub embedding : Option < PgVector >"));

Ok(())
}

#[test]
fn test_gen_import_active_enum() -> io::Result<()> {
let entities = vec![
Expand Down
5 changes: 5 additions & 0 deletions sea-orm-codegen/src/entity/writer/compact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,18 @@ impl EntityWriter {
}
ts = quote! { #[sea_orm(#ts)] };
}
let ts_type_attribute = col.get_ts_type_attrs(
model_extra_derives,
model_extra_attributes,
);
let serde_attribute = col.get_serde_attribute(
is_primary_key,
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
);
ts = quote! {
#ts
#ts_type_attribute
#serde_attribute
};
ts
Expand Down
5 changes: 5 additions & 0 deletions sea-orm-codegen/src/entity/writer/dense.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,18 @@ impl EntityWriter {
}
ts = quote! { #[sea_orm(#ts)] };
}
let ts_type_attribute = col.get_ts_type_attrs(
model_extra_derives,
model_extra_attributes,
);
let serde_attribute = col.get_serde_attribute(
is_primary_key,
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
);
ts = quote! {
#ts
#ts_type_attribute
#serde_attribute
};
ts
Expand Down
26 changes: 21 additions & 5 deletions sea-orm-codegen/src/entity/writer/expanded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,34 @@ impl EntityWriter {
let column_names_snake_case = entity.get_column_names_snake_case();
let column_rs_types = entity.get_column_rs_types(column_option);
let if_eq_needed = entity.get_eq_needed();
let serde_attributes = entity.get_column_serde_attributes(
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
);
let column_attributes: Vec<TokenStream> = entity
.columns
.iter()
.map(|col| {
let is_primary_key = entity.primary_keys.iter().any(|pk| pk.name == col.name);
let ts_type_attribute = col.get_ts_type_attrs(
model_extra_derives,
model_extra_attributes,
);
let serde_attribute = col.get_serde_attribute(
is_primary_key,
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
);
quote! {
#ts_type_attribute
#serde_attribute
}
})
.collect();
let extra_derive = with_serde.extra_derive();

quote! {
#[derive(Clone, Debug, PartialEq, DeriveModel, DeriveActiveModel #if_eq_needed #extra_derive #model_extra_derives)]
#model_extra_attributes
pub struct Model {
#(
#serde_attributes
#column_attributes
pub #column_names_snake_case: #column_rs_types,
)*
}
Expand Down
12 changes: 10 additions & 2 deletions sea-orm-codegen/src/entity/writer/frontend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,19 @@ impl EntityWriter {
.iter()
.map(|col| {
let is_primary_key = primary_keys.contains(&col.name);
col.get_serde_attribute(
let ts_type_attribute = col.get_ts_type_attrs(
model_extra_derives,
model_extra_attributes,
);
let serde_attribute = col.get_serde_attribute(
is_primary_key,
serde_skip_deserializing_primary_key,
serde_skip_hidden_column,
)
);
quote! {
#ts_type_attribute
#serde_attribute
}
})
.collect();
let extra_derive = with_serde.extra_derive();
Expand Down