diff --git a/python/fast_plaid/search/fast_plaid.py b/python/fast_plaid/search/fast_plaid.py index 8fdfd0e..57ea819 100644 --- a/python/fast_plaid/search/fast_plaid.py +++ b/python/fast_plaid/search/fast_plaid.py @@ -115,7 +115,7 @@ def compute_kmeans( # noqa: PLR0913 def search_on_device( # noqa: PLR0913 device: str, - queries_embeddings: torch.Tensor, + queries_embeddings: list[torch.Tensor], batch_size: int, n_full_scores: int, top_k: int, @@ -154,6 +154,18 @@ def search_on_device( # noqa: PLR0913 ] +def cleanup_embeddings( + embeddings: list[torch.Tensor] | torch.Tensor, +) -> list[torch.Tensor]: + """Convert embeddings to a list and remove extra dimensions.""" + if isinstance(embeddings, torch.Tensor): + embeddings = [embeddings[i] for i in range(embeddings.shape[0])] + return [ + embedding.squeeze(0) if embedding.dim() == 3 else embedding + for embedding in embeddings + ] + + class FastPlaid: """A class for creating and searching a FastPlaid index. @@ -288,15 +300,7 @@ def create( # noqa: PLR0913 Optional list of dictionaries containing metadata for each document. """ - if isinstance(documents_embeddings, torch.Tensor): - documents_embeddings = [ - documents_embeddings[i] for i in range(documents_embeddings.shape[0]) - ] - - documents_embeddings = [ - embedding.squeeze(0) if embedding.dim() == 3 else embedding - for embedding in documents_embeddings - ] + documents_embeddings = cleanup_embeddings(documents_embeddings) num_docs = len(documents_embeddings) self._prepare_index_directory(index_path=self.index) @@ -473,17 +477,8 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 corresponding inner list. 
""" - if isinstance(queries_embeddings, list): - queries_embeddings = torch.nn.utils.rnn.pad_sequence( - sequences=[ - embedding[0] if embedding.dim() == 3 else embedding - for embedding in queries_embeddings - ], - batch_first=True, - padding_value=0.0, - ) - - num_queries = queries_embeddings.shape[0] + queries_embeddings = cleanup_embeddings(queries_embeddings) + num_queries = len(queries_embeddings) if subset is not None: if isinstance(subset, int): @@ -529,16 +524,15 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 num_cpus = len(self.devices) # Use torch.chunk to split the tensor into num_cpus - queries_embeddings_splits = torch.chunk( - input=queries_embeddings, - chunks=num_cpus, - dim=0, - ) + queries_embeddings_splits = [ + queries_embeddings[i : i + num_cpus] + for i in range(0, num_queries, num_cpus) + ] # Filter out empty chunks that torch.chunk might create # if num_queries < num_cpus non_empty_splits = [ - split for split in queries_embeddings_splits if split.shape[0] > 0 + split for split in queries_embeddings_splits if len(split) > 0 ] num_splits = len(non_empty_splits) @@ -548,7 +542,7 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 if subset is not None: current_idx = 0 for split in non_empty_splits: - size = split.shape[0] + size = len(split) subset_splits.append(subset[current_idx : current_idx + size]) # type: ignore current_idx += size @@ -600,16 +594,16 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915 subset=subset, # type: ignore ) - queries_embeddings_splits = torch.split( - tensor=queries_embeddings, - split_size_or_sections=len(self.devices), - ) + queries_embeddings_splits = [ + queries_embeddings[i : i + len(self.devices)] + for i in range(0, num_queries, len(self.devices)) + ] num_splits = len(queries_embeddings_splits) if subset is not None: current_idx = 0 for split in queries_embeddings_splits: - size = split.shape[0] + size = len(split) subset_splits.append(subset[current_idx : current_idx + size]) # type: 
ignore current_idx += size else: diff --git a/rust/lib.rs b/rust/lib.rs index 748a83e..91f54e9 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -376,7 +376,7 @@ fn load_and_search( index: String, torch_path: String, device: String, - queries_embeddings: PyTensor, + queries_embeddings: Vec<PyTensor>, search_parameters: &SearchParameters, show_progress: bool, preload_index: bool, @@ -397,9 +397,14 @@ fn load_and_search( Ok(Arc::new(loaded_index)) }?; + let queries_embeddings: Vec<_> = queries_embeddings + .into_iter() + .map(|tensor| tensor.to_kind(Kind::Half)) + .collect(); + // Perform the search let results = search_many( - &queries_embeddings.to_kind(Kind::Half), + &queries_embeddings, &index, search_parameters, device_tch, diff --git a/rust/search/search.rs b/rust/search/search.rs index 36051ed..b59e049 100644 --- a/rust/search/search.rs +++ b/rust/search/search.rs @@ -1,8 +1,8 @@ -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, Result}; use indicatif::{ProgressBar, ProgressIterator}; use pyo3::prelude::*; use serde::Serialize; -use tch::{Device, IndexOp, Kind, Tensor}; +use tch::{Device, Kind, Tensor}; use crate::search::load::LoadedIndex; use crate::search::padding::direct_pad_sequences; @@ -165,22 +165,21 @@ impl SearchParameters { /// A `Result` with a `Vec<QueryResult>`. Individual search failures result in an empty /// `QueryResult` for that specific query, ensuring the operation doesn't halt. pub fn search_many( - queries: &Tensor, + queries: &Vec<Tensor>, index: &LoadedIndex, params: &SearchParameters, device: Device, show_progress: bool, subset: Option<Vec<Vec<i64>>>, ) -> Result<Vec<QueryResult>> { - let [num_queries, _, query_dim] = queries.size()[..] 
else { - bail!( - "Expected a 3D tensor for queries, but got shape {:?}", - queries.size() - ); - }; + let num_queries = queries.len(); + if num_queries == 0 { + return Ok(Vec::new()); + } + let query_dim = queries[0].size()[queries[0].dim() - 1]; - let search_closure = |query_index| { - let query_embedding = queries.i(query_index).to(device); + let search_closure = |query_index: usize| { + let query_embedding = &queries[query_index].to(device); // Handle the per-query subset list let query_subset = subset.as_ref().and_then(|s| s.get(query_index as usize));