-
Notifications
You must be signed in to change notification settings - Fork 102
/
Copy pathembeddings.rs
60 lines (51 loc) · 1.69 KB
/
embeddings.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
use serde::{Deserialize, Serialize};
use crate::Ollama;
use super::options::GenerationOptions;
impl Ollama {
/// Generate embeddings from a model
/// * `model_name` - Name of model to generate embeddings from
/// * `prompt` - Prompt to generate embeddings for
pub async fn generate_embeddings(
&self,
model_name: String,
prompt: String,
options: Option<GenerationOptions>,
) -> crate::error::Result<GenerateEmbeddingsResponse> {
let request = GenerateEmbeddingsRequest {
model_name,
prompt,
options,
};
let uri = format!("{}/api/embeddings", self.uri());
let serialized = serde_json::to_string(&request).map_err(|e| e.to_string())?;
let res = self
.reqwest_client
.post(uri)
.body(serialized)
.send()
.await
.map_err(|e| e.to_string())?;
if !res.status().is_success() {
return Err(res.text().await.unwrap_or_else(|e| e.to_string()).into());
}
let res = res.bytes().await.map_err(|e| e.to_string())?;
let res = serde_json::from_slice::<GenerateEmbeddingsResponse>(&res)
.map_err(|e| e.to_string())?;
Ok(res)
}
}
/// An embeddings generation request to Ollama.
#[derive(Debug, Serialize)]
struct GenerateEmbeddingsRequest {
#[serde(rename = "model")]
model_name: String,
prompt: String,
options: Option<GenerationOptions>,
}
/// An embeddings generation response from Ollama.
#[derive(Debug, Deserialize, Clone)]
pub struct GenerateEmbeddingsResponse {
#[serde(rename = "embedding")]
#[allow(dead_code)]
pub embeddings: Vec<f64>,
}