Skip to content

Commit

Permalink
Support PgVector (#2500)
Browse files Browse the repository at this point in the history
* change sea-query to git

* add column type

* change to fork

* vector

* add feature flag

* trygetable for vector

* fix eq for vector

* add size to vector

* fix: pgvector version

* Update Cargo.toml

* Support PgVector

* Use `pgvector/pgvector` docker image

* Apply suggestions from code review

* Fixup

* Update tests/embedding_tests.rs

---------

Co-authored-by: Leon Camus <[email protected]>
Co-authored-by: Leon Camus <[email protected]>
Co-authored-by: Chris Tsang <[email protected]>
  • Loading branch information
4 people authored Feb 18, 2025
1 parent f5dab25 commit ce69458
Show file tree
Hide file tree
Showing 13 changed files with 252 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ jobs:
tls: [native-tls]
services:
postgres:
image: postgres:${{ matrix.version }}
image: pgvector/pgvector:pg${{ matrix.version }}
env:
POSTGRES_HOST: 127.0.0.1
POSTGRES_USER: root
Expand Down
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ serde = { version = "1.0", default-features = false }
serde_json = { version = "1.0", default-features = false, optional = true }
sqlx = { version = "0.8.2", default-features = false, optional = true }
uuid = { version = "1", default-features = false, optional = true }
pgvector = { version = "~0.4", default-features = false, optional = true }
ouroboros = { version = "0.18", default-features = false }
url = { version = "2.2", default-features = false }
thiserror = { version = "1", default-features = false }
Expand All @@ -54,7 +55,7 @@ tokio = { version = "1.6", features = ["full"] }
actix-rt = { version = "2.2.0" }
maplit = { version = "1" }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
sea-orm = { path = ".", features = ["mock", "debug-print", "tests-cfg", "postgres-array", "sea-orm-internal"] }
sea-orm = { path = ".", features = ["mock", "debug-print", "tests-cfg", "postgres-array", "postgres-vector", "sea-orm-internal"] }
pretty_assertions = { version = "0.7" }
time = { version = "0.3.36", features = ["macros"] }
uuid = { version = "1", features = ["v4"] }
Expand All @@ -76,13 +77,14 @@ default = [
macros = ["sea-orm-macros/derive"]
mock = []
proxy = ["serde_json", "serde/derive"]
with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "rust_decimal?/serde", "bigdecimal?/serde", "uuid?/serde", "time?/serde", "sea-query-binder?/with-json", "sqlx?/json"]
with-json = ["serde_json", "sea-query/with-json", "chrono?/serde", "rust_decimal?/serde", "bigdecimal?/serde", "uuid?/serde", "time?/serde", "pgvector?/serde", "sea-query-binder?/with-json", "sqlx?/json"]
with-chrono = ["chrono", "sea-query/with-chrono", "sea-query-binder?/with-chrono", "sqlx?/chrono"]
with-rust_decimal = ["rust_decimal", "sea-query/with-rust_decimal", "sea-query-binder?/with-rust_decimal", "sqlx?/rust_decimal"]
with-bigdecimal = ["bigdecimal", "sea-query/with-bigdecimal", "sea-query-binder?/with-bigdecimal", "sqlx?/bigdecimal"]
with-uuid = ["uuid", "sea-query/with-uuid", "sea-query-binder?/with-uuid", "sqlx?/uuid"]
with-time = ["time", "sea-query/with-time", "sea-query-binder?/with-time", "sqlx?/time"]
postgres-array = ["sea-query/postgres-array", "sea-query-binder?/postgres-array", "sea-orm-macros/postgres-array"]
postgres-vector = ["pgvector", "sea-query/postgres-vector", "sea-query-binder?/postgres-vector"]
json-array = ["postgres-array"] # this does not actually enable sqlx-postgres, but only a few traits to support array in sea-query
sea-orm-internal = []
sqlx-dep = []
Expand Down
7 changes: 7 additions & 0 deletions build-tools/docker-create.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ docker stop "mysql-5.7"

# Setup PostgreSQL

docker run \
--name "postgres-vector-14" \
--env POSTGRES_USER="root" \
--env POSTGRES_PASSWORD="root" \
-d -p 5432:5432 pgvector/pgvector:pg14
docker stop "postgres-vector-14"

docker run \
--name "postgres-13" \
--env POSTGRES_USER="root" \
Expand Down
1 change: 1 addition & 0 deletions sea-orm-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ smol = "1.2.5"
default = ["codegen", "cli", "runtime-async-std-native-tls", "async-std"]
codegen = ["sea-schema/sqlx-all", "sea-orm-codegen"]
cli = ["clap", "dotenvy"]
postgres-vector = ["sea-schema/postgres-vector"]
runtime-actix = ["sqlx/runtime-tokio", "sea-schema/runtime-actix"]
runtime-async-std = ["sqlx/runtime-async-std", "sea-schema/runtime-async-std"]
runtime-tokio = ["sqlx/runtime-tokio", "sea-schema/runtime-tokio"]
Expand Down
1 change: 1 addition & 0 deletions sea-orm-codegen/src/entity/base_entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ impl Entity {
match col_type {
ColumnType::Float | ColumnType::Double => true,
ColumnType::Array(col_type) => is_floats(col_type),
ColumnType::Vector(_) => true,
_ => false,
}
}
Expand Down
5 changes: 5 additions & 0 deletions sea-orm-codegen/src/entity/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl Column {
ColumnType::Array(column_type) => {
format!("Vec<{}>", write_rs_type(column_type, date_time_crate))
}
ColumnType::Vector(_) => "::pgvector::Vector".to_owned(),
ColumnType::Bit(None | Some(1)) => "bool".to_owned(),
ColumnType::Bit(_) | ColumnType::VarBit(_) => "Vec<u8>".to_owned(),
ColumnType::Year => "i32".to_owned(),
Expand Down Expand Up @@ -180,6 +181,10 @@ impl Column {
let column_type = write_col_def(column_type);
quote! { ColumnType::Array(RcOrArc::new(#column_type)) }
}
ColumnType::Vector(size) => match size {
Some(size) => quote! { ColumnType::Vector(Some(#size)) },
None => quote! { ColumnType::Vector(None) },
},
#[allow(unreachable_patterns)]
_ => unimplemented!(),
}
Expand Down
3 changes: 3 additions & 0 deletions src/entity/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,6 @@ pub use bigdecimal::BigDecimal;

#[cfg(feature = "with-uuid")]
pub use uuid::Uuid;

#[cfg(feature = "postgres-vector")]
pub use pgvector::Vector as PgVector;
36 changes: 36 additions & 0 deletions src/executor/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,42 @@ mod postgres_array {
}
}
}

#[cfg(feature = "postgres-vector")]
impl TryGetable for pgvector::Vector {
#[allow(unused_variables)]
fn try_get_by<I: ColIdx>(res: &QueryResult, idx: I) -> Result<Self, TryGetError> {
match &res.row {
#[cfg(feature = "sqlx-mysql")]
QueryResultRow::SqlxMySql(_) => {
Err(type_err("Vector unsupported by sqlx-mysql").into())
}
#[cfg(feature = "sqlx-postgres")]
QueryResultRow::SqlxPostgres(row) => row
.try_get::<Option<pgvector::Vector>, _>(idx.as_sqlx_postgres_index())
.map_err(|e| sqlx_error_to_query_err(e).into())
.and_then(|opt| opt.ok_or_else(|| err_null_idx_col(idx))),
#[cfg(feature = "sqlx-sqlite")]
QueryResultRow::SqlxSqlite(_) => {
Err(type_err("Vector unsupported by sqlx-sqlite").into())
}
#[cfg(feature = "mock")]
QueryResultRow::Mock(row) => row.try_get::<pgvector::Vector, _>(idx).map_err(|e| {
debug_print!("{:#?}", e.to_string());
err_null_idx_col(idx)
}),
#[cfg(feature = "proxy")]
QueryResultRow::Proxy(row) => {
row.try_get::<pgvector::Vector, _>(idx).map_err(|e| {
debug_print!("{:#?}", e.to_string());
err_null_idx_col(idx)
})
}
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}
}
}

// TryGetableMany //
Expand Down
2 changes: 2 additions & 0 deletions src/query/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ impl FromQueryResult for JsonValue {
try_get_type!(String, col);
#[cfg(feature = "postgres-array")]
try_get_type!(Vec<String>, col);
#[cfg(feature = "postgres-vector")]
try_get_type!(pgvector::Vector, col);
#[cfg(feature = "with-uuid")]
try_get_type!(uuid::Uuid, col);
#[cfg(all(feature = "with-uuid", feature = "postgres-array"))]
Expand Down
15 changes: 15 additions & 0 deletions tests/common/features/embedding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use super::sea_orm_active_enums::*;
use sea_orm::entity::prelude::*;

#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
#[sea_orm(table_name = "embedding")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: i32,
pub embedding: PgVector,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

impl ActiveModelBehavior for ActiveModel {}
2 changes: 2 additions & 0 deletions tests/common/features/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod collection_expanded;
pub mod custom_active_model;
pub mod dyn_table_name_lazy_static;
pub mod edit_log;
pub mod embedding;
pub mod event_trigger;
pub mod insert_default;
pub mod json_struct;
Expand Down Expand Up @@ -40,6 +41,7 @@ pub use collection::Entity as Collection;
pub use collection_expanded::Entity as CollectionExpanded;
pub use dyn_table_name_lazy_static::Entity as DynTableNameLazyStatic;
pub use edit_log::Entity as EditLog;
pub use embedding::Entity as Embedding;
pub use event_trigger::Entity as EventTrigger;
pub use insert_default::Entity as InsertDefault;
pub use json_struct::Entity as JsonStruct;
Expand Down
26 changes: 26 additions & 0 deletions tests/common/features/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub async fn create_tables(db: &DatabaseConnection) -> Result<(), DbErr> {
create_collection_table(db).await?;
create_event_trigger_table(db).await?;
create_categories_table(db).await?;
create_embedding_table(db).await?;
}

Ok(())
Expand Down Expand Up @@ -601,6 +602,31 @@ pub async fn create_categories_table(db: &DbConn) -> Result<ExecResult, DbErr> {
create_table(db, &create_table_stmt, Categories).await
}

pub async fn create_embedding_table(db: &DbConn) -> Result<ExecResult, DbErr> {
db.execute(sea_orm::Statement::from_string(
db.get_database_backend(),
"CREATE EXTENSION IF NOT EXISTS vector",
))
.await?;

let create_table_stmt = sea_query::Table::create()
.table(embedding::Entity.table_ref())
.col(
ColumnDef::new(embedding::Column::Id)
.integer()
.not_null()
.primary_key(),
)
.col(
ColumnDef::new(embedding::Column::Embedding)
.vector(None)
.not_null(),
)
.to_owned();

create_table(db, &create_table_stmt, Embedding).await
}

pub async fn create_binary_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let create_table_stmt = sea_query::Table::create()
.table(binary::Entity.table_ref())
Expand Down
149 changes: 149 additions & 0 deletions tests/embedding_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#![allow(unused_imports, dead_code)]

pub mod common;

pub use common::{features::*, setup::*, TestContext};
use pretty_assertions::assert_eq;
use sea_orm::{
entity::prelude::*, entity::*, DatabaseConnection, DerivePartialModel, FromQueryResult,
};
use serde_json::json;

#[sea_orm_macros::test]
#[cfg(all(feature = "sqlx-postgres", feature = "postgres-vector"))]
async fn main() -> Result<(), DbErr> {
let ctx = TestContext::new("embedding_tests").await;
create_tables(&ctx.db).await?;
insert_embedding(&ctx.db).await?;
update_embedding(&ctx.db).await?;
select_embedding(&ctx.db).await?;
ctx.delete().await;

Ok(())
}

pub async fn insert_embedding(db: &DatabaseConnection) -> Result<(), DbErr> {
use embedding::*;

assert_eq!(
Model {
id: 1,
embedding: PgVector::from(vec![1.]),
}
.into_active_model()
.insert(db)
.await?,
Model {
id: 1,
embedding: PgVector::from(vec![1.]),
}
);

assert_eq!(
Model {
id: 2,
embedding: PgVector::from(vec![1., 2.]),
}
.into_active_model()
.insert(db)
.await?,
Model {
id: 2,
embedding: PgVector::from(vec![1., 2.]),
}
);

assert_eq!(
Model {
id: 3,
embedding: PgVector::from(vec![1., 2., 3.]),
}
.into_active_model()
.insert(db)
.await?,
Model {
id: 3,
embedding: PgVector::from(vec![1., 2., 3.]),
}
);

assert_eq!(
Entity::find_by_id(3).into_json().one(db).await?,
Some(json!({
"id": 3,
"embedding": [1., 2., 3.],
}))
);

Ok(())
}

pub async fn update_embedding(db: &DatabaseConnection) -> Result<(), DbErr> {
use embedding::*;

let model = Entity::find_by_id(1).one(db).await?.unwrap();

ActiveModel {
embedding: Set(PgVector::from(vec![10.])),
..model.into_active_model()
}
.update(db)
.await?;

ActiveModel {
id: Unchanged(3),
embedding: Set(PgVector::from(vec![10., 20., 30.])),
}
.update(db)
.await?;

Ok(())
}

pub async fn select_embedding(db: &DatabaseConnection) -> Result<(), DbErr> {
use embedding::*;

#[derive(DerivePartialModel, FromQueryResult, Debug, PartialEq)]
#[sea_orm(entity = "Entity")]
struct PartialSelectResult {
embedding: PgVector,
}

let result = Entity::find_by_id(1)
.into_partial_model::<PartialSelectResult>()
.one(db)
.await?;

assert_eq!(
result,
Some(PartialSelectResult {
embedding: PgVector::from(vec![10.]),
})
);

let result = Entity::find_by_id(2)
.into_partial_model::<PartialSelectResult>()
.one(db)
.await?;

assert_eq!(
result,
Some(PartialSelectResult {
embedding: PgVector::from(vec![1., 2.]),
})
);

let result = Entity::find_by_id(3)
.into_partial_model::<PartialSelectResult>()
.one(db)
.await?;

assert_eq!(
result,
Some(PartialSelectResult {
embedding: PgVector::from(vec![10., 20., 30.]),
})
);

Ok(())
}

0 comments on commit ce69458

Please sign in to comment.