[multi-vector] Multi-vectors with support for per-row metadata #785
[multi-vector] Multi-vectors with support for per-row metadata #785arkrishn94 wants to merge 10 commits intomainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
Adds a new multi-vector matrix representation whose rows are canonical meta::slice::Slice views (vector elements + per-row metadata), and wires it into MinMax’s asymmetric (full-query vs quantized-doc) distance path.
Changes:
- Introduces
SliceMatRepr<T, M>+Repr/ReprMut/ReprOwned+ constructors for owned/borrowed matrices backed by canonical[u8]layout. - Extends MinMax multi-vector MaxSim/Chamfer to accept arbitrary query
Repr(enabling full-precision query rows with metadata). - Fixes
MatRef::rows()iterator lifetime and addsSliceMut::from_canonical_mut_unchecked.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| diskann-quantization/src/multi_vector/slice.rs | New SliceMatRepr matrix representation for canonical slice rows (vector + per-row metadata), with extensive tests. |
| diskann-quantization/src/multi_vector/mod.rs | Exposes SliceMatRepr and related errors; adds module wiring and docs table update. |
| diskann-quantization/src/multi_vector/matrix.rs | Adjusts MatRef::rows() return lifetime to match the underlying MatRef. |
| diskann-quantization/src/minmax/multi/mod.rs | Updates docs/example to show full-query matrix + asymmetric distances. |
| diskann-quantization/src/minmax/multi/max_sim.rs | Generalizes MinMax MaxSim/Chamfer to generic query Repr; adds FullQueryMatRef alias and multi-vector CompressInto for full-query rows. |
| diskann-quantization/src/minmax/mod.rs | Re-exports FullQueryMatRef. |
| diskann-quantization/src/meta/slice.rs | Adds SliceMut::from_canonical_mut_unchecked and reuses it from the checked constructor. |
Comments suppressed due to low confidence (1)
diskann-quantization/src/minmax/multi/max_sim.rs:67
MinMaxKernel::max_sim_kerneldoesn’t special-case an emptydoc: whendoc.num_vectors() == 0,min_distanceremainsf32::MAXand the callback is invoked with that value for every query row. This makesMaxSim/Chamferbehave very differently from the standardSimpleKernel(which returns early and leaves scores unchanged / chamfer=0). Consider adding an early return whendochas zero rows (or defining/ documenting the intended semantics and adjusting callers/tests accordingly).
for (i, q_ref) in query.rows().enumerate() {
// Find min distance (IP returns negated, so min = max similarity)
let mut min_distance = f32::MAX;
for d_ref in doc.rows() {
let dist = MinMaxIP::evaluate(q_ref, d_ref)?;
min_distance = min_distance.min(dist);
}
f(i, min_distance);
}
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| std::ptr::copy_nonoverlapping(v.ptr.as_ptr(), buffer.as_mut_ptr(), total); | ||
| } | ||
|
|
||
| // SAFETY: `buffer` has the correct length and alignment from above cheks. |
There was a problem hiding this comment.
Typo in comment: "cheks" → "checks".
| // SAFETY: `buffer` has the correct length and alignment from above cheks. | |
| // SAFETY: `buffer` has the correct length and alignment from above checks. |
| /// Returns the cached byte stride per row. | ||
| const fn stride(ncols: usize) -> usize { | ||
| slice::SliceRef::<T, M>::canonical_bytes(ncols) | ||
| } |
There was a problem hiding this comment.
SliceMatRepr uses SliceRef::canonical_bytes(ncols) as the row stride, but canonical_bytes does not include any trailing padding to ensure the next row starts at SliceRef::canonical_align(). For metadata with higher alignment than T (e.g. #[repr(align(8))]), rows after the first can become misaligned, and SliceRef::from_canonical_unchecked will then create misaligned references (UB). Consider defining a padded row stride (e.g. canonical_bytes(ncols).next_multiple_of(canonical_align)) for pointer stepping, while still passing a canonical_bytes-sized subslice to SliceRef/SliceMut for each row; update total_bytes/layout/check_slice accordingly and document the padding.
| pub fn new(nrows: usize, ncols: usize) -> Result<Self, SliceMatReprError> { | ||
| let stride = Self::stride(ncols); | ||
|
|
||
| // Check that total bytes don't overflow or exceed isize::MAX. | ||
| let total = nrows | ||
| .checked_mul(stride) | ||
| .ok_or(SliceMatReprError::Overflow { |
There was a problem hiding this comment.
SliceMatRepr::new relies on SliceRef::canonical_bytes(ncols) for stride, but canonical_bytes computes size_of::<T>() * ncols without checked arithmetic and can overflow usize for large ncols, potentially letting new succeed with an incorrect (wrapped) stride and leading to OOB/UB later. It would be safer to compute the stride with checked operations (or explicitly validate ncols against usize::MAX / size_of::<T>() and addend overflow) before the nrows * stride check.
| fn check_slice(&self, slice: &[u8]) -> Result<(), SliceMatError> { | ||
| let expected = self.total_bytes(); | ||
| if slice.len() != expected { | ||
| return Err(SliceMatError::LengthMismatch { | ||
| expected, | ||
| found: slice.len(), | ||
| }); | ||
| } | ||
|
|
||
| let align = Self::alignment().raw(); | ||
| if !(slice.as_ptr() as usize).is_multiple_of(align) { | ||
| return Err(SliceMatError::NotAligned { expected: align }); | ||
| } |
There was a problem hiding this comment.
check_slice enforces base-pointer alignment even when expected == 0. For empty matrices (e.g. nrows == 0), callers may reasonably pass an empty &[u8]/&mut [u8] that has a dangling pointer not meeting canonical_align(), causing MatRef::new/MatMut::new to fail even though no memory will be dereferenced. Consider treating expected == 0 as a special case and skipping the alignment check (or accepting any pointer) so zero-sized matrices can be constructed from empty slices.
| /// | ||
| /// This invariant is checked in debug builds and will panic if not satisfied. | ||
| pub unsafe fn from_canonical_mut_unchecked(data: &'a mut [u8], dim: usize) -> Self { | ||
| debug_assert_eq!(data.len(), Self::canonical_bytes(dim)); |
There was a problem hiding this comment.
from_canonical_mut_unchecked docs state that both length and alignment invariants are checked in debug builds, but the implementation only debug_assert_eq!s the length. Consider adding a debug_assert! on data.as_ptr() alignment to Self::canonical_align() (or adjusting the docs) so the safety contract matches behavior and misaligned callers are caught earlier.
| debug_assert_eq!(data.len(), Self::canonical_bytes(dim)); | |
| let expected_align = Self::canonical_align().raw(); | |
| let expected_len = Self::canonical_bytes(dim); | |
| debug_assert!((data.as_ptr() as usize).is_multiple_of(expected_align)); | |
| debug_assert_eq!(data.len(), expected_len); |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #785 +/- ##
==========================================
+ Coverage 89.00% 89.05% +0.04%
==========================================
Files 428 429 +1
Lines 78417 79194 +777
==========================================
+ Hits 69795 70525 +730
- Misses 8622 8669 +47
Flags with carried forward coverage won't be shown. Click here to find out more.
🚀 New features to boost your workflow:
|
This change adds support for multi-vectors where every vector is a full-precision-like vector with some metadata. It builds on the
Sliceimplementation that supports the owned, referenced and mutable referenced versions of these vectors.Primarily this is done by introducing a new struct
SliceMatRepr<T, M>(defined for any primitive datatypeT: Podand metadataM : Pod) and providing implementations ofReprand its owned and mut variants for it. This is the bulk of the change in this PR.This matrix type has been used to support computing max sim and chamfer distance between full precision queries and minmax quantized vectors.
Other changes -
from_canonical_mut_uncheckedtoSliceMutfor constructing from raw data.rowsmethod ofMatRefto ensure the caller is guaranteed that the row yielded by the iterator lives as long as the underlyingMatRef.