Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support PgVector #2500

Merged
merged 15 commits into from
Feb 18, 2025
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ jobs:
tls: [native-tls]
services:
postgres:
image: postgres:${{ matrix.version }}
image: pgvector/pgvector:pg${{ matrix.version }}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean we can't use the official image?

env:
POSTGRES_HOST: 127.0.0.1
POSTGRES_USER: root
Expand Down
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ serde = { version = "1.0", default-features = false }
serde_json = { version = "1.0", default-features = false, optional = true }
sqlx = { version = "0.8.2", default-features = false, optional = true }
uuid = { version = "1", default-features = false, optional = true }
pgvector = { version = "~0.4", default-features = false, optional = true }
ouroboros = { version = "0.18", default-features = false }
url = { version = "2.2", default-features = false }
thiserror = { version = "1", default-features = false }
Expand All @@ -54,7 +55,7 @@ tokio = { version = "1.6", features = ["full"] }
actix-rt = { version = "2.2.0" }
maplit = { version = "1" }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
sea-orm = { path = ".", features = ["mock", "debug-print", "tests-cfg", "postgres-array", "sea-orm-internal"] }
sea-orm = { path = ".", features = ["mock", "debug-print", "tests-cfg", "postgres-array", "postgres-vector", "sea-orm-internal"] }
pretty_assertions = { version = "0.7" }
time = { version = "0.3.36", features = ["macros"] }
uuid = { version = "1", features = ["v4"] }
Expand All @@ -76,13 +77,14 @@ default = [
macros = ["sea-orm-macros/derive"]
mock = []
proxy = ["serde_json", "serde/derive"]
with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "rust_decimal?/serde", "bigdecimal?/serde", "uuid?/serde", "time?/serde", "sea-query-binder?/with-json", "sqlx?/json"]
with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "rust_decimal?/serde", "bigdecimal?/serde", "uuid?/serde", "time?/serde", "pgvector?/serde", "sea-query-binder?/with-json", "sqlx?/json"]
with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono", "sqlx?/chrono"]
with-rust_decimal = ["rust_decimal", "sea-query/with-rust_decimal", "sea-query-binder?/with-rust_decimal", "sqlx?/rust_decimal"]
with-bigdecimal = ["bigdecimal", "sea-query/with-bigdecimal", "sea-query-binder?/with-bigdecimal", "sqlx?/bigdecimal"]
with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"]
with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"]
postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros/postgres-array"]
postgres-vector = ["pgvector", "sea-query/postgres-vector", "sea-query-binder?/postgres-vector"]
json-array = ["postgres-array"] # this does not actually enable sqlx-postgres, but only a few traits to support array in sea-query
sea-orm-internal = []
sqlx-dep = []
Expand Down
7 changes: 7 additions & 0 deletions build-tools/docker-create.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ docker stop "mysql-5.7"

# Setup PostgreSQL

# Create a PostgreSQL 14 container from the pgvector image, then stop it
# right away so it only runs when started explicitly.
docker run \
--name "postgres-vector-14" \
--env POSTGRES_USER="root" \
--env POSTGRES_PASSWORD="root" \
-d -p 5432:5432 pgvector/pgvector:pg14
docker stop "postgres-vector-14"

docker run \
--name "postgres-13" \
--env POSTGRES_USER="root" \
Expand Down
1 change: 1 addition & 0 deletions sea-orm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ smol = "1.2.5"
default = ["codegen", "cli", "runtime-async-std-native-tls", "async-std"]
codegen = ["sea-schema/sqlx-all", "sea-orm-codegen"]
cli = ["clap", "dotenvy"]
postgres-vector = ["sea-schema/postgres-vector"]
runtime-actix = ["sqlx/runtime-tokio", "sea-schema/runtime-actix"]
runtime-async-std = ["sqlx/runtime-async-std", "sea-schema/runtime-async-std"]
runtime-tokio = ["sqlx/runtime-tokio", "sea-schema/runtime-tokio"]
Expand Down
1 change: 1 addition & 0 deletions sea-orm-codegen/src/entity/base_entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ impl Entity {
match col_type {
ColumnType::Float | ColumnType::Double => true,
ColumnType::Array(col_type) => is_floats(col_type),
ColumnType::Vector(_) => true,
_ => false,
}
}
Expand Down
5 changes: 5 additions & 0 deletions sea-orm-codegen/src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl Column {
ColumnType::Array(column_type) => {
format!("Vec<{}>", write_rs_type(column_type, date_time_crate))
}
ColumnType::Vector(_) => "::pgvector::Vector".to_owned(),
ColumnType::Bit(None | Some(1)) => "bool".to_owned(),
ColumnType::Bit(_) | ColumnType::VarBit(_) => "Vec<u8>".to_owned(),
ColumnType::Year => "i32".to_owned(),
Expand Down Expand Up @@ -180,6 +181,10 @@ impl Column {
let column_type = write_col_def(column_type);
quote! { ColumnType::Array(RcOrArc::new(#column_type)) }
}
ColumnType::Vector(size) => match size {
Some(size) => quote! { ColumnType::Vector(Some(#size)) },
None => quote! { ColumnType::Vector(None) },
},
#[allow(unreachable_patterns)]
_ => unimplemented!(),
}
Expand Down
3 changes: 3 additions & 0 deletions src/entity/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,6 @@ pub use bigdecimal::BigDecimal;

#[cfg(feature = "with-uuid")]
pub use uuid::Uuid;

#[cfg(feature = "postgres-vector")]
pub use pgvector::Vector as PgVector;
36 changes: 36 additions & 0 deletions src/executor/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,42 @@ mod postgres_array {
}
}
}

#[cfg(feature = "postgres-vector")]
impl TryGetable for pgvector::Vector {
Comment on lines +957 to +958
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we nesting this under postgres-array?

/// Extracts a `pgvector::Vector` from the column addressed by `idx`.
///
/// MySQL and SQLite rows always yield a type error. For Postgres rows a
/// driver error is converted with `sqlx_error_to_query_err`, and a `NULL`
/// value is reported through `err_null_idx_col`. For mock and proxy rows any
/// lookup failure is reported through `err_null_idx_col`.
#[allow(unused_variables)]
fn try_get_by<I: ColIdx>(res: &QueryResult, idx: I) -> Result<Self, TryGetError> {
    match &res.row {
        #[cfg(feature = "sqlx-mysql")]
        QueryResultRow::SqlxMySql(_) => {
            let unsupported = type_err("Vector unsupported by sqlx-mysql");
            Err(unsupported.into())
        }
        #[cfg(feature = "sqlx-postgres")]
        QueryResultRow::SqlxPostgres(row) => {
            // Decode as `Option` so that NULL can be told apart from a driver error.
            let fetched: Result<Option<pgvector::Vector>, TryGetError> = row
                .try_get::<Option<pgvector::Vector>, _>(idx.as_sqlx_postgres_index())
                .map_err(|e| sqlx_error_to_query_err(e).into());
            match fetched {
                Ok(Some(vector)) => Ok(vector),
                Ok(None) => Err(err_null_idx_col(idx)),
                Err(e) => Err(e),
            }
        }
        #[cfg(feature = "sqlx-sqlite")]
        QueryResultRow::SqlxSqlite(_) => {
            let unsupported = type_err("Vector unsupported by sqlx-sqlite");
            Err(unsupported.into())
        }
        #[cfg(feature = "mock")]
        QueryResultRow::Mock(row) => match row.try_get::<pgvector::Vector, _>(idx) {
            Ok(vector) => Ok(vector),
            Err(e) => {
                debug_print!("{:#?}", e.to_string());
                Err(err_null_idx_col(idx))
            }
        },
        #[cfg(feature = "proxy")]
        QueryResultRow::Proxy(row) => match row.try_get::<pgvector::Vector, _>(idx) {
            Ok(vector) => Ok(vector),
            Err(e) => {
                debug_print!("{:#?}", e.to_string());
                Err(err_null_idx_col(idx))
            }
        },
        #[allow(unreachable_patterns)]
        _ => unreachable!(),
    }
}
}
}

// TryGetableMany //
Expand Down
2 changes: 2 additions & 0 deletions src/query/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ impl FromQueryResult for JsonValue {
try_get_type!(String, col);
#[cfg(feature = "postgres-array")]
try_get_type!(Vec<String>, col);
#[cfg(feature = "postgres-vector")]
try_get_type!(pgvector::Vector, col);
#[cfg(feature = "with-uuid")]
try_get_type!(uuid::Uuid, col);
#[cfg(all(feature = "with-uuid", feature = "postgres-array"))]
Expand Down
15 changes: 15 additions & 0 deletions tests/common/features/embedding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use super::sea_orm_active_enums::*;
use sea_orm::entity::prelude::*;

/// Test entity mapped to the `embedding` table: an integer key plus one vector column.
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "embedding")]
pub struct Model {
// Primary key; declared with `auto_increment = false`.
#[sea_orm(primary_key, auto_increment = false)]
pub id: i32,
// The vector value stored for this row.
pub embedding: PgVector,
}

/// Relations of the `embedding` entity; the enum is empty, so none are declared.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

// Empty impl: no `ActiveModelBehavior` method is overridden for this entity.
impl ActiveModelBehavior for ActiveModel {}
2 changes: 2 additions & 0 deletions tests/common/features/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod collection_expanded;
pub mod custom_active_model;
pub mod dyn_table_name_lazy_static;
pub mod edit_log;
pub mod embedding;
pub mod event_trigger;
pub mod insert_default;
pub mod json_struct;
Expand Down Expand Up @@ -40,6 +41,7 @@ pub use collection::Entity as Collection;
pub use collection_expanded::Entity as CollectionExpanded;
pub use dyn_table_name_lazy_static::Entity as DynTableNameLazyStatic;
pub use edit_log::Entity as EditLog;
pub use embedding::Entity as Embedding;
pub use event_trigger::Entity as EventTrigger;
pub use insert_default::Entity as InsertDefault;
pub use json_struct::Entity as JsonStruct;
Expand Down
26 changes: 26 additions & 0 deletions tests/common/features/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> {
create_collection_table(db).await?;
create_event_trigger_table(db).await?;
create_categories_table(db).await?;
create_embedding_table(db).await?;
}

Ok(())
Expand Down Expand Up @@ -601,6 +602,31 @@ pub async fn create_categories_table(db: &DbConn) -> Result<ExecResult, DbErr> {
create_table(db, &create_table_stmt, Categories).await
}

/// Creates the table for the `embedding` test entity.
///
/// Runs `CREATE EXTENSION IF NOT EXISTS vector` first, then builds and
/// executes the `CREATE TABLE` statement through `create_table`.
pub async fn create_embedding_table(db: &DbConn) -> Result<ExecResult, DbErr> {
    // Make sure the extension is present before the table is created.
    let enable_extension = sea_orm::Statement::from_string(
        db.get_database_backend(),
        "CREATE EXTENSION IF NOT EXISTS vector",
    );
    db.execute(enable_extension).await?;

    let mut stmt = sea_query::Table::create();
    stmt.table(embedding::Entity.table_ref());
    stmt.col(
        ColumnDef::new(embedding::Column::Id)
            .integer()
            .not_null()
            .primary_key(),
    );
    // `None` leaves the vector dimension unspecified.
    stmt.col(
        ColumnDef::new(embedding::Column::Embedding)
            .vector(None)
            .not_null(),
    );

    create_table(db, &stmt, Embedding).await
}

pub async fn create_binary_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let create_table_stmt = sea_query::Table::create()
.table(binary::Entity.table_ref())
Expand Down
125 changes: 125 additions & 0 deletions tests/embedding_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#![allow(unused_imports, dead_code)]

pub mod common;

pub use common::{features::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
use sea_orm::{
entity::prelude::*, entity::*, DatabaseConnection, DerivePartialModel, FromQueryResult,
};
use serde_json::json;

// Integration test entry point: creates the schema, then runs the insert,
// update and select scenarios in that order before tearing the context down.
#[sea_orm_macros::test]
#[cfg(all(feature = "sqlx-postgres", feature = "postgres-vector"))]
async fn main() -> Result<(), DbErr> {
    let ctx = TestContext::new("embedding_tests").await;
    let db = &ctx.db;

    create_tables(db).await?;
    insert_embedding(db).await?;
    update_embedding(db).await?;
    select_embedding(db).await?;

    ctx.delete().await;
    Ok(())
}

/// Inserts three rows with vectors of length 1, 2 and 3 and checks that each
/// insert returns the model that was written, then checks the JSON form of row 3.
pub async fn insert_embedding(db: &DatabaseConnection) -> Result<(), DbErr> {
    use embedding::*;

    let fixtures: [(i32, Vec<f32>); 3] = [
        (1, vec![1.]),
        (2, vec![1., 2.]),
        (3, vec![1., 2., 3.]),
    ];

    for (id, values) in fixtures {
        let expected = Model {
            id,
            embedding: PgVector::from(values),
        };
        let inserted = expected.clone().into_active_model().insert(db).await?;
        assert_eq!(inserted, expected);
    }

    // The JSON representation exposes the vector as a plain array of numbers.
    let as_json = Entity::find_by_id(3).into_json().one(db).await?;
    let expected_json = json!({
        "id": 3,
        "embedding": [1., 2., 3.],
    });
    assert_eq!(as_json, Some(expected_json));

    Ok(())
}

/// Updates the vectors of rows 1 and 3 and verifies the models returned by
/// each `update` call.
///
/// Row 1 is updated from a model loaded from the database; row 3 is updated
/// from a hand-built `ActiveModel` whose key is marked `Unchanged`.
///
/// # Errors
/// Returns any `DbErr` raised by the lookup or by either update.
///
/// # Panics
/// Panics if row 1 does not exist or if a returned model differs from the
/// values that were written.
pub async fn update_embedding(db: &DatabaseConnection) -> Result<(), DbErr> {
    use embedding::*;

    let model = Entity::find_by_id(1)
        .one(db)
        .await?
        .expect("row 1 must exist: insert_embedding runs before update_embedding");

    let updated_first = ActiveModel {
        embedding: Set(PgVector::from(vec![10.])),
        ..model.into_active_model()
    }
    .update(db)
    .await?;

    // Check the returned model; otherwise a broken vector write path would go unnoticed.
    assert_eq!(
        updated_first,
        Model {
            id: 1,
            embedding: PgVector::from(vec![10.]),
        }
    );

    let updated_third = ActiveModel {
        id: Unchanged(3),
        embedding: Set(PgVector::from(vec![10., 20., 30.])),
    }
    .update(db)
    .await?;

    // Row 3 is not read back by any later step, so this is its only verification.
    assert_eq!(
        updated_third,
        Model {
            id: 3,
            embedding: PgVector::from(vec![10., 20., 30.]),
        }
    );

    Ok(())
}

/// Reads row 1 back through a partial model that selects only the vector
/// column and checks that it holds the value `[10.]`.
pub async fn select_embedding(db: &DatabaseConnection) -> Result<(), DbErr> {
    use embedding::*;

    #[derive(DerivePartialModel, FromQueryResult, Debug, PartialEq)]
    #[sea_orm(entity = "Entity")]
    struct PartialSelectResult {
        embedding: PgVector,
    }

    let expected = PartialSelectResult {
        embedding: PgVector::from(vec![10.]),
    };

    let fetched = Entity::find_by_id(1)
        .into_partial_model::<PartialSelectResult>()
        .one(db)
        .await?;

    assert_eq!(fetched, Some(expected));

    Ok(())
}
Loading