Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,77 @@ mod tests {
}
}

/// Managers created for the same prefix and identifier agree on the resumption point,
/// and a manager with a different identifier does not report completion.
#[test]
fn test_new_checkpoint_creates_fresh_record() -> ANNResult<()> {
    let temp_dir = tempdir()?;
    let prefix_path = temp_dir.path().join("fresh_index");
    let index_prefix = prefix_path.to_str().unwrap().to_string();

    // Same prefix + identifier: both managers must observe identical checkpoint state.
    let first = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 42);
    let second = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 42);
    let first_point = first.get_resumption_point(WorkStage::Start)?;
    let second_point = second.get_resumption_point(WorkStage::Start)?;
    assert_eq!(first_point, second_point);

    // A manager keyed by another identifier must not be affected by the ones above.
    let other = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 99);
    let other_completed = other.has_completed()?;
    assert!(!other_completed);

    Ok(())
}

/// `has_completed` reports `false` for a prefix under which this test never wrote a checkpoint.
#[test]
fn test_has_completed_false_when_no_file() -> ANNResult<()> {
    let temp_dir = tempdir()?;
    let prefix_path = temp_dir.path().join("nonexistent_index");
    let index_prefix = prefix_path.to_str().unwrap().to_string();

    let manager = CheckpointRecordManagerWithFileStorage::new(&index_prefix, 999);
    let completed = manager.has_completed()?;
    assert!(!completed);

    Ok(())
}

/// Marking a checkpoint as invalid makes the reported resumption progress fall back to 0.
#[test]
fn test_mark_as_invalid() -> ANNResult<()> {
    let temp_dir = tempdir()?;
    let prefix_path = temp_dir.path().join("test_invalid");
    let index_prefix = prefix_path.to_str().unwrap().to_string();
    let identifier = 77;

    // Record one completed stage, then partial progress (42) in the following stage.
    let mut writer = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
    writer.update(Progress::Completed, WorkStage::QuantizeFPV)?;
    writer.update(Progress::Processed(42), WorkStage::InMemIndexBuild)?;

    // A freshly constructed manager must see the recorded progress of 42.
    let reader = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
    let resumed = reader.get_resumption_point(WorkStage::QuantizeFPV)?;
    assert_eq!(resumed, Some(42));

    // After invalidation the progress is read back as 0 (is_valid=false => progress 0).
    let mut invalidator = CheckpointRecordManagerWithFileStorage::new(&index_prefix, identifier);
    invalidator.mark_as_invalid()?;
    let after_invalidation = invalidator.get_resumption_point(WorkStage::QuantizeFPV)?;
    assert_eq!(after_invalidation, Some(0));

    clean_checkpoint_file(&index_prefix, identifier);
    Ok(())
}

#[test]
fn test_checkpoint_manager_interruption_and_resumption() -> ANNResult<()> {
let temp_dir = tempdir()?;
Expand Down
144 changes: 144 additions & 0 deletions diskann-disk/src/build/chunking/continuation/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,150 @@ mod tests {
}
}

/// A tracker that returns Stop after `stop_after` Continue grants.
///
/// The counter lives behind an `Arc<Mutex<_>>`, so clones of a tracker share
/// the same count of grants handed out.
#[derive(Clone)]
struct StopAfterTracker {
/// Number of `Continue` grants handed out so far.
count: std::sync::Arc<std::sync::Mutex<usize>>,
/// Number of `Continue` grants to hand out before switching to `Stop`.
stop_after: usize,
}

impl ContinuationTrackerTrait for StopAfterTracker {
    /// Hands out `Continue` until `stop_after` grants have been issued, then `Stop`.
    ///
    /// Panics if the counter mutex is poisoned.
    fn get_continuation_grant(&self) -> ContinuationGrant {
        let mut granted = self.count.lock().unwrap();
        if *granted < self.stop_after {
            // Still within budget: record the grant before handing it out.
            *granted += 1;
            return ContinuationGrant::Continue;
        }
        ContinuationGrant::Stop
    }
}

/// Processing halts once the tracker answers `Stop`, reporting how many items were handled.
#[test]
fn test_process_while_resource_is_available_stops_early() {
    let mut processed = Vec::new();
    let items = vec![10, 20, 30, 40, 50];
    let tracker = StopAfterTracker {
        stop_after: 3,
        count: std::sync::Arc::new(std::sync::Mutex::new(0)),
    };

    let result = process_while_resource_is_available(
        |item| {
            processed.push(item);
            Ok::<(), TestError>(())
        },
        items.into_iter(),
        Box::new(tracker),
    );

    assert!(result.is_ok());
    if let Progress::Processed(idx) = result.unwrap() {
        // The tracker stopped before the item at index 3 was processed.
        assert_eq!(idx, 3);
        assert_eq!(processed, vec![10, 20, 30]);
    } else {
        panic!("Expected Processed");
    }
}

/// A tracker that yields once (with a zero-length duration), then continues.
///
/// The flag lives behind an `Arc<Mutex<_>>`, so clones of a tracker share it.
#[derive(Clone)]
struct YieldOnceThenContinueTracker {
/// Set to `true` once the single `Yield` grant has been issued.
yielded: std::sync::Arc<std::sync::Mutex<bool>>,
}

impl ContinuationTrackerTrait for YieldOnceThenContinueTracker {
    /// Returns `Yield(Duration::ZERO)` on the first call and `Continue` on every later call.
    ///
    /// Panics if the flag mutex is poisoned.
    fn get_continuation_grant(&self) -> ContinuationGrant {
        let mut yielded = self.yielded.lock().unwrap();
        if !*yielded {
            // First call: remember that the single yield has been issued.
            *yielded = true;
            ContinuationGrant::Yield(std::time::Duration::ZERO)
        } else {
            // After yielding once, always continue
            ContinuationGrant::Continue
        }
    }
}

/// A single `Yield` grant must not prevent the remaining items from being processed.
#[test]
fn test_process_while_resource_is_available_yield_then_continue() {
    let mut processed = Vec::new();
    let items = vec![1, 2];
    let tracker = YieldOnceThenContinueTracker {
        yielded: std::sync::Arc::new(std::sync::Mutex::new(false)),
    };

    let result = process_while_resource_is_available(
        |item| {
            processed.push(item);
            Ok::<(), TestError>(())
        },
        items.into_iter(),
        Box::new(tracker),
    );

    assert!(result.is_ok());
    // Once the yield is over, every item should have been handled.
    if let Progress::Completed = result.unwrap() {
        assert_eq!(processed, vec![1, 2]);
    } else {
        panic!("Expected Completed");
    }
}

/// An error returned by the per-item action is surfaced as an `Err` result.
#[test]
fn test_process_while_resource_is_available_action_error() {
    let items = vec![1, 2, 3];
    let tracker = Box::new(NaiveContinuationTracker::default());

    // The action fails only for the item with value 2.
    let result = process_while_resource_is_available(
        |item| {
            if item != 2 {
                Ok(())
            } else {
                Err(TestError)
            }
        },
        items.into_iter(),
        tracker,
    );

    assert!(result.is_err());
}

/// The async variant halts once the tracker answers `Stop`, reporting the processed count.
#[tokio::test]
async fn test_process_while_resource_is_available_async_stops_early() {
    let processed = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
    let items = vec![1, 2, 3, 4, 5];
    let tracker = StopAfterTracker {
        stop_after: 2,
        count: std::sync::Arc::new(std::sync::Mutex::new(0)),
    };

    let result = process_while_resource_is_available_async(
        |item| {
            // Each future gets its own handle to the shared result list.
            let sink = processed.clone();
            async move {
                sink.lock().await.push(item);
                Ok::<(), TestError>(())
            }
        },
        items.into_iter(),
        Box::new(tracker),
    )
    .await;

    assert!(result.is_ok());
    if let Progress::Processed(idx) = result.unwrap() {
        assert_eq!(idx, 2);
        let seen = processed.lock().await;
        assert_eq!(*seen, vec![1, 2]);
    } else {
        panic!("Expected Processed");
    }
}

#[tokio::test]
async fn test_process_while_resource_is_available_async_completes() {
let checker = Box::new(NaiveContinuationTracker::default());
Expand Down
26 changes: 26 additions & 0 deletions diskann-disk/src/search/pq/pq_scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,30 @@ mod tests {
assert_eq!(pq_scratch.query_scratch[i], query[i]);
});
}

/// `set` must fail with a dimension-mismatch error when the query has fewer than `dim` elements.
#[test]
fn test_pq_scratch_set_rejects_short_query() {
    let dim = 16;
    let mut pq_scratch = PQScratch::new(64, dim, 4, 256).unwrap();

    // `1..dim` yields dim-1 values, i.e. one element too few.
    let short_query: Vec<f32> = (1..dim).map(|value| value as f32).collect();

    let err = pq_scratch.set(&short_query).unwrap_err();
    let message = err.to_string();
    assert_eq!(err.kind(), diskann::ANNErrorKind::DimensionMismatchError);
    assert!(message.contains("expected query of length"));
}

/// `set` accepts a query longer than `dim` and copies only its first `dim` elements.
#[test]
fn test_pq_scratch_set_accepts_oversized_query() {
    let dim = 8;
    let mut pq_scratch = PQScratch::new(64, dim, 4, 256).unwrap();

    // Ten more elements than the scratch dimension.
    let long_query: Vec<f32> = (1..=dim + 10).map(|value| value as f32).collect();
    pq_scratch.set(&long_query).unwrap();

    // Only the leading `dim` values are expected to land in the scratch buffer.
    long_query
        .iter()
        .take(dim)
        .enumerate()
        .for_each(|(i, &val)| {
            assert_eq!(pq_scratch.query_scratch[i], val);
        });
}
Comment on lines +144 to +156

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this expected behavior? Shouldn't the scratch fail for an incorrectly sized query?

@arrayka arrayka Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is the expected behavior, according to the documentation of pq_scratch.set():

    /// Copy the first `dim` elements of `query` into `query_scratch`.
    ///
    /// `query` must already be in full-precision `f32` representation; quantized
    /// inputs (e.g. `MinMaxElement`) should be decoded via `VectorRepr::as_f32`
    /// at the caller boundary before invoking this method.
    ///
    /// Accepts oversized `query` (only the first `dim` elements are used) for
    /// backwards compatibility with callers that hold alignment-padded buffers.
    /// Returns `DimensionMismatchError` if `query.len() < query_scratch.len()`.
    pub fn set(&mut self, query: &[f32]) -> ANNResult<()>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the chain of calls to set() and new() and I don't think supporting a larger input query dimension is actually needed here (the scratch length is derived from the dimension of the fp vectors in the PQ table, which is the same as the query dimension).

Can we fix this since we've found it? The reason I'm asking is because this is actually incorrect behavior; I'm not actually sure how this ended up getting supported (it's probably my fault!). The f32 dimension is not larger than the dimension of minmax vectors, it's actually smaller :)

}
44 changes: 44 additions & 0 deletions diskann-disk/src/search/provider/disk_sector_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,4 +373,48 @@ mod disk_sector_graph_test {
let data = &graph;
assert_eq!(data.len(), 512);
}

/// Reconfiguring to a larger batch size raises the batch limit and resizes the sector buffer.
#[test]
fn test_reconfigure_grows_buffer() {
    let factory = AlignedFileReaderFactory::new(test_index_path());
    let reader = factory.build().unwrap();
    let mut graph = test_initialize_disk_sector_graph(2, 1, reader);
    assert_eq!(graph.max_n_batch_sector_read, 4);

    // Ask for a larger batch: the buffer has to grow to hold 16 sectors.
    graph.reconfigure(16).unwrap();

    let expected_len = 16 * 64;
    assert_eq!(graph.max_n_batch_sector_read, 16);
    assert_eq!(graph.sectors_data.len(), expected_len);
}

/// Reconfiguring to the current or a smaller batch size leaves the graph untouched.
#[test]
fn test_reconfigure_noop_for_smaller_size() {
    let factory = AlignedFileReaderFactory::new(test_index_path());
    let reader = factory.build().unwrap();
    let mut graph = test_initialize_disk_sector_graph(2, 1, reader);
    let original_len = graph.sectors_data.len();

    // First the same size (4), then a smaller one (2): neither may change anything.
    for requested in [4, 2] {
        graph.reconfigure(requested).unwrap();
        assert_eq!(graph.max_n_batch_sector_read, 4);
        assert_eq!(graph.sectors_data.len(), original_len);
    }
}

/// A header carrying `block_size == 0` makes the graph use `DEFAULT_DISK_SECTOR_LEN`.
#[test]
fn test_new_disk_sector_graph_zero_block_size_defaults() {
    let reader = AlignedFileReaderFactory::new(test_index_path())
        .build()
        .unwrap();

    // Zero block size in the header should trigger the default, regardless of version.
    let layout_version = GraphLayoutVersion::new(1, 0);
    let metadata = GraphMetadata::new(1000, 32, 500, 32, 2, 20, 50, 1024, 256);
    let header = GraphHeader::new(metadata, 0, layout_version);

    let graph = DiskSectorGraph::new(reader, &header, 2).unwrap();
    assert_eq!(graph.block_size, DEFAULT_DISK_SECTOR_LEN);
}
}
36 changes: 36 additions & 0 deletions diskann-disk/src/storage/quant/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,4 +598,40 @@ mod generator_tests {

Ok(())
}

/// `validate_params` reports `FileNotFoundError` when the context has a non-zero offset
/// but the compressed file was never created.
#[test]
fn test_validate_params_missing_compressed_file() -> ANNResult<()> {
    let num_points = 100;
    let dim = 8;
    let output_dim = 4u32;
    let data_path = "/test_data/data.bin";
    let compressed_path = "/test_data/compressed.bin";

    let storage_provider = VirtualStorageProvider::new_memory();
    storage_provider
        .filesystem()
        .create_dir("/test_data")
        .expect("Could not create test directory");

    // Write only the source data file.
    let source_data = create_test_data(num_points, dim);
    let source_view = MatrixView::try_from(source_data.as_slice(), num_points, dim).unwrap();
    write_bin(source_view, &mut storage_provider.create_for_write(data_path)?)?;

    // Offset 10 (> 0) while the compressed file is deliberately left missing.
    let context = GeneratorContext::new(10, compressed_path.to_string());
    let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
        data_path.to_string(),
        context,
        &output_dim,
    )
    .unwrap();

    let validation = generator.validate_params(num_points, &storage_provider);
    let err = validation.unwrap_err();
    assert_eq!(err.kind(), diskann::ANNErrorKind::FileNotFoundError);
    assert!(err.to_string().contains("expected compressed file"));
    Ok(())
}
}
Loading
Loading