Skip to content

Commit

Permalink
Support array datatype in PostgreSQL (#1132)
Browse files Browse the repository at this point in the history
* PostgreSQL array (draft)

* Fixup

* Fixup

* Fixup

* Fixup

* Fixup

* Refactoring

* generate entity for Postgres array fields

* Add tests

* Update Cargo.toml

Co-authored-by: Chris Tsang <[email protected]>
  • Loading branch information
billy1624 and tyt2y3 authored Oct 23, 2022
1 parent 2757190 commit b5b9790
Show file tree
Hide file tree
Showing 16 changed files with 764 additions and 178 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ actix-rt = { version = "2.2.0" }
maplit = { version = "^1" }
rust_decimal_macros = { version = "^1" }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
sea-orm = { path = ".", features = ["mock", "debug-print", "tests-cfg"] }
sea-orm = { path = ".", features = ["mock", "debug-print", "tests-cfg", "postgres-array"] }
pretty_assertions = { version = "^0.7" }
time = { version = "^0.3", features = ["macros"] }

Expand All @@ -76,6 +76,7 @@ with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono
with-rust_decimal = ["rust_decimal", "sea-query/with-rust_decimal", "sea-query-binder?/with-rust_decimal", "sqlx?/decimal"]
with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"]
with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"]
postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array"]
sqlx-dep = []
sqlx-all = ["sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"]
sqlx-mysql = ["sqlx-dep", "sea-query-binder/sqlx-mysql", "sqlx/mysql"]
Expand Down
2 changes: 1 addition & 1 deletion sea-orm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ clap = { version = "^3.2", features = ["env", "derive"] }
dotenvy = { version = "^0.15", optional = true }
async-std = { version = "^1.9", features = [ "attributes", "tokio1" ], optional = true }
sea-orm-codegen = { version = "^0.10.0", path = "../sea-orm-codegen", optional = true }
sea-schema = { version = "^0.10.0" }
sea-schema = { version = "^0.10.1" }
sqlx = { version = "^0.6", default-features = false, features = [ "mysql", "postgres" ], optional = true }
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing = { version = "0.1" }
Expand Down
218 changes: 117 additions & 101 deletions sea-orm-codegen/src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,53 +28,59 @@ impl Column {
}

pub fn get_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
#[allow(unreachable_patterns)]
let ident: TokenStream = match &self.col_type {
ColumnType::Char(_)
| ColumnType::String(_)
| ColumnType::Text
| ColumnType::Custom(_) => "String".to_owned(),
ColumnType::TinyInteger(_) => "i8".to_owned(),
ColumnType::SmallInteger(_) => "i16".to_owned(),
ColumnType::Integer(_) => "i32".to_owned(),
ColumnType::BigInteger(_) => "i64".to_owned(),
ColumnType::TinyUnsigned(_) => "u8".to_owned(),
ColumnType::SmallUnsigned(_) => "u16".to_owned(),
ColumnType::Unsigned(_) => "u32".to_owned(),
ColumnType::BigUnsigned(_) => "u64".to_owned(),
ColumnType::Float(_) => "f32".to_owned(),
ColumnType::Double(_) => "f64".to_owned(),
ColumnType::Json | ColumnType::JsonBinary => "Json".to_owned(),
ColumnType::Date => match date_time_crate {
DateTimeCrate::Chrono => "Date".to_owned(),
DateTimeCrate::Time => "TimeDate".to_owned(),
},
ColumnType::Time(_) => match date_time_crate {
DateTimeCrate::Chrono => "Time".to_owned(),
DateTimeCrate::Time => "TimeTime".to_owned(),
},
ColumnType::DateTime(_) => match date_time_crate {
DateTimeCrate::Chrono => "DateTime".to_owned(),
DateTimeCrate::Time => "TimeDateTime".to_owned(),
},
ColumnType::Timestamp(_) => match date_time_crate {
DateTimeCrate::Chrono => "DateTimeUtc".to_owned(),
// ColumnType::Timpestamp(_) => time::PrimitiveDateTime: https://docs.rs/sqlx/0.3.5/sqlx/postgres/types/index.html#time
DateTimeCrate::Time => "TimeDateTime".to_owned(),
},
ColumnType::TimestampWithTimeZone(_) => match date_time_crate {
DateTimeCrate::Chrono => "DateTimeWithTimeZone".to_owned(),
DateTimeCrate::Time => "TimeDateTimeWithTimeZone".to_owned(),
},
ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal".to_owned(),
ColumnType::Uuid => "Uuid".to_owned(),
ColumnType::Binary(_) | ColumnType::VarBinary(_) => "Vec<u8>".to_owned(),
ColumnType::Boolean => "bool".to_owned(),
ColumnType::Enum { name, .. } => name.to_string().to_camel_case(),
_ => unimplemented!(),
fn write_rs_type(col_type: &ColumnType, date_time_crate: &DateTimeCrate) -> String {
#[allow(unreachable_patterns)]
match col_type {
ColumnType::Char(_)
| ColumnType::String(_)
| ColumnType::Text
| ColumnType::Custom(_) => "String".to_owned(),
ColumnType::TinyInteger(_) => "i8".to_owned(),
ColumnType::SmallInteger(_) => "i16".to_owned(),
ColumnType::Integer(_) => "i32".to_owned(),
ColumnType::BigInteger(_) => "i64".to_owned(),
ColumnType::TinyUnsigned(_) => "u8".to_owned(),
ColumnType::SmallUnsigned(_) => "u16".to_owned(),
ColumnType::Unsigned(_) => "u32".to_owned(),
ColumnType::BigUnsigned(_) => "u64".to_owned(),
ColumnType::Float(_) => "f32".to_owned(),
ColumnType::Double(_) => "f64".to_owned(),
ColumnType::Json | ColumnType::JsonBinary => "Json".to_owned(),
ColumnType::Date => match date_time_crate {
DateTimeCrate::Chrono => "Date".to_owned(),
DateTimeCrate::Time => "TimeDate".to_owned(),
},
ColumnType::Time(_) => match date_time_crate {
DateTimeCrate::Chrono => "Time".to_owned(),
DateTimeCrate::Time => "TimeTime".to_owned(),
},
ColumnType::DateTime(_) => match date_time_crate {
DateTimeCrate::Chrono => "DateTime".to_owned(),
DateTimeCrate::Time => "TimeDateTime".to_owned(),
},
ColumnType::Timestamp(_) => match date_time_crate {
DateTimeCrate::Chrono => "DateTimeUtc".to_owned(),
// ColumnType::Timpestamp(_) => time::PrimitiveDateTime: https://docs.rs/sqlx/0.3.5/sqlx/postgres/types/index.html#time
DateTimeCrate::Time => "TimeDateTime".to_owned(),
},
ColumnType::TimestampWithTimeZone(_) => match date_time_crate {
DateTimeCrate::Chrono => "DateTimeWithTimeZone".to_owned(),
DateTimeCrate::Time => "TimeDateTimeWithTimeZone".to_owned(),
},
ColumnType::Decimal(_) | ColumnType::Money(_) => "Decimal".to_owned(),
ColumnType::Uuid => "Uuid".to_owned(),
ColumnType::Binary(_) | ColumnType::VarBinary(_) => "Vec<u8>".to_owned(),
ColumnType::Boolean => "bool".to_owned(),
ColumnType::Enum { name, .. } => name.to_string().to_camel_case(),
ColumnType::Array(column_type) => {
format!("Vec<{}>", write_rs_type(column_type, date_time_crate))
}
_ => unimplemented!(),
}
}
.parse()
.unwrap();
let ident: TokenStream = write_rs_type(&self.col_type, date_time_crate)
.parse()
.unwrap();
match self.not_null {
true => quote! { #ident },
false => quote! { Option<#ident> },
Expand All @@ -97,62 +103,72 @@ impl Column {
}

pub fn get_def(&self) -> TokenStream {
let mut col_def = match &self.col_type {
ColumnType::Char(s) => match s {
Some(s) => quote! { ColumnType::Char(Some(#s)).def() },
None => quote! { ColumnType::Char(None).def() },
},
ColumnType::String(s) => match s {
Some(s) => quote! { ColumnType::String(Some(#s)).def() },
None => quote! { ColumnType::String(None).def() },
},
ColumnType::Text => quote! { ColumnType::Text.def() },
ColumnType::TinyInteger(_) => quote! { ColumnType::TinyInteger.def() },
ColumnType::SmallInteger(_) => quote! { ColumnType::SmallInteger.def() },
ColumnType::Integer(_) => quote! { ColumnType::Integer.def() },
ColumnType::BigInteger(_) => quote! { ColumnType::BigInteger.def() },
ColumnType::TinyUnsigned(_) => quote! { ColumnType::TinyUnsigned.def() },
ColumnType::SmallUnsigned(_) => quote! { ColumnType::SmallUnsigned.def() },
ColumnType::Unsigned(_) => quote! { ColumnType::Unsigned.def() },
ColumnType::BigUnsigned(_) => quote! { ColumnType::BigUnsigned.def() },
ColumnType::Float(_) => quote! { ColumnType::Float.def() },
ColumnType::Double(_) => quote! { ColumnType::Double.def() },
ColumnType::Decimal(s) => match s {
Some((s1, s2)) => quote! { ColumnType::Decimal(Some((#s1, #s2))).def() },
None => quote! { ColumnType::Decimal(None).def() },
},
ColumnType::DateTime(_) => quote! { ColumnType::DateTime.def() },
ColumnType::Timestamp(_) => quote! { ColumnType::Timestamp.def() },
ColumnType::TimestampWithTimeZone(_) => {
quote! { ColumnType::TimestampWithTimeZone.def() }
}
ColumnType::Time(_) => quote! { ColumnType::Time.def() },
ColumnType::Date => quote! { ColumnType::Date.def() },
ColumnType::Binary(BlobSize::Blob(_)) | ColumnType::VarBinary(_) => {
quote! { ColumnType::Binary.def() }
}
ColumnType::Binary(BlobSize::Tiny) => quote! { ColumnType::TinyBinary.def() },
ColumnType::Binary(BlobSize::Medium) => quote! { ColumnType::MediumBinary.def() },
ColumnType::Binary(BlobSize::Long) => quote! { ColumnType::LongBinary.def() },
ColumnType::Boolean => quote! { ColumnType::Boolean.def() },
ColumnType::Money(s) => match s {
Some((s1, s2)) => quote! { ColumnType::Money(Some((#s1, #s2))).def() },
None => quote! { ColumnType::Money(None).def() },
},
ColumnType::Json => quote! { ColumnType::Json.def() },
ColumnType::JsonBinary => quote! { ColumnType::JsonBinary.def() },
ColumnType::Uuid => quote! { ColumnType::Uuid.def() },
ColumnType::Custom(s) => {
let s = s.to_string();
quote! { ColumnType::Custom(#s.to_owned()).def() }
}
ColumnType::Enum { name, .. } => {
let enum_ident = format_ident!("{}", name.to_string().to_camel_case());
quote! { #enum_ident::db_type() }
fn write_col_def(col_type: &ColumnType) -> TokenStream {
match col_type {
ColumnType::Char(s) => match s {
Some(s) => quote! { ColumnType::Char(Some(#s)) },
None => quote! { ColumnType::Char(None) },
},
ColumnType::String(s) => match s {
Some(s) => quote! { ColumnType::String(Some(#s)) },
None => quote! { ColumnType::String(None) },
},
ColumnType::Text => quote! { ColumnType::Text },
ColumnType::TinyInteger(_) => quote! { ColumnType::TinyInteger },
ColumnType::SmallInteger(_) => quote! { ColumnType::SmallInteger },
ColumnType::Integer(_) => quote! { ColumnType::Integer },
ColumnType::BigInteger(_) => quote! { ColumnType::BigInteger },
ColumnType::TinyUnsigned(_) => quote! { ColumnType::TinyUnsigned },
ColumnType::SmallUnsigned(_) => quote! { ColumnType::SmallUnsigned },
ColumnType::Unsigned(_) => quote! { ColumnType::Unsigned },
ColumnType::BigUnsigned(_) => quote! { ColumnType::BigUnsigned },
ColumnType::Float(_) => quote! { ColumnType::Float },
ColumnType::Double(_) => quote! { ColumnType::Double },
ColumnType::Decimal(s) => match s {
Some((s1, s2)) => quote! { ColumnType::Decimal(Some((#s1, #s2))) },
None => quote! { ColumnType::Decimal(None) },
},
ColumnType::DateTime(_) => quote! { ColumnType::DateTime },
ColumnType::Timestamp(_) => quote! { ColumnType::Timestamp },
ColumnType::TimestampWithTimeZone(_) => {
quote! { ColumnType::TimestampWithTimeZone }
}
ColumnType::Time(_) => quote! { ColumnType::Time },
ColumnType::Date => quote! { ColumnType::Date },
ColumnType::Binary(BlobSize::Blob(_)) | ColumnType::VarBinary(_) => {
quote! { ColumnType::Binary }
}
ColumnType::Binary(BlobSize::Tiny) => quote! { ColumnType::TinyBinary },
ColumnType::Binary(BlobSize::Medium) => quote! { ColumnType::MediumBinary },
ColumnType::Binary(BlobSize::Long) => quote! { ColumnType::LongBinary },
ColumnType::Boolean => quote! { ColumnType::Boolean },
ColumnType::Money(s) => match s {
Some((s1, s2)) => quote! { ColumnType::Money(Some((#s1, #s2))) },
None => quote! { ColumnType::Money(None) },
},
ColumnType::Json => quote! { ColumnType::Json },
ColumnType::JsonBinary => quote! { ColumnType::JsonBinary },
ColumnType::Uuid => quote! { ColumnType::Uuid },
ColumnType::Custom(s) => {
let s = s.to_string();
quote! { ColumnType::Custom(#s.to_owned()) }
}
ColumnType::Enum { name, .. } => {
let enum_ident = format_ident!("{}", name.to_string().to_camel_case());
quote! { #enum_ident::db_type() }
}
ColumnType::Array(column_type) => {
let column_type = write_col_def(column_type);
quote! { ColumnType::Array(sea_orm::sea_query::SeaRc::new(#column_type)) }
}
#[allow(unreachable_patterns)]
_ => unimplemented!(),
}
#[allow(unreachable_patterns)]
_ => unimplemented!(),
};
}
let mut col_def = write_col_def(&self.col_type);
col_def.extend(quote! {
.def()
});
if !self.not_null {
col_def.extend(quote! {
.null()
Expand Down
49 changes: 44 additions & 5 deletions sea-orm-codegen/src/entity/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ mod tests {
};
use pretty_assertions::assert_eq;
use proc_macro2::TokenStream;
use sea_query::{ColumnType, ForeignKeyAction};
use sea_query::{ColumnType, ForeignKeyAction, SeaRc};
use std::io::{self, BufRead, BufReader, Read};

fn setup() -> Vec<Entity> {
Expand Down Expand Up @@ -1120,6 +1120,41 @@ mod tests {
name: "id".to_owned(),
}],
},
Entity {
table_name: "collection".to_owned(),
columns: vec![
Column {
name: "id".to_owned(),
col_type: ColumnType::Integer(Some(11)),
auto_increment: true,
not_null: true,
unique: false,
},
Column {
name: "integers".to_owned(),
col_type: ColumnType::Array(SeaRc::new(Box::new(ColumnType::Integer(
None,
)))),
auto_increment: false,
not_null: true,
unique: false,
},
Column {
name: "integers_opt".to_owned(),
col_type: ColumnType::Array(SeaRc::new(Box::new(ColumnType::Integer(
None,
)))),
auto_increment: false,
not_null: false,
unique: false,
},
],
relations: vec![],
conjunct_relations: vec![],
primary_keys: vec![PrimaryKey {
name: "id".to_owned(),
}],
},
]
}

Expand All @@ -1144,7 +1179,7 @@ mod tests {
#[test]
fn test_gen_expanded_code_blocks() -> io::Result<()> {
let entities = setup();
const ENTITY_FILES: [&str; 8] = [
const ENTITY_FILES: [&str; 9] = [
include_str!("../../tests/expanded/cake.rs"),
include_str!("../../tests/expanded/cake_filling.rs"),
include_str!("../../tests/expanded/filling.rs"),
Expand All @@ -1153,8 +1188,9 @@ mod tests {
include_str!("../../tests/expanded/rust_keyword.rs"),
include_str!("../../tests/expanded/cake_with_float.rs"),
include_str!("../../tests/expanded/cake_with_double.rs"),
include_str!("../../tests/expanded/collection.rs"),
];
const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 8] = [
const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 9] = [
include_str!("../../tests/expanded_with_schema_name/cake.rs"),
include_str!("../../tests/expanded_with_schema_name/cake_filling.rs"),
include_str!("../../tests/expanded_with_schema_name/filling.rs"),
Expand All @@ -1163,6 +1199,7 @@ mod tests {
include_str!("../../tests/expanded_with_schema_name/rust_keyword.rs"),
include_str!("../../tests/expanded_with_schema_name/cake_with_float.rs"),
include_str!("../../tests/expanded_with_schema_name/cake_with_double.rs"),
include_str!("../../tests/expanded_with_schema_name/collection.rs"),
];

assert_eq!(entities.len(), ENTITY_FILES.len());
Expand Down Expand Up @@ -1224,7 +1261,7 @@ mod tests {
#[test]
fn test_gen_compact_code_blocks() -> io::Result<()> {
let entities = setup();
const ENTITY_FILES: [&str; 8] = [
const ENTITY_FILES: [&str; 9] = [
include_str!("../../tests/compact/cake.rs"),
include_str!("../../tests/compact/cake_filling.rs"),
include_str!("../../tests/compact/filling.rs"),
Expand All @@ -1233,8 +1270,9 @@ mod tests {
include_str!("../../tests/compact/rust_keyword.rs"),
include_str!("../../tests/compact/cake_with_float.rs"),
include_str!("../../tests/compact/cake_with_double.rs"),
include_str!("../../tests/compact/collection.rs"),
];
const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 8] = [
const ENTITY_FILES_WITH_SCHEMA_NAME: [&str; 9] = [
include_str!("../../tests/compact_with_schema_name/cake.rs"),
include_str!("../../tests/compact_with_schema_name/cake_filling.rs"),
include_str!("../../tests/compact_with_schema_name/filling.rs"),
Expand All @@ -1243,6 +1281,7 @@ mod tests {
include_str!("../../tests/compact_with_schema_name/rust_keyword.rs"),
include_str!("../../tests/compact_with_schema_name/cake_with_float.rs"),
include_str!("../../tests/compact_with_schema_name/cake_with_double.rs"),
include_str!("../../tests/compact_with_schema_name/collection.rs"),
];

assert_eq!(entities.len(), ENTITY_FILES.len());
Expand Down
17 changes: 17 additions & 0 deletions sea-orm-codegen/tests/compact/collection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.10.0
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq)]
#[sea_orm(table_name = "collection")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
pub integers: Vec<i32> ,
pub integers_opt: Option<Vec<i32> > ,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}
Loading

0 comments on commit b5b9790

Please sign in to comment.